Aryan a-r-r-o-w
wandering on a rock
import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

class Model(nn.Module):
    """
    Simple model that performs a single square matrix multiplication (C = A * B)
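The preview above is cut off inside the docstring. For reference only, a minimal sketch of how such a module typically reads in full (the class name, forward signature, and shapes below are assumptions, not the gist's actual code); TF32 is disabled so the fp32 result can be compared exactly against a reference matmul:

import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = False  # keep matmul in true fp32
torch.backends.cudnn.allow_tf32 = False


class SquareMatmul(nn.Module):
    """Sketch: a single square matrix multiplication, C = A * B."""

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        return torch.matmul(A, B)


if __name__ == "__main__":
    A, B = torch.randn(512, 512), torch.randn(512, 512)
    torch.testing.assert_close(SquareMatmul()(A, B), A @ B)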
@a-r-r-o-w
a-r-r-o-w / attention_free_transformer.py
Created July 11, 2025 11:29
Attention-free transformer
"""
Implementation of "An Attention-Free Transformer": https://arxiv.org/abs/2105.14103
"""
import contextlib
import functools
import torch
import triton
import triton.language as tl
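For context, AFT-full from the paper replaces query-key dot products with learned pairwise position biases w: each output is sigmoid(Q_t) times a weighted average of V whose weights are exp(K_t' + w[t, t']). A plain PyTorch sketch of that formula (not the gist's Triton implementation; shapes and the toy usage are assumptions):

import torch


def aft_full(q, k, v, w):
    # q, k, v: [batch, seq, dim]; w: [seq, seq] learned pairwise position biases.
    # Y_t = sigmoid(Q_t) * sum_t' exp(K_t' + w[t, t']) * V_t' / sum_t' exp(K_t' + w[t, t'])
    logits = w.unsqueeze(0).unsqueeze(-1) + k.unsqueeze(1)   # [batch, t, t', dim]
    logits = logits - logits.amax(dim=2, keepdim=True)       # subtract max for stability
    weights = logits.exp()
    num = (weights * v.unsqueeze(1)).sum(dim=2)
    den = weights.sum(dim=2)
    return torch.sigmoid(q) * num / den


# Toy usage: zero biases reduce AFT-full to a per-dimension softmax average over keys.
b, t, d = 2, 16, 32
q, k, v = (torch.randn(b, t, d) for _ in range(3))
print(aft_full(q, k, v, torch.zeros(t, t)).shape)  # torch.Size([2, 16, 32])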
@a-r-r-o-w
a-r-r-o-w / fused_adaln_zero_triton.py
Created July 8, 2025 08:16
Fused AdaLN-Zero in Triton. Can be faster than torch.compile if you don't use masks, which is almost always possible in common transformer scenarios with aligned block sizes
import torch
import triton
import triton.language as tl
torch._dynamo.config.cache_size_limit = 10000
ENABLE_TRITON = True
ENABLE_DEEP_AUTOTUNE = True
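For context, AdaLN-Zero (the DiT-style conditioning this kernel fuses) regresses shift, scale, and gate vectors from a conditioning embedding, modulates a LayerNorm output with them, and zero-initializes the gate. A simplified single-branch PyTorch reference of the unfused operation (a sketch; module and parameter names are assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaLNZero(nn.Module):
    # Eager reference for the operation a fused kernel would implement: shift/scale/gate
    # come from the conditioning embedding, and the projection is zero-initialised.
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(dim, 3 * dim)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, dim], cond: [batch, dim]
        shift, scale, gate = self.proj(F.silu(cond)).chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return gate.unsqueeze(1) * h  # the "zero" in AdaLN-Zero: gate starts at 0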
@a-r-r-o-w
a-r-r-o-w / sequential_ring.py
Created June 30, 2025 11:36
Sequential and templated ring/Ulysses/unified attention implementation
import torch

torch.manual_seed(42)

def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
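The private aten op above is used because it returns both the attention output and the log-sum-exp, which is exactly what is needed to merge partial results across KV blocks. A single-process sketch of that merge using only public ops, iterating over KV chunks the way a ring pass would (block count and shapes are assumptions):

import torch


def ring_attention_sequential(q, k, v, num_blocks: int = 4):
    # q, k, v: [batch, heads, seq, dim]. Processes KV in blocks and folds each block's
    # partial output into a running result via a numerically stable log-sum-exp merge.
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    lse = q.new_full(q.shape[:-1] + (1,), float("-inf"))
    for k_blk, v_blk in zip(k.chunk(num_blocks, dim=2), v.chunk(num_blocks, dim=2)):
        scores = q @ k_blk.transpose(-2, -1) * scale
        blk_lse = scores.logsumexp(dim=-1, keepdim=True)
        blk_out = scores.softmax(dim=-1) @ v_blk
        new_lse = torch.logaddexp(lse, blk_lse)
        out = out * (lse - new_lse).exp() + blk_out * (blk_lse - new_lse).exp()
        lse = new_lse
    return out


# Sanity check against full attention on one device.
q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(ring_attention_sequential(q, k, v), ref, atol=1e-4, rtol=1e-4)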
import argparse
import contextlib
import math
import pathlib
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
@a-r-r-o-w
a-r-r-o-w / attempt_eager_layernorm_linear_activation.py
Created June 20, 2025 13:13
Attempt to make a fused LayerNorm + Linear + Activation kernel
import pathlib
import torch
import torch._dynamo.config
import triton
import triton.language as tl
torch._dynamo.config.cache_size_limit = 10000
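The eager baseline such a fused kernel targets is just three ops back to back, each of which reads and writes the activation in global memory. A minimal reference (a sketch; the activation choice and names are assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormLinearGelu(nn.Module):
    # Eager baseline for a fused LayerNorm + Linear + Activation kernel:
    # three separate ops, hence three round trips through global memory.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.norm = nn.LayerNorm(in_features)
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.gelu(self.linear(self.norm(x)))


x = torch.randn(4, 1024, 768)
print(LayerNormLinearGelu(768, 3072)(x).shape)  # torch.Size([4, 1024, 3072])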
import argparse
import contextlib
import math
import pathlib
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import argparse
import contextlib
import math
import pathlib
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.profiler._utils
@a-r-r-o-w
a-r-r-o-w / ring_attention_when_you_forget_to_do_the_rotations.py
Created June 18, 2025 05:17
ring attention when you forget to do the rotations
import argparse
import contextlib
import math
import pathlib
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
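The "rotations" here are the point-to-point exchange that passes each rank's KV shard to its neighbour after every local attention step; skip them and each rank only ever attends to its own shard. A sketch of one rotation step (assumes torch.distributed is already initialized, e.g. under torchrun; not the gist's code):

import torch
import torch.distributed as dist


def ring_rotate_kv(k: torch.Tensor, v: torch.Tensor):
    # Send the local KV block to the next rank in the ring and receive the previous
    # rank's block. Called once per ring step so every rank eventually sees all KV.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    recv_k, recv_v = torch.empty_like(k), torch.empty_like(v)
    ops = [
        dist.P2POp(dist.isend, k.contiguous(), (rank + 1) % world_size),
        dist.P2POp(dist.isend, v.contiguous(), (rank + 1) % world_size),
        dist.P2POp(dist.irecv, recv_k, (rank - 1) % world_size),
        dist.P2POp(dist.irecv, recv_v, (rank - 1) % world_size),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_k, recv_v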
@a-r-r-o-w
a-r-r-o-w / benchmark_attention.py
Last active July 4, 2025 20:05
SDPA benchmark for torch, FlashAttention-2, FlashAttention-3, Transformer Engine, xFormers, SageAttention, and the HF kernels library
#!/usr/bin/env python3
# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths
import functools
import os
import pathlib
import matplotlib.pyplot as plt
import torch
import torch._dynamo.config
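A stripped-down version of such a benchmark can be written against the public SDPA API alone, picking backends with torch.nn.attention.sdpa_kernel (recent PyTorch) and timing with CUDA events; the sketch below only covers the built-in backends and a single Flux-like shape, whereas the gist sweeps many more libraries and sequence lengths:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def time_sdpa(backend, q, k, v, iters: int = 50) -> float:
    # Returns average milliseconds per SDPA call for one backend.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with sdpa_kernel(backend):
        for _ in range(5):  # warmup
            F.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            F.scaled_dot_product_attention(q, k, v)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


if __name__ == "__main__":
    # Flux-like shape (assumed): 4096 image tokens at 1024x1024px + 512 text tokens,
    # 24 heads, head dim 128.
    q = k = v = torch.randn(1, 24, 4096 + 512, 128, device="cuda", dtype=torch.bfloat16)
    for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION):
        print(backend, f"{time_sdpa(backend, q, k, v):.3f} ms")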