#
# pip install ai-edge-torch txtai
#
# See https://github.com/google-ai-edge/ai-edge-torch
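#
# Exports a txtai embeddings model to TFLite with INT8 quantization,
# then runs a similarity check against the exported model.
#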

import torch
import ai_edge_torch
import numpy as np

from ai_edge_torch.generative.quantize import quant_attrs, quant_recipes
from torch import nn
from transformers import AutoTokenizer
from txtai.models import PoolingFactory


class Pooling(nn.Module):
    def __init__(self, path, device, **kwargs):
        super().__init__()

        # Create pooling method based on configuration
        self.model = PoolingFactory.create({"path": path, "device": device, "modelargs": kwargs})

    # pylint: disable=W0221
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        # Build arguments dict dynamically since some models take token_type_ids
        # and others don't
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            inputs["token_type_ids"] = token_type_ids

        return self.model.forward(**inputs)


def export():
    path = "neuml/biomedbert-hash-nano-embeddings"
    model = Pooling(path, "cpu", trust_remote_code=True).eval()

    # Sample inputs
    maxlength = 64
    inputs = (
        torch.ones(1, maxlength, dtype=torch.long),  # tokens
        torch.ones(1, maxlength, dtype=torch.long),  # attention_mask
        torch.ones(1, maxlength, dtype=torch.long),  # token_type_ids
    )
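
    # Note: the converted graph is traced with this fixed (1, 64) shape, which
    # is why test() below pads inputs to the same max_length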

    # INT8 Quantization
    config = quant_recipes.full_dynamic_recipe(
        mcfg=model.model.model.config,
        weight_dtype=quant_attrs.Dtype.INT8,
    )
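    # full_dynamic_recipe stores weights as INT8; activations are quantized
    # dynamically at inference time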

    # Convert to tflite
    model = ai_edge_torch.convert(model, inputs, quant_config=config)
    model.export("biomedbert-hash-nano-embeddings.tflite")


def test():
    def tokenize(text):
        inputs = tokenizer(text, return_tensors="np", padding="max_length", max_length=64)
        return [inputs[key] for key in ["input_ids", "attention_mask", "token_type_ids"]]

    # Load tflite model
    tokenizer = AutoTokenizer.from_pretrained("neuml/biomedbert-hash-nano-embeddings")
    model = ai_edge_torch.load("biomedbert-hash-nano-embeddings.tflite")

    # Embed query
    data = model(*tokenize("cancer"))

    # Embed doc
    data2 = model(*tokenize("tumor"))

    # Normalize and compute dot product
    data /= np.linalg.norm(data, axis=1)[:, np.newaxis]
    data2 /= np.linalg.norm(data2, axis=1)[:, np.newaxis]

    print(np.dot(data[0], data2.T))
    # [0.8580084]
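

# Optional: a minimal sketch of running the exported model with TensorFlow's
# own tf.lite.Interpreter (assumes tensorflow is installed). Input tensor
# names and ordering in the converted graph are not guaranteed, so this
# matches inputs against get_input_details() rather than relying on position.
def interpret():
    import tensorflow as tf

    tokenizer = AutoTokenizer.from_pretrained("neuml/biomedbert-hash-nano-embeddings")
    inputs = tokenizer("cancer", return_tensors="np", padding="max_length", max_length=64)

    interpreter = tf.lite.Interpreter(model_path="biomedbert-hash-nano-embeddings.tflite")
    interpreter.allocate_tensors()

    # Bind each graph input to the matching tokenizer output, casting to the
    # dtype the graph expects (tokenizers return int64, TFLite often wants int32)
    for detail in interpreter.get_input_details():
        for key in ("input_ids", "attention_mask", "token_type_ids"):
            if key in detail["name"]:
                interpreter.set_tensor(detail["index"], inputs[key].astype(detail["dtype"]))

    interpreter.invoke()
    print(interpreter.get_tensor(interpreter.get_output_details()[0]["index"]).shape)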


if __name__ == "__main__":
    export()
    test()