Created
February 1, 2025 16:14
-
-
Save hartmamt/544e5428c519e71700e90254fe522d80 to your computer and use it in GitHub Desktop.
Open Source Vertex AI processor for Redpanda Connect
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package processor | |
import ( | |
"context" | |
"fmt" | |
"io" | |
"log" | |
"mime" | |
"net/http" | |
"os" | |
"path/filepath" | |
"strconv" | |
"strings" | |
"time" | |
"cloud.google.com/go/storage" | |
"cloud.google.com/go/vertexai/genai" | |
"github.com/redpanda-data/benthos/v4/public/service" | |
"google.golang.org/api/option" | |
) | |
// ---------------------------------------------------------------------------- | |
// Configuration Specification | |
// ---------------------------------------------------------------------------- | |
// vertexAiProcessorConfigSpec declares the configuration surface of the
// "vertex_ai_chat" processor: input/output field paths, GCP project and
// credentials, model selection, prompting, sampling temperature, input
// type, and the GCS bucket used to stage binary (image/audio) inputs.
var vertexAiProcessorConfigSpec = service.NewConfigSpec().
	Summary("Creates a processor that sends requests to the Vertex AI API for text generation, image-to-text (Gemini), or audio transcription.").
	// Dot-separated path resolved against the structured message (see getByKey).
	Field(service.NewStringField("source_field").
		Description("The field in the message to use as input (for example, a file path or raw image bytes). If absent, the entire message is used.").
		Example("input.content")).
	Field(service.NewStringField("target_field").
		Description("The field where the Vertex AI response will be stored.").
		Example("output.response")).
	Field(service.NewStringField("project").
		Description("GCP Project ID.").
		Example("my-gcp-project")).
	// Credentials are optional; if left empty, Application Default Credentials are used.
	Field(service.NewStringField("credentials_json").
		Description("GCP Credentials JSON. Leave empty to use Application Default Credentials.").
		Optional().Secret()).
	Field(service.NewStringField("model").
		Description("The Vertex AI model to use. For Gemini image-to-text, for example use 'gemini-1.5-flash-001'.").
		Example("gemini-1.5-flash-001")).
	Field(service.NewStringField("location").
		Description("GCP region where the Vertex AI model is hosted.").
		Default("us-central1")).
	Field(service.NewStringField("prompt").
		Description("A default prompt to use if the source field is missing (for text generation).").
		Example("Default text prompt.")).
	Field(service.NewStringField("system_prompt").
		Description("A system prompt to guide the Vertex AI model. For image-to-text, you might use something like 'describe this image.'").
		Example("describe this image.")).
	Field(service.NewFloatField("temperature").
		Description("Controls randomness in responses.").
		Default(0.7)).
	Field(service.NewStringField("input_type").
		Description("The type of input to process. Supported values are 'text', 'image', or 'audio'.").
		Default("text").
		Example("image")).
	// New field for specifying the GCS bucket to upload files.
	// Required at construction time when input_type is "image" or "audio".
	Field(service.NewStringField("gcs_bucket").
		Description("The GCS bucket to use for file uploads (required for image or audio input).").
		Optional().
		Example("my-gcs-bucket"))
// ---------------------------------------------------------------------------- | |
// Vertex AI Client Interface and Implementation | |
// ---------------------------------------------------------------------------- | |
// VertexAIProcessor defines an interface for sending requests to Vertex AI.
// All methods return the generated text (or an error); Close must be called
// exactly once when the processor shuts down to release the client.
type VertexAIProcessor interface {
	// Ask is used for text generation.
	Ask(prompt string, systemPrompt string, temperature float64) (string, error)
	// ProcessImage uses Gemini to generate a text description from an image.
	// The imageInput parameter may be either a string (a file path) or []byte containing raw image data.
	ProcessImage(imageInput interface{}, systemPrompt string, temperature float64) (string, error)
	// TranscribeAudio transcribes audio input. audioPath is a local file path.
	TranscribeAudio(audioPath string, systemPrompt string, temperature float64) (string, error)
	// Close releases any allocated resources.
	Close() error
}
// vertexAiProcessorClient is a concrete implementation of VertexAIProcessor
// backed by the cloud.google.com/go/vertexai/genai SDK.
type vertexAiProcessorClient struct {
	// client is the underlying Vertex AI genai client.
	client *genai.Client
	// model is the model name passed to GenerativeModel for every request.
	model string
	// bucket is the GCS bucket used for file uploads (image/audio staging).
	bucket string
}
// NewVertexAIProcessor creates a new Vertex AI client instance using the provided parameters. | |
// If credentials is an empty string, Application Default Credentials will be used. | |
func NewVertexAIProcessor(project, credentials, model, location, bucket string) (VertexAIProcessor, error) { | |
ctx := context.Background() | |
opts := []option.ClientOption{} | |
if credentials != "" { | |
opts = append(opts, option.WithCredentialsJSON([]byte(credentials))) | |
} | |
// Otherwise, ADC is used. | |
client, err := genai.NewClient(ctx, project, location, opts...) | |
if err != nil { | |
return nil, fmt.Errorf("failed to create Vertex AI client: %w", err) | |
} | |
return &vertexAiProcessorClient{ | |
client: client, | |
model: model, | |
bucket: bucket, | |
}, nil | |
} | |
// Ask sends the text prompt and system instructions to Vertex AI and returns the generated response. | |
func (v *vertexAiProcessorClient) Ask(prompt string, systemPrompt string, temperature float64) (string, error) { | |
ctx := context.Background() | |
modelInst := v.client.GenerativeModel(v.model) | |
modelInst.SetTemperature(float32(temperature)) | |
// For text generation, send a system prompt and a user prompt. | |
res, err := modelInst.GenerateContent(ctx, genai.Text(systemPrompt), genai.Text(prompt)) | |
if err != nil { | |
log.Printf("Text generation error: %v", err) | |
return "", fmt.Errorf("failed to generate content: %w", err) | |
} | |
if len(res.Candidates) == 0 || len(res.Candidates[0].Content.Parts) == 0 { | |
return "", fmt.Errorf("no valid text response from Vertex AI") | |
} | |
return fmt.Sprintf("%v", res.Candidates[0].Content.Parts[0]), nil | |
} | |
// ProcessImage implements image-to-text using Gemini. | |
// If the input is a string, it is interpreted as a file path, which is uploaded | |
// to GCS using the Cloud Storage SDK; the returned file URI and MIME type are then used for content generation. | |
// If the input is []byte, the code falls back to using content detection. | |
func (v *vertexAiProcessorClient) ProcessImage(imageInput interface{}, systemPrompt string, temperature float64) (string, error) { | |
ctx := context.Background() | |
modelInst := v.client.GenerativeModel(v.model) | |
modelInst.SetTemperature(float32(temperature)) | |
// Configure safety settings. | |
modelInst.SafetySettings = []*genai.SafetySetting{ | |
{ | |
Category: genai.HarmCategoryHarassment, | |
Threshold: genai.HarmBlockLowAndAbove, | |
}, | |
{ | |
Category: genai.HarmCategoryDangerousContent, | |
Threshold: genai.HarmBlockLowAndAbove, | |
}, | |
} | |
var imagePart genai.Part | |
switch input := imageInput.(type) { | |
case string: | |
// Interpret the string as a file path; upload it to GCS. | |
uri, err := uploadFileToGCS(ctx, v.bucket, input) | |
if err != nil { | |
return "", fmt.Errorf("failed to upload file to GCS: %w", err) | |
} | |
// Determine MIME type from the file extension. | |
mt := mime.TypeByExtension(filepath.Ext(input)) | |
if mt == "" { | |
mt = "image/jpeg" | |
} | |
imagePart = genai.FileData{ | |
MIMEType: mt, | |
FileURI: uri, | |
} | |
case []byte: | |
// Fallback: if raw bytes are provided, use content detection. | |
fileType := "jpeg" | |
detected := http.DetectContentType(input) | |
if strings.HasPrefix(detected, "image/") { | |
fileType = strings.TrimPrefix(detected, "image/") | |
} | |
imagePart = genai.ImageData(fileType, input) | |
default: | |
return "", fmt.Errorf("unsupported image input type %T", input) | |
} | |
parts := []genai.Part{ | |
imagePart, | |
genai.Text(systemPrompt), | |
} | |
res, err := modelInst.GenerateContent(ctx, parts...) | |
if err != nil { | |
log.Printf("Image-to-text generation error: %v", err) | |
return "", fmt.Errorf("unable to generate content from image: %w", err) | |
} | |
if len(res.Candidates) == 0 || len(res.Candidates[0].Content.Parts) == 0 { | |
return "", fmt.Errorf("no valid image response from Vertex AI") | |
} | |
return fmt.Sprintf("%v", res.Candidates[0].Content.Parts[0]), nil | |
} | |
// TranscribeAudio uploads the audio file from a path to GCS and then transcribes it | |
// using the model's GenerateContent method. | |
func (v *vertexAiProcessorClient) TranscribeAudio(audioPath string, systemPrompt string, temperature float64) (string, error) { | |
ctx := context.Background() | |
// Upload the audio file to GCS. | |
uri, err := uploadFileToGCS(ctx, v.bucket, audioPath) | |
if err != nil { | |
return "", fmt.Errorf("failed to upload audio file to GCS: %w", err) | |
} | |
// Determine MIME type from the file extension. | |
mt := mime.TypeByExtension(filepath.Ext(audioPath)) | |
if mt == "" { | |
mt = "audio/mpeg" | |
} | |
// Get the generative model and set the temperature. | |
modelInst := v.client.GenerativeModel(v.model) | |
modelInst.SetTemperature(float32(temperature)) // You may adjust this default if needed. | |
// Build the file data part using the uploaded file's URI. | |
audioPart := genai.FileData{ | |
MIMEType: mt, | |
FileURI: uri, | |
} | |
// Call GenerateContent with the audio file and the transcription prompt. | |
res, err := modelInst.GenerateContent(ctx, audioPart, genai.Text(systemPrompt)) | |
if err != nil { | |
log.Printf("Audio transcription error: %v", err) | |
return "", fmt.Errorf("failed to transcribe audio: %w", err) | |
} | |
// Validate the response. | |
if len(res.Candidates) == 0 || len(res.Candidates[0].Content.Parts) == 0 { | |
return "", fmt.Errorf("empty transcription returned from Vertex AI") | |
} | |
// Return the transcription. | |
transcript := fmt.Sprintf("%v", res.Candidates[0].Content.Parts[0]) | |
log.Printf("Transcription: %s", transcript) | |
return transcript, nil | |
} | |
// Close releases resources held by the Vertex AI client. It must be called
// once when the processor shuts down; subsequent API calls will fail.
func (v *vertexAiProcessorClient) Close() error {
	return v.client.Close()
}
// ---------------------------------------------------------------------------- | |
// GCS Upload Helper | |
// ---------------------------------------------------------------------------- | |
// uploadFileToGCS uploads a local file (at localPath) to the specified GCS bucket | |
// using the official Cloud Storage SDK. It returns the GCS URI of the uploaded file. | |
func uploadFileToGCS(ctx context.Context, bucketName, localPath string) (string, error) { | |
// Create a Cloud Storage client. | |
storageClient, err := storage.NewClient(ctx) | |
if err != nil { | |
return "", fmt.Errorf("failed to create storage client: %w", err) | |
} | |
defer storageClient.Close() | |
bucket := storageClient.Bucket(bucketName) | |
// Create a unique object name (using a timestamp and the base filename). | |
base := filepath.Base(localPath) | |
objectName := fmt.Sprintf("%d-%s", time.Now().UnixNano(), base) | |
wc := bucket.Object(objectName).NewWriter(ctx) | |
// Optionally, set the content type: | |
wc.ContentType = mime.TypeByExtension(filepath.Ext(localPath)) | |
// Open the local file. | |
f, err := os.Open(localPath) | |
if err != nil { | |
return "", fmt.Errorf("failed to open local file: %w", err) | |
} | |
defer f.Close() | |
// Copy file contents to the GCS writer. | |
if _, err = io.Copy(wc, f); err != nil { | |
return "", fmt.Errorf("failed to copy file to GCS: %w", err) | |
} | |
if err := wc.Close(); err != nil { | |
return "", fmt.Errorf("failed to close GCS writer: %w", err) | |
} | |
// Return the GCS URI. | |
return fmt.Sprintf("gs://%s/%s", bucketName, objectName), nil | |
} | |
// ---------------------------------------------------------------------------- | |
// Processor Implementation | |
// ---------------------------------------------------------------------------- | |
// vertexAiProcessor is the Benthos processor wrapping a VertexAIProcessor
// client. Its fields mirror the parsed configuration (see
// vertexAiProcessorConfigSpec).
type vertexAiProcessor struct {
	// sourceField is the dot-separated path read from the message.
	sourceField string
	// targetField is the key the response is written under.
	targetField string
	// client performs the actual Vertex AI calls.
	client VertexAIProcessor
	// prompt is the default/user prompt for text generation.
	prompt string
	// systemPrompt guides the model for all input types.
	systemPrompt string
	// temperature controls sampling randomness.
	temperature float64
	// inputType is one of "text", "image", or "audio" (lowercased).
	inputType string
}
func init() { | |
err := service.RegisterProcessor( | |
"vertex_ai_chat", | |
vertexAiProcessorConfigSpec, | |
func(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { | |
return newVertexAiProcessor(conf, mgr) | |
}, | |
) | |
if err != nil { | |
panic(err) | |
} | |
} | |
func newVertexAiProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { | |
sourceField, err := conf.FieldString("source_field") | |
if err != nil { | |
return nil, err | |
} | |
targetField, err := conf.FieldString("target_field") | |
if err != nil { | |
return nil, err | |
} | |
project, err := conf.FieldString("project") | |
if err != nil { | |
return nil, err | |
} | |
credentials, err := conf.FieldString("credentials_json") | |
if err != nil { | |
return nil, err | |
} | |
model, err := conf.FieldString("model") | |
if err != nil { | |
return nil, err | |
} | |
location, err := conf.FieldString("location") | |
if err != nil { | |
return nil, err | |
} | |
prompt, err := conf.FieldString("prompt") | |
if err != nil { | |
return nil, err | |
} | |
systemPrompt, err := conf.FieldString("system_prompt") | |
if err != nil { | |
return nil, err | |
} | |
temperature, err := conf.FieldFloat("temperature") | |
if err != nil { | |
return nil, err | |
} | |
inputType, err := conf.FieldString("input_type") | |
if err != nil { | |
inputType = "text" // default if not provided | |
} | |
// For image or audio input, we require a GCS bucket. | |
gcsBucket, _ := conf.FieldString("gcs_bucket") | |
if (inputType == "image" || inputType == "audio") && gcsBucket == "" { | |
return nil, fmt.Errorf("gcs_bucket is required for input_type %s", inputType) | |
} | |
client, err := NewVertexAIProcessor(project, credentials, model, location, gcsBucket) | |
if err != nil { | |
return nil, err | |
} | |
return &vertexAiProcessor{ | |
sourceField: sourceField, | |
targetField: targetField, | |
client: client, | |
prompt: prompt, | |
systemPrompt: systemPrompt, | |
temperature: temperature, | |
inputType: strings.ToLower(inputType), | |
}, nil | |
} | |
// Process extracts the input from the message, calls Vertex AI, and writes the response. | |
func (v *vertexAiProcessor) Process(ctx context.Context, m *service.Message) (service.MessageBatch, error) { | |
content, err := m.AsStructuredMut() | |
if err != nil { | |
// Not structured; fallback to raw bytes. | |
content = nil | |
} | |
var response string | |
switch v.inputType { | |
case "text": | |
var finalPrompt string | |
if content != nil { | |
if val, ok := getByKey(content, v.sourceField); ok { | |
finalPrompt = fmt.Sprintf("%s\nUser: %v", v.prompt, val) | |
} else { | |
finalPrompt = v.prompt | |
} | |
} else { | |
b, err := m.AsBytes() | |
if err != nil { | |
return nil, err | |
} | |
finalPrompt = string(b) | |
} | |
response, err = v.client.Ask(finalPrompt, v.systemPrompt, v.temperature) | |
if err != nil { | |
return []*service.Message{m}, nil | |
} | |
case "image": | |
var imageInput interface{} | |
if content != nil { | |
if val, ok := getByKey(content, v.sourceField); ok { | |
switch v := val.(type) { | |
case string: | |
imageInput = v | |
case []byte: | |
imageInput = v | |
default: | |
imageInput = fmt.Sprintf("%v", v) | |
} | |
} else { | |
b, err := m.AsBytes() | |
if err != nil { | |
return nil, err | |
} | |
imageInput = b | |
} | |
} else { | |
b, err := m.AsBytes() | |
if err != nil { | |
return nil, err | |
} | |
imageInput = b | |
} | |
response, err = v.client.ProcessImage(imageInput, v.systemPrompt, v.temperature) | |
if err != nil { | |
log.Printf("Image processing error: %v", err) | |
return []*service.Message{m}, nil | |
} | |
log.Printf("Image response: %s", response) | |
case "audio": | |
var audioPath string | |
if content != nil { | |
if val, ok := getByKey(content, v.sourceField); ok { | |
if s, ok := val.(string); ok { | |
audioPath = s | |
} else { | |
audioPath = fmt.Sprintf("%v", val) | |
} | |
} else { | |
b, err := m.AsBytes() | |
if err != nil { | |
return nil, err | |
} | |
audioPath = string(b) | |
} | |
} else { | |
b, err := m.AsBytes() | |
if err != nil { | |
return nil, err | |
} | |
audioPath = string(b) | |
} | |
response, err = v.client.TranscribeAudio(audioPath, v.systemPrompt, v.temperature) | |
if err != nil { | |
return []*service.Message{m}, nil | |
} | |
default: | |
return []*service.Message{m}, fmt.Errorf("unsupported input type: %s", v.inputType) | |
} | |
payload := make(map[string]interface{}) | |
if contentMap, ok := content.(map[string]interface{}); ok { | |
for k, val := range contentMap { | |
payload[k] = val | |
} | |
} | |
payload[v.targetField] = response | |
m.SetStructured(payload) | |
return []*service.Message{m}, nil | |
} | |
// Close releases the underlying Vertex AI client when the pipeline shuts
// down. The ctx parameter is accepted to satisfy service.Processor but is
// not used by the client's Close.
func (v *vertexAiProcessor) Close(ctx context.Context) error {
	return v.client.Close()
}
// ---------------------------------------------------------------------------- | |
// Helper Function | |
// ---------------------------------------------------------------------------- | |
// getByKey resolves a dot-separated path against nested maps and slices.
// Map segments are looked up by key; slice segments are parsed as integer
// indices. It returns the resolved value and whether the full path exists.
func getByKey(m interface{}, key string) (any, bool) {
	segment, rest, more := strings.Cut(key, ".")

	var next any
	switch container := m.(type) {
	case map[string]interface{}:
		val, ok := container[segment]
		if !ok {
			return nil, false
		}
		next = val
	case []interface{}:
		idx, err := strconv.Atoi(segment)
		if err != nil || idx < 0 || idx >= len(container) {
			return nil, false
		}
		next = container[idx]
	default:
		// Path continues but the current node is neither a map nor a slice.
		return nil, false
	}

	if !more {
		return next, true
	}
	// Recurse into the remainder of the path.
	return getByKey(next, rest)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment