Pruna eval script for the SmolLM collection: https://huggingface.co/collections/AINovice2005/smollm-smashed-tiny-giants-optimized-for-speed-68e16113d90eab12bae42a34
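To see which optimized checkpoints the collection contains, the Hub collections API can enumerate them. A minimal sketch, assuming a recent `huggingface_hub` release (one that ships the collections API, i.e. `get_collection`):

```python
from huggingface_hub import get_collection

# Fetch the collection by its slug (taken from the URL above).
collection = get_collection(
    "AINovice2005/smollm-smashed-tiny-giants-optimized-for-speed-68e16113d90eab12bae42a34"
)

# Each item is a repo in the collection; print the model IDs.
for item in collection.items:
    print(item.item_type, item.item_id)
```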
`eval_script.py`:

````python
import os
import torch
import argparse
from datetime import datetime
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from pruna import smash, SmashConfig
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.evaluation.metrics import (
    TotalTimeMetric,
    LatencyMetric,
    ThroughputMetric,
    TotalParamsMetric,
    TotalMACsMetric,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def parse_args():
    parser = argparse.ArgumentParser(description='Enhanced Model Evaluation with Quantization')
    # Model parameters
    parser.add_argument(
        '--models', type=str, nargs='+', required=True,
        help='Model(s) to evaluate. Can pass space-separated list or comma-separated string.'
    )
    parser.add_argument('--quantization', type=str, default="hqq",
                        choices=['hqq', 'torchao', 'llm_int8', 'half', 'none'], help='Quantization method')
    parser.add_argument('--bits', type=int, default=4, choices=[4, 8, 16])
    parser.add_argument('--compile', action='store_true', default=False)
    parser.add_argument('--compile-mode', type=str, default='max-autotune',
                        choices=['default', 'reduce-overhead', 'max-autotune'])
    # Evaluation parameters
    parser.add_argument('--iters', type=int, default=100)
    parser.add_argument('--warmup', type=int, default=10)
    parser.add_argument('--max-new-tokens', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--seq-len', type=int, default=512)
    parser.add_argument('--samples', type=int, default=5)
    # Output
    parser.add_argument('--output-dir', type=str, default='reports')
    # Hardware
    parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cuda', 'cpu'])
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--verbose', action='store_true')
    return parser.parse_args()


def format_number(num):
    if num >= 1e12: return f"{num / 1e12:.2f}T"
    if num >= 1e9: return f"{num / 1e9:.2f}B"
    if num >= 1e6: return f"{num / 1e6:.2f}M"
    if num >= 1e3: return f"{num / 1e3:.2f}K"
    return f"{num:.2f}"


def format_time(microseconds):
    if microseconds >= 1e6: return f"{microseconds / 1e6:.2f}s"
    if microseconds >= 1e3: return f"{microseconds / 1e3:.2f}ms"
    return f"{microseconds:.2f}μs"


def format_memory(bytes_val):
    if bytes_val >= 1024**3: return f"{bytes_val / (1024**3):.2f} GB"
    if bytes_val >= 1024**2: return f"{bytes_val / (1024**2):.2f} MB"
    if bytes_val >= 1024: return f"{bytes_val / 1024:.2f} KB"
    return f"{bytes_val} B"


def get_model_device(model):
    """Get the device of the model's parameters."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device('cpu')


def ensure_model_on_device(model, device):
    """Ensure the model is on the specified device."""
    model_device = get_model_device(model)
    if str(model_device) != device:
        print(f" ⚠️ Model device mismatch. Moving from {model_device} to {device}...")
        model = model.to(device)
    return model


def create_metrics():
    return [
        TotalTimeMetric(),
        LatencyMetric(),
        ThroughputMetric(),
        TotalParamsMetric(),
        TotalMACsMetric(),
    ]


def validate_metrics(results, verbose=False):
    """Validate that metrics were recorded correctly."""
    validation_passed = True
    validation_report = []
    for metric_result in results:
        metric_name = metric_result.name
        metric_value = metric_result.result
        # Check if metric has valid data
        is_valid = False
        error_msg = None
        warning_msg = None
        if isinstance(metric_value, dict):
            if 'mean' in metric_value and metric_value['mean'] > 0:
                is_valid = True
                # Check for unrealistic values
                if metric_name == 'latency':
                    latency_ms = metric_value['mean'] / 1000
                    if latency_ms < 1:  # Less than 1ms is suspicious
                        warning_msg = f"⚠️ Suspiciously low ({latency_ms:.4f}ms)"
                    elif latency_ms > 10000:  # More than 10s is suspicious
                        warning_msg = f"⚠️ Suspiciously high ({latency_ms:.2f}ms)"
            else:
                error_msg = "Dict missing 'mean' or mean is zero/negative"
        elif isinstance(metric_value, (int, float)):
            if metric_value > 0:
                is_valid = True
                # Check for unrealistic values
                if metric_name == 'latency':
                    latency_ms = metric_value / 1000
                    if latency_ms < 1:
                        warning_msg = f"⚠️ Suspiciously low ({latency_ms:.4f}ms)"
            else:
                error_msg = f"Value is zero or negative: {metric_value}"
        elif isinstance(metric_value, list):
            if len(metric_value) > 0 and any(v > 0 for v in metric_value):
                is_valid = True
            else:
                error_msg = "List is empty or all values are zero/negative"
        else:
            error_msg = f"Unexpected type: {type(metric_value)}"
        status = "✓" if is_valid else "✗"
        msg = f"{status} {metric_name}: "
        if warning_msg:
            msg += warning_msg
        elif error_msg:
            msg += error_msg
        else:
            msg += "OK"
        validation_report.append(msg)
        if not is_valid:
            validation_passed = False
    if verbose or not validation_passed:
        print("\n📊 Metrics Validation:")
        for line in validation_report:
            print(f" {line}")
    return validation_passed, validation_report


def generate_report(model_name, results, cuda_allocated, cuda_reserved, total_params, args,
                    validation_report, manual_test_results=None):
    # Extract metrics safely
    metrics_dict = {}
    for r in results:
        try:
            if isinstance(r.result, dict):
                metrics_dict[r.name] = r.result.get("mean", 0)
            elif isinstance(r.result, (int, float)):
                metrics_dict[r.name] = r.result
            elif isinstance(r.result, list):
                metrics_dict[r.name] = sum(r.result) / len(r.result) if r.result else 0
            else:
                metrics_dict[r.name] = 0
        except Exception:
            metrics_dict[r.name] = 0
    latency_ms = metrics_dict.get('latency', 0) / 1000
    throughput_samples_per_sec = 1 / (latency_ms / 1000) if latency_ms > 0 else 0
    tokens_per_sec = throughput_samples_per_sec * args.max_new_tokens
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Sanity check flags
    warnings = []
    if latency_ms < 1:
        warnings.append("⚠️ **WARNING**: Latency is unrealistically low. Metrics may not be measuring actual inference.")
    if tokens_per_sec > 100000:
        warnings.append("⚠️ **WARNING**: Token generation rate is impossibly high. Check metric configuration.")
    if total_params < 1_000_000:
        warnings.append("⚠️ **WARNING**: Parameter count seems too low for this model.")
    markdown = f"""# Model Evaluation Report: {model_name}
**Date:** {timestamp}
**Device:** {args.device if args.device != 'auto' else 'CUDA' if torch.cuda.is_available() else 'CPU'}
**Quantization:** {args.quantization.upper()} {args.bits}-bit
**Torch Compile:** {"Yes" if args.compile else "No"}
"""
    if warnings:
        markdown += "## ⚠️ Sanity Check Warnings\n"
        for warning in warnings:
            markdown += f"{warning}\n\n"
    markdown += f"""## Performance Metrics
- Total Parameters: {format_number(total_params)}
- Average Latency: {format_time(metrics_dict.get('latency', 0))}
- Throughput (samples/sec): {throughput_samples_per_sec:.2f}
- Token Generation Rate: {tokens_per_sec:.1f} tokens/sec
- GPU Memory Allocated: {format_memory(cuda_allocated * 1024**2)}
- GPU Memory Reserved: {format_memory(cuda_reserved * 1024**2)}
"""
    if manual_test_results:
        markdown += f"""
## Manual Verification Test
- Actual Generation Time: {manual_test_results['time_ms']:.2f}ms
- Tokens Generated: {manual_test_results['tokens']}
- Measured Tokens/sec: {manual_test_results['tokens_per_sec']:.1f}
- **Discrepancy**: {abs(tokens_per_sec - manual_test_results['tokens_per_sec']) / manual_test_results['tokens_per_sec'] * 100:.1f}% difference
"""
    markdown += f"""
## Detailed Metrics
| Metric | Value |
|--------|-------|
"""
    for name, value in metrics_dict.items():
        markdown += f"| {name} | {format_number(value) if value > 1000 else f'{value:.4f}'} |\n"
    markdown += f"""
## Validation Report
```
"""
    for line in validation_report:
        markdown += f"{line}\n"
    markdown += "```\n"
    return markdown


def main():
    args = parse_args()
    # Flatten comma-separated models
    model_list = []
    for m in args.models:
        model_list.extend([x.strip() for x in m.split(',') if x.strip()])
    args.models = model_list
    os.makedirs(args.output_dir, exist_ok=True)
    # Determine device
    device = "cuda" if (args.device == "auto" and torch.cuda.is_available()) else args.device
    for model_name in tqdm(args.models, desc="Evaluating models", unit="model"):
        print(f"\n🚀 Evaluating {model_name} on {device.upper()}")
        try:
            # Load model & tokenizer
            print(" Loading model and tokenizer...")
            model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            # Apply quantization if needed
            if args.quantization != 'none':
                print(f" Applying {args.quantization} quantization ({args.bits}-bit)...")
                cfg = SmashConfig(device=device)
                cfg["quantizer"] = args.quantization
                cfg[f"{args.quantization}_weight_bits"] = args.bits
                cfg[f"{args.quantization}_compute_dtype"] = "torch.bfloat16"
                if args.compile:
                    cfg["compiler"] = "torch_compile"
                    cfg["torch_compile_fullgraph"] = True
                    cfg["torch_compile_dynamic"] = True
                    cfg["torch_compile_mode"] = args.compile_mode
                model = smash(model, cfg)
                # Ensure model is on correct device after quantization
                model = ensure_model_on_device(model, device)
            # Double-check model device
            if args.verbose:
                print(f" Model is on device: {get_model_device(model)}")
            # Prepare dataset
            print(f" Preparing dataset ({args.samples} samples)...")
            datamodule = PrunaDataModule.from_string(
                dataset_name="WikiText",
                tokenizer=tokenizer,
                collate_fn_args={"max_seq_len": args.seq_len},
                dataloader_args={"batch_size": args.batch_size, "num_workers": args.num_workers}
            )
            datamodule.limit_datasets(args.samples)
            # Inspect batch format for debugging
            if args.verbose:
                test_loader = datamodule.train_dataloader()
                test_batch = next(iter(test_loader))
                print(f" Batch format: {type(test_batch)}")
                if isinstance(test_batch, dict):
                    print(f" Batch keys: {test_batch.keys()}")
                elif isinstance(test_batch, (tuple, list)):
                    print(f" Batch length: {len(test_batch)}")
                    print(f" Batch element shapes: {[b.shape if isinstance(b, torch.Tensor) else type(b) for b in test_batch]}")
                del test_loader, test_batch
            # Warm-up phase
            if args.warmup > 0:
                print(f" Running {args.warmup} warmup iterations...")
                warmup_loader = datamodule.train_dataloader()
                for i, batch in enumerate(tqdm(warmup_loader, desc=" Warmup", total=args.warmup, leave=False)):
                    if i >= args.warmup:
                        break
                    # Handle both dict and tuple batch formats
                    if isinstance(batch, dict):
                        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                        with torch.no_grad():
                            _ = model(**batch)
                    elif isinstance(batch, (tuple, list)):
                        batch = tuple(b.to(device) if isinstance(b, torch.Tensor) else b for b in batch)
                        with torch.no_grad():
                            if len(batch) == 1:
                                _ = model(batch[0])
                            else:
                                _ = model(*batch)
                    else:
                        # Single tensor
                        batch = batch.to(device) if isinstance(batch, torch.Tensor) else batch
                        with torch.no_grad():
                            _ = model(batch)
                if device == "cuda":
                    torch.cuda.synchronize()
            # Evaluation
            print(f" Running evaluation on {args.samples} samples...")
            metrics = create_metrics()
            task = Task(metrics, datamodule=datamodule)
            eval_agent = EvaluationAgent(task)
            # Force CUDA sync before evaluation for accurate timing
            if device == "cuda":
                torch.cuda.synchronize()
            try:
                results = eval_agent.evaluate(model)
            except AttributeError as e:
                if "'tuple' object has no attribute" in str(e):
                    print(" ⚠️ Detected batch format issue. Wrapping model to handle tuple inputs...")

                    # Create a wrapper that handles tuple inputs
                    class ModelWrapper(torch.nn.Module):
                        def __init__(self, base_model):
                            super().__init__()
                            self.model = base_model

                        def forward(self, *args, **kwargs):
                            if args and not kwargs:
                                # Tuple/positional args - convert to dict
                                if len(args) == 1:
                                    return self.model(input_ids=args[0])
                                elif len(args) == 2:
                                    return self.model(input_ids=args[0], attention_mask=args[1])
                                else:
                                    return self.model(*args)
                            else:
                                return self.model(*args, **kwargs)

                    wrapped_model = ModelWrapper(model)
                    results = eval_agent.evaluate(wrapped_model)
                else:
                    raise
            # Force CUDA sync after evaluation
            if device == "cuda":
                torch.cuda.synchronize()
            # Manual sanity check with actual inference
            manual_test_results = None
            if args.verbose:
                print(" Running manual sanity check...")
                import time
                # Ensure model is on correct device
                model = ensure_model_on_device(model, device)
                model_device = get_model_device(model)
                # Create test input on the same device as model
                test_input = tokenizer("Hello world", return_tensors="pt")
                test_input = {k: v.to(model_device) for k, v in test_input.items()}
                # Warm up
                for _ in range(3):
                    with torch.no_grad():
                        _ = model.generate(**test_input, max_new_tokens=10)
                if device == "cuda":
                    torch.cuda.synchronize()
                # Time actual generation
                start = time.perf_counter()
                with torch.no_grad():
                    output = model.generate(**test_input, max_new_tokens=args.max_new_tokens)
                if device == "cuda":
                    torch.cuda.synchronize()
                end = time.perf_counter()
                actual_time_ms = (end - start) * 1000
                actual_tokens = output.shape[1] - test_input['input_ids'].shape[1]
                actual_tokens_per_sec = actual_tokens / (end - start)
                manual_test_results = {
                    'time_ms': actual_time_ms,
                    'tokens': actual_tokens,
                    'tokens_per_sec': actual_tokens_per_sec
                }
                print(f" Manual test: {actual_time_ms:.2f}ms for {actual_tokens} tokens")
                print(f" Manual tokens/sec: {actual_tokens_per_sec:.1f}")
            # Validate metrics
            validation_passed, validation_report = validate_metrics(results, verbose=args.verbose)
            if not validation_passed:
                print(" ⚠️ WARNING: Some metrics may not have been recorded correctly!")
            # GPU memory info
            if device == "cuda":
                cuda_allocated = torch.cuda.memory_allocated() / 1024**2
                cuda_reserved = torch.cuda.memory_reserved() / 1024**2
            else:
                cuda_allocated = cuda_reserved = 0
            total_params = sum(p.numel() for p in model.parameters())
            # Generate markdown report
            report = generate_report(model_name, results, cuda_allocated, cuda_reserved,
                                     total_params, args, validation_report, manual_test_results)
            output_path = os.path.join(args.output_dir, f"{model_name.replace('/', '_')}_report.md")
            with open(output_path, 'w') as f:
                f.write(report)
            print(f"✅ Report saved to {output_path}")
            # Clean up
            del model
            if device == "cuda":
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"❌ Error evaluating {model_name}: {str(e)}")
            if args.verbose:
                import traceback
                traceback.print_exc()
            continue


if __name__ == "__main__":
    main()
````
Pinned dependencies (`requirements.txt`):

```text
torch==2.7.0
transformers>=4.53.0
accelerate>=1.0.0
pruna==0.2.8
```
The script can be run as follows. For multi-model evaluation:

```bash
python eval_script.py \
  --models "HuggingFaceTB/SmolLM-135M,HuggingFaceTB/SmolLM-1.7B,HuggingFaceTB/SmolLM3-3B" \
  --quantization hqq --bits 4 --compile \
  --output-dir reports \
  --batch-size 8 --max-new-tokens 128 --samples 5
```

For single-model evaluation:
```bash
python eval_script.py \
  --models "HuggingFaceTB/SmolLM2-1.7B" \
  --quantization hqq \
  --bits 4 \
  --samples 3 \
  --batch-size 2 \
  --max-new-tokens 64 \
  --warmup 5 \
  --verbose
```
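For reference, the script's core logic reduces to a handful of Pruna calls. The sketch below mirrors what `eval_script.py` does with its defaults (HQQ 4-bit, WikiText, a few samples); it is a minimal sketch under the pinned versions above, and the model name and batch sizes are illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pruna import smash, SmashConfig
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.evaluation.metrics import LatencyMetric, ThroughputMetric, TotalParamsMetric

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "HuggingFaceTB/SmolLM-135M"  # illustrative; any SmolLM checkpoint works

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Quantize with HQQ at 4 bits, matching the script's defaults.
cfg = SmashConfig(device=device)
cfg["quantizer"] = "hqq"
cfg["hqq_weight_bits"] = 4
cfg["hqq_compute_dtype"] = "torch.bfloat16"
model = smash(model, cfg)

# Evaluate on a small WikiText slice with Pruna's evaluation agent.
datamodule = PrunaDataModule.from_string(
    dataset_name="WikiText",
    tokenizer=tokenizer,
    collate_fn_args={"max_seq_len": 512},
    dataloader_args={"batch_size": 2, "num_workers": 0},
)
datamodule.limit_datasets(5)

task = Task([LatencyMetric(), ThroughputMetric(), TotalParamsMetric()], datamodule=datamodule)
results = EvaluationAgent(task).evaluate(model)
for r in results:
    print(r.name, r.result)
```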