@shashankprasanna
Created April 20, 2023 10:23
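
import torch
import torch._dynamo
import torch.nn.functional as F

def f(x):
    # sin(x)**2 + cos(x)**2 == 1 for every x, so f is numerically ~1 everywhere
    return torch.sin(x)**2 + torch.cos(x)**2

# Clear previously compiled graphs so compilation starts fresh
torch._dynamo.reset()
# Compile with the Inductor backend and dump debug traces, including graph diagrams
compiled_f = torch.compile(f, backend='inductor',
                           options={'trace.enabled': True,
                                    'trace.graph_diagram': True})

# device = 'cpu'
device = 'cuda'
torch.manual_seed(0)
# requires_grad_ after .to() keeps x a leaf tensor, so x.grad is populated
x = torch.rand(1000).to(device).requires_grad_(True)
y = torch.ones_like(x)
# backward() returns None, so keep a handle on the loss itself
loss = F.mse_loss(compiled_f(x), y)
loss.backward()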
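Because sin(x)**2 + cos(x)**2 = 1, `f(x)` should be numerically 1 everywhere, so both the MSE loss against `y` and the gradient with respect to `x` should be close to zero. The sketch below is one possible sanity check that the compiled function agrees with eager mode. It assumes a CPU run and swaps in the `aot_eager` debug backend, which exercises Dynamo and AOTAutograd but skips Inductor code generation (so no C++ or Triton toolchain is needed). `check_compiled_matches_eager` is a hypothetical helper name, not part of PyTorch.

```python
import torch
import torch.nn.functional as F

def f(x):
    return torch.sin(x)**2 + torch.cos(x)**2

def check_compiled_matches_eager(device='cpu', backend='aot_eager'):
    """Hypothetical helper: compare losses and input gradients of eager vs compiled f."""
    compiled_f = torch.compile(f, backend=backend)
    torch.manual_seed(0)
    # Leaf tensor for the eager path
    x_eager = torch.rand(1000, device=device, requires_grad=True)
    # Independent leaf with identical values for the compiled path
    x_comp = x_eager.detach().clone().requires_grad_(True)
    y = torch.ones_like(x_eager)

    loss_eager = F.mse_loss(f(x_eager), y)
    loss_eager.backward()
    loss_comp = F.mse_loss(compiled_f(x_comp), y)
    loss_comp.backward()

    return loss_eager, loss_comp, x_eager.grad, x_comp.grad
```

Swapping `backend='inductor'` and `device='cuda'` back in reproduces the original setup. Inductor may fuse or reorder floating-point operations, so comparing with a small tolerance (for example `torch.allclose`) is safer than exact equality.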