Guided script for Unsloth
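The script below interactively fine-tunes a model with Unsloth: it asks for a base model, LoRA hyperparameters, a chat template and a dataset, trains with TRL's SFTTrainer, runs a quick generation test, and then offers to save or push the LoRA adapters, a merged 16-bit model, and GGUF quantizations. As a usage sketch (the filename is just an example), save it as unsloth_guided.py on a CUDA machine with unsloth, trl, transformers and datasets installed, run python unsloth_guided.py, and answer the prompts; pressing Enter accepts the default shown in parentheses.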
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from datasets import load_dataset
from huggingface_hub import login
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq, TextStreamer
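# Helper used later by dataset.map: flattens each ShareGPT-style conversation into a
# single training string with the tokenizer's chat template (the tokenizer itself is
# created further below, before the dataset is mapped).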
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
des_max_seq_length = int(input('max_seq_length (2048): ') or "2048")
max_seq_length = des_max_seq_length # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "############# Our popular models: #############",
    "# unsloth/Meta-Llama-3.1-8B-bnb-4bit #", # Llama-3.1 2x faster
    "# unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit #",
    "# unsloth/Meta-Llama-3.1-70B-bnb-4bit #",
    "# unsloth/Meta-Llama-3.1-405B-bnb-4bit #", # 4bit for 405b!
    "# unsloth/Mistral-Small-Instruct-2409 #", # Mistral 22b 2x faster!
    "# unsloth/mistral-7b-instruct-v0.3-bnb-4bit #",
    "# unsloth/Phi-3.5-mini-instruct #", # Phi-3.5 2x faster!
    "# unsloth/Phi-3-medium-4k-instruct #",
    "# unsloth/Phi-4 #", # Phi-4 2x faster!
    "# unsloth/Phi-4-unsloth-bnb-4bit #", # Phi-4 Unsloth Dynamic 4-bit Quant
    "# unsloth/gemma-2-9b-bnb-4bit #",
    "# unsloth/gemma-2-27b-bnb-4bit #", # Gemma 2x faster!
    "# unsloth/Llama-3.2-1B-bnb-4bit #", # NEW! Llama 3.2 models
    "# unsloth/Llama-3.2-1B-Instruct-bnb-4bit #",
    "# unsloth/Llama-3.2-3B-bnb-4bit #",
    "# unsloth/Llama-3.2-3B-Instruct-bnb-4bit #",
    "# unsloth/Llama-3.3-70B-Instruct-bnb-4bit #", # NEW! Llama 3.3 70B!
    "###############################################"
] # More models at https://huggingface.co/unsloth
print("\n".join(fourbit_models))
des_model_name = str(input('Enter model_name (unsloth/Llama-3.2-1B-Instruct): ') or "unsloth/Llama-3.2-1B-Instruct")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = des_model_name, # or choose "unsloth/Llama-3.2-3B-Instruct"
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
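# LoRA hyperparameters collected next: the rank sets adapter capacity, alpha scales
# the adapter updates (alpha equal to the rank is a common starting point), and
# dropout regularizes the adapters.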
des_lora_rank = int(input('LoRA rank (16): ') or "16")
des_lora_alpha = int(input('LoRA alpha (16): ') or "16")
des_lora_dropout = float(input('LoRA dropout (0): ') or "0") # dropout is a fraction, so accept values like 0.05
model = FastLanguageModel.get_peft_model(
    model,
    r = des_lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = des_lora_alpha,
    lora_dropout = des_lora_dropout, # Supports any, but = 0 is optimized
    bias = "none", # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False, # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
print("###############################################################################") | |
print("### We use our get_chat_template function to get the correct chat template. ###") | |
print("### We support zephyr, chatml, mistral, llama, alpaca, vicuna, ###") | |
print("### vicuna_old, phi3, phi4, llama3 and more. ###") | |
print("###############################################################################") | |
des_chat_template = str(input('Enter chat_template (llama-3.1): ') or "llama-3.1") | |
des_sys_message = str(input('Enter default_system_message (Optional): ') or "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.") | |
tokenizer = get_chat_template( | |
tokenizer, | |
chat_template = des_chat_template, | |
system_message = des_sys_message, | |
) | |
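# Dataset preparation: load a ShareGPT-style dataset, keep only the requested
# percentage of it, normalise it to role/content messages with standardize_sharegpt,
# then render every conversation to plain text with formatting_prompts_func.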
des_dataset = str(input('Enter dataset (mlabonne/FineTome-100k): ') or "mlabonne/FineTome-100k")
des_split = str(input('Enter split (train): ') or "train")
des_train_size = int(input('Dataset size in % (1): ') or "1")
des_train_size = des_train_size / 100
# login("hf_...") if you want private datasets
dataset = load_dataset(des_dataset, split = des_split)
dataset = dataset.train_test_split(train_size = des_train_size)["train"] # train_test_split always returns "train"/"test" keys
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched = True,)
print("### Loaded dataset: ###")
print(des_dataset)
print("####################################################################################")
print(dataset[5]["text"])
print("####################################################################################")
des_batch_size = int(input('batch_size (2): ') or "2")
des_gas = int(input('gradient_accumulation_steps (4): ') or "4")
des_train_on_responses_only = ''
while True:
    # Normalise the answer so the later comparisons against 'True' also accept "true"/"TRUE".
    des_train_on_responses_only = input('train_on_responses_only? True / False: ').capitalize()
    if des_train_on_responses_only == 'True':
        print('The user typed in True')
        break
    elif des_train_on_responses_only == 'False':
        print('The user typed in False')
        break
    else:
        print('Enter True or False')
        continue
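# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps,
# so the defaults of 2 and 4 give 8 examples per optimizer step.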
trainer_args = {
    "model": model,
    "tokenizer": tokenizer,
    "train_dataset": dataset,
    "dataset_text_field": "text",
    "max_seq_length": max_seq_length,
    "dataset_num_proc": 1,
    "packing": False, # Can make training 5x faster for short sequences.
    "args": TrainingArguments(
        per_device_train_batch_size=des_batch_size,
        gradient_accumulation_steps=des_gas,
        warmup_ratio=0.1,
        num_train_epochs=1, # Set this for 1 full training run.
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none", # Use this for WandB etc
    ),
}
if des_train_on_responses_only == 'True':
    trainer_args["data_collator"] = DataCollatorForSeq2Seq(tokenizer=tokenizer)
trainer = SFTTrainer(**trainer_args)
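# train_on_responses_only masks every prompt token's label with -100 so the loss is
# computed only on assistant responses; the instruction/response markers below match
# the llama-3 family chat template headers, so adjust them if you chose another template.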
if des_train_on_responses_only == 'True':
    trainer = train_on_responses_only(
        trainer,
        instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
        response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
    )
    # The "labels" column only exists after the masking step, so show the check inside this branch.
    print("############################# - Before masking - ####################################")
    print(tokenizer.decode(trainer.train_dataset[5]["input_ids"]))
    space = tokenizer(" ", add_special_tokens = False).input_ids[0]
    print("############################# - After masking - ####################################")
    print(tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5]["labels"]]))
    print("####################################################################################")
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")
input("Let's train? Press Enter to start or Ctrl+C to cancel...")
trainer_stats = trainer.train()
print("##########################") | |
print("### Training complete! ###") | |
print("##########################") | |
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) | |
used_memory_for_lora = round(used_memory - start_gpu_memory, 3) | |
used_percentage = round(used_memory / max_memory * 100, 3) | |
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3) | |
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.") | |
print( | |
f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training." | |
) | |
print(f"Peak reserved memory = {used_memory} GB.") | |
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.") | |
print(f"Peak reserved memory % of max memory = {used_percentage} %.") | |
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.") | |
print("##########################") | |
print("### Time for inference to test your trained model ###") | |
test_prompt = str(input('Enter prompt (2 + 2): ') or "2 + 2")
tokenizer = get_chat_template(
    tokenizer,
    chat_template = des_chat_template,
)
FastLanguageModel.for_inference(model)
messages = [
    {"role": "user", "content": f"{test_prompt}"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize = True,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
).to("cuda")
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True, temperature = 1.5, min_p = 0.1)
print("######################") | |
print("### Save the lora ###") | |
print("######################") | |
yes_choices = ['yes', 'y']
want_save_lora = input('Save lora (yes/no): ')
if want_save_lora.lower() in yes_choices:
    model.save_pretrained_merged("lora_model", tokenizer, save_method = "lora",)
    pass
want_push_lora = input('Push lora to hub (yes/no): ')
if want_push_lora.lower() in yes_choices:
    username_modelname = str(input('your_name/lora_model: ') or "lora_model")
    write_token = str(input('Enter HF token: ') or "")
    model.push_to_hub_merged(username_modelname, tokenizer, save_method = "lora", token = write_token)
    pass
print("######################")
print("### Save the model ###")
print("######################")
print("### Merge to 16bit ###")
want_save_model = input('Save model (yes/no): ')
if want_save_model.lower() in yes_choices:
    model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
    pass
want_push_model = input('Push model to hub (yes/no): ')
if want_push_model.lower() in yes_choices:
    username_modelname = str(input('your_name/lora_model: ') or "lora_model")
    write_token = str(input('Enter HF token: ') or "")
    model.push_to_hub_merged(username_modelname, tokenizer, save_method = "merged_16bit", token = write_token)
    pass
print("######################") | |
print("### Save to GGUF ###") | |
print("######################") | |
# choose yours from list here https://docs.unsloth.ai/basics/running-and-saving-models/saving-to-gguf
quantization_variants = ["q4_k_m", "q5_k_m", "q6_k", "q8_0",]
print('Default variants: "q4_k_m", "q5_k_m", "q6_k", "q8_0"')
want_save_gguf = input('Save GGUF (yes/no): ')
if want_save_gguf.lower() in yes_choices:
    model.save_pretrained_gguf("model", tokenizer, quantization_method = quantization_variants)
    pass
want_push_gguf = input('Push GGUF to hub (yes/no): ')
if want_push_gguf.lower() in yes_choices:
    username_modelname = str(input('your_name/lora_model: ') or "lora_model")
    write_token = str(input('Enter HF token: ') or "")
    model.push_to_hub_gguf(username_modelname, tokenizer, quantization_method = quantization_variants, token = write_token)
    pass