Reranker FastAPI Qwen3
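
A single-file FastAPI service that wraps Qwen/Qwen3-Reranker-8B through vLLM. It exposes one POST /rerank endpoint that scores each document against the query using the model's yes/no token logprobs and returns the documents sorted by relevance.
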
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from vllm import LLM, SamplingParams
from vllm.inputs.data import TokensPrompt
from transformers import AutoTokenizer
import math

class RerankRequest(BaseModel):
    query: str
    documents: List[str]
    instruction: Optional[str] = 'Given a web search query, retrieve relevant passages that answer the query'
    top_n: Optional[int] = None


class RerankResult(BaseModel):
    document: str
    score: float
    index: int


class RerankResponse(BaseModel):
    results: List[RerankResult]

MODEL_NAME = "Qwen/Qwen3-Reranker-8B"
print(f"🚀 Loading model: {MODEL_NAME}")
llm = LLM(
    model=MODEL_NAME,
    tensor_parallel_size=1,  # Adjust for your hardware
    max_model_len=8192,
    enable_prefix_caching=True,  # The shared system/instruction prefix is reused across documents
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# The suffix closes the user turn and opens an assistant turn with an empty
# <think> block, so the model's very next token is its yes/no judgment.
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
MAX_LENGTH = 8192
SUFFIX_TOKENS = tokenizer.encode(SUFFIX, add_special_tokens=False)
TRUE_TOKEN_ID = tokenizer("yes", add_special_tokens=False).input_ids[0]
FALSE_TOKEN_ID = tokenizer("no", add_special_tokens=False).input_ids[0]
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=1,
    logprobs=20,  # Request logprobs for top tokens
    allowed_token_ids=[TRUE_TOKEN_ID, FALSE_TOKEN_ID],
)
print("✅ Model loaded successfully.")
app = FastAPI(title="High-Performance Reranker API")
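
# Scoring sketch: with decoding restricted to {"yes", "no"}, the relevance
# score computed in the handler below is a two-way softmax over the final
# logprobs, P(yes) = exp(l_yes) / (exp(l_yes) + exp(l_no)).
# Illustrative numbers: l_yes = -0.1, l_no = -2.4 gives P(yes) ≈ 0.909.
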
@app.post("/rerank", response_model=RerankResponse)
async def rerank(request: RerankRequest):
messages = []
for doc in request.documents:
text = [
{"role": "system", "content": "Judge whether the Document meets the requirements..."},
{"role": "user", "content": f"<Instruct>: {request.instruction}\n\n<Query>: {request.query}\n\n<Document>: {doc}"}
]
tokenized_message = tokenizer.apply_chat_template(
text, tokenize=True, add_generation_prompt=False, enable_thinking=False
)
final_tokens = tokenized_message[:MAX_LENGTH - len(SUFFIX_TOKENS)] + SUFFIX_TOKENS
messages.append(TokensPrompt(prompt_token_ids=final_tokens))
outputs = await llm.generate(messages, sampling_params, use_tqdm=False)
scores = []
for output in outputs:
final_logprobs = output.outputs[0].logprobs[-1]
true_logit = final_logprobs.get(TRUE_TOKEN_ID, -100)
false_logit = final_logprobs.get(FALSE_TOKEN_ID, -100)
true_score = math.exp(true_logit.logprob if hasattr(true_logit, 'logprob') else true_logit)
false_score = math.exp(false_logit.logprob if hasattr(false_logit, 'logprob') else false_logit)
score = true_score / (true_score + false_score)
scores.append(score)
results = [
RerankResult(document=doc, score=score, index=i)
for i, (doc, score) in enumerate(zip(request.documents, scores))
]
results.sort(key=lambda x: x.score, reverse=True)
if request.top_n:
results = results[:request.top_n]
return RerankResponse(results=results)
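
For reference, a minimal client-side sketch. It assumes the service above is saved as main.py and served with `uvicorn main:app --port 8000`; the file name, host, and port are placeholders.

# Hypothetical client call; the request body mirrors RerankRequest above.
import requests

resp = requests.post(
    "http://localhost:8000/rerank",
    json={
        "query": "What is the capital of France?",
        "documents": [
            "Paris is the capital and largest city of France.",
            "Bananas are a good source of potassium.",
        ],
        "top_n": 1,
    },
)
resp.raise_for_status()
for r in resp.json()["results"]:
    print(f"{r['score']:.3f}  [{r['index']}] {r['document']}")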