Skip to content

Instantly share code, notes, and snippets.

@alexdremov
Created January 5, 2025 15:38
Show Gist options
  • Save alexdremov/15d8a35f1f047d53e7d91775862a2c4a to your computer and use it in GitHub Desktop.
Code snippet uploaded via Python script (py)
import torch
def softmax_pytorch(x):
    """Compute a numerically stable softmax over the last dimension of *x*.

    The row-wise maximum is subtracted before exponentiation so that
    ``exp`` never overflows, which leaves the result unchanged because
    softmax is shift-invariant.
    """
    shifted = x - x.amax(dim=-1, keepdim=True)
    exps = shifted.exp()
    return exps / exps.sum(dim=-1, keepdim=True)
@torch.compile
def compiled_softmax(x):
    """torch.compile-wrapped entry point around :func:`softmax_pytorch`.

    Compilation is lazy: the graph is traced and optimized on the first
    call, not at decoration time.
    """
    result = softmax_pytorch(x)
    return result
if __name__ == "__main__":
# Example usage:
input_tensor = torch.randn((2, 4), device="cuda")
output = compiled_softmax(input_tensor)
print("Input:", input_tensor)
print("Compiled Softmax Output:", output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment