Skip to content

Instantly share code, notes, and snippets.

@Tagussan
Created December 3, 2021 09:23
Show Gist options
  • Save Tagussan/29027d77925c0a3f3cbceb36d4c011e1 to your computer and use it in GitHub Desktop.
Save Tagussan/29027d77925c0a3f3cbceb36d4c011e1 to your computer and use it in GitHub Desktop.
import torch
import timm
import torch.neuron
# Compile a pretrained timm FBNet-C model with torch-neuron, save the compiled
# TorchScript artifact, reload it, and compare its output against the original
# eager-mode model on the same dummy input.

MODEL_NAME = "fbnetc_100"
COMPILED_PATH = "test.pt"  # written by model_neuron.save(), read back by torch.jit.load()
INPUT_SHAPE = (1, 3, 512, 512)  # the traced module is compiled for this input shape

# original model
model = timm.create_model(MODEL_NAME, pretrained=True)
model.eval()  # inference mode for layers such as dropout / batch-norm before tracing

# number of parameters
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(f"total_param: {pytorch_total_params}")

# dummy data
data_in = torch.rand(*INPUT_SHAPE)

# Reference inference in eager mode. no_grad avoids building an autograd graph
# that is never used (saves memory and keeps grad_fn out of the printed tensor).
with torch.no_grad():
    cpu_out = model(data_in)
print(cpu_out)

# trace the model
print("Trace model")
torch.neuron.analyze_model(model, example_inputs=[data_in])
model_neuron = torch.neuron.trace(model, example_inputs=[data_in])
model_neuron.save(COMPILED_PATH)

# load the compiled model and run inference
print("Run inference on neuron")
model_neuron_loaded = torch.jit.load(COMPILED_PATH)
with torch.no_grad():
    neuron_out = model_neuron_loaded(data_in)
print(neuron_out)

# Report how far the compiled model's output is from the eager reference.
# Informational only: the Neuron compiler may use reduced precision, so no fixed
# tolerance is asserted here.
print(f"max_abs_diff: {(cpu_out - neuron_out).abs().max().item()}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment