import os
import torch
import argparse
from datetime import datetime
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from pruna import smash, SmashConfig
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.evaluation.metrics import (
    TotalTimeMetric,
    LatencyMetric,
    ThroughputMetric,
    TotalParamsMetric,
    TotalMACsMetric,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def parse_args():
    parser = argparse.ArgumentParser(description='Enhanced Model Evaluation with Quantization')

    # Model parameters
    parser.add_argument(
        '--models', type=str, nargs='+', required=True,
        help='Model(s) to evaluate. Can pass space-separated list or comma-separated string.'
    )
    parser.add_argument('--quantization', type=str, default="hqq",
                        choices=['hqq', 'torchao', 'llm_int8', 'half', 'none'], help='Quantization method')
    parser.add_argument('--bits', type=int, default=4, choices=[4, 8, 16])
    parser.add_argument('--compile', action='store_true', default=False)
    parser.add_argument('--compile-mode', type=str, default='max-autotune',
                        choices=['default', 'reduce-overhead', 'max-autotune'])

    # Evaluation parameters
    parser.add_argument('--iters', type=int, default=100)
    parser.add_argument('--warmup', type=int, default=10)
    parser.add_argument('--max-new-tokens', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--seq-len', type=int, default=512)
    parser.add_argument('--samples', type=int, default=5)

    # Output
    parser.add_argument('--output-dir', type=str, default='reports')

    # Hardware
    parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cuda', 'cpu'])
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--verbose', action='store_true')

    return parser.parse_args()

def format_number(num):
    if num >= 1e12: return f"{num / 1e12:.2f}T"
    if num >= 1e9: return f"{num / 1e9:.2f}B"
    if num >= 1e6: return f"{num / 1e6:.2f}M"
    if num >= 1e3: return f"{num / 1e3:.2f}K"
    return f"{num:.2f}"

def format_time(microseconds):
    if microseconds >= 1e6: return f"{microseconds / 1e6:.2f}s"
    if microseconds >= 1e3: return f"{microseconds / 1e3:.2f}ms"
    return f"{microseconds:.2f}μs"

def format_memory(bytes_val):
    if bytes_val >= 1024**3: return f"{bytes_val / (1024**3):.2f} GB"
    if bytes_val >= 1024**2: return f"{bytes_val / (1024**2):.2f} MB"
    if bytes_val >= 1024: return f"{bytes_val / 1024:.2f} KB"
    return f"{bytes_val} B"

def get_model_device(model):
    """Get the device of the model's parameters"""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device('cpu')

def ensure_model_on_device(model, device):
    """Ensure model is on the specified device"""
    model_device = get_model_device(model)
    if str(model_device) != device:
        print(f" ⚠️ Model device mismatch. Moving from {model_device} to {device}...")
        model = model.to(device)
    return model

def create_metrics():
    return [
        TotalTimeMetric(),
        LatencyMetric(),
        ThroughputMetric(),
        TotalParamsMetric(),
        TotalMACsMetric(),
    ]

def validate_metrics(results, verbose=False):
    """Validate that metrics were recorded correctly"""
    validation_passed = True
    validation_report = []

    for metric_result in results:
        metric_name = metric_result.name
        metric_value = metric_result.result

        # Check if metric has valid data
        is_valid = False
        error_msg = None
        warning_msg = None

        if isinstance(metric_value, dict):
            if 'mean' in metric_value and metric_value['mean'] > 0:
                is_valid = True
                # Check for unrealistic values
                if metric_name == 'latency':
                    latency_ms = metric_value['mean'] / 1000
                    if latency_ms < 1:  # Less than 1ms is suspicious
                        warning_msg = f"⚠️ Suspiciously low ({latency_ms:.4f}ms)"
                    elif latency_ms > 10000:  # More than 10s is suspicious
                        warning_msg = f"⚠️ Suspiciously high ({latency_ms:.2f}ms)"
            else:
                error_msg = "Dict missing 'mean' or mean is zero/negative"
        elif isinstance(metric_value, (int, float)):
            if metric_value > 0:
                is_valid = True
                # Check for unrealistic values
                if metric_name == 'latency':
                    latency_ms = metric_value / 1000
                    if latency_ms < 1:
                        warning_msg = f"⚠️ Suspiciously low ({latency_ms:.4f}ms)"
            else:
                error_msg = f"Value is zero or negative: {metric_value}"
        elif isinstance(metric_value, list):
            if len(metric_value) > 0 and any(v > 0 for v in metric_value):
                is_valid = True
            else:
                error_msg = "List is empty or all values are zero/negative"
        else:
            error_msg = f"Unexpected type: {type(metric_value)}"

        status = "✓" if is_valid else "✗"
        msg = f"{status} {metric_name}: "
        if warning_msg:
            msg += warning_msg
        elif error_msg:
            msg += error_msg
        else:
            msg += "OK"
        validation_report.append(msg)

        if not is_valid:
            validation_passed = False

    if verbose or not validation_passed:
        print("\n📊 Metrics Validation:")
        for line in validation_report:
            print(f" {line}")

    return validation_passed, validation_report

def generate_report(model_name, results, cuda_allocated, cuda_reserved, total_params, args, validation_report, manual_test_results=None):
    # Extract metrics safely
    metrics_dict = {}
    for r in results:
        try:
            if isinstance(r.result, dict):
                metrics_dict[r.name] = r.result.get("mean", 0)
            elif isinstance(r.result, (int, float)):
                metrics_dict[r.name] = r.result
            elif isinstance(r.result, list):
                metrics_dict[r.name] = sum(r.result) / len(r.result) if r.result else 0
            else:
                metrics_dict[r.name] = 0
        except Exception:
            metrics_dict[r.name] = 0

    latency_ms = metrics_dict.get('latency', 0) / 1000
    throughput_samples_per_sec = 1 / (latency_ms / 1000) if latency_ms > 0 else 0
    tokens_per_sec = throughput_samples_per_sec * args.max_new_tokens
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Sanity check flags
    warnings = []
    if latency_ms < 1:
        warnings.append("⚠️ **WARNING**: Latency is unrealistically low. Metrics may not be measuring actual inference.")
    if tokens_per_sec > 100000:
        warnings.append("⚠️ **WARNING**: Token generation rate is impossibly high. Check metric configuration.")
    if total_params < 1_000_000:
        warnings.append("⚠️ **WARNING**: Parameter count seems too low for this model.")

    markdown = f"""# Model Evaluation Report: {model_name}
**Date:** {timestamp}
**Device:** {args.device if args.device != 'auto' else 'CUDA' if torch.cuda.is_available() else 'CPU'}
**Quantization:** {args.quantization.upper()} {args.bits}-bit
**Torch Compile:** {"Yes" if args.compile else "No"}
"""

    if warnings:
        markdown += "## ⚠️ Sanity Check Warnings\n"
        for warning in warnings:
            markdown += f"{warning}\n\n"

    markdown += f"""## Performance Metrics
- Total Parameters: {format_number(total_params)}
- Average Latency: {format_time(metrics_dict.get('latency', 0))}
- Throughput (samples/sec): {throughput_samples_per_sec:.2f}
- Token Generation Rate: {tokens_per_sec:.1f} tokens/sec
- GPU Memory Allocated: {format_memory(cuda_allocated * 1024**2)}
- GPU Memory Reserved: {format_memory(cuda_reserved * 1024**2)}
"""

    if manual_test_results:
        markdown += f"""
## Manual Verification Test
- Actual Generation Time: {manual_test_results['time_ms']:.2f}ms
- Tokens Generated: {manual_test_results['tokens']}
- Measured Tokens/sec: {manual_test_results['tokens_per_sec']:.1f}
- **Discrepancy**: {abs(tokens_per_sec - manual_test_results['tokens_per_sec']) / manual_test_results['tokens_per_sec'] * 100:.1f}% difference
"""

    markdown += f"""
## Detailed Metrics
| Metric | Value |
|--------|-------|
"""
    for name, value in metrics_dict.items():
        markdown += f"| {name} | {format_number(value) if value > 1000 else f'{value:.4f}'} |\n"

    markdown += f"""
## Validation Report
```
"""
    for line in validation_report:
        markdown += f"{line}\n"
    markdown += "```\n"

    return markdown

def main():
    args = parse_args()

    # Flatten comma-separated models
    model_list = []
    for m in args.models:
        model_list.extend([x.strip() for x in m.split(',') if x.strip()])
    args.models = model_list

    os.makedirs(args.output_dir, exist_ok=True)

    # Determine device
    device = "cuda" if (args.device == "auto" and torch.cuda.is_available()) else args.device

    for model_name in tqdm(args.models, desc="Evaluating models", unit="model"):
        print(f"\n🚀 Evaluating {model_name} on {device.upper()}")
        try:
            # Load model & tokenizer
            print(" Loading model and tokenizer...")
            model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Apply quantization if needed
            if args.quantization != 'none':
                print(f" Applying {args.quantization} quantization ({args.bits}-bit)...")
                cfg = SmashConfig(device=device)
                cfg["quantizer"] = args.quantization
                cfg[f"{args.quantization}_weight_bits"] = args.bits
                cfg[f"{args.quantization}_compute_dtype"] = "torch.bfloat16"
                if args.compile:
                    cfg["compiler"] = "torch_compile"
                    cfg["torch_compile_fullgraph"] = True
                    cfg["torch_compile_dynamic"] = True
                    cfg["torch_compile_mode"] = args.compile_mode
                model = smash(model, cfg)

                # Ensure model is on correct device after quantization
                model = ensure_model_on_device(model, device)

            # Double-check model device
            if args.verbose:
                print(f" Model is on device: {get_model_device(model)}")

            # Prepare dataset
            print(f" Preparing dataset ({args.samples} samples)...")
            datamodule = PrunaDataModule.from_string(
                dataset_name="WikiText",
                tokenizer=tokenizer,
                collate_fn_args={"max_seq_len": args.seq_len},
                dataloader_args={"batch_size": args.batch_size, "num_workers": args.num_workers}
            )
            datamodule.limit_datasets(args.samples)

            # Inspect batch format for debugging
            if args.verbose:
                test_loader = datamodule.train_dataloader()
                test_batch = next(iter(test_loader))
                print(f" Batch format: {type(test_batch)}")
                if isinstance(test_batch, dict):
                    print(f" Batch keys: {test_batch.keys()}")
                elif isinstance(test_batch, (tuple, list)):
                    print(f" Batch length: {len(test_batch)}")
                    print(f" Batch element shapes: {[b.shape if isinstance(b, torch.Tensor) else type(b) for b in test_batch]}")
                del test_loader, test_batch

            # Warm-up phase
            if args.warmup > 0:
                print(f" Running {args.warmup} warmup iterations...")
                warmup_loader = datamodule.train_dataloader()
                for i, batch in enumerate(tqdm(warmup_loader, desc=" Warmup", total=args.warmup, leave=False)):
                    if i >= args.warmup:
                        break
                    # Handle both dict and tuple batch formats
                    if isinstance(batch, dict):
                        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                        with torch.no_grad():
                            _ = model(**batch)
                    elif isinstance(batch, (tuple, list)):
                        batch = tuple(b.to(device) if isinstance(b, torch.Tensor) else b for b in batch)
                        with torch.no_grad():
                            if len(batch) == 1:
                                _ = model(batch[0])
                            else:
                                _ = model(*batch)
                    else:
                        # Single tensor
                        batch = batch.to(device) if isinstance(batch, torch.Tensor) else batch
                        with torch.no_grad():
                            _ = model(batch)
                if device == "cuda":
                    torch.cuda.synchronize()

            # Evaluation
            print(f" Running evaluation on {args.samples} samples...")
            metrics = create_metrics()
            task = Task(metrics, datamodule=datamodule)
            eval_agent = EvaluationAgent(task)

            # Force CUDA sync before evaluation for accurate timing
            if device == "cuda":
                torch.cuda.synchronize()

            try:
                results = eval_agent.evaluate(model)
            except AttributeError as e:
                if "'tuple' object has no attribute" in str(e):
                    print(" ⚠️ Detected batch format issue. Wrapping model to handle tuple inputs...")

                    # Create a wrapper that handles tuple inputs
                    class ModelWrapper(torch.nn.Module):
                        def __init__(self, base_model):
                            super().__init__()
                            self.model = base_model

                        def forward(self, *args, **kwargs):
                            if args and not kwargs:
                                # Tuple/positional args - convert to dict
                                if len(args) == 1:
                                    return self.model(input_ids=args[0])
                                elif len(args) == 2:
                                    return self.model(input_ids=args[0], attention_mask=args[1])
                                else:
                                    return self.model(*args)
                            else:
                                return self.model(*args, **kwargs)

                    wrapped_model = ModelWrapper(model)
                    results = eval_agent.evaluate(wrapped_model)
                else:
                    raise

            # Force CUDA sync after evaluation
            if device == "cuda":
                torch.cuda.synchronize()

            # Manual sanity check with actual inference
            manual_test_results = None
            if args.verbose:
                print(" Running manual sanity check...")
                import time

                # Ensure model is on correct device
                model = ensure_model_on_device(model, device)
                model_device = get_model_device(model)

                # Create test input on the same device as model
                test_input = tokenizer("Hello world", return_tensors="pt")
                test_input = {k: v.to(model_device) for k, v in test_input.items()}

                # Warm up
                for _ in range(3):
                    with torch.no_grad():
                        _ = model.generate(**test_input, max_new_tokens=10)
                if device == "cuda":
                    torch.cuda.synchronize()

                # Time actual generation
                start = time.perf_counter()
                with torch.no_grad():
                    output = model.generate(**test_input, max_new_tokens=args.max_new_tokens)
                if device == "cuda":
                    torch.cuda.synchronize()
                end = time.perf_counter()

                actual_time_ms = (end - start) * 1000
                actual_tokens = output.shape[1] - test_input['input_ids'].shape[1]
                actual_tokens_per_sec = actual_tokens / (end - start)
                manual_test_results = {
                    'time_ms': actual_time_ms,
                    'tokens': actual_tokens,
                    'tokens_per_sec': actual_tokens_per_sec
                }
                print(f" Manual test: {actual_time_ms:.2f}ms for {actual_tokens} tokens")
                print(f" Manual tokens/sec: {actual_tokens_per_sec:.1f}")

            # Validate metrics
            validation_passed, validation_report = validate_metrics(results, verbose=args.verbose)
            if not validation_passed:
                print(" ⚠️ WARNING: Some metrics may not have been recorded correctly!")

            # GPU memory info
            if device == "cuda":
                cuda_allocated = torch.cuda.memory_allocated() / 1024**2
                cuda_reserved = torch.cuda.memory_reserved() / 1024**2
            else:
                cuda_allocated = cuda_reserved = 0

            total_params = sum(p.numel() for p in model.parameters())

            # Generate markdown report
            report = generate_report(model_name, results, cuda_allocated, cuda_reserved,
                                     total_params, args, validation_report, manual_test_results)
            output_path = os.path.join(args.output_dir, f"{model_name.replace('/', '_')}_report.md")
            with open(output_path, 'w') as f:
                f.write(report)
            print(f"✅ Report saved to {output_path}")

            # Clean up
            del model
            if device == "cuda":
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"❌ Error evaluating {model_name}: {str(e)}")
            if args.verbose:
                import traceback
                traceback.print_exc()
            continue


if __name__ == "__main__":
    main()

torch==2.7.0
transformers>=4.53.0
accelerate>=1.0.0
pruna==0.2.8
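
The pinned dependencies above can be installed up front; assuming they are saved to a requirements.txt file (the gist does not name the file), for example:

pip install -r requirements.txt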

ParagEkbote commented Oct 4, 2025

We can run the script as follows for multi-model eval:

python eval_script.py \
  --models "HuggingFaceTB/SmolLM-135M,HuggingFaceTB/SmolLM-1.7B,HuggingFaceTB/SmolLM3-3B" \
  --quantization hqq --bits 4 --compile \
  --output-dir reports \
  --batch-size 8 --max-new-tokens 128 --samples 5

For a single-model eval:

python eval_script.py \
  --models "HuggingFaceTB/SmolLM2-1.7B" \
  --quantization hqq \
  --bits 4 \
  --samples 3 \
  --batch-size 2 \
  --max-new-tokens 64 \
  --warmup 5 \
  --verbose
