Script for chatting with Qwen2.5-VL. Supports CUDA, MPS (Apple Silicon), and CPU backends. Basic vision support as well.
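Setup note (illustrative, not part of the original script): the script looks for model weights under a local __models__ directory, one subfolder per model (see get_available_models below). Assuming the Hugging Face CLI is installed, one way to populate it, using the 7B variant as an example, is `huggingface-cli download Qwen/Qwen2.5-VL-7B-Instruct --local-dir __models__/Qwen2.5-VL-7B-Instruct`. The imports imply roughly `pip install torch transformers qwen-vl-utils rich` (qwen-vl-utils is the PyPI name of the qwen_vl_utils helper).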
import os

import torch
from qwen_vl_utils import process_vision_info
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import Confirm, Prompt
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
)

# Initialize rich console as a global object for use throughout the program
console = Console()

class QwenInference:
    def __init__(self):
        self.device = self.detect_device()
        self.model = None
        self.processor = None
        self.chat_history = []
        self.mps_device = None

    def detect_device(self):
        # Check for CUDA or MPS (Apple Silicon) availability
        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        return device
    def load_model(self, model_path):
        # Load the model with appropriate settings for the device
        torch_dtype = torch.float16  # bfloat16 is better but not all devices support it

        # For MPS, we need to ensure compatibility
        if self.device == "mps":
            # MPS-specific settings
            torch.mps.set_per_process_memory_fraction(
                0.99
            )  # Allow using 99% of available memory
            if hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()  # Clear cache before loading
            # Set the mps_device for future use
            self.mps_device = torch.device("mps")

        # Load the model with appropriate device configuration
        if self.device == "mps":
            # For MPS, we need to explicitly move the model after loading
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
            )
            # Explicitly move model to MPS device
            self.model = self.model.to(self.mps_device)
        else:
            # For CUDA or CPU, use automatic device mapping
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path, torch_dtype=torch_dtype, device_map="auto"
            )
            if self.device == "cuda":
                self.mps_device = self.model.device

        # Load processor
        self.processor = AutoProcessor.from_pretrained(model_path)
        return self.model, self.processor
    def build_messages(self, user_input, use_history=True, custom_text=None):
        """Build messages for model input, handling both text and image inputs."""
        # Handle image input
        if user_input.startswith("image:"):
            image_path = user_input[6:].strip()
            if not os.path.exists(image_path):
                return None, f"Image file not found: {image_path}"

            # Create messages with image
            text_prompt = custom_text or "Describe this image."
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": f"file://{image_path}",
                        },
                        {"type": "text", "text": text_prompt},
                    ],
                }
            ]
            return messages, None
        else:
            # Text-only messages
            messages = [{"role": "user", "content": user_input}]
            # Add chat history if available and requested
            if use_history and self.chat_history:
                messages = self.chat_history + messages
            return messages, None
    def prepare_inference_inputs(self, messages):
        # Prepare for inference
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Process images if present
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Ensure all tensors are on the same device as the model
        inputs = {
            k: v.to(self.mps_device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        return inputs
    def process_model_output(self, generated_ids, inputs):
        if "input_ids" not in inputs:
            raise KeyError("Missing 'input_ids' in processor outputs")

        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Verify output is not empty
        if not output_text or output_text.isspace():
            return None
        return output_text
    def generate_response(self, inputs):
        # Configure generation parameters
        generation_params = {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
        }

        # For MPS, we'll use a more memory-efficient approach
        if self.device == "mps":
            # Clear cache before generation
            if hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()
            # Use more conservative parameters for MPS
            generation_params["max_new_tokens"] = 256  # Reduce for memory efficiency
            # Additional settings for MPS stability
            generation_params["num_beams"] = 1  # Disable beam search
            generation_params["use_cache"] = True  # Enable KV caching for speed

        # Generate the response with error handling
        try:
            return self.model.generate(**inputs, **generation_params)
        except torch.cuda.OutOfMemoryError:
            console.print(
                "[yellow]GPU ran out of memory. Trying with more conservative settings...[/yellow]"
            )
            # Retry with more conservative settings
            if hasattr(torch.cuda, "empty_cache"):
                torch.cuda.empty_cache()
            generation_params["max_new_tokens"] = 128
            generation_params["num_beams"] = 1
            return self.model.generate(**inputs, **generation_params)
        except RuntimeError as e:
            if "MPS" in str(e) and "memory" in str(e).lower():
                console.print(
                    "[yellow]Apple Silicon GPU ran out of memory. Trying with more conservative settings...[/yellow]"
                )
                # Retry with more conservative settings
                if hasattr(torch.mps, "empty_cache"):
                    torch.mps.empty_cache()
                # Fall back to CPU if needed
                inputs = {
                    k: v.to("cpu") if isinstance(v, torch.Tensor) else v
                    for k, v in inputs.items()
                }
                temp_model = self.model.to("cpu")
                generation_params["max_new_tokens"] = 64
                generation_params["num_beams"] = 1
                generated_ids = temp_model.generate(**inputs, **generation_params)

                # Move model back to original device afterward
                self.model = self.model.to(self.mps_device)
                # Return inputs to original device as well
                inputs = {
                    k: v.to(self.mps_device) if isinstance(v, torch.Tensor) else v
                    for k, v in inputs.items()
                }
                return generated_ids
            else:
                raise
    def add_to_chat_history(self, user_message, model_response):
        """Add a message pair to the chat history."""
        self.chat_history.append(user_message)
        self.chat_history.append(model_response)
        # Keep chat history limited to last 10 messages to avoid context length issues
        if len(self.chat_history) > 10:
            self.chat_history = self.chat_history[-10:]

    def clear_chat_history(self):
        """Clear the chat history."""
        self.chat_history = []
    def run_inference(self, messages, user_input=None):
        """Process messages and return the model's response."""
        try:
            # Prepare inputs for inference
            inputs = self.prepare_inference_inputs(messages)
            # Generate response with the model
            generated_ids = self.generate_response(inputs)
            # Process the output
            output_text = self.process_model_output(generated_ids, inputs)
            if not output_text:
                return None, "I couldn't generate a proper response. Let's try again."

            # Update chat history for context
            if user_input is not None:
                # Preserve the original message format (image or text);
                # the current user message is always the last entry
                if isinstance(messages[-1]["content"], list):  # Image content
                    user_message = {"role": "user", "content": messages[-1]["content"]}
                else:  # Text content
                    user_message = {"role": "user", "content": user_input}
                model_message = {"role": "assistant", "content": output_text}
                self.add_to_chat_history(user_message, model_message)

            return output_text, None
        except Exception as e:
            return None, f"Error occurred: {e}"
def get_available_models():
    """Get available models from the __models__ directory."""
    models_path = "__models__"

    # Check if __models__ directory exists
    if not os.path.exists(models_path):
        return []

    # List directories in __models__ that likely contain models
    models = []
    for item in os.listdir(models_path):
        item_path = os.path.join(models_path, item)
        # Check if it's a directory and likely contains a model
        if os.path.isdir(item_path) and (
            os.path.exists(os.path.join(item_path, "config.json"))
            or os.path.exists(os.path.join(item_path, "model.safetensors"))
            or os.path.exists(os.path.join(item_path, "pytorch_model.bin"))
        ):
            models.append(item)
    return models
def main():
    # Print welcome message
    console.print(
        Panel.fit(
            "[bold cyan]Qwen Multimodal Chat Interface[/bold cyan]",
            border_style="cyan",
            padding=(1, 2),
        )
    )

    # Create inference instance
    qwen = QwenInference()

    if qwen.device == "cpu":
        console.print(
            Panel.fit(
                "[bold yellow]WARNING[/bold yellow]: Neither CUDA nor MPS detected. Running on CPU will be extremely slow for these models.\n"
                "Consider using a machine with GPU support or Apple Silicon.",
                border_style="yellow",
            )
        )
        proceed = Confirm.ask("Do you want to continue anyway?")
        if not proceed:
            console.print("[yellow]Exiting...[/yellow]")
            return
    elif qwen.device == "mps":
        console.print(
            Panel.fit(
                "[bold green]Using Apple Silicon GPU acceleration (MPS).[/bold green]\n"
                "NOTE: These are large models which may exceed available memory on some Apple Silicon devices.\n"
                "If you encounter memory errors, you might need to try a smaller model variant.",
                border_style="green",
            )
        )
    else:
        console.print(
            Panel.fit(
                "[bold green]Using CUDA GPU acceleration.[/bold green]",
                border_style="green",
            )
        )

    # Get available models
    available_models = get_available_models()
    if not available_models:
        console.print(
            Panel.fit(
                "[bold red]No models found in the __models__ directory.[/bold red]\n"
                "Please download at least one model and place it in the __models__ directory.",
                border_style="red",
            )
        )
        return

    # Let user choose a model
    console.print("[bold]Available models:[/bold]")
    for i, model_name in enumerate(available_models, 1):
        console.print(f"  {i}. [cyan]{model_name}[/cyan]")

    # Get user choice
    while True:
        choice = Prompt.ask("\nSelect model by number", default="1")
        try:
            choice_idx = int(choice) - 1
            if 0 <= choice_idx < len(available_models):
                selected_model = available_models[choice_idx]
                break
            else:
                console.print(
                    "[bold red]Invalid selection. Please try again.[/bold red]"
                )
        except ValueError:
            console.print("[bold red]Please enter a number.[/bold red]")

    console.print(f"[bold green]Selected model:[/bold green] {selected_model}")

    # Set model path
    model_path = f"__models__/{selected_model}"

    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]Loading model...[/bold blue]"),
        console=console,
    ) as progress:
        task = progress.add_task("Loading", total=None)
        try:
            qwen.load_model(model_path)
            progress.update(task, completed=True)
            console.print("[bold green]Model loaded successfully![/bold green]")
        except Exception as e:
            progress.update(task, completed=True)
            console.print(f"[bold red]Error loading model:[/bold red] {e}")
            console.print(
                "[yellow]Make sure you have downloaded the model correctly using huggingface-cli.[/yellow]"
            )
            return

    # Show instructions
    console.print(
        Panel.fit(
            "[bold]Chat Session Instructions:[/bold]\n"
            "- Type [cyan]exit[/cyan] or [cyan]quit[/cyan] to end the conversation\n"
            "- To use an image, type [cyan]image:[/cyan] followed by the path to the image\n"
            "- Example: [cyan]image:/path/to/your/image.jpg[/cyan]",
            title="Instructions",
            border_style="cyan",
        )
    )
    # Chat loop
    while True:
        # Get user input
        user_input = Prompt.ask("\n[bold green]You[/bold green]")
        if user_input.lower() in ["exit", "quit"]:
            console.print("[bold cyan]Exiting chat. Goodbye![/bold cyan]")
            break

        # Build messages based on input type
        if user_input.startswith("image:"):
            # First build basic image messages
            messages, error = qwen.build_messages(user_input, use_history=False)
            if messages is None:
                console.print(f"[bold red]{error}[/bold red]")
                continue

            # Ask for specific question about the image
            text_input = Prompt.ask(
                "[bold]What would you like to ask about this image?[/bold]"
            )
            # Rebuild messages with the custom text
            messages, _ = qwen.build_messages(
                user_input, use_history=False, custom_text=text_input
            )
        else:
            # Build text messages
            messages, error = qwen.build_messages(user_input)
            if messages is None:
                console.print(f"[bold red]{error}[/bold red]")
                continue

        # Common inference processing for all input types
        with Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]Qwen is thinking...[/bold blue]"),
            console=console,
            transient=True,
        ) as progress:
            task = progress.add_task("Thinking", total=None)
            # Process message through the inference pipeline
            output_text, error = qwen.run_inference(messages, user_input)
            progress.update(task, completed=True)

        # Handle response
        if output_text is None:
            console.print(f"\n[bold yellow]Qwen:[/bold yellow] {error}")
            continue

        # Display response as markdown for better formatting
        console.print("\n[bold blue]Qwen:[/bold blue]", end=" ")
        console.print(Markdown(output_text))


if __name__ == "__main__":
    main()
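The QwenInference class can also be driven without the interactive loop. A minimal sketch, assuming the gist is saved as qwen_chat.py (filename assumed) and a model has been downloaded to __models__/Qwen2.5-VL-7B-Instruct (directory name assumed):

from qwen_chat import QwenInference  # filename is an assumption

qwen = QwenInference()
qwen.load_model("__models__/Qwen2.5-VL-7B-Instruct")  # assumed model directory
prompt = "Describe what you can do."
messages, error = qwen.build_messages(prompt)
reply, error = qwen.run_inference(messages, user_input=prompt)
print(reply if reply is not None else error)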