Created
March 6, 2023 14:46
-
-
Save vivekhaldar/fef84a7239b9e5667b8a0cdb59d7e281 to your computer and use it in GitHub Desktop.
Simple Python script to invoke the ChatGPT API.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
#
# Reads a chat transcript (for ChatGPT) from stdin, sends it to the
# OpenAI ChatGPT API, and prints the response on stdout.
#
# Your OpenAI API key must be set in the environment variable
# OPENAI_API_KEY.
#
# Logs are written to ~/chat.log.
import sys | |
import os | |
import re | |
import openai | |
import datetime | |
import logging | |
# Destination for the script's debug log (the parsed request and the API's
# response). The leading "~" is expanded to the current user's home directory.
LOG_FILE = os.path.expanduser("~/chat.log")
# Read stdin and save into a string. | |
def read_stdin(): | |
return sys.stdin.read() | |
def set_openai_key():
    """Point the openai client at the key stored in $OPENAI_API_KEY.

    When the variable is not set, ``openai.api_key`` is assigned ``None``.
    """
    key = os.environ.get("OPENAI_API_KEY")
    openai.api_key = key
def parse_input(input):
    """Split a chat transcript into a list of ChatGPT API messages.

    A line starting with ``%role%`` (e.g. ``%user%``) begins a new message
    whose ``role`` is the text between the percent signs. Any other text on
    that marker line is ignored. The lines that follow, up to the next marker
    or the end of the input, become the message's ``content``. Text before
    the first marker is discarded.

    Args:
        input: The full transcript text.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in transcript
        order. Line breaks inside a message are preserved, and blank lines
        at the start or end of a message are dropped. Returns an empty list
        when the transcript contains no marker lines.
    """
    logging.debug(input)
    messages = []
    cur_role = None  # None until the first marker line has been seen.
    cur_lines = []
    for line in input.splitlines():
        # Check if line matches a regex like '%user%'.
        marker = re.match(r"^%(.+)%", line)
        if marker:
            if cur_role is not None:
                messages.append({
                    "role": cur_role,
                    # Re-insert the newlines splitlines() removed so
                    # multi-line messages are not glued together.
                    "content": "\n".join(cur_lines).strip("\n"),
                })
            cur_role = marker.group(1)
            cur_lines = []
        else:
            cur_lines.append(line)
    # Flush the final message (if any marker was seen at all).
    if cur_role is not None:
        messages.append({
            "role": cur_role,
            "content": "\n".join(cur_lines).strip("\n"),
        })
    return messages
def get_response(messages, model="gpt-3.5-turbo"):
    """Send *messages* to the ChatGPT API and return the reply text.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts,
            as produced by ``parse_input``.
        model: Name of the chat model to request. Defaults to
            ``"gpt-3.5-turbo"``, the model this script has always used.

    Returns:
        The ``content`` of the first choice's message.
    """
    # NOTE(review): openai.ChatCompletion is the pre-1.0 openai-python
    # interface and no longer exists in openai>=1.0; confirm the pinned version.
    completion = openai.ChatCompletion.create(
        model=model,
        messages=messages,
    )
    logging.debug(completion)
    # Look up the first choice's message once rather than re-indexing per field.
    message = completion["choices"][0]["message"]
    reply = message["content"]
    role = message["role"]  # Only used for logging.
    logging.debug('==== OUTPUT⇟\n')
    logging.debug(reply)
    logging.debug(role)
    return reply
def main():
    """Read a transcript from stdin, query ChatGPT, and print the reply.

    All DEBUG-level log records from this run are written to LOG_FILE.
    """
    logging.basicConfig(filename=LOG_FILE, level=logging.DEBUG)
    # Stamp each run in the log with the current local date and time.
    started = datetime.datetime.now()
    logging.debug(str(started))

    transcript = read_stdin()
    messages = parse_input(transcript)
    logging.debug('==== INPUT⇟\n')
    logging.debug(messages)

    set_openai_key()
    print(get_response(messages))
# Only run the CLI when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment