Skip to content

Instantly share code, notes, and snippets.

@tlrmchlsmth
Created February 11, 2025 19:49
Show Gist options
  • Save tlrmchlsmth/916607e15813e2f5be93fe994410230c to your computer and use it in GitHub Desktop.
Save tlrmchlsmth/916607e15813e2f5be93fe994410230c to your computer and use it in GitHub Desktop.
def decompress_2_4(metadata, values):
    """Decompress a 2:4 sparse tensor into its dense form.

    The dense row is split into groups of 4 consecutive columns, and two
    values are stored for each group. Every stored value has a 2-bit index
    giving its column inside its group, so one group needs 4 bits and one
    metadata byte describes 2 groups (8 dense columns).

    Bit layout of one metadata byte, least-significant bits first:
        bits 0-1: column of the 1st stored value of the 1st group
        bits 2-3: column of the 2nd stored value of the 1st group
        bits 4-5: column of the 1st stored value of the 2nd group
        bits 6-7: column of the 2nd stored value of the 2nd group

    Args:
        metadata: Integer tensor of shape (N, K/8), normally uint8.
        values: Tensor of shape (N, K/2) (e.g. bf16) holding the stored
            values in the same order as their indices in ``metadata``.
            Expected to live on the same device as ``metadata``.

    Returns:
        Dense tensor of shape (N, K) with the dtype and device of ``values``.
        Columns that no index points at are zero.

    Raises:
        ValueError: If ``metadata`` is not 2-D, or if the shape of ``values``
            is not (N, K/2) for the K implied by ``metadata``.
    """
    if metadata.dim() != 2:
        raise ValueError(
            f"metadata must be 2-D (N, K/8), got shape {tuple(metadata.shape)}"
        )
    N, K_bytes = metadata.shape
    num_values = K_bytes * 4  # 4 two-bit indices are packed in every byte
    K = K_bytes * 8  # ... and every byte covers 8 dense columns
    if tuple(values.shape) != (N, num_values):
        # scatter_ would silently ignore surplus columns of `values`, so a
        # mismatch has to be rejected here rather than left to scatter_.
        raise ValueError(
            f"values must have shape {(N, num_values)} to match metadata of "
            f"shape {tuple(metadata.shape)}, got {tuple(values.shape)}"
        )

    # Unpack the four 2-bit indices of every byte with one broadcast shift.
    # Shape: (N, K/8, 1) >> (4,) -> (N, K/8, 4)
    shifts = torch.tensor(
        [0, 2, 4, 6], dtype=metadata.dtype, device=metadata.device
    )
    local_positions = (metadata.unsqueeze(-1) >> shifts) & 0x3
    # Use the explicit width instead of -1: reshape cannot infer a dimension
    # of an empty tensor, which would make N == 0 fail.
    local_positions = local_positions.reshape(N, num_values)

    # Two consecutive stored values share a group, and group g starts at
    # dense column 4 * g.
    value_ids = torch.arange(num_values, device=metadata.device)
    group_offsets = (value_ids // 2) * 4

    # scatter_ requires an int64 index tensor.
    global_positions = local_positions.to(torch.int64) + group_offsets

    dense = torch.zeros((N, K), dtype=values.dtype, device=values.device)
    # NOTE: if both indices of a group are equal (malformed metadata), which
    # of the two values ends up in the output is not deterministic.
    dense.scatter_(1, global_positions, values)
    return dense
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment