FastAPI Huggingface LLM Server
Client: an aiohttp script that posts a prompt to the /generate endpoint and prints the streamed tokens as they arrive.
import aiohttp
import asyncio
import json

# URL of your FastAPI server
url = "http://localhost:8000/generate"

# Parameters for the text generation
params = {
    "prompt": "Explain quantum physics in simple terms"
}

async def generate_text():
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(url, json=params) as response:
                if response.status == 200:
                    print("Received response. Generated text:\n\n")
                    async for chunk in response.content.iter_any():
                        try:
                            data = json.loads(chunk)
                            if 'generated_text' in data:
                                print(data['generated_text'], end='', flush=True)
                        except json.JSONDecodeError:
                            # The server streams raw token bytes, so plain text is the normal case
                            print(chunk.decode(), end='', flush=True)
                        except TypeError as e:
                            print("ERROR", e)
                    print("\n\nGeneration complete.")
                else:
                    print(f"Error: {response.status}")
                    print(await response.text())
        except aiohttp.ClientError as e:
            print(f"Request failed: {e}")

# Run the async function
asyncio.run(generate_text())
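For a quick check without asyncio, a synchronous client along these lines should also work. This is a sketch using the requests library, which is not part of the gist; the URL and payload match the server below, and the chunk-by-chunk decoding assumes each streamed chunk is valid UTF-8 (true here because the server encodes whole tokens per chunk).

import requests

# Hypothetical synchronous client (assumes the `requests` package is installed).
# The server streams raw UTF-8 token bytes, so we print chunks as they arrive.
with requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Explain quantum physics in simple terms"},
    stream=True,
) as response:
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=None):
        print(chunk.decode("utf-8", errors="replace"), end="", flush=True)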
Server: a FastAPI app that wraps a Hugging Face text-generation pipeline (AWQ-quantized Mistral-7B-Instruct) and streams tokens back to the client as they are generated.
from fastapi import FastAPI
import asyncio
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

app = FastAPI()

# Load AWQ Mistral model and tokenizer
model_name = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)
# Reuse the already-loaded model object; passing model_name here would load a
# second copy of the weights. The model is already dispatched by device_map above.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, streamer=streamer)

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 10_000
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    repetition_penalty: float = 1.1

async def generate_text(request: GenerationRequest):
    # Huggingface is not native-async so spawn a thread
    thread = Thread(target=generator, args=(request.prompt,), kwargs=dict(
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        max_new_tokens=request.max_new_tokens,
        num_return_sequences=1,
    ))
    thread.start()
    # The streamer yields decoded text pieces as the pipeline produces them
    for token in streamer:
        yield token.encode('utf-8')
        await asyncio.sleep(0)  # yield control to the event loop
    thread.join()

@app.post("/generate")
async def generate(request: GenerationRequest):
    return StreamingResponse(generate_text(request))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
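One caveat with the server above: the TextIteratorStreamer is created once at module level, so two overlapping requests would read from the same streamer and interleave their tokens. A per-request variant is sketched below; it is an assumption on my part rather than part of the original gist, and it calls model.generate directly in the worker thread instead of going through the shared pipeline. It reuses the app, model, and tokenizer defined above; the /generate_isolated route name is hypothetical.

# Hypothetical per-request variant: each request gets its own streamer, so
# concurrent generations do not share state.
async def generate_text_isolated(request: GenerationRequest):
    local_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
    thread = Thread(target=model.generate, kwargs=dict(
        **inputs,
        streamer=local_streamer,
        do_sample=True,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        max_new_tokens=request.max_new_tokens,
    ))
    thread.start()
    for token in local_streamer:
        yield token.encode("utf-8")
        await asyncio.sleep(0)  # yield control to the event loop
    thread.join()

@app.post("/generate_isolated")
async def generate_isolated(request: GenerationRequest):
    return StreamingResponse(generate_text_isolated(request))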