Script for chatting with Qwen2.5-VL. Supports CUDA, MPS (Apple Silicon), and CPU backends. Basic vision support as well.
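The script loads checkpoints from a local `__models__` directory, one sub-directory per model (see `get_available_models()` below). As an illustrative example only (exact repo name and CLI flags depend on your setup), a checkpoint could be fetched with something like `huggingface-cli download Qwen/Qwen2.5-VL-3B-Instruct --local-dir __models__/Qwen2.5-VL-3B-Instruct`.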
import os

import torch
from qwen_vl_utils import process_vision_info
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import Confirm, Prompt
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
)

# Initialize rich console as a global object for use throughout the program
console = Console()


class QwenInference:
    def __init__(self):
        self.device = self.detect_device()
        self.model = None
        self.processor = None
        self.chat_history = []
        self.mps_device = None

    def detect_device(self):
        # Check for CUDA or MPS (Apple Silicon) availability
        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        return device

    def load_model(self, model_path):
        # Load the model with appropriate settings for the device
        torch_dtype = torch.float16  # bfloat16 is better but not all devices support it

        # For MPS, we need to ensure compatibility
        if self.device == "mps":
            # MPS-specific settings
            torch.mps.set_per_process_memory_fraction(
                0.99
            )  # Allow using 99% of available memory
            if hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()  # Clear cache before loading
            # Set the mps_device for future use
            self.mps_device = torch.device("mps")

        # Load the model with appropriate device configuration
        if self.device == "mps":
            # For MPS, we need to explicitly move the model after loading
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
            )
            # Explicitly move model to MPS device
            self.model = self.model.to(self.mps_device)
        else:
            # For CUDA or CPU, use automatic device mapping
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path, torch_dtype=torch_dtype, device_map="auto"
            )
            # Record the device the weights landed on (CUDA or CPU) so that
            # inputs can be moved to the same place later
            self.mps_device = self.model.device

        # Load processor
        self.processor = AutoProcessor.from_pretrained(model_path)

        return self.model, self.processor

    def build_messages(self, user_input, use_history=True, custom_text=None):
        """Build messages for model input, handling both text and image inputs."""
        # Handle image input
        if user_input.startswith("image:"):
            image_path = user_input[6:].strip()
            if not os.path.exists(image_path):
                return None, f"Image file not found: {image_path}"

            # Create messages with image
            text_prompt = custom_text or "Describe this image."
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": f"file://{image_path}",
                        },
                        {"type": "text", "text": text_prompt},
                    ],
                }
            ]
            return messages, None
        else:
            # Text-only messages
            messages = [{"role": "user", "content": user_input}]
            # Add chat history if available and requested
            if use_history and self.chat_history:
                messages = self.chat_history + messages
            return messages, None

    def prepare_inference_inputs(self, messages):
        # Prepare for inference
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Process images/videos if present
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Ensure all tensors are on the same device as the model
        inputs = {
            k: v.to(self.mps_device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        return inputs

    def process_model_output(self, generated_ids, inputs):
        if "input_ids" not in inputs:
            raise KeyError("Missing 'input_ids' in processor outputs")

        # Strip the prompt tokens so only the newly generated text is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Verify output is not empty
        if not output_text or output_text.isspace():
            return None
        return output_text

    def generate_response(self, inputs):
        # Configure generation parameters
        generation_params = {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
        }

        # For MPS, we'll use a more memory-efficient approach
        if self.device == "mps":
            # Clear cache before generation
            if hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()
            # Use more conservative parameters for MPS
            generation_params["max_new_tokens"] = 256  # Reduce for memory efficiency
            # Additional settings for MPS stability
            generation_params["num_beams"] = 1  # Disable beam search
            generation_params["use_cache"] = True  # Enable KV caching for speed

        # Generate the response with error handling
        try:
            return self.model.generate(**inputs, **generation_params)
        except torch.cuda.OutOfMemoryError:
            console.print(
                "[yellow]GPU ran out of memory. Trying with more conservative settings...[/yellow]"
            )
            # Retry with more conservative settings
            if hasattr(torch.cuda, "empty_cache"):
                torch.cuda.empty_cache()
            generation_params["max_new_tokens"] = 128
            generation_params["num_beams"] = 1
            return self.model.generate(**inputs, **generation_params)
        except RuntimeError as e:
            if "MPS" in str(e) and "memory" in str(e).lower():
                console.print(
                    "[yellow]Apple Silicon GPU ran out of memory. Trying with more conservative settings...[/yellow]"
                )
                # Retry with more conservative settings
                if hasattr(torch.mps, "empty_cache"):
                    torch.mps.empty_cache()
                # Fall back to CPU if needed
                inputs = {
                    k: v.to("cpu") if isinstance(v, torch.Tensor) else v
                    for k, v in inputs.items()
                }
                temp_model = self.model.to("cpu")
                generation_params["max_new_tokens"] = 64
                generation_params["num_beams"] = 1
                generated_ids = temp_model.generate(**inputs, **generation_params)
                # Move the model back to the original device afterward
                self.model = self.model.to(self.mps_device)
                return generated_ids
            else:
                raise

    def add_to_chat_history(self, user_message, model_response):
        """Add a message pair to the chat history."""
        self.chat_history.append(user_message)
        self.chat_history.append(model_response)
        # Keep chat history limited to the last 10 messages to avoid context length issues
        if len(self.chat_history) > 10:
            self.chat_history = self.chat_history[-10:]

    def clear_chat_history(self):
        """Clear the chat history."""
        self.chat_history = []

    def run_inference(self, messages, user_input=None):
        """Process messages and return the model's response."""
        try:
            # Prepare inputs for inference
            inputs = self.prepare_inference_inputs(messages)
            # Generate response with the model
            generated_ids = self.generate_response(inputs)
            # Process the output
            output_text = self.process_model_output(generated_ids, inputs)
            if not output_text:
                return None, "I couldn't generate a proper response. Let's try again."

            # Update chat history for context
            if user_input is not None:
                # Preserve the original message format (image or text); the current
                # user message is always the last entry, since history is prepended
                if isinstance(messages[-1]["content"], list):  # Image content
                    user_message = {"role": "user", "content": messages[-1]["content"]}
                else:  # Text content
                    user_message = {"role": "user", "content": user_input}
                model_message = {"role": "assistant", "content": output_text}
                self.add_to_chat_history(user_message, model_message)

            return output_text, None
        except Exception as e:
            return None, f"Error occurred: {e}"


def get_available_models():
    """Get available models from the __models__ directory."""
    models_path = "__models__"

    # Check if the __models__ directory exists
    if not os.path.exists(models_path):
        return []

    # List directories in __models__ that likely contain models
    models = []
    for item in os.listdir(models_path):
        item_path = os.path.join(models_path, item)
        # Check if it's a directory and likely contains a model
        if os.path.isdir(item_path) and (
            os.path.exists(os.path.join(item_path, "config.json"))
            or os.path.exists(os.path.join(item_path, "model.safetensors"))
            or os.path.exists(os.path.join(item_path, "pytorch_model.bin"))
        ):
            models.append(item)

    return models
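
# Expected on-disk layout (illustrative sketch; the directory name is arbitrary,
# only the marker files checked above matter):
#
#   __models__/
#       Qwen2.5-VL-3B-Instruct/
#           config.json
#           model.safetensors  (or pytorch_model.bin)
#           ...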


def main():
    # Print welcome message
    console.print(
        Panel.fit(
            "[bold cyan]Qwen Multimodal Chat Interface[/bold cyan]",
            border_style="cyan",
            padding=(1, 2),
        )
    )

    # Create inference instance
    qwen = QwenInference()

    if qwen.device == "cpu":
        console.print(
            Panel.fit(
                "[bold yellow]WARNING[/bold yellow]: Neither CUDA nor MPS detected. Running on CPU will be extremely slow for these models.\n"
                "Consider using a machine with GPU support or Apple Silicon.",
                border_style="yellow",
            )
        )
        proceed = Confirm.ask("Do you want to continue anyway?")
        if not proceed:
            console.print("[yellow]Exiting...[/yellow]")
            return
    elif qwen.device == "mps":
        console.print(
            Panel.fit(
                "[bold green]Using Apple Silicon GPU acceleration (MPS).[/bold green]\n"
                "NOTE: These are large models which may exceed available memory on some Apple Silicon devices.\n"
                "If you encounter memory errors, you might need to try a smaller model variant.",
                border_style="green",
            )
        )
    else:
        console.print(
            Panel.fit(
                "[bold green]Using CUDA GPU acceleration.[/bold green]",
                border_style="green",
            )
        )

    # Get available models
    available_models = get_available_models()
    if not available_models:
        console.print(
            Panel.fit(
                "[bold red]No models found in the __models__ directory.[/bold red]\n"
                "Please download at least one model and place it in the __models__ directory.",
                border_style="red",
            )
        )
        return

    # Let user choose a model
    console.print("[bold]Available models:[/bold]")
    for i, model_name in enumerate(available_models, 1):
        console.print(f"  {i}. [cyan]{model_name}[/cyan]")

    # Get user choice
    while True:
        choice = Prompt.ask("\nSelect model by number", default="1")
        try:
            choice_idx = int(choice) - 1
            if 0 <= choice_idx < len(available_models):
                selected_model = available_models[choice_idx]
                break
            else:
                console.print(
                    "[bold red]Invalid selection. Please try again.[/bold red]"
                )
        except ValueError:
            console.print("[bold red]Please enter a number.[/bold red]")

    console.print(f"[bold green]Selected model:[/bold green] {selected_model}")

    # Set model path
    model_path = f"__models__/{selected_model}"

    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]Loading model...[/bold blue]"),
        console=console,
    ) as progress:
        task = progress.add_task("Loading", total=None)
        try:
            qwen.load_model(model_path)
            progress.update(task, completed=True)
            console.print("[bold green]Model loaded successfully![/bold green]")
        except Exception as e:
            progress.update(task, completed=True)
            console.print(f"[bold red]Error loading model:[/bold red] {e}")
            console.print(
                "[yellow]Make sure you have downloaded the model correctly using huggingface-cli.[/yellow]"
            )
            return
    # Show instructions
    console.print(
        Panel.fit(
            "[bold]Chat Session Instructions:[/bold]\n"
            "- Type [cyan]exit[/cyan] or [cyan]quit[/cyan] to end the conversation\n"
            "- To use an image, type [cyan]image:[/cyan] followed by the path to the image\n"
            "- Example: [cyan]image:/path/to/your/image.jpg[/cyan]",
            title="Instructions",
            border_style="cyan",
        )
    )

    # Chat loop
    while True:
        # Get user input
        user_input = Prompt.ask("\n[bold green]You[/bold green]")
        if user_input.lower() in ["exit", "quit"]:
            console.print("[bold cyan]Exiting chat. Goodbye![/bold cyan]")
            break

        # Build messages based on input type
        if user_input.startswith("image:"):
            # First build basic image messages
            messages, error = qwen.build_messages(user_input, use_history=False)
            if messages is None:
                console.print(f"[bold red]{error}[/bold red]")
                continue

            # Ask for a specific question about the image
            text_input = Prompt.ask(
                "[bold]What would you like to ask about this image?[/bold]"
            )
            # Rebuild messages with the custom text
            messages, _ = qwen.build_messages(
                user_input, use_history=False, custom_text=text_input
            )
        else:
            # Build text messages
            messages, error = qwen.build_messages(user_input)
            if messages is None:
                console.print(f"[bold red]{error}[/bold red]")
                continue

        # Common inference processing for all input types
        with Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]Qwen is thinking...[/bold blue]"),
            console=console,
            transient=True,
        ) as progress:
            task = progress.add_task("Thinking", total=None)
            # Process message through the inference pipeline
            output_text, error = qwen.run_inference(messages, user_input)
            progress.update(task, completed=True)

        # Handle response
        if output_text is None:
            console.print(f"\n[bold yellow]Qwen:[/bold yellow] {error}")
            continue

        # Display response as markdown for better formatting
        console.print("\n[bold blue]Qwen:[/bold blue]", end=" ")
        console.print(Markdown(output_text))


if __name__ == "__main__":
    main()
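
# Programmatic usage (hedged sketch, not part of the interactive flow above):
# QwenInference can also be driven directly from other code. The model directory
# and image path below are hypothetical placeholders.
#
#   qwen = QwenInference()
#   qwen.load_model("__models__/Qwen2.5-VL-3B-Instruct")
#   messages, err = qwen.build_messages(
#       "image:/path/to/photo.jpg",
#       use_history=False,
#       custom_text="What is in this picture?",
#   )
#   if err is None:
#       output_text, err = qwen.run_inference(messages, user_input="image:/path/to/photo.jpg")
#       print(output_text)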