block diagonal matmul

from typing import *  # for convenience; also, this is valid for typing
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor  # shorten type signatures


# def try_compile(f: Callable):
# def try_compile():
try:
    do_compile = torch.compile
except AttributeError:
    # do_compile = lambda: x: x  # noqa
    do_compile = torch.jit.script


def _diagonal_flat_idxs(W: Tensor) -> Tensor:
    """Return the flat (1-D) indices of the main diagonal of W's last two dims."""
    min_dim = min(W.shape[-2], W.shape[-1])
    eye = torch.eye(min_dim)
    tmp = torch.zeros_like(W)
    tmp[..., :min_dim, :min_dim] = eye
    return torch.where(tmp.view(-1))[0]
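

# Illustrative example (not in the original gist): for a 2x3 matrix, the
# diagonal entries sit at flat indices 0 and 4, i.e.
#   _diagonal_flat_idxs(torch.zeros(2, 3))  ->  tensor([0, 4])
# These flat indices are what the X_identity_idxs / W_identity_idxs arguments
# below index into (via W.view(-1)[idxs] += 1).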


class FlexibleMatmul(torch.autograd.Function):
    """Do {batched, single} matmuls {with, without} {bias, out, existing .grad}.

    Also lets you choose whether output should be rowmajor or colmajor."""

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float16)
    def forward(ctx,
                bias: Optional[Tensor],
                X: Tensor,
                W: Tensor,
                out: Optional[Tensor],
                colmajor_output: bool = False,
                # bias_grad: Optional[Tensor] = None,
                Xgrad: Optional[Tensor] = None,
                Wgrad: Optional[Tensor] = None,
                X_identity_idxs: Optional[Tensor] = None,
                W_identity_idxs: Optional[Tensor] = None,
                ) -> Tensor:
        # TODO just save bias shape, rather than the actual tensor
        # ctx.save_for_backward(bias.shape, bias_requires_grad, bias_grad, X, W, Wgrad)
        # Wgrad = Wgrad or (W.grad if hasattr(W, 'grad') else None)
        # ctx.save_for_backward(bias, X, W, Xgrad, Wgrad)
        ctx.save_for_backward(bias, X, W, Xgrad, Wgrad,
                              X_identity_idxs, W_identity_idxs)
        # assert Wgrad is not None  # TODO rm
        # print("------------------------> Is wgrad None? ", Wgrad is None)

        # TODO throw informative errors instead
        assert X.ndim in (2, 3)
        assert W.ndim in (2, 3)
        assert out is None or out.ndim in (2, 3)

        # add in identity mat if requested
        if X_identity_idxs is not None:
            X.view(-1)[X_identity_idxs] += 1
        if W_identity_idxs is not None:
            W.view(-1)[W_identity_idxs] += 1

        # reduce regular matmuls to bmms with batch size 1
        orig_ndim = max(X.ndim, W.ndim)
        X = X.view(1, *X.shape) if X.ndim == 2 else X
        W = W.view(1, *W.shape) if W.ndim == 2 else W
        out = out.view(1, *out.shape) if out is not None and out.ndim == 2 else out

        # now we handle the different cases; this looks grosser than it is
        # because we can't transpose None and can't supply None as the bias
        # to baddbmm; but conceptually, we're always just doing either:
        #   ret = torch.baddbmm(bias, X, W, out=out)          # rowmajor
        # or
        #   ret = torch.baddbmm(bias.T, W.T, X.T, out=out).T  # colmajor
        if colmajor_output:
            Xt = X.transpose(-2, -1)
            Wt = W.transpose(-2, -1)
            if out is not None:
                out = out.reshape(X.shape[0], W.shape[-1], X.shape[-2])
            if bias is None:
                ret = torch.bmm(Wt, Xt, out=out).transpose(-2, -1)
            else:
                bias = torch.atleast_2d(bias).transpose(-2, -1)
                # print("bmm Xt shape", Xt.shape)
                # print("bmm Wt shape", Wt.shape)
                # print("bmm biasT shape", bias.shape)
                ret = torch.baddbmm(bias, Wt, Xt, out=out).transpose(-2, -1)
        else:  # rowmajor
            if bias is None:
                ret = torch.bmm(X, W, out=out)
            else:
                ret = torch.baddbmm(bias, X, W, out=out)
        if orig_ndim == 2:
            ret = ret.view(ret.shape[1:])

        # undo addition of identity mat
        if X_identity_idxs is not None:
            X.view(-1)[X_identity_idxs] -= 1
        if W_identity_idxs is not None:
            W.view(-1)[W_identity_idxs] -= 1
        return ret
    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dOut: Tensor) -> Tuple[
            Optional[Tensor], Optional[Tensor], Optional[Tensor],
            None, None, None, None, None, None]:
        bias, X, W, Xgrad, Wgrad, X_identity_idxs, W_identity_idxs = ctx.saved_tensors

        dX = None
        if ctx.needs_input_grad[1]:  # dgrad; X is forward input 1
            # add in identity mat if requested; we do this here instead of
            # just passing the idxs into the apply call to avoid having to
            # reason about how these idxs get transposed
            if W_identity_idxs is not None:
                W.view(-1)[W_identity_idxs] += 1
            colmajor_out = X.stride()[-2] < X.stride()[-1]
            dX = FlexibleMatmul.apply(Xgrad, dOut, W.transpose(-2, -1), Xgrad, colmajor_out)
            # return None since we already wrote to Xgrad
            dX = dX if Xgrad is None else None
            # undo addition of identity mat
            if W_identity_idxs is not None:
                W.view(-1)[W_identity_idxs] -= 1

        dW = None
        if ctx.needs_input_grad[2]:  # wgrad; W is forward input 2
            # add in identity mat if requested
            if X_identity_idxs is not None:
                X.view(-1)[X_identity_idxs] += 1
            Xt = X.transpose(-2, -1)
            dW_strides = Wgrad.stride() if Wgrad is not None else W.stride()
            colmajor_out = dW_strides[-2] < dW_strides[-1]
            dW = FlexibleMatmul.apply(Wgrad, Xt, dOut, Wgrad, colmajor_out)
            # return None since we already wrote to Wgrad
            dW = dW if Wgrad is None else None
            # undo addition of identity mat
            if X_identity_idxs is not None:
                X.view(-1)[X_identity_idxs] -= 1

        dBias = None
        if bias is not None and ctx.needs_input_grad[0]:
            if dOut.ndim == 2:  # always treat as baddbmm
                dOut = dOut.reshape(1, *dOut.shape)
            # figure out which dims we need to sum over; we just treat
            # bias as a 3d tensor, prepending 1s as needed, and then check
            # which dims are 1.
            shape_as_3d = (1, 1, 1)[:-bias.ndim] + tuple(bias.shape)
            contract_dims = tuple(i for i, dim in enumerate(shape_as_3d) if dim == 1)
            dBias = dOut.sum(dim=contract_dims) if contract_dims else dOut
            dBias = dBias.view(bias.shape)

        return dBias, dX, dW, None, None, None, None, None, None


def flexible_gemm(X: Tensor,
                  W: Tensor,
                  bias: Optional[Tensor] = None,
                  out: Optional[Tensor] = None,
                  colmajor_output: bool = False,
                  Xgrad: Optional[Tensor] = None,
                  Wgrad: Optional[Tensor] = None,
                  X_identity_idxs: Optional[Tensor] = None,
                  W_identity_idxs: Optional[Tensor] = None,
                  f_act: Optional[str] = None,
                  ) -> Tensor:
    # if Xgrad / Wgrad are passed (or already present on X / W), backward
    # accumulates into them directly
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # suppress autograd non-leaf .grad warning
        if Xgrad is None and hasattr(X, 'grad'):
            Xgrad = X.grad
        if Wgrad is None and hasattr(W, 'grad'):
            Wgrad = W.grad

    # we can't just forward *args because "torch.jit.frontend.NotSupportedError:
    # Compiled functions can't take variable number of arguments or use
    # keyword-only arguments with defaults"
    # @do_compile  # can't wrap an autograd Function's apply(), so we're SOL for now
    def _body(bias, X, W, out, colmajor_output, Xgrad, Wgrad,
              X_identity_idxs, W_identity_idxs):
        ret = FlexibleMatmul.apply(bias, X, W, out, colmajor_output,
                                   Xgrad, Wgrad,
                                   X_identity_idxs, W_identity_idxs)
        if f_act == 'pow2':  # TODO other options
            ret = 2 ** ret
        return ret

    return _body(bias, X, W, out, colmajor_output, Xgrad, Wgrad,
                 X_identity_idxs, W_identity_idxs)
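

# Illustrative equivalence (not in the original gist): for plain 2-D inputs
# with no out / grad buffers, flexible_gemm matches a regular addmm, e.g.
#   X2, W2, b2 = torch.randn(5, 7), torch.randn(7, 3), torch.randn(3)
#   torch.allclose(flexible_gemm(X2, W2, bias=b2), torch.addmm(b2, X2, W2))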


def block_diag_addmm(bias: Optional[Tensor],
                     X: Tensor,
                     W: Tensor,
                     out: Optional[Tensor] = None,
                     X_identity_idxs: Optional[Tensor] = None,
                     W_identity_idxs: Optional[Tensor] = None,
                     f_act: Optional[str] = None,
                     ) -> Tensor:
    # num_subspaces = W.shape[0]
    # if (X.ndim != 2) or (X.stride()[0] < X.stride()[1]):
    #     raise NotImplementedError("Only rowmajor X supported")
    # print("input X shape", X.shape)
    # print("input W shape", W.shape)
    # print("input bias shape", bias.shape)
    M, K = X.shape
    num_subspaces = W.shape[0]
    assert K == num_subspaces * W.shape[1]
    N = num_subspaces * W.shape[2]
    if bias is not None:
        # bias is now (nrows or 1) x nsubspaces x out_subspace_len
        bias = bias.reshape(-1, num_subspaces, N // num_subspaces)
        bias = bias.transpose(0, 1)  # nsubspaces x (nrows or 1) x out_subspace_len
        # TODO support col vector as bias, not just row vec and full mat

    # view X as rowmajor (nrows x num_subspaces x subspace_len); strides
    # are still descending
    X = X.reshape(M, num_subspaces, K // num_subspaces)
    # leading dim needs to be the bmm batch dim, so turn X into a batch of
    # strided rowmajor mats of shape (num_subspaces x nrows x subspace_len)
    X = X.transpose(0, 1)  # strides now (middle, biggest, smallest)
    # output has shape (num_subspaces, nrows, out_subspace_len) and
    # strides (biggest, smallest, middle)
    ret = flexible_gemm(X, W, bias=bias, out=out,
                        colmajor_output=True,
                        X_identity_idxs=X_identity_idxs,
                        W_identity_idxs=W_identity_idxs,
                        f_act=f_act)
    return ret.transpose(1, 0).view(M, N)  # nrows x ncols_out
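

# Shape contract (illustrative note, not in the original gist): for X of shape
# (M, K) and W of shape (S, K // S, N // S), block_diag_addmm computes
#   bias + X @ torch.block_diag(W[0], ..., W[S - 1])   # -> (M, N)
# without ever materializing the dense (K, N) block-diagonal weight matrix.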


class FlexiLinear(nn.Module):

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = False,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 add_identity: bool = False,
                 colmajor_output: bool = False,
                 # num_subspaces: int = -1):
                 num_subspaces: int = -1,
                 f_act: Optional[str] = None):
        super().__init__()
        self.add_identity = add_identity
        self.colmajor_output = colmajor_output
        self.num_subspaces = num_subspaces  # for block diag linear
        self.f_act = f_act
        if num_subspaces > 0:  # block diag matmul
            if in_features % num_subspaces != 0:
                raise ValueError(f'Subspace count {num_subspaces} does not ' +
                                 f'evenly divide in_features {in_features}')
            if out_features % num_subspaces != 0:
                raise ValueError(f'Subspace count {num_subspaces} does not ' +
                                 f'evenly divide out_features {out_features}')
            self.weight = torch.nn.Parameter(torch.empty(
                num_subspaces,
                in_features // num_subspaces,
                out_features // num_subspaces,
                dtype=dtype,
                device=device))
        else:
            # NOTE: dims are transpose of vanilla linear
            self.weight = torch.nn.Parameter(torch.empty(
                in_features, out_features, dtype=dtype, device=device))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(
                out_features, dtype=dtype, device=device))
        else:
            self.bias = None
        if add_identity:
            # these are flat integer indices into weight.view(-1), so only move
            # them to the target device; don't cast them to the weight dtype
            idxs = _diagonal_flat_idxs(self.weight).to(device=device)
            self.register_buffer('identity_idxs', idxs)
        else:
            self.identity_idxs = None

    def forward(self, X: Tensor, accum: Optional[Tensor] = None) -> Tensor:
        if self.bias is not None and accum is not None:
            bias = accum + self.bias
        elif self.bias is not None and accum is None:
            bias = self.bias
        elif self.bias is None and accum is not None:
            bias = accum
        else:  # no bias and no accum
            bias = None
        if self.num_subspaces < 1:
            return flexible_gemm(X, self.weight, bias=bias,
                                 colmajor_output=self.colmajor_output,
                                 W_identity_idxs=self.identity_idxs,
                                 f_act=self.f_act)
        # XXX blockdiag output is (necessarily) always colmajor, and thus
        # ignores our colmajor_output arg
        return block_diag_addmm(X=X, W=self.weight, bias=bias,
                                W_identity_idxs=self.identity_idxs,
                                f_act=self.f_act)
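

# Illustrative usage sketch (not in the original gist); note that this module
# leaves its weights uninitialized, so initialize them yourself:
#   layer = FlexiLinear(in_features=64, out_features=64, num_subspaces=4)
#   with torch.no_grad():
#       layer.weight.normal_(std=0.02)
#   Y = layer(torch.randn(8, 64))  # shape (8, 64); output strides are colmajor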


if __name__ == '__main__':
    M, K, N = 4, 6, 8
    num_subspaces = 2
    X = torch.randn(M, K, requires_grad=True)
    W = torch.randn(num_subspaces, K // num_subspaces, N // num_subspaces,
                    requires_grad=True)
    accum = torch.randn(M, N, requires_grad=True)
    Y = block_diag_addmm(accum, X, W)
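
    # quick sanity checks -- added for illustration, not part of the original
    # gist. The block-diagonal matmul should match a dense matmul against an
    # explicit block-diagonal weight built with torch.block_diag.
    W_dense = torch.block_diag(*W.unbind(0))  # (K, N) with W[i] along the diagonal
    Y_ref = accum + X @ W_dense
    assert Y.shape == (M, N)
    assert torch.allclose(Y, Y_ref, atol=1e-5)

    # gradients flow through the custom autograd.Function back to the leaves
    Y.sum().backward()
    assert X.grad is not None and W.grad is not None and accum.grad is not None
    print('block_diag_addmm matches the dense block-diagonal reference')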