AutoMM Multi-GPU FSDP for mT5-XL
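Finetunes google/mt5-xl on the English PAWS-X paraphrase-identification data with AutoGluon MultiModal (AutoMM), using parameter-efficient finetuning, bf16 precision, and FSDP to shard the 3.7B-parameter backbone across multiple GPUs, then evaluates zero-shot transfer on the test splits of all seven PAWS-X languages. A companion shell script sweeps the efficient-finetuning mode, learning-rate decay, pooling mode, and prompt.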
automm_mt5_xl_multi_gpu_fsdp.py
from autogluon.multimodal import MultiModalPredictor
from datasets import load_dataset
import json
import os
import time
import argparse

# The seven languages covered by the PAWS-X benchmark.
PAWS_X_LANGUAGE_L = ['en', 'fr', 'es', 'de', 'zh', 'ja', 'ko']

os.makedirs("data_cache", exist_ok=True)

# Train and tune on English only; evaluate zero-shot on every language's test split.
train_data = load_dataset("paws-x", name="en", split="train", cache_dir="data_cache").to_pandas().drop('id', axis=1)
tuning_data = load_dataset("paws-x", name="en", split="validation", cache_dir="data_cache").to_pandas().drop('id', axis=1)
test_data_all_languages = [
    [lang, load_dataset('paws-x', name=lang, split='test', cache_dir="data_cache").to_pandas().drop('id', axis=1)]
    for lang in PAWS_X_LANGUAGE_L
]
label = 'label'
backbone = 'google/mt5-xl'
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Finetune mT5-XL on PAWS-X with AutoMM (multi-GPU FSDP).')
    parser.add_argument('--prompt', type=str, help='optional prompt prepended to sentence1 to indicate the task', default='')
    parser.add_argument('--lr_decay', type=float, help='layer-wise learning rate decay', default=1.0)
    parser.add_argument('--learning_rate', type=float, help='learning rate', default=1e-03)
    parser.add_argument('--efficient_finetune', type=str, help='parameter-efficient finetuning type', default='lora_norm')
    parser.add_argument('--pooling_mode', type=str, help='pooling mode', default='mean')
    parser.add_argument('--seed', type=int, default=1)
    args = parser.parse_args()
    save_path = f'{backbone}_{args.pooling_mode}_{args.efficient_finetune}_lr{args.learning_rate}_{args.lr_decay}_pawsx_prompt_'

    # Prepend the (possibly empty) prompt to the first sentence of every split.
    train_data['sentence1'] = train_data['sentence1'].apply(lambda ele: args.prompt + ' ' + ele)
    tuning_data['sentence1'] = tuning_data['sentence1'].apply(lambda ele: args.prompt + ' ' + ele)
    for i in range(len(test_data_all_languages)):
        test_data_all_languages[i][1]['sentence1'] = test_data_all_languages[i][1]['sentence1'].apply(lambda ele: args.prompt + ' ' + ele)

    # Finetune in bf16; FSDP shards the 3.7B-parameter backbone across all visible GPUs.
    train_start = time.time()
    predictor = MultiModalPredictor(label=label, path=save_path).fit(
        train_data,
        tuning_data=tuning_data,  # validation split for model selection
        hyperparameters={
            "model.hf_text.checkpoint_name": backbone,
            "model.hf_text.pooling_mode": args.pooling_mode,
            "optimization.efficient_finetune": args.efficient_finetune,
            "optimization.lr_decay": args.lr_decay,
            "optimization.learning_rate": args.learning_rate,
            "env.precision": "bf16",
            "env.strategy": "fsdp",
            "env.per_gpu_batch_size": 4,
        })
    train_end = time.time()
    print('training time (s) =', train_end - train_start)

    # Evaluate accuracy on the test split of every PAWS-X language and dump per-language predictions.
    all_lang_scores = dict()
    for lang, test_data in test_data_all_languages:
        y_pred = predictor.predict(data=test_data)
        scores = {'acc': float((y_pred == test_data[label]).to_numpy().mean())}
        y_pred.to_csv(os.path.join(predictor.path, f'prediction_{lang}.csv'))
        print('lang=', lang, 'scores=', scores)
        all_lang_scores[lang] = scores
    with open(os.path.join(predictor.path, 'test_metrics.json'), 'w') as fp:
        json.dump(all_lang_scores, fp)
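For later inference without retraining, the saved predictor can be reloaded with MultiModalPredictor.load. A minimal sketch, assuming the save path produced by the script's default arguments; the example sentence pair is hypothetical:

from autogluon.multimodal import MultiModalPredictor
import pandas as pd

# Hypothetical path: the one the default arguments above would produce.
predictor = MultiModalPredictor.load('google/mt5-xl_mean_lora_norm_lr0.001_1.0_pawsx_prompt_')
example = pd.DataFrame({'sentence1': ['The cat sat on the mat.'],
                        'sentence2': ['A cat is sitting on the mat.']})
print(predictor.predict(data=example))  # PAWS-X labels: 1 = paraphrase, 0 = not a paraphrase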
Sweep script:
set -x

for efficient_finetune in lora_norm lora_bias bit_fit norm_fit
do
    for lr_decay in 0.9 1.0
    do
        for pooling_mode in mean cls
        do
            for prompt in "" "paraphrase:"
            do
                # Quote the prompt so the empty string is still passed as an argument.
                python3 automm_mt5_xl_multi_gpu_fsdp.py \
                    --efficient_finetune ${efficient_finetune} \
                    --lr_decay ${lr_decay} \
                    --pooling_mode ${pooling_mode} \
                    --prompt "${prompt}"
            done
        done
    done
done
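The sweep launches 4 × 2 × 2 × 2 = 32 runs, one per combination of efficient-finetuning mode, learning-rate decay, pooling mode, and prompt. Each run shards mT5-XL across all visible GPUs (CUDA_VISIBLE_DEVICES can restrict the set) and writes its per-language accuracies to test_metrics.json inside its save path.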