# torchtune vs hf sanity check
# gist by @felipemello1, created January 14, 2025
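#
# This script loads the same Llama checkpoints through Hugging Face
# transformers and through torchtune, runs a single forward pass on identical
# tokens, and checks that the embedding outputs and final logits match.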
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from torchtune.models.llama3_1 import llama3_1_8b
from torchtune.models.llama3_2 import llama3_2_1b, llama3_2_3b
from torchtune.training import FullModelHFCheckpointer
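
# forward passes only: disabling autograd (not in the original gist) keeps
# memory down, especially for the 8B model
torch.set_grad_enabled(False)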
#######################################
############## llama 8b ###############
#######################################
# load hf
tokenizer_hf = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model_hf = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
# load torchtune model
model_tune = llama3_1_8b()
# collect hidden states from the first layers; the forward pass then returns
# those hidden states followed by the final logits as the last list element
model_tune.output_hidden_states = [0, 1, 2]
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3.1-8B-Instruct/",
    checkpoint_files=[
        "model-00001-of-00004.safetensors",
        "model-00002-of-00004.safetensors",
        "model-00003-of-00004.safetensors",
        "model-00004-of-00004.safetensors",
    ],
    recipe_checkpoint=None,  # no recipe checkpoint provided
    output_dir="/tmp/Meta-Llama-3.1-8B-Instruct/",
    model_type="LLAMA3",
)
model_state_dict = checkpointer.load_checkpoint()["model"]
model_tune.load_state_dict(model_state_dict)
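# not in the original gist: eval mode for a deterministic forward pass
# (a no-op for these Llama configs, which have no dropout)
model_tune.eval()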
# get tokens
inputs = tokenizer_hf("Hello world!", return_tensors="pt")
# run model
outputs_hf = model_hf(inputs["input_ids"], output_hidden_states=True)
outputs_tune = model_tune(inputs["input_ids"])
# compare outputs: HF's hidden_states[0] is the embedding output, which should
# line up with the first entry of torchtune's output list
print(outputs_hf.hidden_states[0].shape)
print(outputs_tune[0].shape)
print(outputs_hf.hidden_states[0].mean())
print(outputs_tune[0].mean())
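# not in the original gist: the same allclose checks used for the 1B and 3B
# sections below, added here for symmetry
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)
assert torch.allclose(outputs_hf[0], outputs_tune[-1], atol=1e-4)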
########################################
############### llama 1b ###############
########################################
# load hf
tokenizer_hf = AutoTokenizer.from_pretrained("nltpt/Llama-3.2-1B-Instruct")
model_hf = LlamaForCausalLM.from_pretrained("nltpt/Llama-3.2-1B-Instruct")
# load torchtune model
model_tune = llama3_2_1b()
model_tune.output_hidden_states = [0, 1, 2]
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3.2-1B-Instruct",
    checkpoint_files=["model.safetensors"],
    recipe_checkpoint=None,  # no recipe checkpoint provided
    output_dir="/tmp/Meta-Llama-3.2-1B-Instruct",
    model_type="LLAMA3",
)
model_state_dict = checkpointer.load_checkpoint()["model"]
model_tune.load_state_dict(model_state_dict)
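model_tune.eval()  # eval mode, as above (added for determinism)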
# get tokens
inputs = tokenizer_hf("Hello world!", return_tensors="pt")
# run model
outputs_hf = model_hf(inputs["input_ids"], output_hidden_states=True)
outputs_tune = model_tune(inputs["input_ids"])
# compare outputs: embeddings first, then final logits
# (outputs_hf[0] is the logits tensor; outputs_tune[-1] is torchtune's final output)
print(outputs_hf.hidden_states[0].mean())
print(outputs_tune[0].mean())
print(outputs_hf[0].mean())
print(outputs_tune[-1].mean())
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)
assert torch.allclose(outputs_hf[0], outputs_tune[-1], atol=1e-4)
########################################
############### llama 3b ###############
########################################
# load hf
tokenizer_hf = AutoTokenizer.from_pretrained("nltpt/Llama-3.2-3B-Instruct")
model_hf = LlamaForCausalLM.from_pretrained("nltpt/Llama-3.2-3B-Instruct")
# load torchtune model
model_tune = llama3_2_3b()
model_tune.output_hidden_states = [0, 1, 2]
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3.2-3B-Instruct",
    checkpoint_files=[
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ],
    recipe_checkpoint=None,  # no recipe checkpoint provided
    output_dir="/tmp/Meta-Llama-3.2-3B-Instruct",
    model_type="LLAMA3",
)
model_state_dict = checkpointer.load_checkpoint()["model"]
model_tune.load_state_dict(model_state_dict)
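model_tune.eval()  # eval mode, as above (added for determinism)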
# get tokens
inputs = tokenizer_hf("Hello world!", return_tensors="pt")
# run model
outputs_hf = model_hf(inputs["input_ids"], output_hidden_states=True)
outputs_tune = model_tune(inputs["input_ids"])
# compare outputs: same embedding and logits checks as the 1B section
print(outputs_hf.hidden_states[0].mean())
print(outputs_tune[0].mean())
print(outputs_hf[0].mean())
print(outputs_tune[-1].mean())
assert torch.allclose(outputs_hf.hidden_states[0], outputs_tune[0], atol=1e-4)
assert torch.allclose(outputs_hf[0], outputs_tune[-1], atol=1e-4)
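# optional diagnostic (not in the original gist): worst-case elementwise gap
# on the final logits, useful when the allclose tolerance is not met
print((outputs_hf[0] - outputs_tune[-1]).abs().max())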