Created
March 12, 2025 18:01
-
-
Save felipemello1/e3f1b1c358e145c7a4d610cf44cca374 to your computer and use it in GitHub Desktop.
Llama 3.1/3.2 numerical parity check: torchtune vs Hugging Face (hidden states and final logits)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchtune.models.llama3_1 import llama3_1_8b | |
from torchtune.models.llama3_2 import llama3_2_1b, llama3_2_3b | |
from transformers import AutoTokenizer, LlamaForCausalLM | |
import torch | |
#######################################
############## llama 8b ###############
#######################################
# Compare the HF and torchtune implementations of Llama 3.1 8B Instruct:
# load both from the same weights, run one prompt, and check that the
# embedding-layer hidden state and the final logits agree numerically.

# load hf
tokenizer_hf = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model_hf = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

# load tune
model_tune = llama3_1_8b()
# Ask torchtune to also return the hidden states of layers 0-2, mirroring
# HF's `output_hidden_states=True` below.
model_tune.output_hidden_states = [0, 1, 2]
from torchtune.training import FullModelHFCheckpointer
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3.1-8B-Instruct/",
    checkpoint_files=[
        "model-00001-of-00004.safetensors",
        "model-00002-of-00004.safetensors",
        "model-00003-of-00004.safetensors",
        "model-00004-of-00004.safetensors"
    ],
    recipe_checkpoint=None,  # No recipe checkpoint provided
    output_dir="/tmp/Meta-Llama-3.1-8B-Instruct/",
    model_type="LLAMA3"
)
model_state_dict = checkpointer.load_checkpoint()["model"]
model_tune.load_state_dict(model_state_dict)

# get tokens
inputs = tokenizer_hf("Hello world!", return_tensors="pt")

# run model — inference only, so skip autograd graph construction
# (significant memory saving on an 8B model; outputs are unchanged)
with torch.no_grad():
    outputs_hf = model_hf(inputs["input_ids"], output_hidden_states=True)
    outputs_tune = model_tune(inputs["input_ids"])

# compare outputs
print(outputs_hf.hidden_states[0].shape)
print(outputs_tune[0].shape)
print(outputs_hf.hidden_states[0].mean())
print(outputs_tune[0].mean())
# Consistency fix: the 1B/3B sections assert numerical agreement, but this
# section originally only printed means — assert here too so a mismatch
# fails loudly instead of requiring eyeballing.
# With output_hidden_states=[0, 1, 2], outputs_tune holds the requested
# hidden states followed by the final output as its last element.
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)
assert torch.allclose(outputs_hf[0], outputs_tune[-1], atol=1e-4)
########################################
############### llama 1b ###############
########################################
# Same parity check for the 1B model: load HF and torchtune variants from
# identical weights and verify hidden states / logits match within atol=1e-4.

# load hf
repo_1b = "nltpt/Llama-3.2-1B-Instruct"
tokenizer_hf = AutoTokenizer.from_pretrained(repo_1b)
model_hf = LlamaForCausalLM.from_pretrained(repo_1b)

# load tune
ckpt_dir_1b = "/tmp/Meta-Llama-3.2-1B-Instruct"
model_tune = llama3_2_1b()
# Also return hidden states for layers 0, 1 and 2.
model_tune.output_hidden_states = [0, 1, 2]
from torchtune.training import FullModelHFCheckpointer
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=ckpt_dir_1b,
    checkpoint_files=["model.safetensors"],
    recipe_checkpoint=None,  # No recipe checkpoint provided
    output_dir=ckpt_dir_1b,
    model_type="LLAMA3"
)
model_tune.load_state_dict(checkpointer.load_checkpoint()["model"])

# get tokens
token_ids = tokenizer_hf("Hello world!", return_tensors="pt")["input_ids"]

# run model
outputs_hf = model_hf(token_ids, output_hidden_states=True)
outputs_tune = model_tune(token_ids)

# compare outputs: means of the embedding hidden state and the final output
for t in (
    outputs_hf.hidden_states[0],
    outputs_tune[0],
    outputs_hf[0],
    outputs_tune[-1],
):
    print(t.mean())
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)
assert torch.allclose(outputs_hf[0], outputs_tune[-1], atol=1e-4)
########################################
############### llama 3b ###############
########################################
# Parity check for the 3B model, identical in structure to the 1B check:
# HF vs torchtune hidden states and final logits must agree within atol=1e-4.

# load hf
repo_3b = "nltpt/Llama-3.2-3B-Instruct"
tokenizer_hf = AutoTokenizer.from_pretrained(repo_3b)
model_hf = LlamaForCausalLM.from_pretrained(repo_3b)

# load tune
ckpt_dir_3b = "/tmp/Meta-Llama-3.2-3B-Instruct"
model_tune = llama3_2_3b()
# Also return hidden states for layers 0, 1 and 2.
model_tune.output_hidden_states = [0, 1, 2]
from torchtune.training import FullModelHFCheckpointer
# The 3B checkpoint is sharded across two safetensors files.
shards = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=ckpt_dir_3b,
    checkpoint_files=shards,
    recipe_checkpoint=None,  # No recipe checkpoint provided
    output_dir=ckpt_dir_3b,
    model_type="LLAMA3"
)
state_dict = checkpointer.load_checkpoint()["model"]
model_tune.load_state_dict(state_dict)

# get tokens
token_ids = tokenizer_hf("Hello world!", return_tensors="pt")["input_ids"]

# run model
outputs_hf = model_hf(token_ids, output_hidden_states=True)
outputs_tune = model_tune(token_ids)

# compare outputs: print means, then assert numerical agreement on both the
# embedding hidden state and the final output.
print(outputs_hf.hidden_states[0].mean())
print(outputs_tune[0].mean())
print(outputs_hf[0].mean())
print(outputs_tune[-1].mean())
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)
assert torch.allclose(outputs_hf[0], outputs_tune[-1], atol=1e-4)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment