Skip to content

Instantly share code, notes, and snippets.

@mokshchadha
Created September 6, 2024 08:33
Show Gist options
  • Save mokshchadha/30bbfd1ac31738953f8ac97723bfcca2 to your computer and use it in GitHub Desktop.
Save mokshchadha/30bbfd1ac31738953f8ac97723bfcca2 to your computer and use it in GitHub Desktop.
Get structured output from an LLM
import os
import time
from dotenv import load_dotenv
load_dotenv()
from pydantic import BaseModel, Field
from typing import TypedDict, Literal
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from langgraph.graph import MessagesState
from langchain.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
class WeatherResponse(BaseModel):
    # Structured weather answer; the graph stores one under ``final_response``.
    # NOTE: field names, declaration order and ``description`` strings all end
    # up in pydantic's generated JSON schema, so editing them changes what any
    # schema consumer sees.
    temperature: float = Field(description="The temperature in Fahrenheit")
    # NOTE(review): the sample tool reports spell directions out ("north-east",
    # "south-east"); nothing in this file converts them to an abbreviation.
    wind_direction: str = Field(description="The direction of wind in abbreviated form")
    wind_speed: float = Field(description="The speed of the wind in mph")
class AgentInput(MessagesState):
    """Chat-history-only schema: adds nothing beyond ``MessagesState``.

    NOTE(review): not referenced elsewhere in this file; presumably intended
    as the graph's input schema.
    """
class AgentOutput(TypedDict):
    """Mapping that carries only the structured weather answer.

    NOTE(review): not referenced elsewhere in this file; presumably intended
    as the graph's output schema.
    """

    final_response: WeatherResponse
class AgentState(MessagesState):
    """Graph state: the inherited ``messages`` history plus the final answer.

    ``final_response`` holds the structured ``WeatherResponse`` once the
    ``respond`` node has produced it.
    """

    # No ``= None`` here: MessagesState is a TypedDict, and TypedDict fields
    # cannot declare defaults (PEP 589; type checkers reject the assignment).
    # At runtime such an assignment only creates an unused class attribute; it
    # never makes the key exist in the state dict, so readers must not assume
    # ``final_response`` is present before ``respond`` has run.
    final_response: WeatherResponse | None
# The docstring below doubles as the tool description shown to the model.
@tool
def get_weather(city: Literal['nyc', 'sf']) -> str:
    """
    Get the weather information for a specific city.

    Args:
        city (str): The city to get weather information for. Must be either 'nyc' or 'sf'.

    Returns:
        str: A string containing weather information for the specified city.
    """
    if city == "nyc":
        return "It is cloudy in nyc with 5mph, wind in north-east direction and a temperature of 40 degrees"
    if city == "sf":
        return "It is 75 degrees and sunny in sf, with 3mph winds in the south-east direction"
    # Only reached for a city outside the two canned reports.
    raise ValueError("Unknown city")
# Tools the model may call; the same list is handed to the ToolNode later.
tools = [get_weather]

# Gemini chat model. The API key is read from the environment (load_dotenv()
# above may have populated it from a .env file); os.getenv yields None if unset.
model = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)

# Variant of the model that advertises ``tools`` for function calling.
model_with_tools = model.bind_tools(tools)
def call_model(state: AgentState):
    """Run the tool-enabled model over the conversation so far.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        dict: ``{"messages": [response]}`` containing just the new model reply.
    """
    response = model_with_tools.invoke(state["messages"])
    # ``messages`` comes from MessagesState, whose ``add_messages`` reducer
    # appends node output to the existing history. Returning the whole history
    # again is redundant (it only works because add_messages de-duplicates by
    # message id), so emit only the new message.
    return {"messages": [response]}
def respond(state: AgentState):
    """Turn the gathered weather information into a ``WeatherResponse``.

    Splitting free text on fixed phrases ("temperature of ", "with ",
    "wind in ") is brittle. It raises IndexError whenever the wording differs:
    the sf report says "75 degrees" and "winds in", and the agent's own reply
    is a paraphrase. Instead, the model fills the ``WeatherResponse`` schema
    directly via structured output.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        dict: ``{"final_response": <WeatherResponse>}`` to merge into the state.
    """
    messages = state["messages"]
    # Prefer the most recent tool result: it carries the raw weather data,
    # whereas the agent's final reply may paraphrase or omit fields.
    source = next(
        (msg for msg in reversed(messages) if getattr(msg, "type", None) == "tool"),
        messages[-1],
    )
    extractor = model.with_structured_output(WeatherResponse)
    response = extractor.invoke(
        f"Extract the weather details from this report:\n{source.content}"
    )
    return {"final_response": response}
def should_continue(state: AgentState):
    """Pick the next node after the agent's turn.

    Args:
        state: Current graph state; only the last entry of ``state["messages"]``
            is inspected.

    Returns:
        str: ``"continue"`` when the latest message requests tool calls (so
        the ToolNode has something to execute), otherwise ``"respond"``.
    """
    last_message = state["messages"][-1]
    # Route on the model's tool calls, not on keywords in its text. A keyword
    # check sent any reply lacking "temperature"/"wind" to the ToolNode even
    # when there was nothing to execute. It also failed when ``content`` was
    # not a plain string.
    if getattr(last_message, "tool_calls", None):
        return "continue"
    return "respond"
# Graph wiring: "agent" runs first; should_continue sends it either to
# "tools" (which loops back to "agent") or to "respond", which ends the run.
workflow = StateGraph(AgentState)

workflow.add_node("agent", call_model)
workflow.add_node("respond", respond)
workflow.add_node("tools", ToolNode(tools))

workflow.set_entry_point("agent")

workflow.add_conditional_edges(
    "agent",
    should_continue,
    {"continue": "tools", "respond": "respond"},
)

workflow.add_edge("tools", "agent")  # tool output goes back to the model
workflow.add_edge("respond", END)    # the structured answer is the last step

graph = workflow.compile()
# Example usage: ask about NYC, then print the structured answer and timing.
inputs = {"messages": [{"role": "user", "content": "What's the weather like in NYC?"}]}
# perf_counter is monotonic and high-resolution, so the measured duration is
# not skewed by wall-clock adjustments the way time.time() can be.
start_time = time.perf_counter()
result = graph.invoke(inputs)
end_time = time.perf_counter()
execution_time = end_time - start_time
print(result["final_response"])
print(f"Execution time: {execution_time:.2f} seconds")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment