Justin Chu justinchuby

🌊
Better ML
justinchuby / export_hf.py
Last active August 28, 2025 01:09
Export HF model to ONNX
"""Export to ONNX.
transformers_version == "4.52.0"
"""
import onnx_diagnostic.tasks.text_generation
import torch
from transformers import AutoConfig, AutoModel
import onnxscript
import onnx_ir as ir
import onnx
def create_model():
    """Create a model that has an unsorted node with a subgraph that uses a value defined later."""
    a = ir.Value(name="a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)))
    b = ir.Value(name="b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((3, 4)))
    b_out = ir.Value(name="b_out", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((3, 4)))
    c = ir.Value(name="c", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((4, 5)))
@justinchuby
justinchuby / pt_export.md
Created August 13, 2025 18:44
PyTorch export non-strict vs strict modes

torch.export non-strict mode uses "make_fx [which] uses __torch_dispatch__ to trace under the hood. this is where it creates the fx nodes. AOTAutograd also calls into make_fx, but before that it also does some things related to functionalization. Since export is now "Training IR", it no longer does functionalization, so we just directly call make_fx."
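
The mechanism described in the note can be illustrated with a small sketch. It is not taken from the gist: the function `f`, the `LogOps` class, and the input tensor are hypothetical names chosen for illustration. The sketch assumes a recent PyTorch where `make_fx` is importable from `torch.fx.experimental.proxy_tensor` and `TorchDispatchMode` from `torch.utils._python_dispatch`. It first logs the ATen ops that reach `__torch_dispatch__`, then traces the same function with `make_fx` to show that those ops become FX nodes.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode


def f(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical function to trace; not part of the original gist.
    return torch.sin(x) + 1


class LogOps(TorchDispatchMode):
    """Record every op that reaches __torch_dispatch__ (illustrative helper)."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.seen.append(func)
        return func(*args, **(kwargs or {}))


x = torch.randn(3)

# 1. Observe the ops that flow through __torch_dispatch__ in eager mode.
mode = LogOps()
with mode:
    f(x)

# 2. Trace the same function with make_fx, which builds FX nodes from those ops.
gm = make_fx(f)(x)
traced_targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]

print(mode.seen)
print(gm.graph)
```

Under these assumptions, both the logged ops and the `call_function` nodes in the traced graph are ATen-level overloads such as `aten.sin.default`, which is the level at which non-strict export records the program.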

@justinchuby
justinchuby / cast_verification.py
Created July 2, 2025 21:50
Cast verification
"""Verify the cast values"""
import os
import onnx
import onnx_ir as ir
DIR = "onnx/backend/test/data/node/"
def verify_one_case(path: str):
    test_name = os.path.basename(path)
    input_path = os.path.join(path, "test_data_set_0", "input_0.pb")
@justinchuby
justinchuby / Exported program bundle.txt
Created May 24, 2025 01:58
Exported program bundle
https://github.com/iree-org/iree-turbine/blob/main/iree/turbine/aot/fx_programs.py
See also the ai-edge-torch exporter.
@justinchuby
justinchuby / stable.py
Last active May 23, 2025 04:35
Stable HLO
from ai_edge_torch.odml_torch.export import exported_program_to_mlir
import torch
class PowModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x ** 0.5
model = PowModel()
print(model(torch.tensor(2)))
@justinchuby
justinchuby / export_hf.py
Created April 22, 2025 23:39
Export HF models with torch.onnx
import torch
from onnx_diagnostic import torch_export_patches
from onnxscript.ir.passes.common import clear_metadata_and_docstring
from transformers import AttentionInterface, AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
# Owner(s): ["module: onnx"]
"""Unit LLM tests for the onnx dynamo exporter."""
from __future__ import annotations
from typing import Any
import logging
import transformers
@justinchuby
justinchuby / torch_geometric_onnx_comp.py
Last active March 7, 2025 00:13
Code for figuring out where an onnx model is inaccurate and visualize with model explorer
import logging
import torch
from torch_geometric.nn import GAT
logger = logging.getLogger(__name__)
logging.getLogger('torch.onnx').setLevel(logging.INFO)
logger.info("Prepare model")
num_features = 23
num_classes = 12
@justinchuby
justinchuby / ops.py
Last active February 27, 2025 17:34
PyTorch ONNX exporter supported ops
"""Display all PyTorch ONNX exporter supported ops.
NOTE: This is using internal methods. Do not use it in production code.
NOTE: Ops implemented via decomp may not be supported because they may still be decomposed
into ops that are without native implementation. They include some backward ops,
svd, sq, and some others.
"""
from torch.onnx._internal.exporter import _decomp, _registration