-
-
Save asehmi/9b6559e1663941cbe36f90db92d5815a to your computer and use it in GitHub Desktop.
basic captioning example using the LAVIS library
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
caption_image.py - basic captioning example using lavis | |
usage: caption_image.py [-h] -i IMAGE_PATH [-m MODEL_TYPE] [-d DEVICE] [-v VERBOSE] | |
# lavis | |
https://github.com/salesforce/LAVIS | |
""" | |
# pip install git+https://github.com/salesforce/LAVIS.git -q | |
import argparse | |
import logging | |
import time | |
from pathlib import Path | |
import re | |
import requests | |
logging.basicConfig( | |
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" | |
) | |
logger = logging.getLogger(__name__) | |
import PIL | |
import torch | |
from lavis.models import load_model_and_preprocess | |
from PIL import Image | |
def load_image(impath: "str | Path") -> Image.Image:
    """
    Load an image from a local filesystem path or an HTTP(S) URL.

    :param impath: path or URL of the image
    :return: the image converted to RGB
    """
    # BUG FIX: the original annotation `str or Path` evaluates to just `str`,
    # and re.match raises TypeError when handed an actual Path — coerce first.
    impath = str(impath)
    if re.match(r"^https?://", impath):
        # stream=True + .raw lets PIL read straight from the HTTP response body
        return Image.open(requests.get(impath, stream=True).raw).convert("RGB")
    return Image.open(impath).convert("RGB")
def load_and_caption_image(
    impath: "str | Path",
    model_type: str = "base_coco",
    device: "str | None" = None,
    verbose: bool = False,
) -> str:
    """
    Load an image and caption it with a LAVIS BLIP captioning model.

    :param impath: path or URL of the image to caption
    :param model_type: LAVIS ``blip_caption`` model variant, defaults to "base_coco"
    :param device: torch device to run on; auto-selects CUDA when available if None
    :param verbose: log extra detail when True
    :return: the generated caption string
    """
    # NOTE: the original re-created a local `logger` here, shadowing the
    # module-level one; the module logger is used directly instead.
    raw_image = load_image(impath)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"loading model {model_type} on device {device} ...")
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip_caption", model_type=model_type, is_eval=True, device=device
    )
    if verbose:
        logger.info(f"Loaded model:\t{model_type}")
    # preprocess to a single-image batch tensor on the target device
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    logger.info("running inference ...")
    st = time.perf_counter()
    caption = model.generate({"image": image})[0]
    rt = round(time.perf_counter() - st, 2)
    logger.info(f"Finished inference in {rt} seconds, caption:\t{caption}")
    return caption
def get_parser() -> argparse.ArgumentParser:
    """
    Build the command-line argument parser for this script.

    :return: configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser(
        description="lavis_basic_captioning.py - basic captioning example using lavis"
    )
    parser.add_argument(
        "-i",
        "--image_path",
        type=str,
        required=True,
        help="path to image to caption",
    )
    parser.add_argument(
        "-m",
        "--model_type",
        type=str,
        default="base_coco",
        help="model type to use",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default=None,
        help="device to use",
    )
    # BUG FIX: the original used type=bool, but argparse applies bool() to the
    # raw string, so any value (including "False") enabled verbose mode.
    # store_true is the correct way to expose a boolean flag.
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        default=False,
        help="verbose",
    )
    return parser
if __name__ == "__main__":
    # BUG FIX: the original built the parser twice
    # (`parser = get_parser()` then `args = get_parser().parse_args()`),
    # leaving the first instance unused; build it once.
    args = get_parser().parse_args()
    caption = load_and_caption_image(
        impath=args.image_path,
        model_type=args.model_type,
        device=args.device,
        verbose=args.verbose,
    )
    print(caption)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment