Skip to content

Instantly share code, notes, and snippets.

@hartmamt
Created February 1, 2025 16:14
Show Gist options
  • Save hartmamt/544e5428c519e71700e90254fe522d80 to your computer and use it in GitHub Desktop.
Save hartmamt/544e5428c519e71700e90254fe522d80 to your computer and use it in GitHub Desktop.
Open Source Vertex AI processor for Redpanda Connect
package processor
import (
"context"
"fmt"
"io"
"log"
"mime"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"cloud.google.com/go/storage"
"cloud.google.com/go/vertexai/genai"
"github.com/redpanda-data/benthos/v4/public/service"
"google.golang.org/api/option"
)
// ----------------------------------------------------------------------------
// Configuration Specification
// ----------------------------------------------------------------------------
// vertexAiProcessorConfigSpec declares the configuration surface of the
// "vertex_ai_chat" processor: input/output field mapping, GCP project and
// credentials, model selection, prompting, sampling temperature, the input
// type (text/image/audio), and the GCS bucket used to stage file uploads.
var vertexAiProcessorConfigSpec = service.NewConfigSpec().
	Summary("Creates a processor that sends requests to the Vertex AI API for text generation, image-to-text (Gemini), or audio transcription.").
	Field(service.NewStringField("source_field").
		Description("The field in the message to use as input (for example, a file path or raw image bytes). If absent, the entire message is used.").
		Example("input.content")).
	Field(service.NewStringField("target_field").
		Description("The field where the Vertex AI response will be stored.").
		Example("output.response")).
	Field(service.NewStringField("project").
		Description("GCP Project ID.").
		Example("my-gcp-project")).
	// Credentials are optional; if left empty, Application Default Credentials are used.
	Field(service.NewStringField("credentials_json").
		Description("GCP Credentials JSON. Leave empty to use Application Default Credentials.").
		Optional().Secret()).
	Field(service.NewStringField("model").
		Description("The Vertex AI model to use. For Gemini image-to-text, for example use 'gemini-1.5-flash-001'.").
		Example("gemini-1.5-flash-001")).
	Field(service.NewStringField("location").
		Description("GCP region where the Vertex AI model is hosted.").
		Default("us-central1")).
	Field(service.NewStringField("prompt").
		Description("A default prompt to use if the source field is missing (for text generation).").
		Example("Default text prompt.")).
	Field(service.NewStringField("system_prompt").
		Description("A system prompt to guide the Vertex AI model. For image-to-text, you might use something like 'describe this image.'").
		Example("describe this image.")).
	Field(service.NewFloatField("temperature").
		Description("Controls randomness in responses.").
		Default(0.7)).
	Field(service.NewStringField("input_type").
		Description("The type of input to process. Supported values are 'text', 'image', or 'audio'.").
		Default("text").
		Example("image")).
	// New field for specifying the GCS bucket to upload files.
	Field(service.NewStringField("gcs_bucket").
		Description("The GCS bucket to use for file uploads (required for image or audio input).").
		Optional().
		Example("my-gcs-bucket"))
