@soulitzer
soulitzer / ac.py
Created April 9, 2025 22:10
graph-based AC
import torch
import functools
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
from torch.utils.checkpoint import CheckpointPolicy, _policy_from_bool
from collections import namedtuple
import weakref
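The preview stops at the imports. To make the "graph-based AC" idea concrete, here is a minimal sketch of the recording pattern those imports suggest: a TorchDispatchMode that maps each output tensor to the op that produced it, so activations could later be freed and rebuilt. RecordingMode and Node are hypothetical names, not the gist's actual code.

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
from collections import namedtuple

Node = namedtuple("Node", ["func", "args", "kwargs"])

class RecordingMode(TorchDispatchMode):
    # Maps each output tensor to the op that produced it; the keys are weak,
    # so recording does not itself keep activations alive.
    def __init__(self):
        super().__init__()
        self.graph = WeakTensorKeyDictionary()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        node = Node(func, args, kwargs)
        tree_map_only(torch.Tensor, lambda t: self.graph.__setitem__(t, node), out)
        return out

with RecordingMode() as mode:
    x = torch.randn(4)
    y = torch.sin(x)
print(mode.graph[y].func)  # aten.sin.default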
@soulitzer
soulitzer / priority_cache.py
Last active October 14, 2024 17:59
Priority Cache
import torch
from torch.utils.weak import WeakTensorKeyDictionary
import weakref
from dataclasses import dataclass
import dataclasses
from typing import *
import sys
@dataclass
class CacheEntry:
one: Optional[Union[torch.Tensor, weakref.ReferenceType]] = None
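The entry above can hold either a strong reference (keeps the tensor alive) or a weak one (lets it die when nothing else needs it), which is the crux of a priority cache. A minimal sketch of how such an entry might be used; get and downgrade are hypothetical helpers, not the gist's API.

import torch
import weakref
from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class CacheEntry:
    one: Optional[Union[torch.Tensor, weakref.ReferenceType]] = None

    def get(self) -> Optional[torch.Tensor]:
        if isinstance(self.one, weakref.ReferenceType):
            return self.one()  # may be None if the tensor was collected
        return self.one

    def downgrade(self):
        # hypothetical policy step: demote a strong ref to a weak one
        if isinstance(self.one, torch.Tensor):
            self.one = weakref.ref(self.one)

t = torch.randn(3)
entry = CacheEntry(one=t)
entry.downgrade()
assert entry.get() is t
del t
# entry.get() may now return None once the tensor is garbage collected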
@soulitzer
soulitzer / sac2.py
Created June 27, 2024 16:19
A new way to do AC
import torch
import functools
from torch.utils._python_dispatch import TorchDispatchMode
import torch.utils._pytree as pytree
from torch.utils.weak import WeakTensorKeyDictionary
class RecomputableTensor(torch.Tensor):
@staticmethod
def __new__(cls, t, func, args):
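The sac2.py preview cuts off inside __new__. A minimal sketch of how such a recomputable wrapper could be completed, under the assumption that it stores the producing op and its inputs so the result can be rebuilt on demand; everything past the signature is an assumption, not the gist's code.

import torch

class RecomputableTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, t, func, args):
        # wrapper with no storage of its own; shape/dtype/device mirror t
        return torch.Tensor._make_wrapper_subclass(
            cls, t.shape, dtype=t.dtype, device=t.device
        )

    def __init__(self, t, func, args):
        self.func = func  # op that produced t
        self.args = args  # inputs needed to recompute t

    def recompute(self):
        return self.func(*self.args)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # rematerialize any wrapped args, then run the op on plain tensors
        kwargs = kwargs or {}
        args = tuple(a.recompute() if isinstance(a, cls) else a for a in args)
        return func(*args, **kwargs)

x = torch.randn(3)
rt = RecomputableTensor(x.sin(), torch.sin, (x,))
assert torch.allclose(rt.recompute(), x.sin())
assert torch.allclose(rt + 1, x.sin() + 1)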
from torch.nested._internal.nested_tensor import jagged_from_list
a = torch.randn(2, 7, 256, requires_grad=True, dtype=torch.float32)
b = torch.randn(3, 7, 256, requires_grad=True, dtype=torch.float32)
c = torch.randn(4, 7, 256, requires_grad=True, dtype=torch.float32)
d = torch.randn(5, 7, 256, requires_grad=True, dtype=torch.float32)
nt1 = jagged_from_list([a, b, c, d], None)[0]
nt2 = jagged_from_list([a, b, c, d], None)[0]
nt1_view = nt1.select(2, 1)
@soulitzer
soulitzer / inference_mode_propagation.py
Created November 3, 2023 00:51
Edge case if we try to patch inference-ness in ADInplaceOrView
import torch
class T(torch.Tensor):
def __new__(cls, elem):
return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)
def __init__(self, elem):
self.elem = elem
@classmethod
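The preview ends on the bare @classmethod decorator; what follows is presumably a __torch_dispatch__ implementation. A minimal sketch of the usual unwrap/rewrap body it might decorate, repeated here with the class for runnability (the dispatch body is an assumption, not the gist's actual code):

import torch
from torch.utils._pytree import tree_map

class T(torch.Tensor):
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda x: x.elem if isinstance(x, cls) else x
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        # rewrap plain tensor outputs so the wrapper survives across ops
        return tree_map(lambda x: cls(x) if isinstance(x, torch.Tensor) else x, out)

t = T(torch.randn(2))
print(type(t + 1))  # <class '__main__.T'>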
# Technically, even in the "easy case" of t._base.requires_grad == t.requires_grad,
# I need to perform two views to recreate that view authentically. Why?
# There are actually two things I need to recreate: (1) the autograd
# graph relationship and (2) the view relationship.
# The reason we don't handle this today is that this autograd connectivity information
# is not accessible during tracing and hence not relevant to compile, in part because dynamo
# doesn't support grad_fn access.
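A quick illustration of the two relationships the comment distinguishes, using only grad_fn/_base inspection (a sketch of standard PyTorch behavior, not the gist's code):

import torch

base = torch.randn(4, requires_grad=True)
v = base.view(2, 2)
# (1) autograd graph relationship: v is connected to base via a view node
print(v.grad_fn)        # <ViewBackward0 ...>
# (2) view relationship: v aliases base's storage
print(v._base is base)  # True

# Replaying the view under no_grad keeps (2) but drops (1), which is why a
# single replay step is not enough to recreate the view authentically:
with torch.no_grad():
    v2 = base.view(2, 2)
print(v2.grad_fn)        # None
print(v2._base is base)  # True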
import torch
from torch.nested._internal.nested_tensor import NestedTensor, jagged_from_list
from torch.profiler import profile, record_function, ProfilerActivity
device="cuda:5"
for nb_unit in (10, 1, 2, 5, 20):
lin = torch.nn.functional.linear
def sin(x):
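This preview truncates at the definition of sin. For context, a minimal sketch of how the torch.profiler imports above are typically used to time a loop of units; the body of sin here is an assumption:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

def sin(x):
    return torch.sin(x)  # assumed unit op; the gist's actual body is truncated

x = torch.randn(1024, 1024)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("units"):
        for _ in range(10):
            x = sin(x)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))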
@soulitzer
soulitzer / nested_tensor.py
Last active July 19, 2023 15:48
NestedTensor __torch_dispatch__ wrapper tensor subclass POC with automatic dispatching to custom kernel
import torch
from torch.library import Library
from typing import List
from torch.utils._pytree import tree_map
import torch.nn.functional as F
import functools
import numpy as np
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
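The description mentions automatic dispatching to a custom kernel. One common way to get that is a registry from aten overloads to hand-written implementations, consulted before a generic fallback. A minimal, self-contained sketch of that pattern; the registry and names are hypothetical, not the gist's code:

import torch

CUSTOM_KERNELS = {}

def register_kernel(op):
    def decorator(fn):
        CUSTOM_KERNELS[op] = fn
        return fn
    return decorator

@register_kernel(torch.ops.aten.add.Tensor)
def custom_add(a, b):
    # stand-in for a real custom kernel (e.g. one that understands the
    # nested layout); here it just calls back into aten
    return torch.ops.aten.add.Tensor(a, b)

def dispatch(func, *args, **kwargs):
    # consult the registry first, fall back to the op itself
    kernel = CUSTOM_KERNELS.get(func)
    if kernel is not None:
        return kernel(*args, **kwargs)
    return func(*args, **kwargs)

print(dispatch(torch.ops.aten.add.Tensor, torch.ones(2), torch.ones(2)))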
@soulitzer
soulitzer / nested_tensor.py
Created July 17, 2023 15:24
NestedTensor python torch dispatch wrapper subclass
import torch
from torch.library import Library
from typing import List
from torch.utils._pytree import tree_map
import functools
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
#
# 1) The __torch_dispatch__ handles pointwise ops entirely in python
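Point (1), handling pointwise ops entirely in Python, works because a jagged tensor's values live in one packed buffer: apply the op to the buffer and reuse the offsets unchanged. A minimal sketch with a hypothetical container, not the gist's NestedTensor:

import torch

class SimpleNested:
    def __init__(self, buffer, offsets):
        self.buffer = buffer    # packed values of all components
        self.offsets = offsets  # component i spans buffer[offsets[i]:offsets[i+1]]

    def pointwise(self, func):
        # unary pointwise ops act elementwise, so the layout is untouched
        return SimpleNested(func(self.buffer), self.offsets)

tensors = [torch.randn(2), torch.randn(3)]
nt = SimpleNested(torch.cat(tensors), torch.tensor([0, 2, 5]))
out = nt.pointwise(torch.sin)
assert torch.allclose(out.buffer[:2], tensors[0].sin())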
@soulitzer
soulitzer / test.py
Created July 6, 2023 23:03
output_nr issue when requires_grad=False
import torch
from torch.library import Library
test_ns = "abc"
lib = Library(test_ns, "FRAGMENT")
lib.define("foo(Tensor(a!) a, Tensor(b!) b) -> (Tensor(a!), Tensor(b!))")
def get_op(name):
return getattr(getattr(torch.ops, test_ns), name).default
op = get_op("foo")
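A hypothetical repro sketch for the issue named in the title: register a trivial CPU impl for foo and inspect output_nr on the two outputs. The impl and the commentary on expected values are assumptions, not the gist's actual code.

import torch
from torch.library import Library

lib = Library("abc", "FRAGMENT")
lib.define("foo(Tensor(a!) a, Tensor(b!) b) -> (Tensor(a!), Tensor(b!))")
lib.impl("foo", lambda a, b: (a, b), "CPU")  # trivial pass-through impl

op = torch.ops.abc.foo.default
x, y = torch.randn(2), torch.randn(2)  # note: requires_grad=False
out_a, out_b = op(x, y)
# output_nr records which output of its producing node a tensor is; the
# gist's title suggests it is not set as expected in the no-grad case.
print(out_a.output_nr, out_b.output_nr)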