import os
import base64
import json
import re
from typing import List, Dict, Any, Optional, Type, TypeVar
from pydantic import BaseModel, Field
from pathlib import Path
# Type variable for Pydantic models
T = TypeVar('T', bound=BaseModel)
# ===== MISTRAL OCR IMPLEMENTATION =====
def process_document_with_mistral_ocr(
    file_path,
    api_key=None,
    model="mistral-ocr-latest",
    output_format="markdown",
    save_to_file=None,
):
    """
    Process a document using the Mistral OCR API and return the results.

    Args:
        file_path (str): Path to the document file to process.
        api_key (str, optional): Mistral API key. Defaults to the MISTRAL_API_KEY environment variable.
        model (str, optional): Mistral OCR model to use. Defaults to "mistral-ocr-latest".
        output_format (str, optional): Output format - "markdown", "json", or "html". Defaults to "markdown".
        save_to_file (str, optional): Path to save the output. If None, returns the result.

    Returns:
        The OCR results in the specified format, or None if saved to a file.
    """
    # Import here so the mistralai dependency stays optional
    from mistralai import Mistral

    # Get the API key from the environment if not provided
    if not api_key:
        api_key = os.environ.get("MISTRAL_API_KEY")
    if not api_key:
        raise ValueError("No API key provided and MISTRAL_API_KEY environment variable not set.")

    # Initialize the Mistral client
    client = Mistral(api_key=api_key)

    # Upload the file
    pdf_file = Path(file_path)
    print(f"Uploading file {pdf_file.name}...")
    uploaded_file = client.files.upload(
        file={
            "file_name": pdf_file.name,
            "content": pdf_file.read_bytes(),
        },
        purpose="ocr",
    )

    # Get a signed URL for the uploaded file (expiry is in hours)
    signed_url = client.files.get_signed_url(file_id=uploaded_file.id, expiry=1)
    print(f"Processing with OCR model: {model}...")

    # Process the document with OCR
    ocr_response = client.ocr.process(
        document={"type": "document_url", "document_url": signed_url.url},
        model=model,
        include_image_base64=True,
    )

    # Shape the response according to the requested output format
    if output_format == "json":
        result = json.loads(ocr_response.model_dump_json())
    else:
        # pages is a list, so join the per-page markdown into one document
        result = "\n\n".join(page.markdown for page in ocr_response.pages)
        # Convert to HTML if requested
        if output_format == "html":
            import markdown
            result = markdown.markdown(result, extensions=['tables', 'fenced_code'])
            result = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body {{ font-family: Arial, sans-serif; line-height: 1.6; max-width: 800px; margin: 0 auto; padding: 20px; }}
img {{ max-width: 100%; }}
pre {{ background-color: #f5f5f5; padding: 10px; overflow: auto; }}
table {{ border-collapse: collapse; width: 100%; }}
th, td {{ border: 1px solid #ddd; padding: 8px; }}
th {{ background-color: #f2f2f2; }}
</style>
</head>
<body>
{result}
</body>
</html>
"""

    # Save to file if requested
    if save_to_file:
        with open(save_to_file, 'w', encoding='utf-8') as f:
            if output_format == "json":
                # result is a dict in JSON mode, so serialize it before writing
                json.dump(result, f, indent=2, ensure_ascii=False)
            else:
                f.write(result)
        print(f"Results saved to {save_to_file}")
        return None
    return result
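
# A minimal usage sketch (the file names are hypothetical placeholders, and it
# assumes MISTRAL_API_KEY is set in the environment). It is not called
# anywhere; it only illustrates the call shape.
def _demo_ocr_usage():
    """Convert a PDF to styled HTML and write it next to the source file."""
    process_document_with_mistral_ocr(
        "contract.pdf",                # hypothetical input document
        output_format="html",
        save_to_file="contract.html",  # writing to a file makes the call return None
    )
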
# ===== VISION LLM EXTRACTION IMPLEMENTATION =====
def encode_image_to_base64(image_path: str) -> str:
    """Convert an image file to a base64-encoded string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')
def extract_json_from_text(text: str) -> Optional[str]:
    """Extract a JSON object from text that might contain additional content."""
    # First, look for JSON inside a fenced markdown code block
    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
    if json_match:
        json_str = json_match.group(1)
    else:
        # Otherwise, grab the outermost brace-delimited span
        json_match = re.search(r'({[\s\S]*})', text)
        if json_match:
            json_str = json_match.group(1)
        else:
            json_str = text
    json_str = json_str.strip()
    # Validate before returning so callers get None rather than broken JSON
    try:
        json.loads(json_str)
        return json_str
    except json.JSONDecodeError:
        return None
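
# The fallback chain above covers the common shapes of vision-model output:
# fenced code blocks, bare braces preceded by chatter, and raw JSON. For example:
#
#   >>> extract_json_from_text('Sure!\n```json\n{"a": 1}\n```')
#   '{"a": 1}'
#   >>> extract_json_from_text("not json at all") is None
#   True
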
def generate_extraction_prompt(model_class: Type[T]) -> str:
    """Generate an extraction prompt from a Pydantic model's JSON schema."""
    schema = model_class.model_json_schema()
    prompt = "Analyze this image and extract information in JSON format.\n\n"
    prompt += f"Return the data according to this JSON schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
    prompt += "Important guidelines:\n"
    prompt += "1. Only return valid JSON that conforms to the schema.\n"
    prompt += "2. If you're not sure about a field, use null instead of guessing.\n"
    prompt += "3. Don't add any explanations outside the JSON structure.\n"
    prompt += "4. Extract as much relevant information as possible.\n"
    return prompt
def extract_structured_data_from_image(
    image_path: str,
    model_class: Type[T],
    model_name: str = "ollama/llava-phi3",
    api_base: str = "http://localhost:11434",
    custom_prompt: Optional[str] = None,
    max_retries: int = 2,
) -> T:
    """
    Extract structured data from an image based on a Pydantic model.

    Args:
        image_path (str): Path to the image file.
        model_class (Type[T]): Pydantic model class defining the structure to extract.
        model_name (str): Name of the vision model to use.
        api_base (str): API endpoint for the model.
        custom_prompt (Optional[str]): Optional custom prompt. If provided, it overrides the generated prompt.
        max_retries (int): Maximum number of retry attempts if parsing fails.

    Returns:
        An instance of model_class populated with the extracted data.
    """
    import litellm
    import mimetypes

    try:
        base64_image = encode_image_to_base64(image_path)
    except Exception as e:
        print(f"Error encoding image: {e}")
        return model_class()

    # Infer the MIME type from the file extension so JPEGs aren't labelled as PNGs
    mime_type = mimetypes.guess_type(image_path)[0] or "image/png"

    # Use the custom prompt if provided; otherwise generate one from the schema
    prompt = custom_prompt if custom_prompt else generate_extraction_prompt(model_class)
    # For a simplified use case you can override the prompt with a fixed
    # instruction. Uncomment the following line to use a fixed prompt:
    # prompt = "Extract all the relevant data from the image in JSON format"

    for attempt in range(max_retries + 1):
        try:
            response = litellm.completion(
                model=model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}"
                                },
                            },
                        ],
                    }
                ],
                api_base=api_base,
            )
            response_text = response.choices[0].message.content
            json_str = extract_json_from_text(response_text)
            if json_str:
                return model_class.model_validate_json(json_str)
            if attempt == max_retries:
                print("Failed to extract JSON from response.")
                return model_class()
        except Exception as e:
            print(f"Error in attempt {attempt + 1}: {e}")
            if attempt == max_retries:
                return model_class()
    return model_class()
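
# A minimal usage sketch (hypothetical image path; assumes an Ollama server is
# listening on localhost:11434 with llava-phi3 pulled). ReceiptData is defined
# further down and is resolved when the function is called, not at import time.
def _demo_vision_usage():
    """Extract receipt fields from a photo into the ReceiptData model."""
    receipt = extract_structured_data_from_image("receipt.jpg", ReceiptData)
    print(receipt.model_dump_json(indent=2))
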
# ===== UTILITY FUNCTIONS =====
def save_structured_data(data: BaseModel, output_path: str):
    """Save structured data to a JSON file."""
    with open(output_path, "w") as f:
        f.write(data.model_dump_json(indent=2))
    print(f"Data saved to {output_path}")

def read_structured_data(file_path: str, model_class: Type[T]) -> T:
    """Read structured data from a JSON file into a Pydantic model."""
    with open(file_path, "r") as f:
        return model_class.model_validate_json(f.read())
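
# Round-trip sketch: the two helpers above are inverses of each other. The path
# and field values are hypothetical; DocumentData is defined below and is
# resolved only when this function is called.
def _demo_round_trip():
    data = DocumentData(title="Quarterly report", key_points=["Revenue up"])
    save_structured_data(data, "report.json")
    assert read_structured_data("report.json", DocumentData) == data
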
# ===== PYDANTIC MODELS =====
class Person(BaseModel):
    """Data model for a person detected in an image."""
    name: Optional[str] = None
    gender: Optional[str] = None
    approximate_age: Optional[str] = None
    clothing_description: Optional[str] = None
    position_in_image: Optional[str] = None

class Object(BaseModel):
    """Data model for an object detected in an image."""
    name: str
    color: Optional[str] = None
    size: Optional[str] = None
    position_in_image: Optional[str] = None
    quantity: Optional[int] = 1

class SceneDescription(BaseModel):
    """Data model for an overall scene description."""
    setting: Optional[str] = None
    time_of_day: Optional[str] = None
    weather: Optional[str] = None
    general_mood: Optional[str] = None
    key_activities: Optional[List[str]] = None

class ImageData(BaseModel):
    """Overall data model for information extracted from an image."""
    persons: List[Person] = Field(default_factory=list)
    objects: List[Object] = Field(default_factory=list)
    scene_description: SceneDescription = Field(default_factory=SceneDescription)
    text_in_image: Optional[List[str]] = None
    additional_notes: Optional[str] = None

class DocumentData(BaseModel):
    """Data model for extracting information from document images."""
    title: Optional[str] = None
    date: Optional[str] = None
    document_type: Optional[str] = None
    content_summary: Optional[str] = None
    key_points: List[str] = Field(default_factory=list)
    entities_mentioned: List[str] = Field(default_factory=list)

class ProductData(BaseModel):
    """Data model for extracting product information from images."""
    product_name: Optional[str] = None
    brand: Optional[str] = None
    category: Optional[str] = None
    color: Optional[str] = None
    features: List[str] = Field(default_factory=list)
    condition: Optional[str] = None
    estimated_price_range: Optional[str] = None

class ReceiptData(BaseModel):
    """Data model for extracting information from receipts."""
    store_name: Optional[str] = None
    date: Optional[str] = None
    total_amount: Optional[str] = None
    items: List[Dict[str, Any]] = Field(default_factory=list)
    payment_method: Optional[str] = None
    tax_amount: Optional[str] = None
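
# Adding a new schema is just a matter of defining another Pydantic model and
# passing it to extract_structured_data_from_image. A hypothetical invoice
# schema (not wired into the CLI below) might look like this:
class InvoiceData(BaseModel):
    """Sketch of a custom extraction schema."""
    invoice_number: Optional[str] = None
    vendor: Optional[str] = None
    line_items: List[Dict[str, Any]] = Field(default_factory=list)
    total_amount: Optional[str] = None
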
# ===== EXAMPLE USAGE =====
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Extract data from documents and images")
    parser.add_argument("file_path", help="Path to the document or image file to process")
    parser.add_argument("--method", choices=["ocr", "vision"], default="ocr",
                        help="Method to use: 'ocr' for Mistral OCR or 'vision' for vision LLM extraction")
    parser.add_argument("--api-key", help="API key (defaults to environment variable)")
    parser.add_argument("--model", help="Model to use (defaults based on method)")
    parser.add_argument("--format", choices=["markdown", "json", "html"], default="markdown",
                        help="Output format for OCR method")
    parser.add_argument("--output", "-o", help="Path to save the output")
    parser.add_argument("--schema", choices=["image", "document", "product", "receipt"], default="document",
                        help="Schema to use for vision method")
    parser.add_argument("--custom_prompt", help="Optional custom prompt for vision extraction", default=None)
    args = parser.parse_args()

    # Process based on method
    if args.method == "ocr":
        # Use Mistral OCR
        result = process_document_with_mistral_ocr(
            args.file_path,
            api_key=args.api_key,
            model=args.model or "mistral-ocr-latest",
            output_format=args.format,
            save_to_file=args.output,
        )
        if result and not args.output:
            print(result)
    else:
        # Use vision LLM extraction
        model_classes = {
            "image": ImageData,
            "document": DocumentData,
            "product": ProductData,
            "receipt": ReceiptData,
        }
        model_class = model_classes[args.schema]
        result = extract_structured_data_from_image(
            args.file_path,
            model_class,
            model_name=args.model or "ollama/llava-phi3",
            custom_prompt=args.custom_prompt,
        )
        if args.output:
            save_structured_data(result, args.output)
        else:
            print(result.model_dump_json(indent=2))
- **Custom vs. generated prompts:** You can override the generated prompt by passing a `--custom_prompt` argument. (The code also includes a commented-out line you can enable if you prefer a fixed prompt.)
- **Consistent base64 encoding:** The `encode_image_to_base64` function is shared between the helper and extraction functions.
- **Unified vision extraction:** The vision extraction function uses the same `litellm.completion` call as in your provided snippet, ensuring the image is sent as a proper base64 data URL.
- **Command-line flexibility:** Run the script from the command line, choosing between the OCR and vision extraction modes; see the example invocations just below.
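
For instance (the script name `extract.py` and the input files are hypothetical placeholders):

```bash
# OCR a PDF to styled HTML with Mistral OCR
python extract.py contract.pdf --method ocr --format html -o contract.html

# Pull structured receipt fields from a photo with a local vision model
python extract.py receipt.jpg --method vision --schema receipt -o receipt.json
```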
This updated script now brings together both methods with enhanced flexibility for your image and document processing needs.