@essevan, forked from CraftsMan-Labs/IDP.md, created March 8, 2025

import os
import base64
import json
import re
from typing import List, Dict, Any, Optional, Type, TypeVar
from pydantic import BaseModel, Field
from pathlib import Path

# Type variable for Pydantic models
T = TypeVar('T', bound=BaseModel)

# ===== MISTRAL OCR IMPLEMENTATION =====

def process_document_with_mistral_ocr(
    file_path, 
    api_key=None, 
    model="mistral-ocr-latest", 
    output_format="markdown",
    save_to_file=None
):
    """
    Process a document using Mistral OCR API and return the results.
    
    Args:
        file_path (str): Path to the document file to process.
        api_key (str, optional): Mistral API key. Defaults to MISTRAL_API_KEY environment variable.
        model (str, optional): Mistral OCR model to use. Defaults to "mistral-ocr-latest".
        output_format (str, optional): Output format - "markdown", "json", or "html". Defaults to "markdown".
        save_to_file (str, optional): Path to save the output. If None, returns the result.
        
    Returns:
        The OCR results in the specified format.
    """
    # Import here to make it optional
    from mistralai import Mistral

    # Get API key from environment if not provided
    if not api_key:
        api_key = os.environ.get("MISTRAL_API_KEY")
        if not api_key:
            raise ValueError("No API key provided and MISTRAL_API_KEY environment variable not set.")

    # Initialize Mistral client
    client = Mistral(api_key=api_key)

    # Process the file
    pdf_file = Path(file_path)
    print(f"Uploading file {pdf_file.name}...")

    # Upload the file
    uploaded_file = client.files.upload(
        file={
            "file_name": pdf_file.stem,
            "content": pdf_file.read_bytes(),
        },
        purpose="ocr",
    )

    # Get signed URL for the uploaded file
    signed_url = client.files.get_signed_url(file_id=uploaded_file.id, expiry=1)

    print(f"Processing with OCR model: {model}...")

    # Process the document with OCR
    ocr_response = client.ocr.process(
        document={"type": "document_url", "document_url": signed_url.url},
        model=model,
        include_image_base64=True,
    )

    # Process the response based on the requested output format
    if output_format == "json":
        result = json.loads(ocr_response.model_dump_json())
    else:
        # ocr_response.pages is a list of page objects; join each page's markdown
        result = "\n\n".join(page.markdown for page in ocr_response.pages)

        # Convert to HTML if requested
        if output_format == "html":
            import markdown
            result = markdown.markdown(result, extensions=['tables', 'fenced_code'])
            result = f"""
<!DOCTYPE html>
<html>
<head>
  <meta charset="UTF-8">
  <style>
    body {{ font-family: Arial, sans-serif; line-height: 1.6; max-width: 800px; margin: 0 auto; padding: 20px; }}
    img {{ max-width: 100%; }}
    pre {{ background-color: #f5f5f5; padding: 10px; overflow: auto; }}
    table {{ border-collapse: collapse; width: 100%; }}
    th, td {{ border: 1px solid #ddd; padding: 8px; }}
    th {{ background-color: #f2f2f2; }}
  </style>
</head>
<body>
{result}
</body>
</html>
"""

    # Save to file if requested
    if save_to_file:
        with open(save_to_file, 'w', encoding='utf-8') as f:
            if isinstance(result, str):
                f.write(result)
            else:
                # JSON output is a dict, so serialize it before writing
                json.dump(result, f, indent=2, ensure_ascii=False)
        print(f"Results saved to {save_to_file}")
        return None

    return result
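
# Example usage (illustrative; assumes MISTRAL_API_KEY is set and that
# "report.pdf" is a hypothetical local file):
#
#     markdown_text = process_document_with_mistral_ocr("report.pdf")
#     process_document_with_mistral_ocr("report.pdf", output_format="html",
#                                       save_to_file="report.html")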

# ===== VISION LLM EXTRACTION IMPLEMENTATION =====

def encode_image_to_base64(image_path: str) -> str:
    """Convert an image file to base64 encoding."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def extract_json_from_text(text: str) -> Optional[str]:
    """Extract JSON object from text that might contain additional content."""
    # Try to find JSON in markdown code blocks (using backticks)
    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
    if json_match:
        json_str = json_match.group(1)
    else:
        # Try to find JSON based on braces
        json_match = re.search(r'({[\s\S]*})', text)
        if json_match:
            json_str = json_match.group(1)
        else:
            json_str = text

    json_str = json_str.strip()
    try:
        json.loads(json_str)
        return json_str
    except json.JSONDecodeError:
        return None
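
# Example (illustrative): this pulls JSON out of surrounding chatter, e.g.
# extract_json_from_text('Sure!\n```json\n{"total": 12.5}\n```') returns
# '{"total": 12.5}'; it returns None when no parseable JSON is found.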

def generate_extraction_prompt(model_class: Type[T]) -> str:
    """Generate a prompt based on a Pydantic model structure."""
    schema = model_class.model_json_schema()
    prompt = "Analyze this image and extract information in JSON format.\n\n"
    prompt += f"Return the data according to this JSON schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
    prompt += "Important guidelines:\n"
    prompt += "1. Only return valid JSON that conforms to the schema.\n"
    prompt += "2. If you're not sure about a field, use null instead of guessing.\n"
    prompt += "3. Don't add any explanations outside the JSON structure.\n"
    prompt += "4. Extract as much relevant information as possible.\n"
    return prompt

def extract_structured_data_from_image(
    image_path: str, 
    model_class: Type[T],
    model_name: str = "ollama/llava-phi3", 
    api_base: str = "http://localhost:11434",
    custom_prompt: Optional[str] = None,
    max_retries: int = 2
) -> T:
    """
    Extract structured data from an image based on a Pydantic model.
    
    Args:
        image_path (str): Path to the image file.
        model_class (Type[T]): Pydantic model class defining the structure to extract.
        model_name (str): Name of the vision model to use.
        api_base (str): API endpoint for the model.
        custom_prompt (Optional[str]): Optional custom prompt to use. If provided, this overrides the generated prompt.
        max_retries (int): Maximum number of retry attempts if parsing fails.
        
    Returns:
        An instance of the provided model_class with extracted data.
    """
    import litellm

    try:
        base64_image = encode_image_to_base64(image_path)
    except Exception as e:
        print(f"Error encoding image: {e}")
        # Fall back to an empty instance (assumes all fields on model_class have defaults)
        return model_class()

    # Use the custom prompt if provided; otherwise generate one from the model's JSON schema
    prompt = custom_prompt if custom_prompt else generate_extraction_prompt(model_class)

    # For a simpler, schema-free extraction, uncomment the fixed prompt below:
    # prompt = "Extract all the relevant data from the image in JSON format"

    for attempt in range(max_retries + 1):
        try:
            response = litellm.completion(
                model=model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{base64_image}"
                                }
                            }
                        ]
                    }
                ],
                api_base=api_base
            )
            
            response_text = response.choices[0].message.content
            json_str = extract_json_from_text(response_text)
            
            if json_str:
                structured_data = model_class.model_validate_json(json_str)
                return structured_data
            else:
                if attempt == max_retries:
                    print("Failed to extract JSON from response.")
                    return model_class()
        except Exception as e:
            print(f"Error in attempt {attempt + 1}: {e}")
            if attempt == max_retries:
                return model_class()
    
    return model_class()
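
# Example usage (illustrative; assumes a local Ollama server exposing a vision
# model, and a hypothetical "receipt.png"; ReceiptData is defined below):
#
#     receipt = extract_structured_data_from_image("receipt.png", ReceiptData)
#     print(receipt.store_name, receipt.total_amount)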

# ===== UTILITY FUNCTIONS =====

def save_structured_data(data: BaseModel, output_path: str):
    """Save structured data to a JSON file."""
    with open(output_path, "w") as f:
        f.write(data.model_dump_json(indent=2))
    print(f"Data saved to {output_path}")

def read_structured_data(file_path: str, model_class: Type[T]) -> T:
    """Read structured data from a JSON file into a Pydantic model."""
    with open(file_path, "r") as f:
        json_data = f.read()
    return model_class.model_validate_json(json_data)
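
# Example round trip (illustrative, using the hypothetical receipt above):
#
#     save_structured_data(receipt, "receipt.json")
#     restored = read_structured_data("receipt.json", ReceiptData)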

# ===== PYDANTIC MODELS =====

class Person(BaseModel):
    """Data model for a person detected in an image."""
    name: Optional[str] = None
    gender: Optional[str] = None
    approximate_age: Optional[str] = None
    clothing_description: Optional[str] = None
    position_in_image: Optional[str] = None

class Object(BaseModel):
    """Data model for an object detected in an image."""
    name: str
    color: Optional[str] = None
    size: Optional[str] = None
    position_in_image: Optional[str] = None
    quantity: Optional[int] = 1

class SceneDescription(BaseModel):
    """Data model for an overall scene description."""
    setting: Optional[str] = None
    time_of_day: Optional[str] = None
    weather: Optional[str] = None
    general_mood: Optional[str] = None
    key_activities: Optional[List[str]] = None

class ImageData(BaseModel):
    """Overall data model for information extracted from an image."""
    persons: List[Person] = Field(default_factory=list)
    objects: List[Object] = Field(default_factory=list)
    scene_description: SceneDescription = Field(default_factory=SceneDescription)
    text_in_image: Optional[List[str]] = None
    additional_notes: Optional[str] = None

class DocumentData(BaseModel):
    """Data model for extracting information from document images."""
    title: Optional[str] = None
    date: Optional[str] = None
    document_type: Optional[str] = None
    content_summary: Optional[str] = None
    key_points: List[str] = Field(default_factory=list)
    entities_mentioned: List[str] = Field(default_factory=list)

class ProductData(BaseModel):
    """Data model for extracting product information from images."""
    product_name: Optional[str] = None
    brand: Optional[str] = None
    category: Optional[str] = None
    color: Optional[str] = None
    features: List[str] = Field(default_factory=list)
    condition: Optional[str] = None
    estimated_price_range: Optional[str] = None

class ReceiptData(BaseModel):
    """Model for extracting information from receipts."""
    store_name: Optional[str] = None
    date: Optional[str] = None
    total_amount: Optional[str] = None
    items: List[Dict[str, Any]] = Field(default_factory=list)
    payment_method: Optional[str] = None
    tax_amount: Optional[str] = None
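
# The extractor works with any Pydantic model whose fields have defaults, so a
# custom schema is just another class, e.g. (illustrative):
#
#     class InvoiceData(BaseModel):
#         invoice_number: Optional[str] = None
#         vendor: Optional[str] = None
#         line_items: List[Dict[str, Any]] = Field(default_factory=list)
#
#     invoice = extract_structured_data_from_image("invoice.png", InvoiceData)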

# ===== EXAMPLE USAGE =====

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Extract data from documents and images")
    parser.add_argument("file_path", help="Path to the document or image file to process")
    parser.add_argument("--method", choices=["ocr", "vision"], default="ocr", 
                        help="Method to use: 'ocr' for Mistral OCR or 'vision' for vision LLM extraction")
    parser.add_argument("--api-key", help="API key (defaults to environment variable)")
    parser.add_argument("--model", help="Model to use (defaults based on method)")
    parser.add_argument("--format", choices=["markdown", "json", "html"], default="markdown", 
                        help="Output format for OCR method")
    parser.add_argument("--output", "-o", help="Path to save the output")
    parser.add_argument("--schema", choices=["image", "document", "product", "receipt"], default="document",
                        help="Schema to use for vision method")
    parser.add_argument("--custom_prompt", help="Optional custom prompt for vision extraction", default=None)

    args = parser.parse_args()

    # Process based on method
    if args.method == "ocr":
        # Use Mistral OCR
        result = process_document_with_mistral_ocr(
            args.file_path,
            api_key=args.api_key,
            model=args.model or "mistral-ocr-latest",
            output_format=args.format,
            save_to_file=args.output
        )

        if result and not args.output:
            print(result)
    else:
        # Use vision LLM extraction
        model_classes = {
            "image": ImageData,
            "document": DocumentData,
            "product": ProductData,
            "receipt": ReceiptData
        }
        model_class = model_classes[args.schema]

        result = extract_structured_data_from_image(
            args.file_path,
            model_class,
            model_name=args.model or "ollama/llava-phi3",
            custom_prompt=args.custom_prompt,
        )

        if args.output:
            save_structured_data(result, args.output)
        else:
            print(result.model_dump_json(indent=2))

Key Updates

  • Custom vs. Generated Prompts:
    You can now override the generated prompt by passing a --custom_prompt parameter. (The code also includes a commented line that you can uncomment if you prefer a fixed prompt.)

  • Consistent Base64 Encoding:
    A single encode_image_to_base64 helper handles image encoding for the vision extraction path.

  • Unified Vision Extraction:
    The vision extraction function sends the image to litellm.completion as a base64-encoded data URL, the message format vision models expect.

  • Command Line Flexibility:
    Run the script from the command line in either OCR or vision extraction mode; see the example invocations below.

This script brings both methods together, with flexible options for image and document processing.
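
Example invocations (assuming the script is saved as extract.py, a file name chosen here for illustration):

    python extract.py report.pdf --method ocr --format html -o report.html
    python extract.py receipt.png --method vision --schema receipt -o receipt.json
    python extract.py photo.jpg --method vision --schema image --custom_prompt "Describe the scene as JSON"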
