@sxjscience
Last active July 12, 2022 14:59
AutoMM Multi-GPU FSDP for mT5-XL
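The script below fine-tunes google/mt5-xl on the English PAWS-X training split with AutoGluon's MultiModalPredictor, using bf16 precision and the FSDP strategy for multi-GPU training together with a parameter-efficient finetuning method, and then evaluates zero-shot cross-lingual transfer accuracy on the PAWS-X test split of each of the seven languages. The bash script at the end sweeps the main hyperparameters.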
from autogluon.multimodal import MultiModalPredictor
from datasets import load_dataset
import json
import os
import time
import argparse

# PAWS-X test languages used for zero-shot cross-lingual evaluation.
PAWS_X_LANGUAGE_L = ['en', 'fr', 'es', 'de', 'zh', 'ja', 'ko']

os.makedirs("data_cache", exist_ok=True)
# Train and tune on the English split only; evaluate on the test split of every language.
train_data = load_dataset("paws-x", name="en", split="train", cache_dir="data_cache").to_pandas().drop('id', axis=1)
tuning_data = load_dataset("paws-x", name="en", split="validation", cache_dir="data_cache").to_pandas().drop('id', axis=1)
test_data_all_languages = [
    [lang, load_dataset('paws-x', name=lang, split='test', cache_dir="data_cache").to_pandas().drop('id', axis=1)]
    for lang in PAWS_X_LANGUAGE_L
]
label = 'label'
backbone = 'google/mt5-xl'
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Fine-tune mT5-XL on PAWS-X with AutoMM multi-GPU FSDP.')
    parser.add_argument('--prompt', type=str, help='the prompt that may indicate the task', default='')
    parser.add_argument('--lr_decay', type=float, help='layer-wise learning rate decay', default=1.0)
    parser.add_argument('--learning_rate', type=float, help='learning rate', default=1e-03)
    parser.add_argument('--efficient_finetune', type=str, help='efficient finetuning type', default='lora_norm')
    parser.add_argument('--pooling_mode', type=str, help='pooling mode', default='mean')
    parser.add_argument('--seed', type=int, default=1)
    args = parser.parse_args()

    save_path = f'{backbone}_{args.pooling_mode}_{args.efficient_finetune}_lr{args.learning_rate}_{args.lr_decay}_pawsx_prompt_'

    # Prepend the (optional) task prompt to the first sentence of every pair.
    train_data['sentence1'] = train_data['sentence1'].apply(lambda ele: args.prompt + ' ' + ele)
    tuning_data['sentence1'] = tuning_data['sentence1'].apply(lambda ele: args.prompt + ' ' + ele)
    for i in range(len(test_data_all_languages)):
        test_data_all_languages[i][1]['sentence1'] = test_data_all_languages[i][1]['sentence1'].apply(
            lambda ele: args.prompt + ' ' + ele)

    train_start = time.time()
    predictor = MultiModalPredictor(label=label, path=save_path).fit(
        train_data,
        tuning_data=tuning_data,
        hyperparameters={
            "model.hf_text.checkpoint_name": backbone,
            "model.hf_text.pooling_mode": args.pooling_mode,
            "optimization.efficient_finetune": args.efficient_finetune,
            "optimization.lr_decay": args.lr_decay,
            "optimization.learning_rate": args.learning_rate,
            "env.precision": "bf16",
            "env.strategy": "fsdp",
            "env.per_gpu_batch_size": 4,
        })
    train_end = time.time()
    print('training time (s):', train_end - train_start)

    # Evaluate zero-shot cross-lingual transfer on every PAWS-X test split.
    all_lang_scores = dict()
    for lang, test_data in test_data_all_languages:
        y_pred = predictor.predict(data=test_data)
        scores = {'acc': (y_pred == test_data[label]).to_numpy().mean()}
        y_pred.to_csv(os.path.join(predictor.path, f'prediction_{lang}.csv'))
        print('lang=', lang, 'scores=', scores)
        all_lang_scores[lang] = scores

    with open(os.path.join(predictor.path, 'test_metrics.json'), 'w') as fp:
        json.dump(all_lang_scores, fp)
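After a run finishes, the saved predictor can be reloaded for ad-hoc inference. A minimal sketch, assuming the directory name below matches the actual predictor.path printed by the training run (the path here is hypothetical; substitute your own):

import pandas as pd
from autogluon.multimodal import MultiModalPredictor

# Hypothetical save directory produced by the default arguments above.
predictor = MultiModalPredictor.load("google/mt5-xl_mean_lora_norm_lr0.001_1.0_pawsx_prompt_")

# Columns must match the training data ('sentence1', 'sentence2').
# If a --prompt was used at training time, prepend it to sentence1 here as well.
pair = pd.DataFrame({
    "sentence1": ["The cat sat on the mat."],
    "sentence2": ["A cat was sitting on the mat."],
})
print(predictor.predict(pair))        # predicted paraphrase label (0/1)
print(predictor.predict_proba(pair))  # class probabilities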
set -x
# Hyperparameter sweep driving automm_mt5_xl_multi_gpu_fsdp.py:
# efficient-finetuning strategy x lr decay x pooling mode x task prompt (4 x 2 x 2 x 2 = 32 runs).
for efficient_finetune in lora_norm lora_bias bit_fit norm_fit
do
    for lr_decay in 0.9 1.0
    do
        for pooling_mode in mean cls
        do
            for prompt in "" "paraphrase:"
            do
                python3 automm_mt5_xl_multi_gpu_fsdp.py \
                    --efficient_finetune "${efficient_finetune}" \
                    --lr_decay "${lr_decay}" \
                    --pooling_mode "${pooling_mode}" \
                    --prompt "${prompt}"
            done
        done
    done
done
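Each run writes its per-language accuracies to test_metrics.json inside its save directory. A minimal sketch for collecting those files into one table, assuming the save-path layout produced by the script above (paths start with the backbone name, so they live under google/):

import glob
import json
import os
import pandas as pd

rows = []
for metrics_file in glob.glob("google/**/test_metrics.json", recursive=True):
    with open(metrics_file) as fp:
        scores = json.load(fp)
    row = {"run": os.path.dirname(metrics_file)}
    row.update({lang: v["acc"] for lang, v in scores.items()})
    rows.append(row)

summary = pd.DataFrame(rows)
print(summary.to_string(index=False))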