Simple function calling proxy
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"slices"
	"strings"

	openai "github.com/gptscript-ai/chat-completion-client"
)
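// writeTool renders a single tool definition as a TypeScript-style type
// signature, the form the system prompt below presents to the model.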
func writeTool(buf *strings.Builder, tool openai.Tool) {
	buf.WriteString("// ")
	buf.WriteString(tool.Function.Description)
	buf.WriteString("\n")
	buf.WriteString("type ")
	buf.WriteString(tool.Function.Name)
	buf.WriteString(" = (_: {\n")
	// Use comma-ok assertions so a malformed or missing JSON schema cannot panic the proxy.
	schema, _ := tool.Function.Parameters.(map[string]any)
	props, _ := schema["properties"].(map[string]any)
	for name, prop := range props {
		propMap, _ := prop.(map[string]any)
		desc, _ := propMap["description"].(string)
		buf.WriteString("// ")
		buf.WriteString(desc)
		buf.WriteString("\n")
		buf.WriteString(name)
		buf.WriteString("?: string,\n")
	}
	buf.WriteString("}) => any;\n\n")
}
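// translate rewrites a tool-calling chat request into a plain text request:
// the tool definitions become a system-prompt preamble, earlier tool calls and
// tool results are folded back into message text, streaming is forced on, and
// "</CALL>" is registered as a stop sequence so the reply ends right after a call.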
func translate(req openai.ChatCompletionRequest) openai.ChatCompletionRequest {
	if len(req.Tools) == 0 {
		return req
	}

	buf := strings.Builder{}
	buf.WriteString("You can have the user call the following functions.\n")
	buf.WriteString("```typescript\n")
	buf.WriteString("## functions\n")
	buf.WriteString("\n")
	buf.WriteString("namespace functions {\n\n")
	for _, tool := range req.Tools {
		writeTool(&buf, tool)
	}
	buf.WriteString("} // namespace functions\n")
	buf.WriteString("```\n\n")
	buf.WriteString("To tell the user to call a function, use the following format:\n")
	buf.WriteString("<CALL>functionToCall({\"arg1\":\"value1\",\"arg2\":\"value2\"})</CALL>\n\n")

	req.Stream = true
	req.Stop = []string{"</CALL>"}

	var systemSet bool
	for i, msg := range req.Messages {
		if msg.Role == openai.ChatMessageRoleSystem && !systemSet {
			msg.Content = buf.String() + msg.Content
			req.Messages[i] = msg
			systemSet = true
		} else if msg.Role == openai.ChatMessageRoleAssistant && len(msg.ToolCalls) > 0 {
			msg.Content = "<CALL>" + msg.ToolCalls[0].Function.Name + "(" + msg.ToolCalls[0].Function.Arguments + ")</CALL>"
			msg.ToolCalls = nil
			req.Messages[i] = msg
		} else if msg.Role == openai.ChatMessageRoleTool {
			msg.Role = openai.ChatMessageRoleUser
			msg.Content = "<CALL_RESULT>" + msg.Content + "</CALL_RESULT>"
			msg.FunctionCall = nil
			req.Messages[i] = msg
		}
	}
	if !systemSet {
		req.Messages = slices.Insert(req.Messages, 0, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleSystem,
			Content: buf.String(),
		})
	}
	req.Tools = nil
	return req
}
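// toContent drains the upstream SSE stream and exposes only the concatenated
// delta content as a plain io.Reader.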
func toContent(resp *http.Response) io.Reader {
	r, w := io.Pipe()
	go func() {
		defer w.Close()
		reader := bufio.NewScanner(resp.Body)
		for reader.Scan() {
			var delta openai.ChatCompletionStreamResponse
			fmt.Println("LINE: ", reader.Text())
			line := strings.TrimPrefix(reader.Text(), "data: ")
			if line == "" {
				continue
			}
			if line == "[DONE]" {
				break
			}
			if err := json.Unmarshal([]byte(line), &delta); err == nil {
				if len(delta.Choices) > 0 {
					fmt.Println("CONTENT: ", delta.Choices[0].Delta.Content)
					_, _ = w.Write([]byte(delta.Choices[0].Delta.Content))
				}
			} else {
				fmt.Print("ERR:", err.Error())
			}
		}
	}()
	return r
}
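// scanner tracks which marker of a single <CALL>name(args) sequence is expected next.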
type scanner struct {
	tokens [][]byte
}
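// newSplitter returns a bufio.SplitFunc primed with the marker sequence for one
// call; the final ")!" marker never appears literally in the output, so the
// closing ")" ends up being dropped at end of stream.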
func newSplitter() bufio.SplitFunc {
	s := scanner{
		tokens: [][]byte{
			[]byte("<CALL>"),
			[]byte("("),
			[]byte(")!"),
		},
	}
	return s.split
}
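// split emits any text before the next expected marker as its own token, then
// the marker itself, advancing to the next marker once one is matched in full.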
func (s *scanner) split(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if len(data) == 0 && atEOF {
		return 0, nil, io.EOF
	}
	i := bytes.Index(data, s.tokens[0][:1])
	if i == -1 {
		return len(data), data, nil
	} else if i > 0 {
		return len(data[:i]), data[:i], nil
	}
	// data begins with the first byte of the current marker, but may not contain the full marker yet
	if !bytes.HasPrefix(data, s.tokens[0]) {
		if atEOF {
			if string(data) == ")" {
				return len(data), nil, io.EOF
			}
			return len(data), data, nil
		} else if len(data) >= len(s.tokens[0]) {
			return len(s.tokens[0]), data[:len(s.tokens[0])], nil
		}
		return 0, nil, nil
	}
	defer func() {
		s.tokens = s.tokens[1:]
	}()
	return len(s.tokens[0]), data[:len(s.tokens[0])], nil
}
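// main runs the proxy: it accepts OpenAI-style chat completion requests on
// :8081, rewrites them with translate, forwards them to the target, and turns
// any <CALL>...</CALL> markers in the streamed reply back into tool-call chunks.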
func main() {
	var target = "http://station76.local:11434/v1/chat/completions"
	//var target = "http://station76.local:8081/v1/chat/completions"
	log.Print("Listening on :8081")
	log.Fatal(http.ListenAndServe(":8081", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req openai.ChatCompletionRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte(err.Error()))
			return
		}
		req = translate(req)
		data, err := json.Marshal(req)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte(err.Error()))
			return
		}
		fmt.Println("REQUEST:", string(data))
		newReq, err := http.NewRequest(http.MethodPost, target, bytes.NewReader(data))
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte(err.Error()))
			return
		}
		// The upstream request body is JSON; the proxy's own response is an SSE stream.
		newReq.Header.Set("Content-Type", "application/json")
		httpResp, err := http.DefaultClient.Do(newReq)
		if err != nil {
			w.WriteHeader(http.StatusBadGateway)
			_, _ = w.Write([]byte(err.Error()))
			return
		}
		defer httpResp.Body.Close()
		w.Header().Set("Content-Type", "text/event-stream")
		reader := bufio.NewScanner(toContent(httpResp))
		reader.Split(newSplitter())
		for reader.Scan() {
			token := reader.Text()
			if token == "<CALL>" {
				var funcName string
				for reader.Scan() {
					token = reader.Text()
					if token == "(" {
						// Everything after "(" is streamed out as tool-call argument chunks.
						for reader.Scan() {
							token = reader.Text()
							_, _ = w.Write([]byte("data: "))
							_ = json.NewEncoder(w).Encode(openai.ChatCompletionStreamResponse{
								Choices: []openai.ChatCompletionStreamChoice{
									{
										Delta: openai.ChatCompletionStreamChoiceDelta{
											Role: openai.ChatMessageRoleAssistant,
											ToolCalls: []openai.ToolCall{
												{
													Index: new(int),
													Type:  openai.ToolTypeFunction,
													Function: openai.FunctionCall{
														Name:      strings.TrimPrefix(funcName, "functions."),
														Arguments: token,
													},
												},
											},
										},
									},
								},
							})
							_, _ = w.Write([]byte("\n\n"))
						}
					} else {
						fmt.Println("FUNC: ", token)
						funcName += token
					}
				}
			} else {
				_, _ = w.Write([]byte("data: "))
				_ = json.NewEncoder(w).Encode(openai.ChatCompletionStreamResponse{
					Choices: []openai.ChatCompletionStreamChoice{
						{
							Delta: openai.ChatCompletionStreamChoiceDelta{
								Role:    openai.ChatMessageRoleAssistant,
								Content: token,
							},
						},
					},
				})
				_, _ = w.Write([]byte("\n\n"))
			}
		}
	})))
}
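A minimal way to try the proxy is a client sketch like the one below, assuming the proxy is running locally on :8081 and the target model follows the prompt format; the model name, tool schema, question, and localhost address are placeholders, not part of the gist.

// client.go (hypothetical): post a tool-calling request to the proxy and print
// the SSE stream it returns.
package main

import (
	"io"
	"log"
	"net/http"
	"os"
	"strings"
)

func main() {
	body := `{
		"model": "llama3",
		"messages": [{"role": "user", "content": "What is the weather in Paris?"}],
		"tools": [{
			"type": "function",
			"function": {
				"name": "getWeather",
				"description": "Look up the current weather for a city",
				"parameters": {
					"type": "object",
					"properties": {
						"city": {"type": "string", "description": "City name"}
					}
				}
			}
		}]
	}`

	resp, err := http.Post("http://localhost:8081/v1/chat/completions", "application/json", strings.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// The proxy replies with "data: {...}" chunks: plain content deltas, or
	// tool_calls chunks whenever the model emitted a <CALL>...</CALL> marker.
	if _, err := io.Copy(os.Stdout, resp.Body); err != nil {
		log.Fatal(err)
	}
}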