# Created September 4, 2024 13:45
# Save gist itsuwari/4ff80c549084310aedadd0d7b02e6777 to your computer and use it in GitHub Desktop.
# ORB-MODELS to ONNX
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import torch
from orb_models.forcefield import pretrained
from orb_models.forcefield.base import AtomGraphs

# Load the pretrained ORB model. Any other constructor exposed by
# ``orb_models.forcefield.pretrained`` may be substituted here.
model = pretrained.orb_d3_v1()

# Only the inner network stored on ``model.model`` is exported
# (presumably the MoleculeGNS instance - TODO confirm for other versions).
core_model = model.model

# Placeholder sizes used to build the dummy input.
# NOTE(review): these are assumed to match the vocabulary and feature widths
# the chosen checkpoint expects - TODO confirm against the model config
# before exporting a different version.
NUM_ELEMENT_CLASSES = 118  # one_hot width: only indices 0..117 are accepted
NODE_FEAT_DIM = 256
EDGE_FEAT_DIM = 53

# Seed the RNG so the random placeholder features, and so the traced
# export, are reproducible from run to run.
torch.manual_seed(0)

# Dummy three-atom graph (H, O, H) with four directed edges. It is used only
# to trace the model during export, so its geometry is arbitrary.
dummy_atomic_numbers = torch.tensor([1, 8, 1])  # H, O, H
dummy_positions = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
)
dummy_senders = torch.tensor([0, 1, 2, 0])
dummy_receivers = torch.tensor([1, 2, 0, 2])

