Skip to content

Instantly share code, notes, and snippets.

@banach-space
Last active March 6, 2025 10:51
Show Gist options
  • Save banach-space/1d09d5f06698fc2467616e311f52ae8d to your computer and use it in GitHub Desktop.
Save banach-space/1d09d5f06698fc2467616e311f52ae8d to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
class GatherModel(nn.Module):
    """Minimal module that exercises ``torch.gather`` with a fixed index tensor.

    The index tensor is hard-coded with 4 rows and 3 columns: rows 0, 1 and 3
    select columns 2, 3 and 4 of the input, while row 2 selects column 0 three
    times. ``torch.gather`` always produces one output row per index row, so
    the all-zero row is only a placeholder for the row this example is
    "trying to skip" -- it is not actually dropped from the result.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Gather three columns per row from ``x``, printing each step.

        Args:
            x: 2-D tensor. The hard-coded indices reach column 4 and cover
                4 rows, so ``x`` needs at least 4 rows and 5 columns.

        Returns:
            A 4x3 tensor where ``out[i][j] == x[i][indices[i][j]]``.
        """
        print("Input Tensor:")
        print(x)

        # One index row per input row; each entry is the column to read.
        # Row 2 is the row we are "trying to skip": gather cannot omit a row,
        # so it is pointed at column 0 three times instead.
        indices = torch.tensor(
            [
                [2, 3, 4],  # Row 0: get columns 2, 3, 4
                [2, 3, 4],  # Row 1: get columns 2, 3, 4
                [0, 0, 0],  # Row 2: trying to skip this one
                [2, 3, 4],  # Row 3: get columns 2, 3, 4
            ]
        )
        print("\nIndex Tensor:")
        print(indices)

        # Apply torch.gather along dimension 1 (columns)
        gathered = torch.gather(x, dim=1, index=indices)
        print("\nGathered Tensor:")
        print(gathered)
        return gathered
# Create a 4x5 input tensor
input_tensor = torch.tensor(
    [
        [10, 20, 30, 40, 50],
        [60, 70, 80, 90, 100],
        [110, 120, 130, 140, 150],  # Trying to skip this row
        [160, 170, 180, 190, 200],
    ]
)

# Initialize the model and perform a forward pass
model = GatherModel()
output = model(input_tensor)

# Wrap the model with the "turbine_cpu" torch.compile backend.
# NOTE(review): model_c is never invoked below, so nothing is actually
# compiled here; the backend is presumably registered by the iree-turbine
# package -- confirm whether a call such as model_c(input_tensor) was intended.
model_c = torch.compile(model, backend="turbine_cpu")

# Kept as a late import so the forward pass above still runs (and prints)
# before the iree-turbine dependency is needed.
import iree.turbine.aot as aot

# Export the original (uncompiled) module with the sample input and write
# the result to disk as MLIR.
export_output = aot.export(model, input_tensor)
mlir_file_path = "/tmp/gather_module_pytorch.mlir"
export_output.save_mlir(mlir_file_path)
# Report only after the file has been written, and say where it went.
print(f"Exported .mlir: {mlir_file_path}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment