Skip to content

Instantly share code, notes, and snippets.

@dcci
Last active January 4, 2025 00:30
Show Gist options
Save dcci/9ae605dfc79130869df275db562eb146 to your computer and use it in GitHub Desktop.
PyTorch MPS example
import sys
import time

import torch
import torch._dynamo.config as config
# Script-wide settings; each line touches a separate piece of global state.
# Print floating-point tensors with 16 digits of precision.
torch.set_printoptions(precision=16)
# Tensors created without an explicit device land on the Apple "mps" backend.
torch.set_default_device("mps")
# Set dynamo's recompile limit to infinity (effectively no limit).
config.recompile_limit = float('inf')
class Model(torch.nn.Module):
    """Minimal module whose forward pass is an element-wise ``torch.exp``.

    It has no parameters or buffers; it exists so the same computation can be
    run both eagerly and through ``torch.compile``.
    """

    def apply_op(self, op, v):
        """Apply the callable *op* to *v* and return the result."""
        return op(v)

    def forward(self, x):
        """Return ``torch.exp(x)``, computed element-wise."""
        return torch.exp(x)
# Time one eager call and one torch.compile'd call of Model on the MPS device.
# Clear dynamo's state so the compile below starts from scratch.
torch._dynamo.reset()
x = torch.randn(10).to('mps')
y = x.clone().to('mps')
print(x)
func = Model()
mps_device = torch.device("mps")
func.to(mps_device)

# perf_counter is monotonic and high-resolution, unlike time.time().
start_time = time.perf_counter()
res1 = func(x)
# MPS work is queued asynchronously; wait for it to finish so the timing
# covers the computation itself rather than just the kernel launch.
torch.mps.synchronize()
print("--- %s seconds ---" % (time.perf_counter() - start_time))

# NOTE: this measurement includes torch.compile's tracing and compilation,
# which happen on the first call of the compiled module.
start_time = time.perf_counter()
func1 = torch.compile(func, fullgraph=True)
res2 = func1(y)
torch.mps.synchronize()
print("--- %s seconds ---" % (time.perf_counter() - start_time))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment