Created
October 15, 2023 20:28
LightGlue for zero-shot image classification
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# streamlit_app.py | |
import os | |
import streamlit as st | |
import glob | |
import torch | |
from pathlib import Path | |
from lightglue import viz2d | |
from lightglue import LightGlue, SuperPoint | |
from lightglue.utils import load_image, rbd | |
from PIL import Image | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
import pickle | |
st.title("Image Matcher using LightGlue")

# Setup: this app only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
matcher = LightGlue(features="superpoint").eval().to(device)

parentfolder = "../all_cropped_icons"
# glob returns entries in arbitrary order; sort so that the slider index ->
# file mapping is stable between runs. Sub-directories are skipped because
# every entry is later handed to load_image().
all_img_files = sorted(
    p for p in glob.glob(parentfolder + "/*") if os.path.isfile(p)
)

# Database of reference images:
#   feat_per_img      -- list of [feats, imgpath, img], one entry per image
#   feat_per_img_dict -- file stem -> feats
#   img2ix            -- file name -> index into feat_per_img
feat_per_img = []
feat_per_img_dict = dict()
img2ix = dict()

# The features are cached in a pickle file next to the script.
# NOTE(review): the cache is only keyed on the file's existence; it is never
# invalidated, so delete it after changing the contents of `parentfolder`.
featurefile = 'feat_per_img.pickle'
if os.path.exists(featurefile):
    # SECURITY: pickle.load can execute arbitrary code. Only load a cache
    # file that this app wrote itself.
    with open(featurefile, 'rb') as handle:
        feat_per_img = pickle.load(handle)
    # Rebuild the lookup tables so they are populated on the cached path too.
    for idx, (feats, imgpath, _img) in enumerate(feat_per_img):
        feat_per_img_dict[Path(imgpath).stem] = feats
        img2ix[Path(imgpath).name] = idx
    print("Loaded features from pickle file")
else:
    print("Computing features for all images...")
    # enumerate(tqdm(...)) rather than tqdm(enumerate(...)): tqdm needs a
    # sized iterable to show a total.
    for idx, imgpath in enumerate(tqdm(all_img_files)):
        img = load_image(imgpath)
        feats = extractor.extract(img.to(device))
        feat_per_img.append([feats, imgpath, img])
        feat_per_img_dict[Path(imgpath).stem] = feats
        img2ix[Path(imgpath).name] = idx
    with open(featurefile, 'wb') as handle:
        pickle.dump(feat_per_img, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print("Saved features to pickle file")
# Function to find top-k matches for the query image
def find_top_k_matches(queryimg, k=5):
    """Return the ``k`` database images with the most matches to ``queryimg``.

    Features are extracted from the query with the module-level
    ``extractor`` and matched with ``matcher`` against every entry of the
    module-level ``feat_per_img`` list.

    Args:
        queryimg: Query image as returned by ``load_image``; it is moved to
            ``device`` before feature extraction.
        k: Maximum number of results to return. Values <= 0 yield an empty
            list.

    Returns:
        A list of at most ``k`` tuples
        ``(num_matches, imgpath, img, feats0, feats1, matches01)`` ordered by
        ``num_matches`` descending. ``feats0`` (query), ``feats1``
        (candidate) and ``matches01`` have been passed through ``rbd``.
        Ties keep their database order because the sort is stable.
    """
    if k <= 0:
        # A negative k would otherwise slice "all but the last |k|" results.
        return []

    queryfeats = extractor.extract(queryimg.to(device))
    # The un-batched query features are identical for every candidate, so
    # compute them once instead of once per loop iteration.
    feats0 = rbd(queryfeats)

    all_matches = []
    for feats, imgpath, img in feat_per_img:
        # The matcher is given the features exactly as extracted; rbd() is
        # applied only to the copies that are returned to the caller.
        matches01 = rbd(matcher({"image0": queryfeats, "image1": feats}))
        feats1 = rbd(feats)
        num_matches = matches01["matches"].shape[0]
        all_matches.append(
            (num_matches, imgpath, img, feats0, feats1, matches01)
        )

    # Sort by number of matches and return top-k
    return sorted(all_matches, key=lambda m: m[0], reverse=True)[:k]
# Display all image names for selection
image_names = [Path(imgpath).name for imgpath in all_img_files]
if not image_names:
    # With no images the slider range would be 0..-1 and indexing would fail.
    st.warning(f"No images found in {parentfolder}")
    st.stop()

if len(image_names) > 1:
    selected_image_index = st.slider("Select an image:", 0, len(image_names) - 1)
else:
    # A slider over a single value offers no choice (and a min == max range
    # may be rejected by Streamlit -- TODO confirm), so skip the widget.
    selected_image_index = 0
selected_image_name = image_names[selected_image_index]

# Create the output dir; ignore if it already exists.
output_dir = Path("output_images/")
output_dir.mkdir(parents=True, exist_ok=True)

# Load the selected image: once as a tensor for matching, once via PIL for
# display. The PIL handle is closed as soon as the image has been handed over.
query_path = os.path.join(parentfolder, selected_image_name)
queryimg = load_image(query_path)
with Image.open(query_path) as pil_img:
    st.image(pil_img.convert("RGB"),
             caption=f'Selected Image: {selected_image_name}')

# Find top-k matches and display them.
# NOTE(review): the query is picked from the same folder the database is built
# from, so it is presumably always among its own top matches -- confirm intent.
top_k_matches = find_top_k_matches(queryimg, k=5)
for ix, (num_matches, matched_imgpath, matched_img, feats0, feats, matches01) in enumerate(top_k_matches):
    kpts0, kpts1, matches = feats0["keypoints"], feats["keypoints"], matches01["matches"]
    # Column 0 / 1 of `matches` index into the query / candidate keypoints.
    m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

    # Plotting: image pair, match lines, then pruning-coloured keypoints.
    viz2d.plot_images([queryimg, matched_img])
    viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
    viz2d.add_text(0, f'Stop after {matches01["stop"]} layers', fs=20)
    kpc0, kpc1 = viz2d.cm_prune(matches01["prune0"]), viz2d.cm_prune(matches01["prune1"])
    viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=10)

    # Save the figure; `ix` keeps file names unique within one result list.
    output_path = output_dir / f"viz2d_{Path(matched_imgpath).stem}_{ix}.png"
    plt.savefig(output_path)
    plt.close()

    # Display the saved image and offer it for download. A markdown link to a
    # local path does not work because Streamlit does not serve arbitrary
    # files over HTTP, so the bytes are passed to a download button instead.
    st.image(str(output_path), caption=f'Matched Image with {num_matches} Keypoints from {matched_imgpath}', use_column_width=True)
    with open(output_path, "rb") as fh:
        st.download_button(
            "Download the visualization",
            data=fh.read(),
            file_name=output_path.name,
            mime="image/png",
            key=f"download_{ix}",  # same label on every button -> needs a unique key
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment