Gist Created in Script Kit
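# Reflection loop for essay writing, built with LangGraph and Azure OpenAI:
# a "generate" node drafts and revises a five-paragraph essay, while a
# "reflect" node critiques each draft as a teacher would.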
import asyncio
import getpass
import os
from typing import Annotated

from typing_extensions import TypedDict

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import AzureChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
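

# Prompt interactively for any secret that is not already set in the environment.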
def _set_if_undefined(var: str) -> None:
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(var)
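

# Note: TAVILY_API_KEY is prompted for here but never used later in this script.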
_set_if_undefined("TAVILY_API_KEY")
azure_endpoint_env = os.environ.get("AZURE_OPENAI_API_ENDPOINT")
azure_openai_api_key_env = os.environ.get("AZURE_OPENAI_API_KEY")
azure_deployment_env = os.environ.get("AZURE_OPENAI_API_DEPLOYMENT_NAME")
azure_api_version_env = os.environ.get("AZURE_OPENAI_API_VERSION")
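
# Generator prompt: writes the essay and revises it whenever critique arrives.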
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an essay assistant tasked with writing excellent 5-paragraph essays."
            " Generate the best essay possible for the user's request."
            " If the user provides critique, respond with a revised version of your previous attempts.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)

llm = AzureChatOpenAI(
    azure_endpoint=azure_endpoint_env,
    openai_api_key=azure_openai_api_key_env,
    azure_deployment=azure_deployment_env,
    api_version=azure_api_version_env,
)
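
# Generator chain: the prompt piped into the model.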
generate = prompt | llm
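
# Reflection prompt: grades the essay and requests concrete improvements.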
reflection_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a teacher grading an essay submission. Generate critique and recommendations for the user's submission."
            " Provide detailed recommendations, including requests for length, depth, style, etc.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
reflect = reflection_prompt | llm
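

# Graph state: the add_messages reducer appends each node's output to the
# message history rather than overwriting it.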
class State(TypedDict):
    messages: Annotated[list, add_messages]


async def generation_node(state: State) -> State:
    return {"messages": [await generate.ainvoke(state["messages"])]}


async def reflection_node(state: State) -> State:
    # Swap roles so the reflector sees the essay as a submission to grade:
    # AI-authored essays become HumanMessages, prior critiques become AIMessages.
    cls_map = {"ai": HumanMessage, "human": AIMessage}
    # The first message is the original user request; keep it unchanged for all nodes.
    translated = [state["messages"][0]] + [
        cls_map[msg.type](content=msg.content) for msg in state["messages"][1:]
    ]
    res = await reflect.ainvoke(translated)
    # Treat the critique as human feedback for the generator's next revision.
    return {"messages": [HumanMessage(content=res.content)]}
builder = StateGraph(State)
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.add_edge(START, "generate")


def should_continue(state: State):
    if len(state["messages"]) > 6:
        # End after 3 generate/reflect iterations
        return END
    return "reflect"


builder.add_conditional_edges("generate", should_continue)
builder.add_edge("reflect", "generate")
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)
config = {"configurable": {"thread_id": "1"}}


async def process_events():
    # Stream graph events so each node's output prints as soon as it is produced.
    async for event in graph.astream(
        {
            "messages": [
                HumanMessage(
                    content="Write an essay on stoic philosophy in times of distress"
                )
            ],
        },
        config,
    ):
        print(event)
        print("---")


if __name__ == "__main__":
    asyncio.run(process_events())

    # Replay the final checkpointed conversation for this thread.
    state = graph.get_state(config)
    ChatPromptTemplate.from_messages(state.values["messages"]).pretty_print()