# Everything is gathered into one AtomGraphs object; the export step later
# unpacks it field by field into positional tensors.
dummy_input = AtomGraphs(
    senders=dummy_senders,
    receivers=dummy_receivers,
    # Counts are derived from the tensors so they cannot drift out of sync.
    n_node=torch.tensor([dummy_positions.shape[0]]),
    n_edge=torch.tensor([dummy_senders.shape[0]]),
    node_features={
        "atomic_numbers": dummy_atomic_numbers,
        # One-hot encoding of the atomic numbers, as float32.
        "atomic_numbers_embedding": torch.nn.functional.one_hot(
            dummy_atomic_numbers, num_classes=NUM_ELEMENT_CLASSES
        ).to(torch.float32),
        "positions": dummy_positions,
        # Random placeholder node features.
        "feat": torch.randn(dummy_positions.shape[0], NODE_FEAT_DIM),
    },
    edge_features={
        # Edge vector = receiver position minus sender position (no periodic
        # shift applied).
        "vectors": dummy_positions[dummy_receivers] - dummy_positions[dummy_senders],
        # Random placeholder edge features.
        "feat": torch.randn(dummy_senders.shape[0], EDGE_FEAT_DIM),
    },
    system_features={
        # Cubic 2 x 2 x 2 cell with a leading dimension of size 1
        # (presumably one cell per graph - TODO confirm).
        "cell": torch.tensor(
            [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
        ).unsqueeze(0),
    },
)

# Put the core model in inference mode (e.g. dropout disabled) before tracing.
core_model.eval()
# Wrapper for ONNX export. torch.onnx.export drives the model with a flat
# tuple of tensors, so the AtomGraphs container is rebuilt inside forward.
class ORBCoreWrapper(torch.nn.Module):
    """Expose ``core_model`` through a flat, tensor-only ``forward`` signature.

    ``torch.onnx.export`` passes positional tensors. This wrapper reassembles
    them into the ``AtomGraphs`` structure that ``core_model`` is called with.
    """

    def __init__(self, core_model: torch.nn.Module) -> None:
        """Register the wrapped model as a submodule.

        Args:
            core_model: Module that is called with a single ``AtomGraphs``.
        """
        super().__init__()
        self.core_model = core_model

    def forward(
        self,
        senders: torch.Tensor,
        receivers: torch.Tensor,
        n_node: torch.Tensor,
        n_edge: torch.Tensor,
        atomic_numbers: torch.Tensor,
        atomic_numbers_embedding: torch.Tensor,
        positions: torch.Tensor,
        node_feat: torch.Tensor,
        edge_vectors: torch.Tensor,
        edge_feat: torch.Tensor,
        cell: torch.Tensor,
    ):
        """Rebuild an ``AtomGraphs`` from flat tensors and run ``core_model``.

        Args:
            senders: Source node index of each edge.
            receivers: Destination node index of each edge.
            n_node: Node count per graph.
            n_edge: Edge count per graph.
            atomic_numbers: Stored as ``node_features["atomic_numbers"]``.
            atomic_numbers_embedding: Stored as
                ``node_features["atomic_numbers_embedding"]``.
            positions: Stored as ``node_features["positions"]``.
            node_feat: Stored as ``node_features["feat"]``.
            edge_vectors: Stored as ``edge_features["vectors"]``.
            edge_feat: Stored as ``edge_features["feat"]``.
            cell: Stored as ``system_features["cell"]``.

        Returns:
            Whatever ``core_model`` returns for the rebuilt graph.
            NOTE(review): the ONNX exporter flattens this value into the
            graph outputs, and output names are assigned in that flattened
            order. Confirm the node/edge feature tensors come first if only
            two output names are supplied - TODO confirm.
        """
        atom_graphs = AtomGraphs(
            senders=senders,
            receivers=receivers,
            n_node=n_node,
            n_edge=n_edge,
            node_features={
                "atomic_numbers": atomic_numbers,
                "atomic_numbers_embedding": atomic_numbers_embedding,
                "positions": positions,
                "feat": node_feat,
            },
            edge_features={
                "vectors": edge_vectors,
                "feat": edge_feat,
            },
            system_features={
                "cell": cell,
            },
        )
        return self.core_model(atom_graphs)
# Wrap the core model so the exporter can drive it with plain tensors.
wrapped_model = ORBCoreWrapper(core_model)
wrapped_model.eval()

# Export to ONNX. input_names[i] names example_inputs[i]; both follow the
# positional parameter order of ORBCoreWrapper.forward.
input_names = [
    "senders",
    "receivers",
    "n_node",
    "n_edge",
    "atomic_numbers",
    "atomic_numbers_embedding",
    "positions",
    "node_feat",
    "edge_vectors",
    "edge_feat",
    "cell",
]
# NOTE(review): these names are applied to the outputs in the order the
# exporter flattens whatever core_model returns. Confirm that the first two
# flattened outputs really are the node and edge features - TODO confirm.
output_names = ["updated_node_features", "updated_edge_features"]
example_inputs = (
    dummy_input.senders,
    dummy_input.receivers,
    dummy_input.n_node,
    dummy_input.n_edge,
    dummy_input.node_features["atomic_numbers"],
    dummy_input.node_features["atomic_numbers_embedding"],
    dummy_input.node_features["positions"],
    dummy_input.node_features["feat"],
    dummy_input.edge_features["vectors"],
    dummy_input.edge_features["feat"],
    dummy_input.system_features["cell"],
)
onnx_path = "orb_d3_v1.onnx"

torch.onnx.export(
    wrapped_model,
    example_inputs,
    onnx_path,
    opset_version=16,
    input_names=input_names,
    output_names=output_names,
    # Per-node and per-edge tensors are declared with symbolic leading dims.
    # NOTE(review): n_node, n_edge and cell get no dynamic axis, so they keep
    # the leading size (1) traced from the single-graph dummy input - TODO
    # confirm before feeding batches of several graphs.
    dynamic_axes={
        "senders": {0: "num_edges"},
        "receivers": {0: "num_edges"},
        "atomic_numbers": {0: "num_nodes"},
        "atomic_numbers_embedding": {0: "num_nodes"},
        "positions": {0: "num_nodes"},
        "node_feat": {0: "num_nodes"},
        "edge_vectors": {0: "num_edges"},
        "edge_feat": {0: "num_edges"},
        "updated_node_features": {0: "num_nodes"},
        "updated_edge_features": {0: "num_edges"},
    },
)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.