torch.export
Non-strict mode uses "make_fx [which] uses __torch_dispatch__
to trace under the hood. This is where it creates the FX nodes.
AOTAutograd also calls into make_fx, but before that it also does some things related to functionalization.
Since export is now 'Training IR', it no longer does functionalization, so we just directly call make_fx."
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Export to ONNX. | |
transformers_version == "4.52.0" | |
""" | |
import onnx_diagnostic.tasks.text_generation | |
import torch | |
from transformers import AutoConfig, AutoModel | |
import onnxscript |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import onnx_ir as ir | |
import onnx | |
def create_model(): | |
"""Create a model that has a unsorted node with subgraph that uses a value defined later.""" | |
a = ir.Value(name="a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))) | |
b = ir.Value(name="b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((3, 4))) | |
b_out = ir.Value(name="b_out", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((3, 4))) | |
c = ir.Value(name="c", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((4, 5))) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Verify the cast values""" | |
import os | |
import onnx | |
import onnx_ir as ir | |
DIR = "onnx/backend/test/data/node/" | |
def verify_one_case(path: str): | |
test_name = os.path.basename(path) | |
input_path = os.path.join(path, "test_data_set_0", "input_0.pb") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
https://github.com/iree-org/iree-turbine/blob/main/iree/turbine/aot/fx_programs.py
See also the ai-edge-torch exporter.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ai_edge_torch.odml_torch.export import exported_program_to_mlir | |
import torch | |
class PowModel(torch.nn.Module):
    """Tiny module whose only computation is raising its input to the power 0.5."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x ** 0.5`` for the given input.

        Args:
            x: Input tensor.

        Returns:
            The input raised to the (Python float) power 0.5.
        """
        result = x ** 0.5
        return result
model = PowModel() | |
print(model(torch.tensor(2))) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from onnx_diagnostic import torch_export_patches | |
from onnxscript.ir.passes.common import clear_metadata_and_docstring | |
from transformers import AttentionInterface, AutoModelForCausalLM, AutoTokenizer | |
from transformers.cache_utils import DynamicCache | |
# Get position_ids from attention_mask | |
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Owner(s): ["module: onnx"] | |
"""Unit LLM tests for the onnx dynamo exporter.""" | |
from __future__ import annotations | |
from typing import Any | |
import logging | |
import transformers |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import torch | |
from torch_geometric.nn import GAT | |
logger = logging.getLogger(__name__) | |
logging.getLogger('torch.onnx').setLevel(logging.INFO) | |
logger.info("Prepare model") | |
num_features = 23 | |
num_classes = 12 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Display all PyTorch ONNX exporter supported ops. | |
NOTE: This is using internal methods. Do not use it in production code. | |
NOTE: Ops implemented via decomp may not be supported because they may still be decomposed | |
into ops that are without native implementation. They include some backward ops, | |
svd, sq, and some others. | |
""" | |
from torch.onnx._internal.exporter import _decomp, _registration |
NewerOlder