@alexdremov
Created January 5, 2025 15:39
Numerically stable, autotuned softmax kernel in Triton with a PyTorch wrapper (Python).
import torch
import triton
import triton.language as tl
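
# The autotuner sweeps BLOCK_SIZE_ROWS, num_stages, and num_warps for each
# distinct N_COLS; the heuristic rounds the column block up to the next
# power of two, since tl.arange requires a power-of-two extent.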
@triton.autotune(
    configs=[
        triton.Config(
            kwargs=dict(
                BLOCK_SIZE_ROWS=BLOCK_SIZE_ROWS,
                num_stages=num_stages,
            ),
            num_warps=num_warps,
            num_stages=num_stages,
        )
        for BLOCK_SIZE_ROWS in (16, 32, 64, 128)
        for num_stages in (2, 3, 4)
        for num_warps in (2, 4, 8)
    ],
    key=['N_COLS'],
)
@triton.heuristics(
    values=dict(
        BLOCK_SIZE_COLS=lambda args: triton.next_power_of_2(args['N_COLS'])
    )
)
@triton.jit
def softmax_kernel(
    input_ptr: tl.tensor,
    output_ptr: tl.tensor,
    input_row_stride: int,
    output_row_stride: int,
    n_rows: int,
    N_COLS: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_COLS: tl.constexpr,
    num_stages: tl.constexpr,
):
    # Block pointers describe the full (n_rows, N_COLS) tensors tile by tile.
    input_ptr = tl.make_block_ptr(
        base=input_ptr,
        shape=(n_rows, N_COLS),
        strides=(input_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS),
        order=(1, 0),
    )
    output_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(n_rows, N_COLS),
        strides=(output_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS),
        order=(1, 0),
    )

    cols_mask = tl.arange(0, BLOCK_SIZE_COLS) < N_COLS
    row_idx = tl.program_id(0) * BLOCK_SIZE_ROWS

    # Each program instance handles BLOCK_SIZE_ROWS consecutive rows;
    # boundary_check pads out-of-range elements on load (re-masked below).
    in_tile_ptr = tl.advance(input_ptr, (row_idx, 0))
    row = tl.load(pointer=in_tile_ptr, boundary_check=(0, 1))

    # Subtract the row maximum for numerical stability, then set the padding
    # columns to -inf so they contribute exp(-inf) = 0 to the sum.
    row_minus_max = row - tl.max(row, axis=1, keep_dims=True)
    row_minus_max = tl.where(cols_mask, row_minus_max, -float('inf'))

    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=1, keep_dims=True)
    softmax_output = numerator / denominator

    # boundary_check on store leaves out-of-range rows and columns untouched.
    out_tile_ptr = tl.advance(output_ptr, (row_idx, 0))
    tl.store(out_tile_ptr, softmax_output, boundary_check=(0, 1))
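
# Host-side wrapper: flattens all leading dimensions into rows and launches
# one program per BLOCK_SIZE_ROWS-row tile.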
def softmax(x: torch.Tensor):
    x_orig_shape = x.shape
    x = x.view(-1, x_orig_shape[-1])
    n_rows, n_cols = x.shape
    y = torch.empty_like(x, memory_format=torch.contiguous_format)

    # The autotuner supplies BLOCK_SIZE_ROWS through the grid callable's args.
    grid = lambda args: (
        triton.cdiv(n_rows, args['BLOCK_SIZE_ROWS']),
        1,
        1,
    )
    softmax_kernel[grid](
        input_ptr=x,
        output_ptr=y,
        input_row_stride=x.stride(0),
        output_row_stride=y.stride(0),
        n_rows=n_rows,
        N_COLS=n_cols,
    )
    return y.view(*x_orig_shape)
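
# A minimal usage sketch (assumes a CUDA device is available; the shapes and
# tolerances are illustrative, not part of the original snippet):
if __name__ == "__main__":
    x = torch.randn(4, 128, 513, device="cuda")
    y = softmax(x)
    y_ref = torch.softmax(x, dim=-1)
    torch.testing.assert_close(y, y_ref, rtol=1e-5, atol=1e-5)
    print("Triton softmax matches torch.softmax")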