@alexdremov
Created January 5, 2025 15:39
Code snippet uploaded via Python script (py)
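# Force padded (out-of-range) columns to -inf before taking the max, so they can
# neither win the max nor contribute to the sum (exp(-inf) == 0)
row = tl.where(cols_mask, row, -float('inf'))
# Subtract the per-row max for numerical stability; softmax is shift-invariant
row_minus_max = row - tl.max(row, axis=1, keep_dims=True)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=1, keep_dims=True)
softmax_output = numerator / denominator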
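The Triton lines compute a masked, numerically stable row-wise softmax over a 2D block. As a rough way to check the math off-GPU, here is a pure-Python reference sketch that follows the same steps. It is not Triton code. The function name `masked_softmax_rows` and the `n_cols` parameter are hypothetical, and it assumes padded lanes (where `cols_mask` is false) are simply the columns at index `>= n_cols`.

```python
import math


def masked_softmax_rows(rows, n_cols):
    """Row-wise softmax over the first n_cols entries of each row.

    Hypothetical reference for the Triton snippet: columns at index >= n_cols
    stand in for padded lanes (cols_mask == False) and get probability 0.
    """
    out = []
    for row in rows:
        # Mask padded columns to -inf *before* the max, as in the kernel
        masked = [x if j < n_cols else -math.inf for j, x in enumerate(row)]
        row_max = max(masked)
        # Shift by the row max so the largest exponent is exp(0) == 1
        numerator = [math.exp(x - row_max) for x in masked]
        denominator = sum(numerator)
        out.append([v / denominator for v in numerator])
    return out


# Valid entries are very negative; the padded lane holds 0.0
print(masked_softmax_rows([[-1000.0, -1001.0, 0.0]], n_cols=2))
```

The example row shows why the mask is applied before the max. If the max were taken over the padded `0.0`, both valid entries would be shifted to about -1000, `exp` would underflow to 0 for each, and the denominator would be 0. Masking first makes the max come from real data, so the largest valid term is exactly `exp(0) == 1` and the sum is at least 1.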