Created
June 7, 2023 07:51
-
-
Save p208p2002/e582202bf1f379ad198deab3c0e14f31 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# $ pip install "deepspeed>=0.9.3"
# $ deepspeed deepspeed_inference.py
import os | |
import deepspeed | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Process topology is read from the environment (the usage note at the top of
# this script runs it through the `deepspeed` launcher); the defaults let a
# plain single-process run work as well.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

model_name_or_id = "/home/ubuntu/Ziya-LLaMA-13B-v1"

# Load the checkpoint directly in fp16. The inference engine below is
# configured for fp16 anyway, so materialising the weights at a wider
# precision first would only inflate host memory on every rank.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_id, torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_id)

# Wrap the model in a DeepSpeed inference engine, sharded across all ranks.
# `tensor_parallel={"tp_size": ...}` is the current spelling of the deprecated
# `mp_size` argument. Kernel injection stays disabled, as in the original.
model = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": world_size},
    dtype=torch.float16,
    replace_with_kernel_inject=False,
)
def generate(inputs):
    """Greedily generate a continuation for each prompt.

    Args:
        inputs: A list of prompt strings (a single string is treated as a
            one-element batch).

    Returns:
        A list of decoded strings, one per prompt, in input order. Each entry
        is the decoded output sequence of ``model.generate`` (at most 100 new
        tokens, ``do_sample=False``) with special tokens stripped.
    """
    if isinstance(inputs, str):
        inputs = [inputs]
    if not inputs:
        # Nothing to encode; avoid handing an empty batch to the tokenizer.
        return []

    # A batch with more than one prompt must be padded to a common length
    # before it can be returned as a single tensor.
    pad_batch = len(inputs) > 1
    if pad_batch:
        # Pad on the left so every prompt ends right where generation starts.
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            # Some tokenizers ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

    encoded = tokenizer(
        inputs,
        return_tensors="pt",
        padding=pad_batch,
        return_attention_mask=True,
    )
    device = torch.cuda.current_device()
    input_ids = encoded["input_ids"].to(device)
    # Pass the mask explicitly so padded positions are ignored.
    attention_mask = encoded["attention_mask"].to(device)

    generated = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=100,
        do_sample=False,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
# The two adjacent literals are joined by the parser into ONE prompt string
# (question immediately followed by the answer cue), not two batch entries.
_prompt = (
    "问:地球环境日益严峻,我们如何减少污染?"
    "答:"
)
outputs = generate([_prompt])

# Report the result once: either torch.distributed was never initialised
# (single process) or this process is rank 0.
_dist = torch.distributed
_is_primary = (not _dist.is_initialized()) or _dist.get_rank() == 0
if _is_primary:
    print("-" * 50, outputs, sep="\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment