import torch
import functools
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
from torch.utils.checkpoint import CheckpointPolicy, _policy_from_bool
from collections import namedtuple
import weakref
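These imports set up a TorchDispatchMode-based activation-checkpointing experiment; the rest of the gist is not shown. As a minimal sketch of the pattern (my own example, not the gist's code), a mode intercepts every dispatched op, which is where a checkpointing policy would decide whether to save or recompute each output:

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # A checkpointing mode would consult its CheckpointPolicy on `func`
        # here instead of just logging.
        print(f"dispatched: {func}")
        return func(*args, **kwargs)

with LoggingMode():
    torch.randn(2).sin()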
import torch
from torch.utils.weak import WeakTensorKeyDictionary
import weakref
from dataclasses import dataclass
import dataclasses
from typing import Optional, Union
import sys

@dataclass
class CacheEntry:
    # Either a strong reference to a tensor or a weakref to one.
    one: Optional[Union[torch.Tensor, weakref.ReferenceType]] = None
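CacheEntry pairs naturally with the WeakTensorKeyDictionary import above. A small usage sketch (assumed; the gist is truncated here) keys entries by tensor without keeping the tensor alive:

cache = WeakTensorKeyDictionary()
t = torch.randn(3)
cache[t] = CacheEntry(one=weakref.ref(t))
assert cache[t].one() is t  # dereference the stored weakref
del t  # the entry disappears once the tensor is garbage-collected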
import torch
import functools
from torch.utils._python_dispatch import TorchDispatchMode
import torch.utils._pytree as pytree
from torch.utils.weak import WeakTensorKeyDictionary

class RecomputableTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, t, func, args):
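The preview cuts off inside __new__. A common way to finish such a wrapper-subclass constructor (an assumption about the likely shape of the code, not the gist's actual body) is _make_wrapper_subclass, stashing the op and args needed to recompute the tensor later:

class RecomputableTensorSketch(torch.Tensor):
    @staticmethod
    def __new__(cls, t, func, args):
        r = torch.Tensor._make_wrapper_subclass(
            cls, t.shape, dtype=t.dtype, device=t.device,
            requires_grad=t.requires_grad)
        r.func = func  # the op to replay when recomputing
        r.args = args  # the inputs to replay it with
        return r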
import torch
from torch.nested._internal.nested_tensor import jagged_from_list

a = torch.randn(2, 7, 256, requires_grad=True, dtype=torch.float32)
b = torch.randn(3, 7, 256, requires_grad=True, dtype=torch.float32)
c = torch.randn(4, 7, 256, requires_grad=True, dtype=torch.float32)
d = torch.randn(5, 7, 256, requires_grad=True, dtype=torch.float32)

nt1 = jagged_from_list([a, b, c, d], None)[0]
nt2 = jagged_from_list([a, b, c, d], None)[0]
nt1_view = nt1.select(2, 1)
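jagged_from_list is internal API; for reference, recent PyTorch exposes a public constructor for the same jagged layout (my note, not part of the gist):

nt = torch.nested.nested_tensor([a, b, c, d], layout=torch.jagged, requires_grad=True)
print(nt.shape)  # (4, j1, 256): j1 is the symbolic ragged dim, dim 2 is dense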
import torch

class T(torch.Tensor):
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
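The class is cut off at @classmethod, which in these POCs is where __torch_dispatch__ lives. A typical body (a sketch of the standard unwrap/re-wrap pattern, not necessarily what the gist does) looks like:

from torch.utils._pytree import tree_map

class TSketch(torch.Tensor):
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the inner tensors, run the real op, re-wrap the outputs.
        unwrap = lambda x: x.elem if isinstance(x, TSketch) else x
        wrap = lambda x: TSketch(x) if isinstance(x, torch.Tensor) else x
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)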
# Technically, even in the "easy case" of t._base.requires_grad == t.requires_grad,
# I need to perform two views to recreate that view authentically. Why?
# There are actually two things I need to recreate: (1) the autograd
# graph relationship and (2) the view relationship.
# The reason we don't handle this today is that this autograd connectivity
# information is not accessible during tracing, and hence not relevant to
# compile, in part because dynamo doesn't support grad_fn access.
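A tiny concrete illustration (my example, not the gist's) of the two relationships the comment distinguishes:

import torch

base = torch.randn(4, requires_grad=True)
v = base[1:3]                  # a differentiable view

assert v._base is base         # (2) the view relationship
assert v.grad_fn is not None   # (1) the autograd graph relationship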
import torch
from torch.nested._internal.nested_tensor import NestedTensor, jagged_from_list
from torch.profiler import profile, record_function, ProfilerActivity

device = "cuda:5"

for nb_unit in (10, 1, 2, 5, 20):
    lin = torch.nn.functional.linear
    def sin(x):
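The benchmark body is cut off. For context, typical use of the profiler imports above looks like this (a generic sketch, not the gist's actual loop):

x = torch.randn(128, 256, device=device)
w = torch.randn(256, 256, device=device)
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("forward"):
        lin(x, w).sin()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))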
import torch
from torch.library import Library
from typing import List
from torch.utils._pytree import tree_map
import torch.nn.functional as F
import functools
import numpy as np

# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
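The gist body is truncated; one plausible skeleton for such a POC (an assumption on my part) backs the wrapper subclass with a plain Python list of constituent tensors:

class POCNestedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensors: List[torch.Tensor]):
        # The wrapper's own shape/dtype are placeholders; a real POC would
        # track the ragged structure separately.
        return torch.Tensor._make_wrapper_subclass(
            cls, (len(tensors),), dtype=tensors[0].dtype)

    def __init__(self, tensors):
        self.tensors = list(tensors)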
import torch
from torch.library import Library
from typing import List
from torch.utils._pytree import tree_map
import functools

# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
#
# 1) The __torch_dispatch__ handles pointwise ops entirely in python
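Continuing the POCNestedTensor sketch above, point (1) can be implemented inside the class body by checking the op's pointwise tag and mapping the op over the components (again my sketch of the stated idea, not the gist's code):

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if torch.Tag.pointwise in func.tags:
            # Apply the op component-by-component; non-nested args broadcast.
            nt = next(a for a in args if isinstance(a, POCNestedTensor))
            outs = [
                func(*[a.tensors[i] if isinstance(a, POCNestedTensor) else a
                       for a in args], **kwargs)
                for i in range(len(nt.tensors))
            ]
            return POCNestedTensor(outs)
        raise NotImplementedError(f"{func} is not pointwise")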
import torch
from torch.library import Library

test_ns = "abc"
lib = Library(test_ns, "FRAGMENT")
lib.define("foo(Tensor(a!) a, Tensor(b!) b) -> (Tensor(a!), Tensor(b!))")

def get_op(name):
    return getattr(getattr(torch.ops, test_ns), name).default

op = get_op("foo")
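op now resolves to the schema above but has no kernel yet. Registering a minimal CPU implementation (my sketch) that mutates both inputs in place, matching the (a!)/(b!) alias annotations, makes it callable:

def foo_impl(a, b):
    a.add_(1)
    b.mul_(2)
    return a, b

lib.impl("foo", foo_impl, "CPU")

x, y = torch.zeros(2), torch.ones(2)
op(x, y)
print(x, y)  # both inputs were mutated in place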