Snippet to run Llama2 with less than 8GB VRAM
# Below snippet lets you run Llama2 7B variants comfortably under 8GB VRAM, usually needing 5-7GB depending on the config.
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Run `huggingface-cli login` once with a read token, generated after access to the model weights has been granted to the same HF account.
model_name = 'meta-llama/Llama-2-7b-chat-hf'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# load_in_4bit quantizes the weights via bitsandbytes so the 7B model fits under 8GB VRAM
model_0 = AutoModelForCausalLM.from_pretrained(model_name, device_map=0, load_in_4bit=True, torch_dtype=torch.bfloat16)

MAX_LENGTH = 100
print("MAX_LENGTH", MAX_LENGTH)

for i in tqdm(range(10)):
    prompt = "Hey, are you conscious? Can you talk to me?"  # Change this to better match the reference prompt template of the chat model
    inputs = tokenizer(prompt, return_tensors="pt")
    generate_ids = model_0.generate(inputs.input_ids.to('cuda'), max_length=MAX_LENGTH)
    print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
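
The Llama-2 chat checkpoints were tuned on a specific prompt format, so a bare string like the one above often gets weaker completions. Below is a minimal sketch of wrapping a user turn in that format; the [INST] / <<SYS>> markers follow Meta's published chat template, while the system text and variable names are just illustrative placeholders.

# Sketch: format a user message with the Llama-2 chat template (system text is a placeholder).
# The tokenizer adds the leading BOS token itself, so it is left out of the string here.
system_prompt = "You are a helpful, honest assistant."
user_message = "Hey, are you conscious? Can you talk to me?"
prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"
inputs = tokenizer(prompt, return_tensors="pt")
generate_ids = model_0.generate(inputs.input_ids.to('cuda'), max_length=MAX_LENGTH)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])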