torchtune vs hf sanity check
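A quick parity check between torchtune and Hugging Face transformers: each section loads the same Llama checkpoint into both frameworks, runs a forward pass on identical input ids, and compares the embedding outputs and final logits (the 1B and 3B sections assert they match within atol=1e-4).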
import torch

from torchtune.models.llama3_1 import llama3_1_8b
from torchtune.models.llama3_2 import llama3_2_1b, llama3_2_3b
from torchtune.training import FullModelHFCheckpointer
from transformers import AutoTokenizer, LlamaForCausalLM

#######################################
############## llama 8b ###############
#######################################
# load the HF reference model and tokenizer
tokenizer_hf = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model_hf = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

# load the torchtune model and ask it to also return the hidden states of layers 0-2
model_tune = llama3_1_8b()
model_tune.output_hidden_states = [0, 1, 2]
# convert the HF safetensors checkpoint into a torchtune-format state dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3.1-8B-Instruct/",
    checkpoint_files=[
        "model-00001-of-00004.safetensors",
        "model-00002-of-00004.safetensors",
        "model-00003-of-00004.safetensors",
        "model-00004-of-00004.safetensors",
    ],
    recipe_checkpoint=None,  # no recipe checkpoint: loading weights only
    output_dir="/tmp/Meta-Llama-3.1-8B-Instruct/",
    model_type="LLAMA3",
)
model_state_dict = checkpointer.load_checkpoint()["model"]
model_tune.load_state_dict(model_state_dict)
# tokenize a short prompt
inputs = tokenizer_hf("Hello world!", return_tensors="pt")

# run both models on the same input ids
outputs_hf = model_hf(inputs["input_ids"], output_hidden_states=True)
outputs_tune = model_tune(inputs["input_ids"])

# compare the embedding outputs: hidden_states[0] in HF corresponds to the
# first entry of the list torchtune returns when output_hidden_states is set
print(outputs_hf.hidden_states[0].shape)
print(outputs_tune[0].shape)
print(outputs_hf.hidden_states[0].mean())
print(outputs_tune[0].mean())
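# Note: unlike the 1B/3B sections below, this 8B block only prints the means.
# A tolerance check of the same form as the later sections could be added here
# as well (a sketch; the atol value mirrors the 1B/3B assertions):
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)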
########################################
############### llama 1b ###############
########################################

# load the HF reference model and tokenizer
tokenizer_hf = AutoTokenizer.from_pretrained("nltpt/Llama-3.2-1B-Instruct")
model_hf = LlamaForCausalLM.from_pretrained("nltpt/Llama-3.2-1B-Instruct")

# load the torchtune model and ask it to also return the hidden states of layers 0-2
model_tune = llama3_2_1b()
model_tune.output_hidden_states = [0, 1, 2]
# convert the HF safetensors checkpoint into a torchtune-format state dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3.2-1B-Instruct",
    checkpoint_files=[
        "model.safetensors",
    ],
    recipe_checkpoint=None,  # no recipe checkpoint: loading weights only
    output_dir="/tmp/Meta-Llama-3.2-1B-Instruct",
    model_type="LLAMA3",
)
model_state_dict = checkpointer.load_checkpoint()["model"]
model_tune.load_state_dict(model_state_dict)
# tokenize a short prompt
inputs = tokenizer_hf("Hello world!", return_tensors="pt")

# run both models on the same input ids
outputs_hf = model_hf(inputs["input_ids"], output_hidden_states=True)
outputs_tune = model_tune(inputs["input_ids"])

# compare embeddings (hidden_states[0] vs outputs_tune[0]) and final logits
# (outputs_hf[0] vs outputs_tune[-1], the last entry of torchtune's output list)
print(outputs_hf.hidden_states[0].mean())
print(outputs_tune[0].mean())
print(outputs_hf[0].mean())
print(outputs_tune[-1].mean())
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)
assert torch.allclose(outputs_hf[0], outputs_tune[-1], atol=1e-4)
########################################
############### llama 3b ###############
########################################

# load the HF reference model and tokenizer
tokenizer_hf = AutoTokenizer.from_pretrained("nltpt/Llama-3.2-3B-Instruct")
model_hf = LlamaForCausalLM.from_pretrained("nltpt/Llama-3.2-3B-Instruct")

# load the torchtune model and ask it to also return the hidden states of layers 0-2
model_tune = llama3_2_3b()
model_tune.output_hidden_states = [0, 1, 2]
# convert the HF safetensors checkpoint into a torchtune-format state dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3.2-3B-Instruct",
    checkpoint_files=[
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ],
    recipe_checkpoint=None,  # no recipe checkpoint: loading weights only
    output_dir="/tmp/Meta-Llama-3.2-3B-Instruct",
    model_type="LLAMA3",
)
model_state_dict = checkpointer.load_checkpoint()["model"]
model_tune.load_state_dict(model_state_dict)
# tokenize a short prompt
inputs = tokenizer_hf("Hello world!", return_tensors="pt")

# run both models on the same input ids
outputs_hf = model_hf(inputs["input_ids"], output_hidden_states=True)
outputs_tune = model_tune(inputs["input_ids"])

# compare embeddings (hidden_states[0] vs outputs_tune[0]) and final logits
# (outputs_hf[0] vs outputs_tune[-1], the last entry of torchtune's output list)
print(outputs_hf.hidden_states[0].mean())
print(outputs_tune[0].mean())
print(outputs_hf[0].mean())
print(outputs_tune[-1].mean())
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)
assert torch.allclose(outputs_hf[0], outputs_tune[-1], atol=1e-4)
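The `/tmp/...` checkpoint directories are assumed to be populated before running the script, e.g. via torchtune's download CLI (`tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct`, and likewise for the 1B and 3B checkpoints); any other local copy of the HF safetensors files works as long as the paths and filenames above match.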