Last active
September 27, 2024 13:21
-
-
Save couragecowardlydog/67c2ea6e2576df70f7a6989c4abe0dda to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.*;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import java.util.List;
import static com.gitrebase.rag.ai.gemini.GeminiApiConstants.*;
public class GeminiEmbeddingModel extends AbstractEmbeddingModel { | |
// Logger | |
private static final Logger logger = LoggerFactory.getLogger(GeminiEmbeddingModel.class); | |
// Observability | |
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); | |
// Retry Template | |
private final RetryTemplate retryTemplate; | |
// Observability | |
private final ObservationRegistry observationRegistry; | |
// API - Class that wraps API | |
private final GeminiApi geminiApi; | |
private EmbeddingModelObservationConvention observationConvention; | |
public GeminiEmbeddingModel(GeminiApiConfig config) { | |
this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; | |
this.observationRegistry = ObservationRegistry.NOOP; | |
this.geminiApi = new GeminiApi(config); | |
} | |
@Override | |
public EmbeddingResponse call(EmbeddingRequest request) { | |
EmbeddingOptions embeddingOptions = EmbeddingOptionsBuilder.builder() | |
.withModel(TEXT_EMBEDDING_MODEL) | |
.withDimensions(TEXT_EMBEDDING_MODEL_DIMENSIONS) | |
.build(); | |
EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() | |
.embeddingRequest(request) | |
.provider(GEMINI_PROVIDER) | |
.requestOptions(embeddingOptions).build(); | |
var input = request.getInstructions(); | |
var rqx = new GeminiApi.GeminiEmbeddingRequest(input); | |
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION | |
.observation( | |
this.observationConvention, | |
DEFAULT_OBSERVATION_CONVENTION, | |
() -> observationContext, this.observationRegistry | |
).observe(() -> { | |
var apiEmbeddingResponse = this.retryTemplate.execute((ctx) -> this.geminiApi.embeddings(rqx).getBody()); | |
if (apiEmbeddingResponse == null) { | |
logger.warn("No embeddings returned for request: {}", request); | |
return new EmbeddingResponse(List.of()); | |
} else { | |
Embedding embedding = new Embedding(apiEmbeddingResponse.getEmbedding().values, 0); | |
EmbeddingResponse embeddingResponse = new EmbeddingResponse(List.of(embedding)); | |
observationContext.setResponse(embeddingResponse); | |
return embeddingResponse; | |
} | |
}); | |
} | |
@Override | |
public float[] embed(Document document) { | |
Assert.notNull(document, "Document must not be null"); | |
return this.embed(document.getFormattedContent(MetadataMode.EMBED)); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment