Last active
November 26, 2024 00:43
-
-
Save aaronvg/00b74e27f9653df9e15511e160059973 to your computer and use it in GitHub Desktop.
instructor-pydantic-1b-param-example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pydantic import BaseModel, Field | |
| from typing import Optional, List, Literal | |
| from datetime import datetime | |
class LogMetadata(BaseModel):
    """Contextual metadata attached to a single parsed log entry."""

    # Required context: every log line is expected to carry these.
    environment: str = Field(description="Production, staging, development")
    resource_id: str = Field(description="ID of the resource generating the log")
    # Optional context: left as None when the log line does not provide it.
    duration: Optional[int] = Field(
        default=None, description="Duration of the operation in milliseconds"
    )
    user_id: Optional[str] = Field(default=None, description="User ID if available")
    request_id: Optional[str] = Field(default=None, description="Request ID for tracing")
class SecurityInfo(BaseModel):
    """Security-relevant attributes extracted from a log entry."""

    threat_level: Literal["LOW", "MEDIUM", "HIGH"]
    ip_address: Optional[str] = Field(
        default=None, description="Source IP address if available"
    )
    suspicious_activity: bool = Field(
        description="Indicates if this is suspicious activity"
    )
    related_events: List[str] = Field(description="Related security event IDs")
class LogEntry(BaseModel):
    """Structured representation of one parsed log line.

    This is the ``response_model`` handed to instructor below: the LLM's JSON
    output is validated against these fields, and the ``Field`` descriptions
    are surfaced to the model via the generated JSON schema.
    """

    timestamp: str = Field(description="The timestamp from the log entry")
    severity: Literal["INFO", "WARN", "ERROR", "CRITICAL", "DEBUG"]
    service_name: str = Field(description="Name of the service that generated the log")
    message: str = Field(description="The main log message")
    error_type: Optional[str] = Field(None, description="Type of error if present")
    stack_trace: Optional[str] = Field(None, description="Stack trace if available")
    metadata: LogMetadata
    # Only populated for security-relevant events; plain logs leave it None.
    security: Optional[SecurityInfo] = None
import instructor
from openai import OpenAI
import os
import pytest
import logging

# DEBUG level makes instructor's request/response payloads visible, which is
# the point of this gist: observing how a 1B-parameter model handles a schema.
logging.basicConfig(level=logging.DEBUG)

# OpenRouter exposes an OpenAI-compatible endpoint, so the stock OpenAI client
# works with only a base_url override. Mode.JSON asks the model for plain JSON
# output instead of tool/function calling, which tiny models handle poorly.
# NOTE(review): assumes OPENROUTER_API_KEY is set in the environment —
# os.getenv returns None otherwise and requests will fail with an auth error.
client = instructor.from_openai(
    OpenAI(
        api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1"
    ),
    mode=instructor.Mode.JSON,
)
@pytest.fixture
def llm_client():
    """Expose the module-level instructor client to the tests below."""
    return client
def test_simple_log_parse(llm_client):
    """A single-line auth-failure log should parse into a valid LogEntry."""
    raw_log = """
    2024-03-15T10:23:45Z [ERROR] AuthService: Failed login attempt - Invalid credentials
    Environment: production
    RequestId: req-123abc
    User: user-789
    Duration: 145ms
    Client IP: 192.168.1.100
    """
    parsed = llm_client.chat.completions.create(
        model="meta-llama/llama-3.2-1b-instruct",
        messages=[{"role": "user", "content": f"Parse the following log:\n{raw_log}"}],
        response_model=LogEntry,
        max_retries=0,  # fail fast — no re-ask loop for the small model
    )
    print(parsed)
    assert isinstance(parsed, LogEntry)
    # Field-level assertions are intentionally disabled: the 1B model produces
    # schema-valid output but is not reliable about exact field values.
    # assert parsed.severity == "ERROR"
    # assert parsed.service_name == "AuthService"
    # assert parsed.metadata.environment == "production"
    # assert parsed.metadata.request_id == "req-123abc"
    # assert parsed.metadata.duration == 145
def test_complex_log_parse(llm_client):
    """A multi-line CRITICAL log with a stack trace should parse into LogEntry."""
    raw_log = """
    2024-03-15T11:45:12Z [CRITICAL] PaymentService: Database connection timeout
    Stack Trace:
    at Connection.connect (/src/db.js:45)
    at PaymentProcessor.process (/src/payments.js:102)
    Environment: production
    ResourceId: svc-payments-01
    RequestId: req-456def
    Duration: 5000ms
    """
    parsed = llm_client.chat.completions.create(
        model="meta-llama/llama-3.2-1b-instruct",
        messages=[{"role": "user", "content": f"Parse the following log:\n{raw_log}"}],
        response_model=LogEntry,
        max_retries=0,  # fail fast — no re-ask loop for the small model
    )
    print(parsed)
    assert isinstance(parsed, LogEntry)
    # Field-level assertions are intentionally disabled: the 1B model produces
    # schema-valid output but is not reliable about exact field values.
    # assert parsed.severity == "CRITICAL"
    # assert parsed.service_name == "PaymentService"
    # assert parsed.stack_trace is not None
    # assert parsed.metadata.resource_id == "svc-payments-01"
    # assert parsed.metadata.duration == 5000
if __name__ == "__main__":
    # Allow running this file directly as an alternative to the pytest CLI.
    pytest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment