@xenova
Last active April 9, 2025 05:03
# !pip install --upgrade onnx==1.17.0 onnxruntime==1.20.1 onnxslim==0.1.48 optimum==1.24.0 transformers==4.48.3
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForCausalLM
import os
import onnxslim
from optimum.onnx.graph_transformations import merge_decoders, check_and_save_model
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
)
output_dir = "converted"
os.makedirs(output_dir, exist_ok=True)
# 1. Export vision encoder
class VisionEncoder(nn.Module):
    """Wraps the model's vision tower and image-projection layers to produce image features
    for the language model."""

    def __init__(self):
        super().__init__()
        self.vision_tower = model.vision_tower
        self.image_projection = model.image_projection
        self.image_proj_norm = model.image_proj_norm
        self.image_pos_embed = model.image_pos_embed
        self.visual_temporal_embed = model.visual_temporal_embed
        self.image_feature_source = model.image_feature_source

    def forward(self, pixel_values):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f"invalid image shape {pixel_values.shape}")

        # Add 2D positional embeddings over the (square) feature map
        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = (num_tokens**0.5).to(torch.int64), (num_tokens**0.5).to(torch.int64)
            assert h * w == num_tokens, "only support square feature maps for now"
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(
                x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
            )
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(
                1, T, 1, x.shape[-1]
            )

        # Pool the features in the ways requested by `image_feature_source`
        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict["last_frame"] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError(
                    "invalid image feature source: {}".format(_image_feature_source)
                )
            new_x.append(x_feat_dict[_image_feature_source])
        x = torch.cat(new_x, dim=1)

        # Project into the language model's embedding space
        x = x @ self.image_projection
        x = self.image_proj_norm(x)
        return x
vision_model = VisionEncoder()
w, h = 768, 768
x = torch.randn(2, 3, h, w, requires_grad=True)
torch.onnx.export(
    vision_model,
    x,
    f"{output_dir}/vision_encoder.onnx",
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["image_features"],
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "image_features": {0: "batch_size", 1: "sequence_length"},
    },
)
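
# Optional sanity check: compare the exported vision encoder against the PyTorch module
# with onnxruntime. The tolerance is an assumption; small numerical differences are expected.
import numpy as np
import onnxruntime as ort

ort_session = ort.InferenceSession(
    f"{output_dir}/vision_encoder.onnx", providers=["CPUExecutionProvider"]
)
(onnx_features,) = ort_session.run(None, {"pixel_values": x.detach().numpy()})
with torch.no_grad():
    torch_features = vision_model(x).numpy()
print("vision encoder matches:", np.allclose(onnx_features, torch_features, atol=1e-3))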
# 2. Export input embedding layer
x = torch.randint(0, 100, (2, 16))
torch.onnx.export(
    model.get_input_embeddings(),
    x,
    f"{output_dir}/embed_tokens.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids"],
    output_names=["inputs_embeds"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    },
)
# 3. Export language model (encoder, decoder w/o past, decoder w/ past, and merged decoder)
text_config = model.config.text_config
num_attention_heads = text_config.decoder_attention_heads
num_layers = text_config.decoder_layers
hidden_size = text_config.d_model
head_dim = hidden_size // num_attention_heads
batch_size = 2
past_decoder_sequence_length = 6
decoder_sequence_length = 13
encoder_sequence_length = 3
encoder_inputs_embeds = torch.randn((batch_size, encoder_sequence_length, hidden_size))
encoder_attention_mask = torch.ones(
    (batch_size, encoder_sequence_length), dtype=torch.int64
)
decoder_inputs_embeds = torch.randn((batch_size, decoder_sequence_length, hidden_size))
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{module}.{key}": torch.zeros(
        batch_size,
        num_attention_heads,
        past_decoder_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for module in ("decoder", "encoder")  # (self, cross_attn)
    for key in ["key", "value"]
}
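# For example, with a BART-base-style text config (12 attention heads, d_model = 768, so
# head_dim = 64), this produces inputs such as "past_key_values.0.decoder.key" with shape
# (2, 12, 6, 64). The exact sizes depend on the checkpoint's text_config.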
encoder_outputs = model.language_model.model.encoder(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = model.language_model.model.encoder

    def forward(self, *args):
        encoder_inputs_embeds, encoder_attention_mask = args
        encoder_outputs = self.encoder(
            inputs_embeds=encoder_inputs_embeds,
            attention_mask=encoder_attention_mask,
        )
        return encoder_outputs.last_hidden_state
encoder_model = Encoder()
torch.onnx.export(
    encoder_model,
    (encoder_inputs_embeds, encoder_attention_mask),
    f=f"{output_dir}/encoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["inputs_embeds", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "encoder_sequence_length"},
        "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "last_hidden_state": {0: "batch_size", 1: "encoder_sequence_length"},
    },
)
encoder_outputs = model.language_model.model.encoder.forward(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)
pkv_input_names = list(dummy_past_key_values_kwargs.keys())
pkv_output_names = list(
    x.replace("past_key_values", "present") for x in dummy_past_key_values_kwargs.keys()
)
class PatchedFlorence2DecoderWithoutPast(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = model.language_model

    def forward(self, *args):
        encoder_attention_mask, encoder_hidden_states, inputs_embeds = args
        decoder_outputs = self.language_model.forward(
            # NOTE: reuses the module-level `encoder_outputs` computed above; the same tensor
            # is also passed in as the `encoder_hidden_states` export input.
            encoder_outputs=encoder_outputs,
            decoder_inputs_embeds=inputs_embeds,
        )
        flattened_outputs = {
            "logits": decoder_outputs.logits,
        }
        # Flatten the per-layer (self-attn key/value, cross-attn key/value) cache into named outputs
        for i in range(num_layers):
            for j, v in enumerate(
                ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
            ):
                flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[
                    i
                ][j]
        return flattened_outputs
decoder_without_past = PatchedFlorence2DecoderWithoutPast()
torch.onnx.export(
    decoder_without_past,
    args=(
        encoder_attention_mask,
        encoder_outputs.last_hidden_state,
        encoder_inputs_embeds,
    ),
    f=f"{output_dir}/decoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["encoder_attention_mask", "encoder_hidden_states", "inputs_embeds"],
    output_names=["logits"] + pkv_output_names,
    dynamic_axes={
        "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"},
        "logits": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length + decoder_sequence_length"
                if "decoder" in k
                else "encoder_sequence_length",
            }
            for k in pkv_output_names
        },
    },
)
class PatchedFlorence2DecoderWithPast(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = model.language_model

    def forward(self, *args):
        encoder_attention_mask, inputs_embeds, *past_key_values_args = args
        # Regroup the flat list of past tensors into per-layer 4-tuples
        # (self-attn key, self-attn value, cross-attn key, cross-attn value)
        pkv_iter = iter(past_key_values_args)
        pkv_tuples = tuple(
            tuple(next(pkv_iter) for i in range(4)) for _ in range(num_layers)
        )
        decoder_outputs = self.language_model.forward(
            # NOTE: no real encoder outputs are needed here; the cross-attention
            # keys/values are reused from the past_key_values tuples instead.
            encoder_outputs=[torch.zeros(0, past_key_values_args[0].shape[2], 0)],
            decoder_inputs_embeds=inputs_embeds,
            past_key_values=pkv_tuples,
            # No need to pass `decoder_attention_mask`
        )
        flattened_outputs = {
            "logits": decoder_outputs.logits,
        }
        for i in range(num_layers):
            for j, v in enumerate(
                ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
            ):
                # The cross-attention ("encoder") cache never changes between steps,
                # so only the self-attention ("decoder") cache is returned.
                if "encoder" in v:
                    continue
                flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[
                    i
                ][j]
        return flattened_outputs
decoder_with_past = PatchedFlorence2DecoderWithPast()
torch.onnx.export(
    decoder_with_past,
    args=(
        encoder_attention_mask,
        encoder_inputs_embeds,
        *dummy_past_key_values_kwargs.values(),
    ),
    f=f"{output_dir}/decoder_with_past_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["encoder_attention_mask", "inputs_embeds"] + pkv_input_names,
    output_names=["logits"] + [x for x in pkv_output_names if "decoder" in x],
    dynamic_axes={
        "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length"
                if "decoder" in k
                else "encoder_sequence_length_out",
            }
            for k in pkv_input_names
        },
        "logits": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length + decoder_sequence_length",
            }
            for k in pkv_output_names
            if "decoder" in k
        },
    },
)
# 4. Post-processing
for f in os.listdir(output_dir):
    p = os.path.join(output_dir, f)
    onnxslim.slim(p, p)

merged_decoder = merge_decoders(
    f"{output_dir}/decoder_model.onnx",
    f"{output_dir}/decoder_with_past_model.onnx",
    strict=False,
)
check_and_save_model(merged_decoder, f"{output_dir}/decoder_model_merged.onnx")
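
# Optional: list the inputs/outputs of the exported graphs. The merged decoder produced by
# optimum's merge_decoders is expected to expose an additional boolean `use_cache_branch`
# input that selects between the without-past and with-past branches.
import onnx

for name in ("vision_encoder", "embed_tokens", "encoder_model", "decoder_model_merged"):
    graph = onnx.load(f"{output_dir}/{name}.onnx").graph
    print(name, [i.name for i in graph.input], "->", [o.name for o in graph.output])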

ir2718 commented Feb 17, 2025

Hi,

regarding an issue I raised earlier in the transformers.js repo, I think the vision encoder here isn't converted correctly. The reason is that the width and height were chosen as 224 x 224, whereas the preprocessor used by the model expects 768 x 768, as seen here.
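
For reference, a quick way to check the resolution the preprocessor actually produces (a small sketch; the dummy image size is arbitrary):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
dummy = Image.new("RGB", (640, 480))
inputs = processor(text="<CAPTION>", images=dummy, return_tensors="pt")
print(inputs["pixel_values"].shape)  # expect torch.Size([1, 3, 768, 768]), i.e. 768 x 768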


Md-Sayeed-Khan commented Mar 12, 2025

Hello @ir2718 @xenova, how do I use these converted ONNX files for Python inference? Please help.
What should the order of the ONNX models be?
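
A rough sketch of the order at inference time (an untested outline with greedy decoding only; "example.jpg" is a placeholder path, and the special-token ids are assumptions based on the BART-style text config that should be checked against the model's generation_config.json). The flow is: vision_encoder on pixel_values, embed_tokens on the prompt ids, concatenate image features with the text embeddings, run encoder_model, then loop the decoder, using decoder_model for the first step and decoder_with_past_model afterwards (or the merged decoder with its use_cache_branch input):

import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import AutoProcessor

d = "converted"
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)

def load(name):
    return ort.InferenceSession(f"{d}/{name}.onnx", providers=["CPUExecutionProvider"])

vision, embed, encoder = load("vision_encoder"), load("embed_tokens"), load("encoder_model")
decoder, decoder_with_past = load("decoder_model"), load("decoder_with_past_model")

def run(session, feed):
    # Only feed the inputs this graph actually declares (unused inputs can be pruned during export)
    names = {i.name for i in session.get_inputs()}
    outputs = session.run(None, {k: v for k, v in feed.items() if k in names})
    return dict(zip([o.name for o in session.get_outputs()], outputs))

inputs = processor(text="<CAPTION>", images=Image.open("example.jpg"), return_tensors="pt")

# 1) vision_encoder: pixel_values -> image_features
image_features = run(vision, {"pixel_values": inputs["pixel_values"].numpy()})["image_features"]
# 2) embed_tokens on the prompt ids, then prepend the image features
text_embeds = run(embed, {"input_ids": inputs["input_ids"].numpy()})["inputs_embeds"]
inputs_embeds = np.concatenate([image_features, text_embeds], axis=1)
attention_mask = np.ones(inputs_embeds.shape[:2], dtype=np.int64)
# 3) encoder_model over the combined sequence
encoder_hidden_states = run(
    encoder, {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
)["last_hidden_state"]

# 4) decoder_model for the first step, decoder_with_past_model for the rest (greedy decoding)
generated = [2]  # assumed BART-style decoder_start_token_id (</s>); verify in generation_config.json
past = None
for _ in range(128):
    dec_ids = np.array([generated if past is None else generated[-1:]], dtype=np.int64)
    dec_embeds = run(embed, {"input_ids": dec_ids})["inputs_embeds"]
    feed = {
        "encoder_attention_mask": attention_mask,
        "encoder_hidden_states": encoder_hidden_states,
        "inputs_embeds": dec_embeds,
        **(past or {}),
    }
    out = run(decoder if past is None else decoder_with_past, feed)
    new_past = {k.replace("present", "past_key_values"): v for k, v in out.items() if k != "logits"}
    past = new_past if past is None else {**past, **new_past}  # cross-attention cache stays fixed
    next_token = int(out["logits"][0, -1].argmax())
    generated.append(next_token)
    if next_token == 2:  # assumed eos_token_id (</s>)
        break

print(processor.tokenizer.decode(generated, skip_special_tokens=False))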
