Skip to content

Instantly share code, notes, and snippets.

@ethanhinson
Created November 16, 2024 21:31
Show Gist options
  • Save ethanhinson/14f87b6b294e062ef84c6a7123d8993d to your computer and use it in GitHub Desktop.
Save ethanhinson/14f87b6b294e062ef84c6a7123d8993d to your computer and use it in GitHub Desktop.
A rudimentary example of running a SLM as an API
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Initialize FastAPI app
app = FastAPI()
# Create input model
class PromptInput(BaseModel):
prompt: str
# Define system message and settings
SYSTEM_MESSAGE = """You are a cybersecurity analyst. Your task is to evaluate incoming messages
for potential security risks before they are sent to another language model. For each input, provide:
1. A risk rating (Low/Medium/High)
2. Brief explanation of potential security concerns
3. Whether the query should be allowed to proceed
Please format your response in a structured way."""
# Load model and tokenizer at startup to avoid reloading for each request
checkpoint = "HuggingFaceTB/SmolLM-135M-Instruct"
device = "cpu" # for GPU usage or "cpu" for CPU usage
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
# Add padding token if it doesn't exist
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
@app.post("/generate")
async def generate_text(prompt_input: PromptInput):
# Construct full prompt with system message
messages = [
{"role": "system", "content": SYSTEM_MESSAGE},
{"role": "user", "content": prompt_input.prompt}
]
input_text=tokenizer.apply_chat_template(messages, tokenize=False)
# Encode with attention mask
encoded = tokenizer(
input_text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512,
return_attention_mask=True
)
# Get input_ids and attention_mask
inputs = encoded["input_ids"].to(device)
attention_mask = encoded["attention_mask"].to(device)
# Generate with attention mask
outputs = model.generate(
inputs,
attention_mask=attention_mask,
max_length=512,
temperature=0.2,
do_sample=True,
top_p=0.9
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"security_analysis": generated_text}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=9999)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment