Created
April 20, 2023 10:22
-
-
Save shashankprasanna/69dd22d2ec46d7b74e5ae5a4bc247758 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch._dynamo | |
from torch.fx.passes.graph_drawer import FxGraphDrawer | |
from functorch.compile import make_boxed_func | |
from torch._functorch.aot_autograd import aot_module_simplified | |
def f(x):
    """Return sin(x)^2 + cos(x)^2 for a tensor *x*.

    Mathematically this is identically 1; the point is to give the
    compiler backend a small graph with a fusible pattern to capture.
    """
    sin_sq = torch.sin(x) ** 2
    cos_sq = torch.cos(x) ** 2
    return sin_sq + cos_sq
def inspect_backend(gm, sample_inputs):
    """Custom ``torch.compile`` backend that inspects captured graphs.

    Runs AOTAutograd over the Dynamo-captured ``gm`` and, for both the
    forward and backward graphs, prints the readable Python code and
    renders the FX graph to an SVG file. The graphs themselves are
    executed unmodified.

    Args:
        gm: the ``torch.fx.GraphModule`` captured by Dynamo.
        sample_inputs: example inputs Dynamo traced with.

    Returns:
        The compiled callable produced by ``aot_module_simplified``.
    """
    def _make_compiler(svg_path):
        # The original had two copy-pasted compiler hooks differing only
        # in the output filename; this closure factors them into one.
        def compiler(gm, sample_inputs):
            gm.print_readable()
            drawer = FxGraphDrawer(gm, 'fn')
            # NOTE: create_svg() requires pydot + graphviz to be installed.
            with open(svg_path, "wb") as file:
                file.write(drawer.get_dot_graph().create_svg())
            # Return the graph unchanged, wrapped for AOTAutograd's
            # boxed (list-of-args) calling convention.
            return make_boxed_func(gm.forward)
        return compiler

    # Call AOTAutograd to split the captured graph into fw/bw graphs.
    return aot_module_simplified(
        gm, sample_inputs,
        fw_compiler=_make_compiler("forward_aot.svg"),
        bw_compiler=_make_compiler("backward_aot.svg"),
    )
# Pick a device explicitly -- the original referenced an undefined `device`
# name, which raises NameError at runtime.
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(0)
# Create the tensor directly on the target device. The original pattern
# `torch.rand(..., requires_grad=True).to(device)` makes `x` a NON-leaf
# tensor (`.to` is a differentiable op), so `x.grad` would never populate.
x = torch.rand(1000, device=device, requires_grad=True)
y = torch.ones_like(x)

torch._dynamo.reset()
compiled_f = torch.compile(f, backend=inspect_backend)
# backward() is called for its gradient side effect; it returns None.
out = torch.nn.functional.mse_loss(compiled_f(x), y).backward()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment