Get structured output from an LLM
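A minimal LangGraph + Gemini example: the model is given a get_weather tool, and once the tool has run, a final respond node parses the tool's text output into a typed WeatherResponse (Pydantic) object that is returned in the graph state.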
import os
import time

from dotenv import load_dotenv

load_dotenv()

from typing import TypedDict, Literal

from pydantic import BaseModel, Field
from langchain.tools import tool
from langchain_core.messages import ToolMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, MessagesState, END
from langgraph.prebuilt import ToolNode
# Structured output schema we want the graph to return
class WeatherResponse(BaseModel):
    temperature: float = Field(description="The temperature in Fahrenheit")
    wind_direction: str = Field(description="The direction of wind in abbreviated form")
    wind_speed: float = Field(description="The speed of the wind in mph")


# Optional input/output schemas (kept for reference; not wired into the graph below)
class AgentInput(MessagesState):
    pass


class AgentOutput(TypedDict):
    final_response: WeatherResponse


# Graph state: the running message history plus the parsed final response
class AgentState(MessagesState):
    final_response: WeatherResponse
@tool
def get_weather(city: Literal["nyc", "sf"]) -> str:
    """
    Get the weather information for a specific city.

    Args:
        city (str): The city to get weather information for. Must be either 'nyc' or 'sf'.

    Returns:
        str: A string containing weather information for the specified city.
    """
    # Both strings use the same phrasing so the respond node can parse them reliably.
    if city == "nyc":
        return "It is cloudy in nyc with 5mph, wind in north-east direction and a temperature of 40 degrees"
    elif city == "sf":
        return "It is sunny in sf with 3mph, wind in south-east direction and a temperature of 75 degrees"
    else:
        raise ValueError("Unknown city")
model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=os.getenv("GOOGLE_API_KEY"))
tools = [get_weather]
model_with_tools = model.bind_tools(tools)
def call_model(state: AgentState):
    # MessagesState appends returned messages to the history, so only the new response is returned.
    response = model_with_tools.invoke(state["messages"])
    return {"messages": [response]}
def respond(state: AgentState):
    # Parse the raw tool output (the most recent ToolMessage) rather than the model's
    # paraphrase, since the tool returns a fixed phrasing that the splits below rely on.
    tool_message = next(m for m in reversed(state["messages"]) if isinstance(m, ToolMessage))
    content = tool_message.content
    temp = float(content.split("temperature of ")[1].split(" degrees")[0])
    wind_speed = float(content.split("with ")[1].split("mph")[0])
    wind_direction = content.split("wind in ")[1].split(" direction")[0]
    response = WeatherResponse(
        temperature=temp,
        wind_speed=wind_speed,
        wind_direction=wind_direction,
    )
    return {"final_response": response}
def should_continue(state: AgentState):
    # If the model asked to call a tool, keep going; otherwise build the structured response.
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "continue"
    return "respond"
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("respond", respond)
workflow.add_node("tools", ToolNode(tools))
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        "continue": "tools",
        "respond": "respond",
    },
)
workflow.add_edge("tools", "agent")
workflow.add_edge("respond", END)
graph = workflow.compile()
# Example usage
inputs = {"messages": [{"role": "user", "content": "What's the weather like in NYC?"}]}

start_time = time.time()
result = graph.invoke(inputs)
end_time = time.time()
execution_time = end_time - start_time

# Prints the parsed WeatherResponse, e.g. temperature=40.0 wind_direction='north-east' wind_speed=5.0
print(result["final_response"])
print(f"Execution time: {execution_time:.2f} seconds")