Last active
October 13, 2023 11:39
-
-
Save thecharlieblake/82f1b54bbf608d8d339043ed8852cf91 to your computer and use it in GitHub Desktop.
Given a numpy function, prints equivalent PyTorch code (as canonical ATen ops) and returns it as a new function.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable, List

import numpy as np
import torch
from torch._dynamo.backends.common import aot_autograd
from torch.fx.graph_module import GraphModule

# NOTE: requires torch >= 2.1.0
def np2torch(fn: Callable) -> Callable:
    """
    Given a numpy function, prints equivalent PyTorch code
    (as canonical ATen ops) and returns it as a new function.

    The returned callable is the torch.compile'd version of `fn`; the
    generated ATen-level Python source for the forward graph is printed
    as a side effect the first time the compiled function is traced.
    """
    # Custom compile backend: print the generated Python source of the
    # traced forward GraphModule, then return it unmodified so execution
    # proceeds normally.
    def aot_compile_backend(gm: GraphModule, _) -> Callable:
        print(gm.code)
        return gm

    # Clear dynamo's global compilation cache so `fn` is re-traced (and
    # therefore re-printed) even if it was compiled earlier in the process.
    torch._dynamo.reset()
    # aot_autograd lowers the captured graph to canonical ATen ops before
    # handing the forward graph to our printing backend.
    compile_backend = aot_autograd(fw_compiler=aot_compile_backend)
    # NOTE(review): relies on torch.compile's numpy interop, hence the
    # torch >= 2.1.0 requirement noted at the top of the file.
    return torch.compile(fn, backend=compile_backend)
def example_fn(a, b):
    """Toy numpy workload: reduce (a + b) - tan(a @ b) over the last axis."""
    elementwise_sum = a + b
    matmul_tan = np.tan(np.matmul(a, b))
    return np.sum(elementwise_sum - matmul_tan, axis=-1)
# Demo: run the numpy function directly, then run the torch-compiled
# equivalent on the same data (printing its ATen code as a side effect).
a = np.random.randn(2**10, 2**10)
b = np.random.randn(2**10, 2**10)
print("Numpy:", example_fn(a, b))

torch_fn = np2torch(example_fn)
a, b = torch.from_numpy(a), torch.from_numpy(b)
print("Torch:", torch_fn(a, b))
Just to prove this really is valid PyTorch: if you then run `forward(None, a, b)`,
you get:
(tensor([-5.4606e+04, -5.2097e+02, 4.1413e+02, ..., -2.6194e+02,
-3.5358e+01, 1.4673e+02], dtype=torch.float64),)
Magic!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Output: