Created
June 19, 2023 05:13
-
-
Save cyberfox/566948b2c59a8f91b8c2bb888ff0e38c to your computer and use it in GitHub Desktop.
Patch to support training LoRA for StarCoder-based models on Oobabooga's text-generation-webui
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/modules/training.py b/modules/training.py
index 75ba82c..c90d823 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -30,12 +30,14 @@ try:
     MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
 except:
     standard_modules = ["q_proj", "v_proj"]
-    model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}
+    model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"],
+                             "gpt_bigcode": ["c_attn", "c_proj", "c_fc"]}
     MODEL_CLASSES = {
         "LlamaForCausalLM": "llama",
         "OPTForCausalLM": "opt",
         "GPTJForCausalLM": "gptj",
-        "GPTNeoXForCausalLM": "gpt_neox"
+        "GPTNeoXForCausalLM": "gpt_neox",
+        "GPTBigCodeForCausalLM": "gpt_bigcode"
     }
@@ -423,7 +425,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
             num_train_epochs=epochs,
             learning_rate=actual_lr,
-            fp16=False if shared.args.cpu else True,
+            fp16=False,  # if shared.args.cpu else True,
             optim=optimizer,
             logging_steps=5,
             evaluation_strategy="steps" if eval_data is not None else "no",
@@ -434,7 +436,8 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             load_best_model_at_end=eval_data is not None,
             # TODO: Enable multi-device support
             ddp_find_unused_parameters=None,
-            no_cuda=shared.args.cpu
+            no_cuda=shared.args.cpu,
+            report_to="wandb"
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
         callbacks=list([Callbacks()])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment