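# Quantize SmolLM2/SmolLM3 checkpoints with pruna (4-bit HQQ weight quantization
# plus torch.compile) and push the resulting smashed models to the Hugging Face Hub.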
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pruna import smash, SmashConfig

# Define per-model Smash + torch.compile configs
models = {
    "HuggingFaceTB/SmolLM2-360M": {
        "bits": 4, "group_size": 64,
        "compiler": "torch_compile",
        "torch_compile_mode": "max-autotune",
        "torch_compile_backend": "inductor",
        "torch_compile_fullgraph": True,
        "torch_compile_dynamic": False,
        "torch_compile_max_kv_cache_size": 400,
        "torch_compile_seqlen_manual_cuda_graph": 200,
        "torch_compile_make_portable": False,
    },
    "HuggingFaceTB/SmolLM2-1.7B": {
        "bits": 4, "group_size": 128,
        "compiler": "torch_compile",
        "torch_compile_mode": "default",
        "torch_compile_backend": "inductor",
        "torch_compile_fullgraph": True,
        "torch_compile_dynamic": False,
        "torch_compile_max_kv_cache_size": 800,
        "torch_compile_seqlen_manual_cuda_graph": 400,
        "torch_compile_make_portable": False,
    },
    "HuggingFaceTB/SmolLM3-3B": {
        "bits": 4, "group_size": 128,
        "compiler": "torch_compile",
        "torch_compile_mode": "reduce-overhead",
        "torch_compile_backend": "cudagraphs",
        "torch_compile_fullgraph": False,
        "torch_compile_dynamic": True,
        "torch_compile_max_kv_cache_size": 1600,
        "torch_compile_seqlen_manual_cuda_graph": 800,
        "torch_compile_make_portable": True,
    },
}

device = "cuda" if torch.cuda.is_available() else "cpu"

for model_name, cfg in models.items():
    print(f"\n🚀 Processing {model_name}")

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Build SmashConfig
    smash_config = SmashConfig(device=device)
    smash_config["quantizer"] = "hqq"
    smash_config["hqq_weight_bits"] = cfg["bits"]
    smash_config["hqq_compute_dtype"] = "torch.bfloat16"
    smash_config["hqq_group_size"] = cfg["group_size"]

    # Torch compile parameters
    for k, v in cfg.items():
        if k.startswith("torch_compile") or k == "compiler":
            smash_config[k] = v

    # Apply smashing
    smashed_model = smash(model, smash_config)

    # Push to Hub (replace with your HF username)
    repo_id = f"AINovice2005/{model_name.split('/')[-1]}-smashed"
    print(f"📤 Pushing {repo_id} to Hugging Face Hub...")
    smashed_model.push_to_hub(repo_id)
    print(f"✅ Done: {repo_id}")
Pinned dependencies for the script above:

torch==2.7.0
transformers>=4.53.0
accelerate>=1.0.0
pruna==0.2.8
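
To consume one of the pushed repos later, pruna provides a PrunaModel wrapper for reloading smashed checkpoints. The sketch below is an assumption-heavy outline rather than a verified recipe: it assumes PrunaModel.from_pretrained can load a smashed checkpoint from a local directory (check the pruna documentation for your installed version), and it uses huggingface_hub's snapshot_download to fetch the repo first. The repo id is simply the one produced by the script above.

from huggingface_hub import snapshot_download
from pruna import PrunaModel

# Download the smashed repo locally, then load it with pruna's wrapper.
# NOTE: the loader name and signature are assumptions; consult the pruna docs.
local_dir = snapshot_download("AINovice2005/SmolLM2-360M-smashed")
reloaded = PrunaModel.from_pretrained(local_dir)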