Last active: September 20, 2024 17:56
Save xeniode/a11706bad5ecc314cd5e4ee7c6b5e086 to your computer and use it in GitHub Desktop.
Gist Created in Script Kit
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from langchain_openai import AzureChatOpenAI | |
import getpass | |
import os | |
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from typing import Annotated, List, Sequence | |
from langgraph.graph import END, StateGraph, START | |
from langgraph.graph.message import add_messages | |
from langgraph.checkpoint.memory import MemorySaver | |
from typing_extensions import TypedDict | |
def _set_if_undefined(var: str) -> None: | |
if os.environ.get(var): | |
return | |
os.environ[var] = getpass.getpass(var) | |
# Ask for the Tavily key interactively if it is not already in the environment.
# NOTE(review): TAVILY_API_KEY is not referenced anywhere else in this script;
# confirm it is still needed before keeping this interactive prompt.
_set_if_undefined("TAVILY_API_KEY")

# Azure OpenAI connection settings, read straight from the environment.
# Each name is bound to None when the corresponding variable is unset.
(
    azure_endpoint_env,
    azure_openai_api_key_env,
    azure_deployment_env,
    azure_api_version_env,
) = (
    os.environ.get(env_name)
    for env_name in (
        "AZURE_OPENAI_API_ENDPOINT",
        "AZURE_OPENAI_API_KEY",
        "AZURE_OPENAI_API_DEPLOYMENT_NAME",
        "AZURE_OPENAI_API_VERSION",
    )
)
# System instructions for the essay writer. The adjacent literals concatenate
# into one string, exactly as in a single inline message tuple.
_ESSAY_SYSTEM_TEXT = (
    "You are an essay assistant tasked with writing excellent 5-paragraph essays."
    " Generate the best essay possible for the user's request."
    " If the user provides critique, respond with a revised version of your previous attempts."
)

# Writer prompt: the fixed system message followed by the running conversation,
# which is injected through the "messages" placeholder.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _ESSAY_SYSTEM_TEXT),
        MessagesPlaceholder(variable_name="messages"),
    ]
)

# Shared Azure chat model; connection settings come from the *_env values read
# from the environment above (any of them may be None if the variable is unset).
llm = AzureChatOpenAI(
    azure_endpoint=azure_endpoint_env,
    openai_api_key=azure_openai_api_key_env,
    azure_deployment=azure_deployment_env,
    api_version=azure_api_version_env,
)

# Writer chain: render the prompt, then call the model.
generate = prompt | llm
# System instructions for the critic ("teacher") role. The adjacent literals
# concatenate into one string.
_CRITIQUE_SYSTEM_TEXT = (
    "You are a teacher grading an essay submission. Generate critique and recommendations for the user's submission."
    " Provide detailed recommendations, including requests for length, depth, style, etc."
)

# Critic prompt: system instructions followed by the (role-swapped) conversation.
reflection_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _CRITIQUE_SYSTEM_TEXT),
        MessagesPlaceholder(variable_name="messages"),
    ]
)

# Critic chain: reuses the same chat model as the writer chain.
reflect = reflection_prompt | llm
class State(TypedDict):
    """State carried through the generate/reflect graph.

    ``messages`` is the whole conversation: the user's original request,
    followed by alternating essay drafts and critiques. The nodes in this
    file return one-element lists and rely on the history growing, which is
    what the ``add_messages`` reducer annotation provides (updates are merged
    into the existing list rather than overwriting it).
    """

    # Conversation history; node updates are combined via add_messages.
    messages: Annotated[list, add_messages]
async def generation_node(state: State) -> State:
    """Write a new (or revised) essay from the conversation so far.

    The full message history goes through the ``generate`` chain (essay
    prompt + LLM). The model's reply is returned as a one-element
    ``messages`` update.
    """
    history = state["messages"]
    draft = await generate.ainvoke(history)
    return {"messages": [draft]}
async def reflection_node(state: State) -> State:
    """Critique the latest essay and return the critique as user feedback.

    The critic sees the conversation from the opposite side. Apart from the
    first message (the original request, passed through unchanged), every
    AI message is re-sent as a human message and vice versa. That way the
    essays look like submissions to grade. The critic's reply is wrapped in
    a ``HumanMessage`` so the writer treats it as feedback from the user.

    Raises ``KeyError`` if a message after the first has a type other than
    ``"ai"`` or ``"human"``.
    """
    history = state["messages"]
    # Role swap applied to everything after the original request.
    swapped_role = {"ai": HumanMessage, "human": AIMessage}

    translated = [history[0]]
    for message in history[1:]:
        translated.append(swapped_role[message.type](content=message.content))

    critique = await reflect.ainvoke(translated)
    # Feed the critique back to the generator as if the user had written it.
    return {"messages": [HumanMessage(content=critique.content)]}
# Graph over the shared State: one node drafts/revises the essay,
# the other produces a critique of the latest draft.
builder = StateGraph(State)

builder.add_node("generate", generation_node)  # essay writer
builder.add_node("reflect", reflection_node)  # essay critic

# Every run starts by producing a draft.
builder.add_edge(START, "generate")
def should_continue(state: State):
    """Decide where to go after a generation step.

    Returns ``"reflect"`` while the history holds at most 6 messages, and
    ``END`` once it holds more. For a thread that starts with a single user
    message, each generate/reflect step adds one message. The graph therefore
    stops after the draft that follows the 3rd critique: 3 reflection rounds
    and 4 drafts in total.
    """
    if len(state["messages"]) <= 6:
        return "reflect"
    # Enough reflection rounds; finish the run.
    return END
# Loop wiring: after each draft, should_continue picks "reflect" or END;
# each critique flows back into a new draft.
builder.add_conditional_edges("generate", should_continue)
builder.add_edge("reflect", "generate")

# Compile with an in-memory checkpointer so state can be inspected per thread.
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)

# Runnable config selecting the conversation thread used by this script.
config = {"configurable": {"thread_id": "1"}}
async def process_events():
    """Run the reflection graph on a fixed essay request, printing each step.

    The run uses the module-level ``config`` (thread ``"1"``). Every item
    yielded by ``graph.astream`` is printed, followed by a ``---`` separator.
    """
    request = HumanMessage(
        content="Write an essay on stoic philosophy in times of distress"
    )
    stream = graph.astream({"messages": [request]}, config)
    async for update in stream:
        print(update)
        print("---")
if __name__ == "__main__":
    import asyncio

    # Run the generate/reflect loop once, streaming each step to stdout.
    asyncio.run(process_events())

    # Pretty-print the final conversation stored for this thread.
    # This must stay under the __main__ guard. On a plain import nothing has
    # run yet, so the checkpoint for thread "1" presumably has no "messages"
    # key, and indexing state.values["messages"] would fail at import time.
    state = graph.get_state(config)
    ChatPromptTemplate.from_messages(state.values["messages"]).pretty_print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.