`transformers` + `torchao` quantization + `torch.compile` on Llama3.1 8B
# REQUIRES torchao, torch nightly (or torch 2.5) and transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TorchAoConfig
from transformers import TextStreamer
import torch
from tqdm import tqdm
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)
torch.set_float32_matmul_precision('high')

# Other configuration options
DEVICE = "cuda:0"
NUM_RUNS = 10
MAX_NEW_TOKENS = 500

# Load the model and prepare generate args
repo_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Set the quantization config.
# You can choose between int4_weight_only (4-bit), int8_weight_only (8-bit) and int8_dynamic_activation_int8_weight (8-bit).
# group_size only applies to int4_weight_only and needs to be one of [32, 64, 128, 256].
# quantization_config = TorchAoConfig(quant_type="int4_weight_only", group_size=128)
# Loading the quantized model takes 6218 MB
model = AutoModelForCausalLM.from_pretrained(repo_id,
                                             torch_dtype=torch.bfloat16,
                                             device_map=DEVICE)
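# A minimal sketch of the quantized load path (not run here): assuming the
# `quantization_config = TorchAoConfig(...)` line above is uncommented, pass it
# to `from_pretrained` via the `quantization_config` argument, e.g.:
# model = AutoModelForCausalLM.from_pretrained(repo_id,
#                                              torch_dtype=torch.bfloat16,
#                                              device_map=DEVICE,
#                                              quantization_config=quantization_config)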
model.generation_config.cache_implementation = "static"
# Optional: compile the decoding forward pass (pairs with the static cache above)
# model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)

messages = [
    {"role": "user", "content": "Write a story: "},
]
model_inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(DEVICE)
generate_kwargs = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "do_sample": True,
    "temperature": 0.2,
    "eos_token_id": -1,  # forces the generation of `max_new_tokens`
}
# Warmup
print("Warming up...")
for _ in range(2):
    gen_out = model.generate(**model_inputs, **generate_kwargs)
print("Done!")
# Measure OR Stream
def measure_generate(model, model_inputs, generate_kwargs):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats(DEVICE)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    for _ in tqdm(range(NUM_RUNS)):
        gen_out = model.generate(**model_inputs, **generate_kwargs)
    end_event.record()
    torch.cuda.synchronize()

    max_memory = torch.cuda.max_memory_allocated(DEVICE)
    print("Max memory (MB): ", max_memory * 1e-6)
    print("Throughput (tokens/sec): ", (NUM_RUNS * MAX_NEW_TOKENS) / (start_event.elapsed_time(end_event) * 1.0e-3))
def stream_generate(model, model_inputs, generate_kwargs):
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_out = model.generate(**model_inputs, streamer=streamer, **generate_kwargs)

stream_generate(model, model_inputs, generate_kwargs)
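# To benchmark throughput and peak memory instead of streaming, swap the call
# above for the other helper defined in this file:
# measure_generate(model, model_inputs, generate_kwargs)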