Created
February 11, 2025 19:49
-
-
Save tlrmchlsmth/916607e15813e2f5be93fe994410230c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def decompress_2_4(metadata, values):
    """Expand a 2:4 structured-sparse tensor back to its dense form.

    In 2:4 sparsity every contiguous block of 4 columns keeps exactly
    2 non-zero values. Each metadata byte packs the four 2-bit column
    indices for two such blocks (low bits first).

    Args:
        metadata: Tensor of shape (N, K/8), uint8. Byte layout, LSB first:
            bits 0-1 and 2-3 are the two indices for the first 4-wide
            block; bits 4-5 and 6-7 are the indices for the second block.
        values: Tensor of shape (N, K/2) holding the kept non-zero
            entries, in the same left-to-right order as the indices.
            (The gist says bf16, but any dtype works.)

    Returns:
        Dense tensor of shape (N, K) with values.dtype on values.device;
        positions not named by metadata are zero.
    """
    n_rows, k_bytes = metadata.shape
    k = k_bytes * 8  # each metadata byte describes 8 output columns

    dense = torch.zeros((n_rows, k), dtype=values.dtype, device=values.device)

    # Unpack the four 2-bit index fields of every byte, LSB-first:
    # two indices for block 0 of the byte, then two for block 1.
    fields = [(metadata >> shift) & 0x3 for shift in (0, 2, 4, 6)]

    # (N, K/8, 4) -> (N, K/2): one local column index per stored value.
    local_idx = torch.stack(fields, dim=-1).reshape(n_rows, -1)

    # Value i belongs to 4-wide block i // 2 (two values per block),
    # so its block's base column is (i // 2) * 4.
    value_ids = torch.arange(local_idx.shape[1], device=local_idx.device)
    block_base = (value_ids // 2) * 4

    # Global column = block base + 2-bit local offset; int64 after
    # promotion with the arange result, as scatter_ requires.
    global_cols = local_idx + block_base.unsqueeze(0)

    dense.scatter_(1, global_cols, values)
    return dense
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment