Last active
March 6, 2025 10:51
-
-
Save banach-space/1d09d5f06698fc2467616e311f52ae8d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
class GatherModel(nn.Module):
    """Minimal ``torch.gather`` repro module.

    ``forward`` gathers a fixed set of column indices from each row of its
    input along dim 1 and prints the input, the index tensor, and the result.

    Note: ``torch.gather`` cannot skip or drop rows. Its output always has the
    same shape as the index tensor (4x3 here). So the ``[0, 0, 0]`` index row
    does not remove row 2 from the result; it selects ``x[2, 0]`` three times.
    To actually drop rows, select rows first (e.g. ``index_select`` on dim 0)
    and then gather columns.
    """

    def forward(self, x):
        """Gather columns from each row of ``x`` using a hard-coded 4x3 index.

        Args:
            x: 2-D tensor. ``torch.gather`` requires the index and input to
                have the same number of dimensions, and
                ``index.size(0) <= x.size(0)``. Because indices go up to 4,
                ``x`` needs at least 4 rows and at least 5 columns.

        Returns:
            A 4x3 tensor where ``gathered[i][j] == x[i][indices[i][j]]``.
        """
        print("Input Tensor:")
        print(x)

        # Rows 0, 1 and 3 pick columns 2, 3, 4 (the last three columns of a
        # 5-column input). Row 2 is NOT skipped: gather keeps the index shape,
        # so [0, 0, 0] repeats x[2, 0] three times.
        # Build the index on x's device so gather also works off-CPU.
        indices = torch.tensor(
            [
                [2, 3, 4],  # Row 0: columns 2, 3, 4
                [2, 3, 4],  # Row 1: columns 2, 3, 4
                [0, 0, 0],  # Row 2: column 0 three times (not skipped)
                [2, 3, 4],  # Row 3: columns 2, 3, 4
            ],
            device=x.device,
        )
        print("\nIndex Tensor:")
        print(indices)

        # gathered[i][j] = x[i][indices[i][j]] (gather along dim 1 / columns)
        gathered = torch.gather(x, dim=1, index=indices)
        print("\nGathered Tensor:")
        print(gathered)
        return gathered
import iree.turbine.aot as aot

# Build a 4x5 input tensor. GatherModel gathers with a 4x3 index tensor, so
# the result is 4x3: row 2 is not skipped, because torch.gather keeps the index shape.
input_tensor = torch.tensor(
    [
        [10, 20, 30, 40, 50],
        [60, 70, 80, 90, 100],
        [110, 120, 130, 140, 150],  # gathered as [110, 110, 110], not dropped
        [160, 170, 180, 190, 200],
    ]
)

# Eager reference run; forward() prints the input, index and gathered tensors.
model = GatherModel()
output = model(input_tensor)

# Unused: model_c is never called, and torch.compile only compiles on the
# first call, so nothing is compiled with the "turbine_cpu" backend here.
model_c = torch.compile(model, backend="turbine_cpu")

# Export the eager module to MLIR via IREE Turbine's AOT API and save it.
export_output = aot.export(model, input_tensor)
mlir_file_path = "/tmp/gather_module_pytorch.mlir"
export_output.save_mlir(mlir_file_path)
# Report the output location only once the save has succeeded.
print(f"Exported .mlir to {mlir_file_path}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment