Custom torch layers, modules and utilities, ready to be copy-and-pasted
import torch
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F


class BackwardGradNormFn(Function):
    """
    A normalization layer that does nothing to the input, but
    normalizes the gradient on the backward pass.

    Reference: https://arxiv.org/abs/2106.09475

    Very cool idea. I tried applying this on the convolution stem
    (the first two conv layers, which scale down the resolution) and
    it works quite well.
    I'm not sure about applying it everywhere like the paper suggests, though.
    """

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        norm = torch.norm(grad_output)
        if norm > 0:
            grad_output = grad_output / norm
        # grad_output = torch.clamp(grad_output, -1000, 1000)
        return grad_output


class BackwardGradNorm(nn.Module):
    def forward(self, x):
        if self.training:
            return BackwardGradNormFn.apply(x)
        else:
            return x
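

# Usage sketch (assumption: this demo is not part of the original gist).
# BackwardGradNorm is placed after a downsampling conv stem, as described in
# the docstring above: the forward pass is untouched, and in training mode the
# gradient flowing back into the stem is L2-normalized. Channel sizes are
# illustrative.
def _demo_backward_grad_norm():
    stem = nn.Sequential(
        nn.Conv2d(3, 32, 3, stride=2, padding=1),
        nn.Conv2d(32, 64, 3, stride=2, padding=1),
        BackwardGradNorm(),
    )
    x = torch.randn(2, 3, 64, 64, requires_grad=True)
    stem(x).sum().backward()  # the stem only sees the normalized gradient
    return x.grad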
class AccNorm(nn.Module):
    """Don't have 8 NVIDIA A100s for that batch size of 100? Gotcha!

    This is a normalization hack to:
    - get the benefits of batch normalization without having to crank up
      the batch size (or own the GPUs to do so); normalization for everyone!
    - avoid the annoying drawbacks of batch normalization, such as the
      train/validation performance gap and the dependence on batch size.
    """

    def __init__(
        self,
        hidden_size: int,
        virtual_batch_size: int = 75,
        eps: float = 1e-5,
        momentum: float = 0.1,
    ):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.T = virtual_batch_size
        shape = (1, hidden_size, 1, 1)
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.register_buffer("t", torch.tensor(0))
        self.register_buffer("t0", torch.tensor(1))
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("acc_mean", torch.zeros(shape))
        self.register_buffer("var", torch.ones(shape))
        self.register_buffer("acc_var", torch.ones(shape))
        self.register_buffer("std", torch.ones(shape))

    @torch.no_grad()
    def update(self, x):
        # Per-sample, per-channel statistics over the spatial dimensions
        var, mean = torch.var_mean(x, (-2, -1), keepdim=True)
        bsize = x.shape[0]
        self.t = self.t + bsize
        self.acc_mean = self.acc_mean + mean.sum(dim=0, keepdim=True)
        self.acc_var = self.acc_var + var.sum(dim=0, keepdim=True)
        if self.t >= self.t0:
            # Calculate mean statistics
            mean = self.acc_mean / self.t
            var = self.acc_var / self.t
            # Update running stats
            mom = self.momentum
            self.mean = self.mean * (1 - mom) + mom * mean
            self.var = self.var * (1 - mom) + mom * var
            self.std = torch.sqrt(self.var + self.eps)
            # Reset accumulators
            self.t.fill_(0)
            self.acc_mean.fill_(0)
            self.acc_var.fill_(1)
            # Scale up the virtual batch size until reaching the limit
            t1 = (self.t0 * 1.5).type(torch.long)
            self.t0 = torch.clamp(t1, 1, self.T)

    def forward(self, x):
        if self.training:
            self.update(x)
        mean, std = self.mean, self.std
        x = (x - mean) / std
        x = x * self.weight + self.bias
        return x
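

# Usage sketch (assumption: this demo is not part of the original gist).
# In training mode AccNorm accumulates per-channel statistics across calls
# until the virtual batch size is reached, then folds them into the running
# mean/std; the same running statistics are used in both train and eval mode.
def _demo_acc_norm():
    norm = AccNorm(hidden_size=64, virtual_batch_size=32)
    x = torch.randn(4, 64, 16, 16)
    norm.train()
    y_train = norm(x)  # accumulates statistics, then normalizes
    norm.eval()
    y_eval = norm(x)   # only normalizes with the stored running stats
    return y_train, y_eval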
class AdaGreedNorm(nn.Module):
    """My goated normalization layer, based on Adam and existing normalization layers."""

    def __init__(self, num_channels: int, eps=1e-5, betas=(0.9, 0.999)):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.register_buffer("need_init", torch.tensor(True))
        self.eps = eps
        self.betas = betas
        self.register_buffer("m_t", torch.zeros(1))
        self.register_buffer("v_t", torch.zeros(1))
        self.register_buffer("v_t_max", torch.zeros(1))
        self.register_buffer("t", torch.ones(1))
        self.register_buffer("m_t_hat", torch.zeros(1))
        self.register_buffer("v_t_hat", torch.zeros(1))
        self.register_buffer("mean", torch.zeros(1))
        self.register_buffer("std", torch.ones(1))

    @torch.no_grad()
    def update_stats(self, x):
        eps = self.eps
        b1, b2 = self.betas
        # Statistics over the whole tensor; this is why it is called greedy
        v, m = torch.var_mean(x)
        # Update running stats (Adam-style exponential moving averages)
        self.m_t = self.m_t * b1 + (1 - b1) * m
        self.v_t = self.v_t * b2 + (1 - b2) * v
        self.m_t_hat = self.m_t / (1 - b1**self.t)
        self.v_t_hat = self.v_t / (1 - b2**self.t)
        self.v_t_max = torch.maximum(self.v_t_max, self.v_t_hat)
        self.t = self.t + 1
        # Calculate shift and std
        self.mean = self.m_t
        self.std = torch.sqrt(self.v_t_max + eps)

    def forward(self, x):
        if self.training:
            self.update_stats(x)
        # Standardize
        x = (x - self.mean) / self.std
        x = x * self.weight + self.bias
        return x
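

# Usage sketch (assumption: this demo is not part of the original gist).
# The running statistics follow Adam-style updates, matching the code above:
#   m_t = b1 * m_{t-1} + (1 - b1) * mean(x)
#   v_t = b2 * v_{t-1} + (1 - b2) * var(x),  v_hat = v_t / (1 - b2^t)
#   std = sqrt(max_over_time(v_hat) + eps)   # AMSGrad-style running max
def _demo_ada_greed_norm():
    norm = AdaGreedNorm(num_channels=64)
    x = torch.randn(4, 64, 16, 16)
    return norm(x)  # in training mode this also updates the running stats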
class WSConv2d(nn.Conv2d):
    """Weight standardized Convolution layer.

    Ref: https://arxiv.org/abs/1903.10520v2
    """

    def __init__(self, *args, eps=1e-5, gain=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = 1
        ks = self.kernel_size
        self.fan_in = ks[0] * ks[1] * self.in_channels
        self.register_buffer("nweight", torch.ones_like(self.weight))

    def get_weight(self):
        weight = self.weight
        fan_in = self.fan_in
        eps = self.eps
        if self.training:
            var, mean = torch.var_mean(weight, dim=(1, 2, 3), keepdim=True)
            # Standardize
            weight = (weight - mean) / torch.sqrt(var * fan_in + eps)
            # Ha! Self, gain weight, get it?
            weight = self.gain * weight
            self.nweight = weight.clone().detach()
        else:
            weight = self.nweight
        return weight

    def forward(self, x):
        weight = self.get_weight()
        return F.conv2d(
            x,
            weight=weight,
            bias=self.bias,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            stride=self.stride,
        )
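

# Usage sketch (assumption: this demo is not part of the original gist).
# WSConv2d is a drop-in replacement for nn.Conv2d: during training the weight
# is standardized per output channel and divided by sqrt(fan_in * var + eps),
# and the result is cached in `nweight`; eval mode reuses the cached weight.
def _demo_ws_conv2d():
    conv = WSConv2d(3, 16, kernel_size=3, padding=1)
    x = torch.randn(2, 3, 32, 32)
    conv.train()
    y = conv(x)   # standardizes the weight and caches it
    conv.eval()
    y = conv(x)   # uses the cached standardized weight
    return y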

from typing import List, Union

import torch
from torch import nn


class GlobalResponseNorm(nn.Module):
    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(1, 1, 1, channels))
        self.bias = nn.Parameter(torch.randn(1, 1, 1, channels))
        self.eps = eps

    def forward(self, x):
        # x dims: B H W C
        Gx = torch.norm(x, dim=(-2, -3), p=2, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + self.eps)
        x = x + (x * Nx) * self.weight + self.bias
        return x


class PermuteDim(nn.Module):
    def __init__(self, src: str, dst: str):
        super().__init__()
        self.perm = [src.index(s) for s in dst]
        self.extra_repr = lambda: f"from='{src}', to='{dst}', perm={self.perm}"

    def forward(self, x):
        x = x.permute(self.perm)
        return x


class ConvNextBlock(nn.Module):
    def __init__(self, channels: int, expansion: int = 4):
        super().__init__()
        self.conv_mlp = nn.Sequential(
            nn.Conv2d(channels, channels, 7, padding=3, groups=channels),
            PermuteDim("bchw", "bhwc"),
            nn.LayerNorm(channels),
            nn.Linear(channels, channels * expansion),
            nn.GELU(approximate="tanh"),
            GlobalResponseNorm(channels * expansion),
            nn.Linear(channels * expansion, channels),
            PermuteDim("bhwc", "bchw"),
        )

    def forward(self, x):
        return self.conv_mlp(x) + x
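

# Shape sanity check (assumption: this demo is not part of the original gist).
# The block is a ConvNeXt-V2-style unit: depthwise 7x7 conv, LayerNorm and an
# inverted-bottleneck MLP with GRN, wrapped in a residual connection, so the
# input and output shapes match.
def _demo_convnext_block():
    block = ConvNextBlock(channels=32)
    x = torch.randn(1, 32, 28, 28)
    y = block(x)
    assert y.shape == x.shape
    return y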
class DownSample(nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        prenorm: bool = True,
    ):
        super().__init__()
        # nn.Sequential runs children in attribute-assignment order, so the
        # strided (patchify) conv is registered before the norm only when
        # prenorm is False.
        down = nn.Conv2d(in_channels, out_channels, stride, stride)
        if not prenorm:
            self.down = down
        self.ch_last = PermuteDim("bchw", "bhwc")
        self.norm = nn.LayerNorm(in_channels if prenorm else out_channels)
        self.ch_first = PermuteDim("bhwc", "bchw")
        if prenorm:
            self.down = down
class ConvNext(nn.Module):
    def __init__(
        self,
        channels: List[int],
        num_layers: List[int],
        expansion: int = 4,
        strides: Union[int, List[int]] = 2,
        patch_size: int = 4,
    ):
        super().__init__()
        layers = [DownSample(3, channels[0], patch_size, prenorm=False)]
        n = len(num_layers)
        if isinstance(strides, int):
            strides = [strides] * n
        for i, nl in enumerate(num_layers):
            c1 = channels[i]
            c2 = channels[i + 1]
            stride = strides[i]
            for _ in range(nl):
                layers.append(ConvNextBlock(c1, expansion))
            if i != n - 1:
                layers.append(DownSample(c1, c2, stride=stride))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    model = ConvNext([20, 40, 60, 80, 80], [2, 2, 6, 2])
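    # Sanity check (assumption: this demo is not part of the original gist).
    # With patch_size=4 and three stride-2 downsamples, a 224x224 input is
    # reduced to a 7x7 feature map with 80 channels (the last stage's width).
    x = torch.randn(1, 3, 224, 224)
    y = model(x)
    print(y.shape)  # torch.Size([1, 80, 7, 7])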

# Reference: https://arxiv.org/abs/1603.07285


def get_conv_output_size_1(x: int, k: int, s: int = 1, p: int = 0, d: int = 0):
    """Get the output resolution of a convolution operation.

    For reference, see https://arxiv.org/abs/1603.07285.

    Args:
        x (int): The input resolution
        k (int): Kernel size
        s (int): Stride (Default: 1)
        p (int): Padding (Default: 0)
        d (int): Dilation (Default: 0)

    Returns:
        _ (int): The output resolution
    """
    if d > 0:
        # Effective kernel size of a dilated convolution
        k = k + (k - 1) * (d - 1)
    return int((x + 2 * p - k) / s) + 1


def get_conv_output_size(x, *configs):
    """Get the output resolution of a stack of convolution operations.

    This function uses `get_conv_output_size_1`.

    Args:
        x (int): The input resolution
        configs (List[Tuple]):
            List of tuples of (kernel size, stride, padding, dilation).

    Returns:
        _ (int): The output resolution
    """
    for config in configs:
        x = get_conv_output_size_1(x, *config)
    return x
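

# Usage sketch (assumption: this demo is not part of the original gist).
# The closed form is floor((x + 2p - k_eff) / s) + 1, where
# k_eff = k + (k - 1) * (d - 1). For example, chaining two 3x3 convs with
# stride 2 and padding 1 maps 224 -> 112 -> 56.
if __name__ == "__main__":
    assert get_conv_output_size_1(224, k=3, s=2, p=1) == 112
    assert get_conv_output_size(224, (3, 2, 1), (3, 2, 1)) == 56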

import torch
from torch import Tensor, nn


# Corner pooling: unbind-stack version
@torch.jit.script
def corner_pool(x: Tensor, dim: int, flip: bool):
    sz = x.size(dim)
    outputs = list(x.unbind(dim))
    for i in range(1, sz):
        if flip:
            i_in = sz - i
            i_out = sz - i - 1
        else:
            i_in = i - 1
            i_out = i
        outputs[i_out] = torch.maximum(outputs[i_out], outputs[i_in])
    return torch.stack(outputs, dim=dim)


class TopPool(nn.Module):
    def forward(self, x):
        return corner_pool(x, -2, True)


class BottomPool(nn.Module):
    def forward(self, x):
        return corner_pool(x, -2, False)


class LeftPool(nn.Module):
    def forward(self, x):
        return corner_pool(x, -1, True)


class RightPool(nn.Module):
    def forward(self, x):
        return corner_pool(x, -1, False)
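

# Usage sketch (assumption: this demo is not part of the original gist).
# Corner pooling is the running-max scan used in CornerNet: each position takes
# the maximum along one direction, e.g. LeftPool scans right-to-left so every
# pixel sees the max of itself and everything to its right.
def _demo_corner_pool():
    x = torch.arange(4.0).reshape(1, 1, 1, 4)  # [[0, 1, 2, 3]]
    left = LeftPool()(x)                       # [[3, 3, 3, 3]]
    right = RightPool()(x)                     # [[0, 1, 2, 3]]
    return left, right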

import copy
from typing import Optional

import torch
from torch import autograd, nn


class ReversibleFN(autograd.Function):
    @staticmethod
    def forward(ctx, Fm, Gm, x, *params):
        x = x.detach()
        with torch.no_grad():
            x1, x2 = torch.chunk(x, chunks=2, dim=1)
            y1 = x1 + Fm(x2)
            y2 = x2 + Gm(y1)
            y = torch.cat((y1, y2), dim=1)
            del x1, x2, y1, y2
        ctx.Fm = Fm
        ctx.Gm = Gm
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        Fm = ctx.Fm
        Gm = ctx.Gm
        Fparams = tuple(Fm.parameters())
        Gparams = tuple(Gm.parameters())
        x = ctx.saved_tensors[0]
        x1, x2 = torch.chunk(x, 2, dim=1)
        # compute outputs building a sub-graph
        with torch.set_grad_enabled(True):
            x1.requires_grad = True
            x2.requires_grad = True
            y1 = x1 + Fm(x2)
            y2 = x2 + Gm(y1)
            y = torch.cat([y1, y2], dim=1)
            inputs = (x1, x2) + Fparams + Gparams
            grads = autograd.grad(y, inputs, grad_output)
        grad_input = torch.cat([grads[0], grads[1]], dim=1)
        return (None, None, grad_input) + tuple(grads[2:])


class Reversible(nn.Module):
    def __init__(self, Fm: nn.Module, Gm: Optional[nn.Module] = None):
        super().__init__()
        self.Fm = Fm
        if Gm is None:
            Gm = copy.deepcopy(Fm)
        self.Gm = Gm

    def forward(self, x):
        if self.training:
            params = list(self.Fm.parameters()) + list(self.Gm.parameters())
            y = ReversibleFN.apply(self.Fm, self.Gm, x, *params)
        else:
            x1, x2 = torch.chunk(x, chunks=2, dim=1)
            y1 = x1 + self.Fm(x2)
            y2 = x2 + self.Gm(y1)
            y = torch.cat((y1, y2), dim=1)
        return y
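

# Usage sketch (assumption: this demo is not part of the original gist).
# A RevNet-style reversible residual block: the input is split into two halves
# along the channel dim, so Fm and Gm must map half the channels back to the
# same number of channels. In training mode, gradients come from the
# recomputation inside ReversibleFN.backward.
def _demo_reversible():
    half = nn.Sequential(
        nn.Conv2d(16, 16, 3, padding=1),
        nn.ReLU(),
    )
    block = Reversible(half)   # Gm defaults to a deep copy of Fm
    x = torch.randn(2, 32, 8, 8, requires_grad=True)
    block.train()
    block(x).sum().backward()  # exercises ReversibleFN.backward
    return x.grad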