Skip to content

Instantly share code, notes, and snippets.

@seanghay
Created February 6, 2023 11:39
Show Gist options
  • Save seanghay/4a627342be9d14e08b09ef95078575bf to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# encoding: utf-8
import json
import time

import torch
from flask import Flask, abort, jsonify, request
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
# Hugging Face checkpoint used for both preprocessing and classification.
VIT_CHECKPOINT = 'google/vit-base-patch16-224'

app = Flask(__name__)

# Loaded once at import time so every request reuses the same weights
# instead of reloading them per call.
feature_extractor = ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)
model = ViTForImageClassification.from_pretrained(VIT_CHECKPOINT)
def predict(image):
    """Classify *image* with the module-level ViT model and return its label.

    Args:
        image: A PIL image to classify.

    Returns:
        The label string of the highest-scoring class, looked up in
        ``model.config.id2label``.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: disabling autograd avoids building a computation graph
    # (and keeping its intermediate activations alive) on every request.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Model predicts one of the 1000 ImageNet classes; take the top-scoring one.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
@app.route('/', methods=['POST'])
def index():
    """Classify an image uploaded as the multipart form field ``file``.

    Responds with a JSON body of the form ``{"result": <label>}``.
    Responds with HTTP 400 when the ``file`` field is missing or the upload
    cannot be decoded as an image.
    """
    upload = request.files.get('file')
    if upload is None:
        abort(400, description="missing multipart form field 'file'")
    try:
        with Image.open(upload.stream) as img:
            # Normalise RGBA / palette / grayscale / CMYK uploads to 3-channel
            # RGB for the ViT checkpoint; convert() also forces PIL's lazy
            # decode while the file handle is still open.
            image = img.convert('RGB')
    except OSError:  # PIL.UnidentifiedImageError and decode errors subclass OSError
        abort(400, description="uploaded file is not a readable image")
    # jsonify sets the application/json content type (json.dumps alone does not).
    return jsonify({'result': predict(image)})
if __name__ == '__main__':
    # Start Flask's built-in development server only when run as a script,
    # so importing this module (e.g. from a WSGI server or tests) has no
    # side effect of binding a port.
    app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.