Created June 21, 2023 17:40
import time
from typing import Dict

from modal import Image, Stub, gpu, method, web_endpoint

IMAGE_MODEL_DIR = "/model"
def download_mt0():
    """Download bigscience/mt0-xxl-mt from the Hugging Face Hub and save it
    to local storage inside the image."""
    print("Downloading model...")
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-xxl-mt")
    tokenizer.save_pretrained(save_directory=IMAGE_MODEL_DIR)
    model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-xxl-mt")
    # safe_serialization=False keeps the weights in the classic PyTorch .bin format.
    model.save_pretrained(save_directory=IMAGE_MODEL_DIR, safe_serialization=False)
image = (
    Image.micromamba()
    .micromamba_install(
        "cudatoolkit=11.7",
        "cudnn=8.1.0",
        "cuda-nvcc",
        "scipy",
        channels=["conda-forge", "nvidia"],
    )
    .apt_install("git")
    .pip_install(
        "bitsandbytes==0.39.0",
        "bitsandbytes-cuda117==0.26.0.post2",
        "transformers @ git+https://github.com/huggingface/transformers.git",
        "accelerate @ git+https://github.com/huggingface/accelerate.git",
        "torch==2.0.0",
        "torchvision==0.15.1",
        "sentencepiece==0.1.97",
        "huggingface_hub==0.14.1",
        "einops==0.6.1",
    )
    # Run the download at image build time so the weights are baked into the image.
    .run_function(download_mt0)
)
stub = Stub(image=image, name="example-mt0-bnb")
@stub.cls(gpu=gpu.A100(), timeout=6000, container_idle_timeout=1200)
class MT0Model:
    def __enter__(self):
        # Runs once per container start: load the tokenizer and the 8-bit model.
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(IMAGE_MODEL_DIR)
        # load_in_8bit=True quantizes the weights with bitsandbytes;
        # device_map="auto" lets accelerate place them on the GPU.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            IMAGE_MODEL_DIR, device_map="auto", load_in_8bit=True
        )
        print("Loaded model.")
    @method()
    def generate(self, prompt: str):
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True).to("cuda")
        start = time.time()
        outputs = self.model.generate(
            **inputs,
            max_length=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,  # nucleus-sampling mass; must be in (0, 1], the original value of 10 was invalid
            repetition_penalty=1.0,
        )
        elapsed = time.time() - start
        # Throughput in generated tokens per second across the batch.
        n_output_tokens = sum(x.shape[0] for x in outputs)
        throughput = n_output_tokens / elapsed
        out = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return (out, throughput)
@stub.function(timeout=1200)
@web_endpoint(method="POST")
def generate_mt0(input_text: Dict):
    model = MT0Model()
    # .call() invokes generate remotely in the GPU container.
    result = model.generate.call(input_text["prompt_text"])
    return {"results": result}
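
Once the app is deployed with the Modal CLI, the endpoint accepts a JSON POST body with a prompt_text field. A minimal client sketch; the URL below is a placeholder, Modal prints the real endpoint URL for your workspace on deploy:

import requests

# Hypothetical URL; substitute the one Modal prints for your workspace and app.
ENDPOINT_URL = "https://your-workspace--example-mt0-bnb-generate-mt0.modal.run"

resp = requests.post(ENDPOINT_URL, json={"prompt_text": "Translate to English: Ich habe Hunger."})
resp.raise_for_status()
print(resp.json())  # {"results": [decoded_outputs, tokens_per_second]}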