# !pip install --upgrade onnx==1.17.0 onnxruntime==1.20.1 onnxslim==0.1.48 optimum==1.24.0 transformers==4.48.3

import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForCausalLM
import os
import onnxslim
from optimum.onnx.graph_transformations import merge_decoders, check_and_save_model

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
)

output_dir = "converted"
os.makedirs(output_dir, exist_ok=True)


# 1. Export vision encoder
class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = model.vision_tower
        self.image_projection = model.image_projection
        self.image_proj_norm = model.image_proj_norm
        self.image_pos_embed = model.image_pos_embed
        self.visual_temporal_embed = model.visual_temporal_embed
        self.image_feature_source = model.image_feature_source

    def forward(self, pixel_values):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f"invalid image shape {pixel_values.shape}")

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = (num_tokens**0.5).to(torch.int64), (num_tokens**0.5).to(
                torch.int64
            )
            assert h * w == num_tokens, "only support square feature maps for now"
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(
                x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
            )
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(
                1, T, 1, x.shape[-1]
            )

        x_feat_dict = {}
        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict["last_frame"] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError(
                    "invalid image feature source: {}".format(_image_feature_source)
                )
            new_x.append(x_feat_dict[_image_feature_source])
        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)
        return x


vision_model = VisionEncoder()

w, h = 768, 768
x = torch.randn(2, 3, h, w, requires_grad=True)

torch.onnx.export(
    vision_model,
    x,
    f"{output_dir}/vision_encoder.onnx",
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["image_features"],
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "image_features": {0: "batch_size", 1: "sequence_length"},
    },
)
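
# Optional sanity check (not part of the original export flow): run the exported
# vision encoder with onnxruntime (pinned in the pip install at the top) on the same
# dummy input and compare against the PyTorch module.
import numpy as np
import onnxruntime as ort

_vision_session = ort.InferenceSession(f"{output_dir}/vision_encoder.onnx")
_onnx_features = _vision_session.run(None, {"pixel_values": x.detach().numpy()})[0]
with torch.no_grad():
    _torch_features = vision_model(x)
print("vision_encoder max abs diff:", np.abs(_onnx_features - _torch_features.numpy()).max())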

# 2. Export input embedding layer
x = torch.randint(0, 100, (2, 16))
torch.onnx.export(
    model.get_input_embeddings(),
    x,
    f"{output_dir}/embed_tokens.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids"],
    output_names=["inputs_embeds"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    },
)
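
# Optional sanity check (not part of the original flow): load the exported embedding
# layer with onnxruntime (imported above) and confirm the output shape is
# (batch_size, sequence_length, hidden_size).
_embed_session = ort.InferenceSession(f"{output_dir}/embed_tokens.onnx")
_embeds = _embed_session.run(None, {"input_ids": x.numpy()})[0]
print("embed_tokens output shape:", _embeds.shape)  # expected: (2, 16, hidden_size)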

# 3. Export language model (encoder, decoder w/o past, decoder w/ past, and merged decoder)
text_config = model.config.text_config
num_attention_heads = text_config.decoder_attention_heads
num_layers = text_config.decoder_layers
hidden_size = text_config.d_model
head_dim = hidden_size // num_attention_heads

batch_size = 2
past_decoder_sequence_length = 6
decoder_sequence_length = 13
encoder_sequence_length = 3

encoder_inputs_embeds = torch.randn((batch_size, encoder_sequence_length, hidden_size))
encoder_attention_mask = torch.ones(
    (batch_size, encoder_sequence_length), dtype=torch.int64
)
decoder_inputs_embeds = torch.randn((batch_size, decoder_sequence_length, hidden_size))
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{module}.{key}": torch.zeros(
        batch_size,
        num_attention_heads,
        past_decoder_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for module in ("decoder", "encoder")  # (self, cross_attn)
    for key in ["key", "value"]
}

encoder_outputs = model.language_model.model.encoder(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = model.language_model.model.encoder

    def forward(self, *args):
        encoder_inputs_embeds, encoder_attention_mask = args
        encoder_outputs = self.encoder(
            inputs_embeds=encoder_inputs_embeds,
            attention_mask=encoder_attention_mask,
        )
        return encoder_outputs.last_hidden_state


encoder_model = Encoder()
torch.onnx.export(
    encoder_model,
    (encoder_inputs_embeds, encoder_attention_mask),
    f=f"{output_dir}/encoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["inputs_embeds", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "encoder_sequence_length"},
        "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "last_hidden_state": {0: "batch_size", 1: "encoder_sequence_length"},
    },
)

encoder_outputs = model.language_model.model.encoder.forward(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)

pkv_input_names = list(dummy_past_key_values_kwargs.keys())
pkv_output_names = list(
    x.replace("past_key_values", "present") for x in dummy_past_key_values_kwargs.keys()
)


class PatchedFlorence2DecoderWithoutPast(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = model.language_model

    def forward(self, *args):
        encoder_attention_mask, encoder_hidden_states, inputs_embeds = args
        decoder_outputs = self.language_model.forward(
            # Use the traced `encoder_hidden_states` input here; referencing the
            # pre-computed `encoder_outputs` from the outer scope would bake those
            # dummy activations into the exported graph as constants.
            encoder_outputs=[encoder_hidden_states],
            decoder_inputs_embeds=inputs_embeds,
        )
        flattened_outputs = {
            "logits": decoder_outputs.logits,
        }
        for i in range(num_layers):
            for j, v in enumerate(
                ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
            ):
                flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[
                    i
                ][j]
        return flattened_outputs


decoder_without_past = PatchedFlorence2DecoderWithoutPast()
torch.onnx.export(
    decoder_without_past,
    args=(
        encoder_attention_mask,
        encoder_outputs.last_hidden_state,
        encoder_inputs_embeds,
    ),
    f=f"{output_dir}/decoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["encoder_attention_mask", "encoder_hidden_states", "inputs_embeds"],
    output_names=["logits"] + pkv_output_names,
    dynamic_axes={
        "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"},
        "logits": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length + decoder_sequence_length"
                if "decoder" in k
                else "encoder_sequence_length",
            }
            for k in pkv_output_names
        },
    },
)


class PatchedFlorence2DecoderWithPast(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = model.language_model

    def forward(self, *args):
        encoder_attention_mask, inputs_embeds, *past_key_values_args = args
        pkv_iter = iter(past_key_values_args)
        pkv_tuples = tuple(
            tuple(next(pkv_iter) for i in range(4)) for _ in range(num_layers)
        )
        decoder_outputs = self.language_model.forward(
            # NOTE: encoder_outputs isn't defined here, but we will reuse k,v, cross attentions from pkv tuples
            encoder_outputs=[torch.zeros(0, past_key_values_args[0].shape[2], 0)],
            decoder_inputs_embeds=inputs_embeds,
            past_key_values=pkv_tuples,
            # No need to pass `decoder_attention_mask`
        )
        flattened_outputs = {
            "logits": decoder_outputs.logits,
        }
        for i in range(num_layers):
            for j, v in enumerate(
                ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
            ):
                if "encoder" in v:
                    continue
                flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[
                    i
                ][j]
        return flattened_outputs


decoder_with_past = PatchedFlorence2DecoderWithPast()
torch.onnx.export(
    decoder_with_past,
    args=(
        encoder_attention_mask,
        encoder_inputs_embeds,
        *dummy_past_key_values_kwargs.values(),
    ),
    f=f"{output_dir}/decoder_with_past_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["encoder_attention_mask", "inputs_embeds"] + pkv_input_names,
    output_names=["logits"] + [x for x in pkv_output_names if "decoder" in x],
    dynamic_axes={
        "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length"
                if "decoder" in k
                else "encoder_sequence_length_out",
            }
            for k in pkv_input_names
        },
        "logits": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length + decoder_sequence_length",
            }
            for k in pkv_output_names
            if "decoder" in k
        },
    },
)

# 4. Post-processing
for f in os.listdir(output_dir):
    p = os.path.join(output_dir, f)
    onnxslim.slim(p, p)

merged_decoder = merge_decoders(
    f"{output_dir}/decoder_model.onnx",
    f"{output_dir}/decoder_with_past_model.onnx",
    strict=False,
)
check_and_save_model(merged_decoder, f"{output_dir}/decoder_model_merged.onnx")
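
# Optional final check (not part of the original script): confirm each exported graph
# loads in onnxruntime and print the input names of the merged decoder. Alongside the
# regular inputs and KV-cache inputs, the merged graph should expose the cache-branch
# selector that optimum's merge_decoders adds (named `use_cache_branch` in recent
# optimum versions).
for name in ("vision_encoder", "embed_tokens", "encoder_model", "decoder_model_merged"):
    session = ort.InferenceSession(f"{output_dir}/{name}.onnx")
    print(name, "inputs:", [inp.name for inp in session.get_inputs()])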
Hi,
Regarding an issue I raised earlier in the transformers.js repo, I think the vision encoder here doesn't get converted correctly. The reason is the width and height choice of 224 x 224, whereas the preprocessor used for the model expects 768 x 768, as seen here.
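
For reference, one quick way to inspect the resolution the preprocessor expects is to print the image processor config (a small sketch; it assumes the processor loads with trust_remote_code and exposes the wrapped image processor via .image_processor, as standard Hugging Face processors do):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
# The printed size / crop_size fields show the input resolution the model was preprocessed with.
print(processor.image_processor)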