@iakashpaul
Created August 18, 2023 13:12
Run Flan-UL2 (20B) instances on V100s with ~26GB VRAM per instance; change device_map according to your hardware. All layers stay on GPU; the additional GPU and CPU RAM budgets below are specified for reference only.
import torch
from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoTokenizer, T5ForConditionalGeneration, T5Config

def load_model_sharded():
    model_name = "google/flan-ul2"
    config = T5Config.from_pretrained(model_name)
    tokenizer_1 = AutoTokenizer.from_pretrained(model_name)
    # Per-device budgets: GPU 2 gets 30GiB, GPU 1 gets 10GiB, CPU offload up to 100GiB
    max_memory_1 = {2: "30GiB", 1: "10GiB", "cpu": "100GiB"}
    # Instantiate on the meta device (no weights allocated) just to plan the layout
    with init_empty_weights():
        model_1 = T5ForConditionalGeneration(config)
    device_map_1 = infer_auto_device_map(
        model_1, no_split_module_classes=["T5Block"], dtype=torch.float16, max_memory=max_memory_1
    )
    # lm_head is tied to the embedding weights, so keep it on the same device
    device_map_1["lm_head"] = device_map_1["decoder.embed_tokens"]
    # Load the real weights from the local checkpoint directory in 4-bit (requires bitsandbytes)
    model_1 = T5ForConditionalGeneration.from_pretrained(
        "./flan-ul2", cache_dir="./cache/", device_map=device_map_1, load_in_4bit=True
    )
    return model_1, tokenizer_1
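For reference, a minimal usage sketch, assuming the loader returns the model and tokenizer as written above. The prompt, generation settings, and target GPU index are illustrative; pick the GPU that ends up holding the embeddings under your device_map (check model.hf_device_map after loading).

model, tokenizer = load_model_sharded()
# Move inputs to the GPU assumed to hold the embeddings (GPU 2 under the budget above)
inputs = tokenizer("Translate to German: How old are you?", return_tensors="pt").to("cuda:2")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))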