Created
October 29, 2023 11:24
-
-
Save iamironz/411893138a58d54d7a054d65f4eee00f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import sys | |
import os | |
from PIL import Image | |
from transformers import Blip2Processor, Blip2ForConditionalGeneration | |
import torch | |
def log_message(message):
    """Echo *message* to stdout and append it as one line to ``log.txt``.

    The log file is addressed by a relative path, so it is created in (and
    appended to within) the current working directory. It is reopened in
    append mode on every call, which keeps earlier entries intact.

    Args:
        message: Text to report. It is interpolated into an f-string, so any
            object with a string representation is accepted.
    """
    print(message)
    # Pin the encoding: captions and file paths may contain non-ASCII
    # characters, and the platform default codec is not guaranteed to be
    # able to encode them.
    with open("log.txt", "a", encoding="utf-8") as log_file:
        log_file.write(f"{message}\n")
# Choose the compute device: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
log_message(f"Device: {device}")

# Load the processor for the BLIP-2 OPT-2.7b checkpoint; it is used below to
# prepare image inputs and to decode generated token ids.
log_message("Loading processor...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")

# Load the captioning model weights from the same checkpoint.
log_message("Loading model...")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

# Place the model on the selected device so it matches the inputs that
# process_images() moves there.
log_message("Moving model to device...")
model.to(device)
log_message("Model loaded")
def process_images(image_paths, batch_size=16):
    """Generate one caption per image using the module-level BLIP-2 model.

    Images are fed to the model in chunks of at most *batch_size* so that
    memory use does not grow with the number of paths.

    Args:
        image_paths: Paths of the image files to caption.
        batch_size: Maximum number of images preprocessed and passed to
            ``model.generate`` in a single call. Must be at least 1.

    Returns:
        A list of decoded caption strings in the same order as
        *image_paths*. An empty input yields an empty list without calling
        the processor or the model.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    image_paths = list(image_paths)
    log_message(f"Processing {len(image_paths)} images")

    generated_texts = []
    for start in range(0, len(image_paths), batch_size):
        images = []
        for image_path in image_paths[start:start + batch_size]:
            # Image.open is lazy and keeps the file open; convert() decodes
            # the pixels into a new image so the handle can be released
            # as soon as the with-block exits.
            with Image.open(image_path) as image:
                images.append(image.convert("RGB"))

        inputs = processor(images=images, return_tensors="pt").to(device)
        generated_ids = model.generate(**inputs)
        generated_texts.extend(
            processor.batch_decode(generated_ids, skip_special_tokens=True)
        )
    return generated_texts
def main(input_path):
    """Caption the file at *input_path*, or every file found beneath it.

    For a directory, the tree is walked recursively, names starting with a
    dot (which includes ``.DS_Store``) are skipped, and each remaining file
    is captioned via ``process_images``. For a single file, that file alone
    is captioned unless its path ends with ``.DS_Store``. Any other input is
    reported as invalid. All results are reported through ``log_message`` as
    ``"<path>: <caption>"`` lines.

    Args:
        input_path: Path to an image file or to a directory of images.
    """
    log_message(f"Input path: {input_path}")

    if os.path.isdir(input_path):
        log_message("Input path is a directory")
        image_paths = []
        for root, _dirs, files in os.walk(input_path):
            # os.walk yields names in arbitrary order; sort them so the
            # captions are logged in a reproducible order.
            image_paths.extend(
                os.path.join(root, file)
                for file in sorted(files)
                if not file.startswith('.')
            )

        # Nothing to caption: stop here instead of handing an empty batch
        # to the processor and model.
        if not image_paths:
            log_message("No files found to process")
            return

        captions = process_images(image_paths)
        for path, caption in zip(image_paths, captions):
            log_message(f"{path}: {caption.strip()}")
    elif os.path.isfile(input_path) and not input_path.endswith('.DS_Store'):
        log_message("Input path is a file")
        caption = process_images([input_path])[0]
        log_message(f"{input_path}: {caption.strip()}")
    else:
        log_message("Invalid input path")
if __name__ == "__main__":
    # Exactly one command-line argument is expected: the file or folder to
    # caption.
    if len(sys.argv) != 2:
        log_message("Usage: script.py <path_to_file_or_folder>")
        # Exit with a non-zero status so calling shells and scripts can
        # detect the usage error instead of seeing a successful run.
        sys.exit(2)
    main(sys.argv[1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment