Created
December 5, 2023 02:15
-
-
Save jwatte/0ae2fab6303e68fddefc9d29d5c706db to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
import json | |
import sys | |
from langchain.chat_models import ChatOpenAI | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.output_parsers import PydanticOutputParser | |
from langchain.schema.messages import HumanMessage, SystemMessage | |
from langchain.chains import ConversationChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain.output_parsers import CommaSeparatedListOutputParser | |
from observe.senders.langchain import ObserveTracer | |
# Set to True to print each intermediate model reply; the same flag is also
# passed to the tracer as its ``log_sends`` option.
debug: bool = False
tracer = ObserveTracer(log_sends=debug)
# Prompt templates, one per generation step, read by position in the __main__
# block: [0] objects, [1] states of one object, [2] description/goal,
# [3] solution, [4] rule for one ordered pair of objects.
temps: list = list()
# Overall system prompt describing the whole puzzle-writing task. It is attached
# only to the first template; later steps presumably rely on the shared
# conversation memory to keep it in context.
initial_prompt = """You are a logic puzzle author, generating puzzles with three steps.
Given a name for the puzzle, you will generate three objects in the puzzle.
Each step will introduce the description of one object with a few states.
Further, you will generate a desired solution to the puzzle, in the form of object=state relationships.
Then, you will generate the description of the goal of the puzzle.
Then, you will generate one constraint for each of the objects, for each of the other objects, where each constraint is fulfilled by the solution."""
# Step 1: list the puzzle's objects. The puzzle name is supplied as the input, so
# this step asks for the objects only: the reply is parsed as a comma-separated
# list of objects, and an echoed name would otherwise turn into a bogus object.
name_and_objects_prompt = """In this step, you will generate the three objects of the puzzle.
The output will be in the form of one comma and space separated line. For example:
input: Fruit Puzzle
output:
apple, orange, banana
You will generate the output in the theme of {name} without further explanation."""
name_and_objects_input = """input: {name}"""
temps.append(
    ChatPromptTemplate.from_messages(
        [
            ("system", initial_prompt),
            ("system", name_and_objects_prompt),
            ("user", name_and_objects_input),
        ]
    )
)
# Step 2 (run once per object): list the possible states of {object}, as one
# comma-separated line, themed after {name}.
object_states_prompt = (
    "In this step, you will generate the states for the object, {object}.\n"
    "This object will have some number of states, each output in a column in a comma and space separated line.\n"
    "For example:\n"
    "input: states of apple\n"
    "output:\n"
    "red, green, yellow\n"
    "You will generate the output in the theme of {name} without further explanation."
)
object_states_input = "input: states of {object}"
temps.append(
    ChatPromptTemplate.from_messages(
        [
            ("system", object_states_prompt),
            ("user", object_states_input),
        ]
    )
)
# Step 3: the puzzle's description and goal, phrased in terms of the objects and
# states produced earlier in the conversation.
description_prompt = """In this step, you will generate the description and goal of the puzzle.
The goal should be in terms of the objects and states of objects above. For example:
input: Goals of Fruit Puzzle
output:
The goal is to find the color of each of the fruits.
You will generate the output in the theme of {name} without further explanation."""
# Prefixed with "input: " so the user message has the same shape as the few-shot
# example above and as the other step inputs.
description_input = """input: Goals of {name}"""
temps.append(
    ChatPromptTemplate.from_messages(
        [
            ("system", description_prompt),
            ("user", description_input),
        ]
    )
)
# Step 4: the intended solution, requested as a JSON object mapping each object
# to one of its states. The doubled braces in the example are escapes so the
# template emits literal { and } rather than treating them as placeholders.
solution_prompt = """In this step, you will generate the solution of the objects, with the
states of the objects all adhering to the rules that will follow. The output is a JSON object. For example:
input: Solution for apple, orange, banana
output:
{{"apple":"red","orange":"yellow","banana":"yellow"}}
You will generate the output in the theme of {name} without further explanation."""
# Prefixed with "input: " to match the few-shot example above and the other step
# inputs. {objects} is expected to be the objects joined with ", ".
solution_input = """input: Solution for {objects}"""
temps.append(
    ChatPromptTemplate.from_messages(
        [
            ("system", solution_prompt),
            ("user", solution_input),
        ]
    )
)
# Step 5 (run once per ordered pair of distinct objects): one rule relating the
# states of {object} to the states of {other}.
object_requirement_prompt = """In this step, you will generate the requirements for the object, {object}, relative to the other object, {other}.
The requirement will qualify the relationship between the two objects, in terms of the states of the object.
For example:
input: requirement for apple, relative to banana
output:
If the banana is green, the apple is not green.
You will generate the output in the theme of {name} without further explanation."""
# Prefixed with "input: " to match the few-shot example above and the other step inputs.
object_requirement_input = """input: requirement for {object}, relative to {other}"""
temps.append(
    ChatPromptTemplate.from_messages(
        [
            ("system", object_requirement_prompt),
            ("user", object_requirement_input),
        ]
    )
)
# Chat model used for every step (sampling temperature 0.5), reporting to the tracer.
model = ChatOpenAI(
    callbacks=[tracer],
    model="gpt-4-1106-preview",
    temperature=0.5,
)
# One conversation with buffer memory, shared by all generation steps so that
# earlier exchanges stay in the history; it also reports to the tracer.
conversation = ConversationChain(
    callbacks=[tracer],
    llm=model,
    memory=ConversationBufferMemory(),
)
if __name__ == '__main__': | |
if len(sys.argv) == 2: | |
a = sys.argv[1] | |
elif len(sys.argv) == 1: | |
a = 'Ghost Story' | |
else: | |
print("Usage: python writer.py 'puzzle name'") | |
sys.exit(1) | |
parms = {'name':a} | |
b = temps[0].format_messages(**parms) | |
c = conversation(b) | |
res = c['response'].replace('output:', '').strip() | |
d = [x.strip() for x in CommaSeparatedListOutputParser().parse(res)] | |
parms['objects'] = ', '.join(d) | |
if debug: | |
print(d) | |
for o in d: | |
e = temps[1].format_messages(object=o, **parms) | |
f = conversation(e) | |
res = f['response'].replace('output:', '').strip() | |
g = CommaSeparatedListOutputParser().parse(res) | |
parms[o] = g | |
if debug: | |
print(o, g, flush=True) | |
h = temps[2].format_messages(**parms) | |
i = conversation(h) | |
j = i['response'].replace('output:', '').strip() | |
parms['description'] = j | |
if debug: | |
print(j) | |
k = temps[3].format_messages(**parms) | |
l = conversation(k) | |
m = l['response'].replace('output:', '').strip() | |
parms['solution'] = json.loads(m) | |
if debug: | |
print(m) | |
for i in range(len(d)): | |
for j in range(len(d)): | |
if i != j: | |
parms['object'] = d[i] | |
parms['other'] = d[j] | |
n = temps[4].format_messages(**parms) | |
o = conversation(n) | |
p = o['response'].replace('output:', '').strip() | |
parms[f"{d[i]}_{d[j]}"] = p | |
if debug: | |
print(p) | |
nl = "\n- " | |
print(f"""{parms['name']} | |
{parms['description']} | |
There are {len(d)} objects: {', '.join(d)}. | |
Each object has the following states: | |
- {nl.join([f"{o}: {', '.join(parms[o])}" for o in d])} | |
The rules are: | |
- {nl.join(parms[k] for k in parms if '_' in k)} | |
Solution: | |
- {nl.join([k + ": " + v for k, v in parms['solution'].items()])} | |
""", flush=True) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example: