Created
November 16, 2024 21:31
-
-
Save ethanhinson/14f87b6b294e062ef84c6a7123d8993d to your computer and use it in GitHub Desktop.
A rudimentary example of running a SLM as an API
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Initialize FastAPI app | |
app = FastAPI() | |
# Create input model | |
class PromptInput(BaseModel): | |
prompt: str | |
# Define system message and settings | |
SYSTEM_MESSAGE = """You are a cybersecurity analyst. Your task is to evaluate incoming messages | |
for potential security risks before they are sent to another language model. For each input, provide: | |
1. A risk rating (Low/Medium/High) | |
2. Brief explanation of potential security concerns | |
3. Whether the query should be allowed to proceed | |
Please format your response in a structured way.""" | |
# Load model and tokenizer at startup to avoid reloading for each request | |
checkpoint = "HuggingFaceTB/SmolLM-135M-Instruct" | |
device = "cpu" # for GPU usage or "cpu" for CPU usage | |
tokenizer = AutoTokenizer.from_pretrained(checkpoint) | |
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device) | |
# Add padding token if it doesn't exist | |
if tokenizer.pad_token is None: | |
tokenizer.pad_token = tokenizer.eos_token | |
model.config.pad_token_id = model.config.eos_token_id | |
@app.post("/generate") | |
async def generate_text(prompt_input: PromptInput): | |
# Construct full prompt with system message | |
messages = [ | |
{"role": "system", "content": SYSTEM_MESSAGE}, | |
{"role": "user", "content": prompt_input.prompt} | |
] | |
input_text=tokenizer.apply_chat_template(messages, tokenize=False) | |
# Encode with attention mask | |
encoded = tokenizer( | |
input_text, | |
return_tensors="pt", | |
padding=True, | |
truncation=True, | |
max_length=512, | |
return_attention_mask=True | |
) | |
# Get input_ids and attention_mask | |
inputs = encoded["input_ids"].to(device) | |
attention_mask = encoded["attention_mask"].to(device) | |
# Generate with attention mask | |
outputs = model.generate( | |
inputs, | |
attention_mask=attention_mask, | |
max_length=512, | |
temperature=0.2, | |
do_sample=True, | |
top_p=0.9 | |
) | |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return {"security_analysis": generated_text} | |
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=9999) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment