Last active
September 21, 2024 10:33
-
-
Save decent-engineer-decent-datascientist/81e04ad86e102eb083416e28150aa2a1 to your computer and use it in GitHub Desktop.
A quick FastAPI wrapper for YOLOv5
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests as r | |
import json | |
from pprint import pprint | |
# Images: two sample pictures from the YOLOv5 repository, sent as one batch.
# (Named base_url rather than `dir` to avoid shadowing the builtin.)
base_url = 'https://github.com/ultralytics/yolov5/raw/master/data/images/'
imgs = [base_url + f for f in ('zidane.jpg', 'bus.jpg')]  # batched list of images
# Send images to endpoint. `json=` serializes the payload and sets the
# "Content-Type: application/json" header; the timeout keeps the client from
# hanging indefinitely if the server is down or still loading the model.
res = r.post("http://localhost:9999/inference", json={'img_list': imgs}, timeout=120)
pprint(res.json())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard library first, then third-party packages.
from typing import Optional

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field

# The ASGI application object; uvicorn serves it (see the __main__ guard below).
app = FastAPI()
class Image(BaseModel):
    # Request body for POST /inference: each entry of img_list is handed to the
    # YOLOv5 model by inference_with_path. The default repeats the same sample
    # URL twice.
    # NOTE(review): the duplicate default may have been meant to be bus.jpg —
    # confirm before changing, since it is visible in the OpenAPI schema.
    img_list: Optional[list] = Field(
        default=[
            "https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg",
            "https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg",
        ],
        title="List of image paths",
    )
# Load the pretrained "yolov5s" model from the ultralytics/yolov5 torch.hub
# repository once, when this module is imported; every request handler below
# shares this single instance.
# TODO: rethink this (it works for now) — e.g. add an endpoint that selects
# and loads the model.
model = torch.hub.load(
    "ultralytics/yolov5",
    "yolov5s",
    pretrained=True,
)
def results_to_json(results, names=None):
    """Convert YOLOv5 detection results into JSON-serializable nested lists.

    Args:
        results: Object whose ``xyxyn`` attribute iterates over one 2-D tensor
            per input image. Each row is read as: columns 0-3 = normalized box,
            column 4 = confidence, column 5 = class index.
        names: Optional mapping or sequence from class index to class name.
            When omitted, the names of the module-level ``model`` are used
            (``model.model.names``), exactly as before.

    Returns:
        A list with one entry per image; each entry is a list of dicts with the
        keys ``class`` (int), ``class_name``, ``normalized_box`` (list of 4
        floats) and ``confidence`` (float).
    """
    if names is None:
        names = model.model.names
    output = []
    for result in results.xyxyn:
        # Convert the whole tensor to Python lists in one call instead of one
        # int()/float() per element (each of which is a device sync on GPU).
        rows = result.tolist()
        output.append(
            [
                {
                    "class": int(row[5]),
                    "class_name": names[int(row[5])],
                    "normalized_box": row[:4],
                    "confidence": float(row[4]),
                }
                for row in rows
            ]
        )
    return output
@app.get("/")
def read_root():
    # Root route: always answers with the same small JSON greeting, which makes
    # it handy as a quick "is the server up?" check.
    return dict(Hello="World")
@app.post("/inference")
def inference_with_path(imgs: Image):
    # Run the shared YOLOv5 model on every entry of the request's img_list and
    # return one list of detections per image (formatted by results_to_json).
    #
    # NOTE(security): entries are client-supplied paths/URLs handed directly to
    # the model, which may read local files or fetch remote URLs on the
    # server's behalf. Restrict allowed inputs before exposing this endpoint
    # to untrusted clients.
    img_list = imgs.img_list
    if not img_list:
        # img_list is Optional, so a client may send null (or an empty list).
        # There is nothing to run inference on, so answer with an empty batch
        # rather than passing None/[] to the model (which presumably cannot
        # handle them — verify against the YOLOv5 AutoShape docs).
        return []
    return results_to_json(model(img_list))
if __name__ == '__main__':
    from pathlib import Path

    import uvicorn

    # With reload=True, uvicorn needs the app as a "module:attribute" import
    # string rather than the app object. Derive the module name from this
    # file instead of hard-coding 'server', so the script keeps working if the
    # file is renamed (for server.py this is still 'server:app').
    app_str = f'{Path(__file__).stem}:app'
    uvicorn.run(app_str, host='localhost', port=9999, log_level='info', reload=True, workers=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment