Created
December 7, 2024 04:14
-
-
Save SaahilClaypool/c32b01026a6a15dbde891667350388c9 to your computer and use it in GitHub Desktop.
Go function calling (OpenAI tool use), using reflection to auto-generate the tool schemas from ordinary Go functions.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
> What is the difference in weather between San Fran and Chicago in C | |
calling open ai | |
Function call: getWeather({"Location": "San Francisco"}) -> Cold, 22 C | |
Function call: getWeather({"Location": "Chicago"}) -> Sunny, 25 C | |
calling open ai | |
Function call: add({"Left":25,"Right":-22}) -> 3 | |
calling open ai | |
The difference in weather between San Francisco and Chicago is 3°C, with Chicago being the warmer of the two. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"github.com/invopop/jsonschema" | |
"github.com/openai/openai-go" | |
) | |
func main() { | |
client := openai.NewClient() | |
ctx := context.Background() | |
system := "Use the add function for all adding and subtracting." | |
question := "What is the difference in weather between San Fran and Chicago in C" | |
print("> ") | |
println(question) | |
funcs := OAIFunctions{} | |
AddTool(&funcs, "getWeather", getWeather) | |
AddTool(&funcs, "add", add) | |
params := openai.ChatCompletionNewParams{ | |
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ | |
openai.SystemMessage(system), | |
openai.UserMessage(question), | |
}), | |
Tools: openai.F(funcs.OAISchema()), | |
Seed: openai.Int(0), | |
Model: openai.F(openai.ChatModelGPT4oMini), | |
} | |
for { | |
fmt.Printf("calling open ai\n\n") | |
completion, err := client.Chat.Completions.New(ctx, params) | |
if err != nil { | |
panic(err) | |
} | |
if len(completion.Choices[0].Message.ToolCalls) > 0 { | |
toolCalls := completion.Choices[0].Message.ToolCalls | |
params.Messages.Value = append(params.Messages.Value, completion.Choices[0].Message) | |
for _, toolCall := range toolCalls { | |
resp := funcs.Invoke(toolCall.Function.Name, toolCall.Function.Arguments) | |
fmt.Printf("Function call: %s(%s) -> %s\n", toolCall.Function.Name, toolCall.Function.Arguments, resp) | |
params.Messages.Value = append(params.Messages.Value, openai.ToolMessage(toolCall.ID, resp)) | |
} | |
} else { | |
println(completion.Choices[0].Message.Content) | |
break | |
} | |
} | |
} | |
// getWeather is a mock tool that simulates weather data retrieval. A location
// whose name starts with 'S' (e.g. "San Francisco") reports "Cold, 22 C". Any
// other location, including an empty one, reports "Sunny, 25 C".
//
// In a real implementation, this function would call a weather API.
func getWeather(location struct{ Location string }) string {
	// Guard the index: the model may send {} or {"Location": ""}, and indexing
	// an empty string would panic and crash the whole tool-calling loop.
	if len(location.Location) > 0 && location.Location[0] == 'S' {
		return "Cold, 22 C"
	}
	return "Sunny, 25 C"
}
// add is a tool that returns input.Left + input.Right formatted as a base-10
// integer string. The model is told (via the system prompt) to use it for
// subtraction too, by passing a negative Right operand.
func add(input struct {
	Left  int
	Right int
}) string {
	return fmt.Sprintf("%d", input.Left+input.Right)
}
// OAIFunc is the shape every tool must have: it receives its arguments
// decoded into a value of type T and returns the text sent back to the model.
type OAIFunc[T any] func(input T) string

// OAIFuncT is a registered tool with its input type erased.
type OAIFuncT struct {
	// Name is the tool name advertised to, and requested by, the model.
	Name string
	// InputPlaceholder holds the zero value of the tool's input type; the
	// schema generator reflects over it to describe the tool's parameters.
	InputPlaceholder any
	// Action decodes a JSON argument string and runs the tool on it.
	Action func(inputJson string) string
}

// OAIFunctions is a registry of tools. The zero value is an empty registry.
type OAIFunctions struct {
	funcs []OAIFuncT
}

// AddTool registers tool in o under name. The generic parameter T lets the
// registry remember the tool's input type for schema generation and lets the
// stored Action decode JSON arguments into that type. If decoding fails,
// Action returns an error description instead of calling the tool.
func AddTool[T any](o *OAIFunctions, name string, tool OAIFunc[T]) {
	decodeAndRun := func(inputJson string) string {
		var args T
		if err := json.Unmarshal([]byte(inputJson), &args); err != nil {
			return fmt.Sprintf("Error unmarshaling input: %v", err)
		}
		return tool(args)
	}

	var zero T
	entry := OAIFuncT{
		Name:             name,
		InputPlaceholder: zero,
		Action:           decodeAndRun,
	}
	o.funcs = append(o.funcs, entry)
}
func (o *OAIFunctions) OAISchema() []openai.ChatCompletionToolParam { | |
schema := make([]openai.ChatCompletionToolParam, 0) | |
for _, f := range o.funcs { | |
fSchema := openai.ChatCompletionToolParam{ | |
Type: openai.F(openai.ChatCompletionToolTypeFunction), | |
Function: openai.F(openai.FunctionDefinitionParam{ | |
Name: openai.String(f.Name), | |
Description: openai.String(fmt.Sprintf("Calls %s", f.Name)), | |
Parameters: openai.F(GenerateSchema(f.InputPlaceholder)), | |
}), | |
} | |
schema = append(schema, fSchema) | |
} | |
return schema | |
} | |
func GenerateSchema(v any) openai.FunctionParameters { | |
reflector := jsonschema.Reflector{ | |
AllowAdditionalProperties: false, | |
DoNotReference: true, | |
} | |
s := reflector.Reflect(v) | |
props := make(map[string]interface{}) | |
pair := s.Properties.Oldest() | |
for pair != nil { | |
props[pair.Key] = map[string]interface{}{ | |
"type": pair.Value.Type, | |
} | |
pair = pair.Next() | |
} | |
params := openai.FunctionParameters{ | |
"type": "object", | |
"properties": props, | |
} | |
if len(s.Required) > 0 { | |
params["required"] = s.Required | |
} | |
return params | |
} | |
func (o *OAIFunctions) Invoke(name string, inputJson string) string { | |
for _, v := range o.funcs { | |
if v.Name == name { | |
return v.Action(inputJson) | |
} | |
} | |
panic(fmt.Sprintf("funtion %s not found", name)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment