Skip to content

Instantly share code, notes, and snippets.

@gin-melodic
Created December 21, 2023 02:30
Show Gist options
  • Save gin-melodic/8e0063be2a30784d5477d01761798e61 to your computer and use it in GitHub Desktop.
Save gin-melodic/8e0063be2a30784d5477d01761798e61 to your computer and use it in GitHub Desktop.
Tips for image processing with Sketchpad in Gradio 4.x
from functools import lru_cache

import gradio as gr
import torch
from PIL import Image, ImageOps
from torchvision import transforms

from model import NeuralNetwork
# Grayscale + tensor conversion is stateless, so the pipeline is built once at
# import time instead of on every request.
_TO_TENSOR = transforms.Compose([
    transforms.Grayscale(1),
    transforms.ToTensor(),
])


@lru_cache(maxsize=None)  # keyed on the device string: only 'cuda' or 'cpu'
def _load_model(device):
    """Return a NeuralNetwork loaded from ./mnist.pth on *device*, in eval mode.

    Cached per device so the checkpoint is read from disk once rather than on
    every prediction request.
    """
    model = NeuralNetwork().to(device)
    # The loaded object is passed straight to load_state_dict, so it only needs
    # to contain tensors; weights_only=True stops torch.load from unpickling
    # arbitrary objects (parameter available since torch 1.13).
    state_dict = torch.load('./mnist.pth', map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _sketch_to_tensor(img):
    """Convert a Sketchpad drawing into a (1, 1, 28, 28) float tensor.

    The drawing is pasted onto a white canvas using its alpha channel as the
    mask, resized to 28x28, colour-inverted, converted to one grayscale
    channel, and given a leading batch dimension of 1.
    """
    # Normalise the mode first: split()[3] below needs an alpha band, which a
    # non-RGBA input (e.g. a different Sketchpad image_mode) would not have.
    rgba = img.convert("RGBA")
    flattened = Image.new("RGB", rgba.size, (255, 255, 255))
    flattened.paste(rgba, mask=rgba.split()[3])
    small = flattened.resize((28, 28))
    # NOTE(review): inversion presumably turns dark strokes on white into the
    # light-on-dark style the model was trained on - confirm brush colour.
    inverted = ImageOps.invert(small)
    return _TO_TENSOR(inverted).unsqueeze(0)


def predict_image(img_data):
    """Classify a digit drawn in a Gradio Sketchpad.

    Args:
        img_data: The Sketchpad value; its 'composite' entry holds the
            flattened PIL image. Guarded defensively in case Gradio passes
            None (or a None composite) for an empty submission.

    Returns:
        A (predicted_label, confidence_text) tuple. predicted_label is the
        argmax class index; confidence_text is that class's softmax
        probability formatted as a percentage with two decimals. When there
        is no image, returns (None, "No drawing submitted").
    """
    img = None if img_data is None else img_data['composite']
    if img is None:
        # Nothing to classify: clear the label output and explain why.
        return None, "No drawing submitted"

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = _load_model(device)
    img_tensor = _sketch_to_tensor(img).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        # output[0]: scores for the single image in the batch of size 1.
        probs = torch.nn.functional.softmax(output[0], dim=0)
        predicted = torch.argmax(probs).item()
        confidence = probs[predicted].item()  # probability of the predicted label
    return predicted, f"{confidence * 100:.2f}%"
# Request PIL images in RGBA so the alpha channel is available as a paste mask
# when predict_image flattens the drawing.
pad = gr.Sketchpad(type="pil", image_mode="RGBA")
# Two outputs: the predicted class ("label") and its confidence ("text").
iface = gr.Interface(fn=predict_image, inputs=pad, outputs=["label", "text"])

# Start the server only when run as a script, so importing this module does
# not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment