Last active
August 21, 2023 09:37
-
-
Save xinsblog/e0926e2bbc3696a9145f83a88358dcf8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import sys
from typing import List, Optional

import openai
class SimChatGPT:
    """Sentence-pair matching "classifier" driven by a ChatGPT conversation.

    The whole model state is the chat history kept in ``self.messages``.
    ``train`` extends that history with a question, the assistant's answer
    and feedback on the answer; ``predict`` asks a question against the
    history without permanently extending it.
    """

    def __init__(self, api_key: str, messages: Optional[List] = None):
        """Set the OpenAI API key and initialise the chat history.

        Args:
            api_key: Key assigned to the module-level ``openai.api_key``.
            messages: Existing chat history to continue from. When omitted
                (or empty) the history starts with the built-in system
                message and task instruction.
        """
        openai.api_key = api_key
        if messages:
            self.messages = messages
        else:
            self.messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "接下来我会提供给你两句短文本,如果这两个文本的语义匹配,则回复'匹配',反之则回复'不匹配',"
                                            "不要回复额外的文字或者标点符号"},
            ]

    def ask_chat_gpt(self) -> str:
        """Send the current history to ``gpt-3.5-turbo`` and return the reply text."""
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=self.messages
        )
        response_content = response['choices'][0]['message']['content']
        return response_content

    def train(self, x1: str, x2: str, y: str):
        """Ask about one labelled pair and record the answer plus feedback.

        Three messages are appended to the history: the question, the
        assistant's reply, and a user message saying whether the reply was
        malformed, correct, or wrong. The full history is printed afterwards.

        Args:
            x1: First sentence of the pair.
            x2: Second sentence of the pair.
            y: Expected answer; compared verbatim against the reply.
        """
        self.messages.append({"role": "user", "content": f"'{x1}'和'{x2}'匹配还是不匹配"})
        response_content = self.ask_chat_gpt()
        self.messages.append({"role": "assistant", "content": response_content})
        if response_content not in {'匹配', '不匹配'}:
            feedback = "你回答的格式不对,你只能回复'匹配'或者'不匹配',不能回复额外的文字或者标点符号"
        elif response_content == y:
            feedback = "你回答的很对,棒棒哒"
        else:
            feedback = f"你回答的不对,你回答的是'{response_content}',正确答案是'{y}'"
        self.messages.append({"role": "user", "content": feedback})
        print(f"\n当前训练样本x1={x1}, x2={x2}, y={y}")
        print("self.messages=")
        for msg in self.messages:
            print(msg)

    def predict(self, x1: str, x2: str) -> str:
        """Return the assistant's reply for a pair without changing the history.

        Args:
            x1: First sentence of the pair.
            x2: Second sentence of the pair.

        Returns:
            The raw reply text from ``ask_chat_gpt``.
        """
        self.messages.append({"role": "user", "content": f"'{x1}'和'{x2}'匹配还是不匹配"})
        try:
            return self.ask_chat_gpt()
        finally:
            # Remove the temporary question even if the request raised, so a
            # failed call cannot leave a stray query in the stored history.
            self.messages.pop()

    def save(self, model_path: str):
        """Write the chat history to ``model_path`` as UTF-8 JSON."""
        model_dict = {
            'messages': self.messages
        }
        with open(model_path, "w", encoding='utf-8') as f:
            # ensure_ascii=False keeps the Chinese text readable in the file.
            json.dump(model_dict, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, model_path: str, api_key: str) -> 'SimChatGPT':
        """Build an instance from a JSON file written by ``save``.

        Args:
            model_path: Path of the JSON file holding a ``messages`` list.
            api_key: OpenAI API key passed to the constructor.

        Returns:
            A new instance of ``cls`` (so subclasses load as themselves).
        """
        with open(model_path, "r", encoding='utf-8') as f:
            model_dict = json.load(f)
        return cls(api_key=api_key, messages=model_dict['messages'])
if __name__ == '__main__':
    # Demo script: feed a few labelled pairs to the model, persist the chat
    # history, reload it, and print predictions for unseen pairs.

    # The API key is the first command-line argument; stop early without it.
    if len(sys.argv) < 2:
        raise RuntimeError("命令行参数缺少api_key")
    key = sys.argv[1]

    # Each entry is (sentence 1, sentence 2, expected answer).
    train_data = [
        ("小张比小王更高吗", "小王比小张更矮吗", "匹配"),
        ("小张比小王更高吗", "小王比小张更高吗", "不匹配"),
        ("上海比北京更远吗", "北京比上海更远吗", "不匹配"),
        ("鱼和鸡蛋能一起吃吗", "鸡蛋和鱼能同时吃吗", "匹配"),
    ]
    test_data = [
        ("苹果8比苹果9更贵吗", "苹果9比苹果8更贵吗", "不匹配"),
        ("iphone8比iphone9更贵吗", "iphone9比8更便宜吗", "匹配"),
        ("上海和北京一样远吗", "北京和上海同样远吗", "匹配"),
        ("杭州比深圳更热吗", "深圳比杭州更热吗", "不匹配"),
    ]
    snapshot_path = 'sim_chatgpt.json'

    learner = SimChatGPT(api_key=key)
    # Show the assistant's reply to the bare task instruction first.
    print(learner.ask_chat_gpt())
    for sample in train_data:
        learner.train(*sample)
    learner.save(snapshot_path)

    # Round-trip through the JSON file and predict with the restored history.
    restored = SimChatGPT.load(snapshot_path, api_key=key)
    for text_a, text_b, expected in test_data:
        print(text_a, text_b, expected, restored.predict(text_a, text_b))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment