Temporary workaround to get Conv2dLocal working in PyTorch
# coding: utf-8

# In[1]:

import math
import collections.abc
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

Module = nn.Module
# In[2]:

def _ntuple(n):
    # Return a parser that passes iterables through unchanged and repeats
    # a scalar n times, mirroring torch.nn.modules.utils._ntuple.
    def parse(x):
        # collections.Iterable was removed in Python 3.10; use collections.abc.
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
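# In[ ]:

# Added illustration (not part of the original gist): _pair normalizes
# scalar hyperparameters such as kernel_size=3 to 2-tuples.
assert _pair(3) == (3, 3)
assert _pair((2, 4)) == (2, 4)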
# In[3]:

# Base class copied from torch.nn.modules.conv for reference. Note that
# Conv2dLocal below subclasses Module directly and does not use _ConvNd.
class _ConvNd(Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init scaled by fan-in: in_channels * prod(kernel_size).
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
# In[4]:

class Conv2dLocal(Module):
    # Locally connected layer: like Conv2d, but with an independent kernel
    # at every output location (weights are not shared spatially), which is
    # why the input height and width must be fixed at construction time.

    def __init__(self, in_height, in_width, in_channels, out_channels,
                 kernel_size, stride=1, padding=0, bias=True, dilation=1):
        super(Conv2dLocal, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.in_height = in_height
        self.in_width = in_width
        # Standard convolution output-size formula.
        self.out_height = int(math.floor(
            (in_height + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1))
        self.out_width = int(math.floor(
            (in_width + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1))
        # One kernel per output location: outH x outW x outC x inC x kH x kW.
        self.weight = Parameter(torch.Tensor(
            self.out_height, self.out_width,
            out_channels, in_channels, *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(
                out_channels, self.out_height, self.out_width))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, input):
        return conv2d_local(
            input, self.weight, self.bias, stride=self.stride,
            padding=self.padding, dilation=self.dilation)
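# In[ ]:

# Added shape illustration (not part of the original gist; values are
# arbitrary): a 3x3 kernel over a 3x3, 64-channel input with no padding
# leaves a single output location, so the weight holds one independent
# kernel there. Only construction is shown here, since forward() depends
# on conv2d_local, which is defined below.
lc = Conv2dLocal(in_height=3, in_width=3, in_channels=64, out_channels=2,
                 kernel_size=3)
print(lc.weight.shape)  # torch.Size([1, 1, 2, 64, 3, 3])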
# In[5]:

# F.unfold extracts sliding local blocks (im2col), turning the input into a
# N x (inC * kH * kW) x (outH * outW) tensor of flattened patches.
unfold = F.unfold
# In[6]:

def conv2d_local(input, weight, bias=None, padding=0, stride=1, dilation=1):
    if input.dim() != 4:
        raise NotImplementedError("Input Error: Only 4D input Tensors supported (got {}D)".format(input.dim()))
    if weight.dim() != 6:
        # outH x outW x outC x inC x kH x kW
        raise NotImplementedError("Input Error: Only 6D weight Tensors supported (got {}D)".format(weight.dim()))

    outH, outW, outC, inC, kH, kW = weight.size()
    kernel_size = (kH, kW)

    # N x (inC * kH * kW) x (outH * outW)
    cols = unfold(input, kernel_size, dilation=dilation, padding=padding, stride=stride)
    # -> N x (outH * outW) x 1 x (inC * kH * kW)
    cols = cols.view(cols.size(0), cols.size(1), cols.size(2), 1).permute(0, 2, 3, 1)

    # Batched matmul of each patch against its own kernel, reshaped to
    # (outH * outW) x (inC * kH * kW) x outC and broadcast over the batch dim.
    out = torch.matmul(cols, weight.view(outH * outW, outC, inC * kH * kW).permute(0, 2, 1))
    out = out.view(cols.size(0), outH, outW, outC).permute(0, 3, 1, 2)
    if bias is not None:
        out = out + bias.expand_as(out)
    return out
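# In[ ]:

# Added sanity check (not part of the original gist): when the same kernel
# is tiled at every output location, a locally connected layer reduces to an
# ordinary convolution, so conv2d_local can be compared against F.conv2d.
def _check_against_conv2d():
    N, inC, H, W = 2, 3, 8, 8
    outC, kH, kW = 4, 3, 3
    outH, outW = H - kH + 1, W - kW + 1  # stride 1, no padding
    x = torch.randn(N, inC, H, W)
    w = torch.randn(outC, inC, kH, kW)
    # Tile the shared kernel over all output locations (.contiguous() so the
    # .view() inside conv2d_local succeeds on the expanded tensor).
    w_local = w.expand(outH, outW, outC, inC, kH, kW).contiguous()
    assert torch.allclose(conv2d_local(x, w_local), F.conv2d(x, w), atol=1e-4)

_check_against_conv2d()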
# In[8]:

# lc = Conv2dLocal(3, 3, 64, 2, 3)

# In[9]:

# lc(torch.randn(1, 64, 3, 3))

# In[43]:

# x = torch.randn(64, 6, 6)

# In[47]:

# lc._backend.SpatialConvolutionLocal??

# In[58]:

# from torch.nn import Conv2dLocal

# In[59]:

# Conv2dLocal??