// ----------------------------------------------------------------------------
// Vertex AI Client Interface and Implementation
// ----------------------------------------------------------------------------
// VertexAIProcessor defines an interface for sending requests to Vertex AI.
// It abstracts the concrete genai-backed client so the processor logic can
// be exercised against a fake implementation in tests.
type VertexAIProcessor interface {
	// Ask is used for text generation. It sends the system prompt and user
	// prompt and returns the model's textual response.
	Ask(prompt string, systemPrompt string, temperature float64) (string, error)
	// ProcessImage uses Gemini to generate a text description from an image.
	// The imageInput parameter may be either a string (a file path) or []byte containing raw image data.
	ProcessImage(imageInput interface{}, systemPrompt string, temperature float64) (string, error)
	// TranscribeAudio transcribes audio input. audioPath is a local file path;
	// the file is staged to GCS before being sent to the model.
	TranscribeAudio(audioPath string, systemPrompt string, temperature float64) (string, error)
	// Close releases any allocated resources.
	Close() error
}
// vertexAiProcessorClient is a concrete implementation of VertexAIProcessor
// backed by the cloud.google.com/go/vertexai/genai client.
type vertexAiProcessorClient struct {
	// client is the underlying Vertex AI generative client.
	client *genai.Client
	// model is the model name passed to GenerativeModel for every request.
	model string
	// bucket is the GCS bucket used for file uploads (image/audio staging).
	bucket string
}
// NewVertexAIProcessor creates a new Vertex AI client instance using the
// provided parameters. If credentials is an empty string, Application
// Default Credentials will be used.
func NewVertexAIProcessor(project, credentials, model, location, bucket string) (VertexAIProcessor, error) {
	ctx := context.Background()
	var opts []option.ClientOption
	if credentials != "" {
		// Explicit JSON credentials were supplied; otherwise ADC applies.
		opts = append(opts, option.WithCredentialsJSON([]byte(credentials)))
	}
	client, err := genai.NewClient(ctx, project, location, opts...)
	if err != nil {
		return nil, fmt.Errorf("failed to create Vertex AI client: %w", err)
	}
	vc := &vertexAiProcessorClient{client: client, model: model, bucket: bucket}
	return vc, nil
}
// Ask sends the text prompt and system instructions to Vertex AI and returns
// the generated response as plain text.
//
// Only the first part of the first candidate is returned. A candidate may
// carry a nil Content (e.g. when the response is blocked), so that case is
// guarded explicitly before indexing into Parts.
func (v *vertexAiProcessorClient) Ask(prompt string, systemPrompt string, temperature float64) (string, error) {
	ctx := context.Background()
	modelInst := v.client.GenerativeModel(v.model)
	modelInst.SetTemperature(float32(temperature))
	// For text generation, send a system prompt and a user prompt.
	res, err := modelInst.GenerateContent(ctx, genai.Text(systemPrompt), genai.Text(prompt))
	if err != nil {
		// Wrap and return once; the caller decides how to log/report.
		return "", fmt.Errorf("failed to generate content: %w", err)
	}
	if len(res.Candidates) == 0 || res.Candidates[0].Content == nil || len(res.Candidates[0].Content.Parts) == 0 {
		return "", fmt.Errorf("no valid text response from Vertex AI")
	}
	return fmt.Sprintf("%v", res.Candidates[0].Content.Parts[0]), nil
}
// ProcessImage implements image-to-text using Gemini.
//
// If the input is a string, it is interpreted as a local file path, uploaded
// to GCS via uploadFileToGCS, and referenced by URI with a MIME type derived
// from the file extension (defaulting to image/jpeg). If the input is []byte,
// the MIME subtype is sniffed from the content instead. Any other input type
// is rejected.
//
// Only the first part of the first candidate is returned; a nil candidate
// Content (e.g. a blocked response) is guarded against before indexing.
func (v *vertexAiProcessorClient) ProcessImage(imageInput interface{}, systemPrompt string, temperature float64) (string, error) {
	ctx := context.Background()
	modelInst := v.client.GenerativeModel(v.model)
	modelInst.SetTemperature(float32(temperature))
	// Configure safety settings: block low-and-above for harassment and
	// dangerous content.
	modelInst.SafetySettings = []*genai.SafetySetting{
		{
			Category:  genai.HarmCategoryHarassment,
			Threshold: genai.HarmBlockLowAndAbove,
		},
		{
			Category:  genai.HarmCategoryDangerousContent,
			Threshold: genai.HarmBlockLowAndAbove,
		},
	}
	var imagePart genai.Part
	switch input := imageInput.(type) {
	case string:
		// Interpret the string as a file path; upload it to GCS.
		uri, err := uploadFileToGCS(ctx, v.bucket, input)
		if err != nil {
			return "", fmt.Errorf("failed to upload file to GCS: %w", err)
		}
		// Determine MIME type from the file extension.
		mt := mime.TypeByExtension(filepath.Ext(input))
		if mt == "" {
			mt = "image/jpeg"
		}
		imagePart = genai.FileData{
			MIMEType: mt,
			FileURI:  uri,
		}
	case []byte:
		// Fallback: if raw bytes are provided, use content detection.
		fileType := "jpeg"
		detected := http.DetectContentType(input)
		if strings.HasPrefix(detected, "image/") {
			fileType = strings.TrimPrefix(detected, "image/")
		}
		imagePart = genai.ImageData(fileType, input)
	default:
		return "", fmt.Errorf("unsupported image input type %T", input)
	}
	parts := []genai.Part{
		imagePart,
		genai.Text(systemPrompt),
	}
	res, err := modelInst.GenerateContent(ctx, parts...)
	if err != nil {
		// Wrap and return once; the caller decides how to log/report.
		return "", fmt.Errorf("unable to generate content from image: %w", err)
	}
	if len(res.Candidates) == 0 || res.Candidates[0].Content == nil || len(res.Candidates[0].Content.Parts) == 0 {
		return "", fmt.Errorf("no valid image response from Vertex AI")
	}
	return fmt.Sprintf("%v", res.Candidates[0].Content.Parts[0]), nil
}
// TranscribeAudio uploads the audio file at audioPath to GCS and then
// transcribes it using the model's GenerateContent method.
//
// The MIME type is derived from the file extension, defaulting to audio/mpeg.
// Only the first part of the first candidate is returned; a nil candidate
// Content (e.g. a blocked response) is guarded against before indexing.
func (v *vertexAiProcessorClient) TranscribeAudio(audioPath string, systemPrompt string, temperature float64) (string, error) {
	ctx := context.Background()
	// Upload the audio file to GCS so the model can reference it by URI.
	uri, err := uploadFileToGCS(ctx, v.bucket, audioPath)
	if err != nil {
		return "", fmt.Errorf("failed to upload audio file to GCS: %w", err)
	}
	// Determine MIME type from the file extension.
	mt := mime.TypeByExtension(filepath.Ext(audioPath))
	if mt == "" {
		mt = "audio/mpeg"
	}
	// Get the generative model and set the temperature.
	modelInst := v.client.GenerativeModel(v.model)
	modelInst.SetTemperature(float32(temperature))
	// Build the file data part using the uploaded file's URI.
	audioPart := genai.FileData{
		MIMEType: mt,
		FileURI:  uri,
	}
	// Call GenerateContent with the audio file and the transcription prompt.
	res, err := modelInst.GenerateContent(ctx, audioPart, genai.Text(systemPrompt))
	if err != nil {
		// Wrap and return once; the caller decides how to log/report.
		return "", fmt.Errorf("failed to transcribe audio: %w", err)
	}
	// Validate the response before indexing into it.
	if len(res.Candidates) == 0 || res.Candidates[0].Content == nil || len(res.Candidates[0].Content.Parts) == 0 {
		return "", fmt.Errorf("empty transcription returned from Vertex AI")
	}
	// Return the transcription. (Deliberately not logged: transcripts can be
	// large and may contain sensitive content.)
	return fmt.Sprintf("%v", res.Candidates[0].Content.Parts[0]), nil
}
// Close releases resources held by the Vertex AI client. It delegates to the
// underlying genai client's Close and returns its error, if any.
func (v *vertexAiProcessorClient) Close() error {
	return v.client.Close()
}
// ----------------------------------------------------------------------------
// GCS Upload Helper
// ----------------------------------------------------------------------------
// uploadFileToGCS uploads a local file (at localPath) to the specified GCS
// bucket using the official Cloud Storage SDK. It returns the GCS URI
// ("gs://bucket/object") of the uploaded file.
//
// The object name is made unique with a nanosecond timestamp prefix.
// NOTE(review): a new storage client is created per call; callers performing
// many uploads may want to hoist client creation out of the hot path.
func uploadFileToGCS(ctx context.Context, bucketName, localPath string) (string, error) {
	// Open the local file first so a bad path fails fast, before we pay the
	// cost of dialing Cloud Storage.
	f, err := os.Open(localPath)
	if err != nil {
		return "", fmt.Errorf("failed to open local file: %w", err)
	}
	defer f.Close()
	// Create a Cloud Storage client.
	storageClient, err := storage.NewClient(ctx)
	if err != nil {
		return "", fmt.Errorf("failed to create storage client: %w", err)
	}
	defer storageClient.Close()
	// Create a unique object name (using a timestamp and the base filename).
	objectName := fmt.Sprintf("%d-%s", time.Now().UnixNano(), filepath.Base(localPath))
	wc := storageClient.Bucket(bucketName).Object(objectName).NewWriter(ctx)
	// Set the content type from the extension; fall back to a generic binary
	// type rather than leaving it empty.
	if ct := mime.TypeByExtension(filepath.Ext(localPath)); ct != "" {
		wc.ContentType = ct
	} else {
		wc.ContentType = "application/octet-stream"
	}
	// Copy file contents to the GCS writer.
	if _, err = io.Copy(wc, f); err != nil {
		// Best-effort close; the copy error is the one worth reporting.
		wc.Close()
		return "", fmt.Errorf("failed to copy file to GCS: %w", err)
	}
	if err := wc.Close(); err != nil {
		return "", fmt.Errorf("failed to close GCS writer: %w", err)
	}
	// Return the GCS URI.
	return fmt.Sprintf("gs://%s/%s", bucketName, objectName), nil
}
// ----------------------------------------------------------------------------
// Processor Implementation
// ----------------------------------------------------------------------------
// vertexAiProcessor is the Benthos processor wrapping a VertexAIProcessor
// client with the parsed plugin configuration.
type vertexAiProcessor struct {
	// sourceField is the dotted path read from structured messages as input.
	sourceField string
	// targetField is the key the Vertex AI response is written to.
	targetField string
	// client performs the actual Vertex AI calls.
	client VertexAIProcessor
	// prompt is the default/base prompt used for text generation.
	prompt string
	// systemPrompt guides the model for all input types.
	systemPrompt string
	// temperature controls response randomness.
	temperature float64
	// inputType is one of "text", "image", or "audio" (lowercased).
	inputType string
}
// init registers the "vertex_ai_chat" processor with the Benthos plugin
// registry; registration failure is fatal at startup.
func init() {
	if err := service.RegisterProcessor(
		"vertex_ai_chat",
		vertexAiProcessorConfigSpec,
		func(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) {
			return newVertexAiProcessor(conf, mgr)
		},
	); err != nil {
		panic(err)
	}
}
// newVertexAiProcessor builds a vertexAiProcessor from the parsed plugin
// configuration, validating required fields and constructing the Vertex AI
// client. mgr is currently unused but kept to match the constructor shape
// expected by service.RegisterProcessor.
func newVertexAiProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) {
	sourceField, err := conf.FieldString("source_field")
	if err != nil {
		return nil, err
	}
	targetField, err := conf.FieldString("target_field")
	if err != nil {
		return nil, err
	}
	project, err := conf.FieldString("project")
	if err != nil {
		return nil, err
	}
	// credentials_json is declared Optional; FieldString errors when an
	// optional field is unset, which here simply means "use Application
	// Default Credentials", so the error is deliberately ignored.
	credentials, _ := conf.FieldString("credentials_json")
	model, err := conf.FieldString("model")
	if err != nil {
		return nil, err
	}
	location, err := conf.FieldString("location")
	if err != nil {
		return nil, err
	}
	prompt, err := conf.FieldString("prompt")
	if err != nil {
		return nil, err
	}
	systemPrompt, err := conf.FieldString("system_prompt")
	if err != nil {
		return nil, err
	}
	temperature, err := conf.FieldFloat("temperature")
	if err != nil {
		return nil, err
	}
	inputType, err := conf.FieldString("input_type")
	if err != nil {
		inputType = "text" // default if not provided
	}
	// Normalize before validating so that e.g. "Image" cannot bypass the
	// gcs_bucket requirement below while still selecting the image path.
	inputType = strings.ToLower(inputType)
	// For image or audio input, we require a GCS bucket.
	gcsBucket, _ := conf.FieldString("gcs_bucket")
	if (inputType == "image" || inputType == "audio") && gcsBucket == "" {
		return nil, fmt.Errorf("gcs_bucket is required for input_type %s", inputType)
	}
	client, err := NewVertexAIProcessor(project, credentials, model, location, gcsBucket)
	if err != nil {
		return nil, err
	}
	return &vertexAiProcessor{
		sourceField:  sourceField,
		targetField:  targetField,
		client:       client,
		prompt:       prompt,
		systemPrompt: systemPrompt,
		temperature:  temperature,
		inputType:    inputType,
	}, nil
}
// Process extracts the input from the message, calls Vertex AI according to
// the configured input type, and writes the response into targetField.
//
// Error semantics (preserved from the original design): failures from the
// Vertex AI client are swallowed and the original message is passed through
// unchanged — best-effort processing. Only failures to read the message
// itself abort the batch.
func (v *vertexAiProcessor) Process(ctx context.Context, m *service.Message) (service.MessageBatch, error) {
	content, err := m.AsStructuredMut()
	if err != nil {
		// Not structured; fall back to raw bytes below.
		content = nil
	}
	var response string
	switch v.inputType {
	case "text":
		// Build the prompt from the source field (appended to the configured
		// base prompt), or from the raw message body when unstructured.
		var finalPrompt string
		if content != nil {
			if val, ok := getByKey(content, v.sourceField); ok {
				finalPrompt = fmt.Sprintf("%s\nUser: %v", v.prompt, val)
			} else {
				finalPrompt = v.prompt
			}
		} else {
			b, err := m.AsBytes()
			if err != nil {
				return nil, err
			}
			finalPrompt = string(b)
		}
		response, err = v.client.Ask(finalPrompt, v.systemPrompt, v.temperature)
		if err != nil {
			// Best-effort: pass the original message through on failure.
			return []*service.Message{m}, nil
		}
	case "image":
		// The image input may be a file path (string) or raw bytes.
		var imageInput interface{}
		if content != nil {
			if val, ok := getByKey(content, v.sourceField); ok {
				// Named "iv" to avoid shadowing the method receiver "v".
				switch iv := val.(type) {
				case string:
					imageInput = iv
				case []byte:
					imageInput = iv
				default:
					imageInput = fmt.Sprintf("%v", iv)
				}
			} else {
				b, err := m.AsBytes()
				if err != nil {
					return nil, err
				}
				imageInput = b
			}
		} else {
			b, err := m.AsBytes()
			if err != nil {
				return nil, err
			}
			imageInput = b
		}
		response, err = v.client.ProcessImage(imageInput, v.systemPrompt, v.temperature)
		if err != nil {
			log.Printf("Image processing error: %v", err)
			return []*service.Message{m}, nil
		}
		log.Printf("Image response: %s", response)
	case "audio":
		// The audio input is always treated as a local file path.
		var audioPath string
		if content != nil {
			if val, ok := getByKey(content, v.sourceField); ok {
				if s, ok := val.(string); ok {
					audioPath = s
				} else {
					audioPath = fmt.Sprintf("%v", val)
				}
			} else {
				b, err := m.AsBytes()
				if err != nil {
					return nil, err
				}
				audioPath = string(b)
			}
		} else {
			b, err := m.AsBytes()
			if err != nil {
				return nil, err
			}
			audioPath = string(b)
		}
		response, err = v.client.TranscribeAudio(audioPath, v.systemPrompt, v.temperature)
		if err != nil {
			// Best-effort: pass the original message through on failure.
			return []*service.Message{m}, nil
		}
	default:
		return []*service.Message{m}, fmt.Errorf("unsupported input type: %s", v.inputType)
	}
	// Merge the response into a copy of the structured payload (if any) so
	// existing fields are preserved alongside the target field.
	payload := make(map[string]interface{})
	if contentMap, ok := content.(map[string]interface{}); ok {
		for k, val := range contentMap {
			payload[k] = val
		}
	}
	payload[v.targetField] = response
	m.SetStructured(payload)
	return []*service.Message{m}, nil
}
// Close releases the processor's resources by closing the underlying
// Vertex AI client.
func (v *vertexAiProcessor) Close(ctx context.Context) error {
	return v.client.Close()
}
// ----------------------------------------------------------------------------
// Helper Function
// ----------------------------------------------------------------------------
// getByKey resolves a dotted path (e.g. "a.b.0.c") against nested maps and
// slices, returning the value found and whether the full path resolved.
// Numeric path segments index into slices; anything else is a map key.
func getByKey(m interface{}, key string) (any, bool) {
	node := m
	rest := key
	for {
		seg, tail, more := strings.Cut(rest, ".")
		switch c := node.(type) {
		case map[string]interface{}:
			child, ok := c[seg]
			if !ok {
				return nil, false
			}
			node = child
		case []interface{}:
			i, err := strconv.Atoi(seg)
			if err != nil || i < 0 || i >= len(c) {
				return nil, false
			}
			node = c[i]
		default:
			// Path continues but the current node is a leaf.
			return nil, false
		}
		if !more {
			return node, true
		}
		rest = tail
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment