import argparse
import os

import torch
from transformers.models.auto import AutoModelForCausalLM

LAYER_FILE_PREFIX = 'layer_'
MODEL_FILE_PREFIX = 'model_'

EMBEDDING_LAYER_INDEX = 1
TRANSFORMER_LAYER_OFFSET = 3

IGNORE_HF_KEYS = {'lm_head.weight'}  # tied weights in megatron
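
# NOTE: These constants assume the Megatron-DeepSpeed pipeline checkpoint layout used for
# BLOOM, where weights are stored in files named
# 'layer_<layer_idx>-model_<model_idx>-model_states.pt': the word embedding sits at
# layer index 1 and transformer block i at index i + TRANSFORMER_LAYER_OFFSET.
# Adjust the constants above if your checkpoint uses a different layout.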

# Bloom ###########
HF_BLOOM_STATE_DICT_MAPPINGS = {
    # ds state dict key => HF state dict key + convert operation
    'word_embeddings.weight': {
        'hf_k': 'transformer.word_embeddings.weight',
    },
    'word_embeddings.norm.weight': {
        'hf_k': 'transformer.word_embeddings_layernorm.weight',
    },
    'word_embeddings.norm.bias': {
        'hf_k': 'transformer.word_embeddings_layernorm.bias',
    },
    'input_layernorm.weight': {
        'hf_k': 'transformer.h.<LAYER>.input_layernorm.weight',
    },
    'input_layernorm.bias': {
        'hf_k': 'transformer.h.<LAYER>.input_layernorm.bias',
    },
    'self_attention.query_key_value.weight': {
        'hf_k': 'transformer.h.<LAYER>.self_attention.query_key_value.weight',
    },
    'self_attention.query_key_value.bias': {
        'hf_k': 'transformer.h.<LAYER>.self_attention.query_key_value.bias',
    },
    'self_attention.dense.weight': {
        'hf_k': 'transformer.h.<LAYER>.self_attention.dense.weight',
        'row_parallel': True,
    },
    'self_attention.dense.bias': {
        'hf_k': 'transformer.h.<LAYER>.self_attention.dense.bias',
    },
    'post_attention_layernorm.weight': {
        'hf_k': 'transformer.h.<LAYER>.post_attention_layernorm.weight',
    },
    'post_attention_layernorm.bias': {
        'hf_k': 'transformer.h.<LAYER>.post_attention_layernorm.bias',
    },
    'mlp.dense_h_to_4h.weight': {
        'hf_k': 'transformer.h.<LAYER>.mlp.dense_h_to_4h.weight',
    },
    'mlp.dense_h_to_4h.bias': {
        'hf_k': 'transformer.h.<LAYER>.mlp.dense_h_to_4h.bias',
    },
    'mlp.dense_4h_to_h.weight': {
        'hf_k': 'transformer.h.<LAYER>.mlp.dense_4h_to_h.weight',
        'row_parallel': True,
    },
    'mlp.dense_4h_to_h.bias': {
        'hf_k': 'transformer.h.<LAYER>.mlp.dense_4h_to_h.bias',
    },
    'bias': {
        'hf_k': 'transformer.ln_f.bias',
    },
    'weight': {
        'hf_k': 'transformer.ln_f.weight',
    },
}
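
# In the mapping above, '<LAYER>' is replaced with the HF transformer-block index at
# conversion time, and 'row_parallel': True marks weights that get sliced along dim 1
# (rather than dim 0) when the target checkpoint uses tensor parallelism.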


def main():
    """
    Override an existing DeepSpeed checkpoint with weights from a pretrained HF model.

    Example usage:

    python convert_hf_to_deepspeed.py ${DATASETS_DIR}/huggingface_transformers/pytorch/bloom-1b3 \
        ${EXP_DIR}/tr1/dummy_checkpoints/global_step0 --bf16

    Supported model types: bloom

    :return:
    """
    # Create the argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "hf_model_name_or_path",
        type=str,
        help="Path to the pretrained HuggingFace model",
    )
    parser.add_argument(
        "checkpoint_dir",
        type=str,
        help="Path to the DeepSpeed checkpoint directory",
    )
    parser.add_argument(
        "tp",
        type=int,
        nargs="?",
        default=1,
        help="Tensor parallelism degree of the target checkpoint (default: 1)",
    )
parser.add_argument("--bf16", action="store_true") | |
parser.add_argument("--dry-run", action="store_true") | |
args = parser.parse_args() | |
checkpoint_dir = args.checkpoint_dir | |
hf_model_name_or_path = args.hf_model_name_or_path | |
bf16 = args.bf16 | |
dry_run = args.dry_run | |
tp = args.tp | |
print(f'Loading pretrained HF model from {hf_model_name_or_path} into {checkpoint_dir} ...') | |
if not os.path.exists(checkpoint_dir): | |
raise FileNotFoundError(f'Checkpoint dir does not exists: {checkpoint_dir}') | |
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name_or_path) | |
hf_config = hf_model.config | |
if hf_config.model_type == 'bloom': | |
original_to_hf_mapping = HF_BLOOM_STATE_DICT_MAPPINGS | |
else: | |
raise ValueError(f'Unsupported model type: {hf_config.model_typ}') | |
FINAL_LAYER_NORM_INDEX = hf_config.n_layer + TRANSFORMER_LAYER_OFFSET + 1 | |
if bf16: | |
print('Converting HF model to bf16') | |
hf_model = hf_model.bfloat16() | |
hf_sd = hf_model.state_dict() | |
matched_hf_keys = set() | |
# Iterate over files in checkpoint_dir | |
for fn in sorted(os.listdir(checkpoint_dir)): | |
fp = os.path.join(checkpoint_dir, fn) | |
if os.path.isfile(fp) and fn.endswith('model_states.pt') and fn.startswith(LAYER_FILE_PREFIX): | |
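            # File names follow 'layer_<layer_idx>-model_<model_idx>-model_states.pt';
            # strip the prefixes to recover the pipeline layer index and the
            # tensor-parallel (model) rank.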
            fn_split = fn.split('-')
            layer_idx = int(fn_split[0][len(LAYER_FILE_PREFIX):])
            model_idx = int(fn_split[1][len(MODEL_FILE_PREFIX):])
            hf_layer_idx = None

            # Determine layer type
            if layer_idx == EMBEDDING_LAYER_INDEX:
                layer_type = 'embedding'
            elif layer_idx == FINAL_LAYER_NORM_INDEX:
                layer_type = 'final_layer_norm'
            else:
                # transformer layer
                hf_layer_idx = layer_idx - TRANSFORMER_LAYER_OFFSET
                layer_type = 'transformer'

            print(f'{layer_type=} {layer_idx} => {hf_layer_idx} {model_idx=} ')

            # Load state dict from disk to CPU
            sd = torch.load(fp, map_location="cpu")

            for original_k, original_v in sd.items():
                if original_k not in original_to_hf_mapping:
                    raise ValueError(f'There is no mapping for {original_k=}')

                hf_mapping = original_to_hf_mapping[original_k]
                hf_k = hf_mapping['hf_k']

                # replace layer index
                hf_k = hf_k.replace('<LAYER>', str(hf_layer_idx))

                # get value
                hf_v = hf_sd[hf_k]
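
                # For tensor-parallel checkpoints, each model_<model_idx> file holds only one
                # slice of the full HF tensor: 'row_parallel' weights are sliced along dim 1,
                # all other partitioned weights along dim 0.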
                if tp > 1:
                    # Tensor parallelism enabled
                    if original_v.shape != hf_v.shape:  # partition only if shapes do not already match
                        hf_shape = hf_v.shape

                        if 'row_parallel' in hf_mapping and hf_mapping['row_parallel']:
                            # row parallel
                            single_partition_size = int(hf_shape[1] / tp)
                            partition_v = hf_v[:, model_idx * single_partition_size:(model_idx + 1) * single_partition_size]
                        else:
                            # column parallel
                            single_partition_size = int(hf_shape[0] / tp)
                            partition_v = hf_v[model_idx * single_partition_size:(model_idx + 1) * single_partition_size]

                        print(f' - partitioned from {hf_shape} to {partition_v.shape} ({tp=})')
                        hf_v = partition_v

                # check if value shapes match
                if original_v.shape != hf_v.shape:
                    raise ValueError(f'Shapes do not match: {original_k} = {original_v.shape}; {hf_k} = {hf_v.shape}')

                # check if data types match
                if original_v.dtype != hf_v.dtype:
                    raise ValueError(
                        f'Data types do not match: {original_k} = {original_v.dtype}; {hf_k} = {hf_v.dtype}')

                print('matched ', original_k, ' = ', hf_k, '; ')

                matched_hf_keys.add(hf_k)

                # replace in state dict
                sd[original_k] = hf_v

            # save to disk
            if dry_run:
                print('skip saving')
            else:
                torch.save(sd, fp)
                print('saved to ', fp)

            print()

    # Check for HF keys that were never matched
    not_matched_hf_keys = set(hf_sd.keys()) - matched_hf_keys - IGNORE_HF_KEYS

    if len(not_matched_hf_keys) > 0:
        raise ValueError('Not matched HF keys: %s' % not_matched_hf_keys)

    print('done')


if __name__ == "__main__":
    main()