Skip to content

Instantly share code, notes, and snippets.

@back2yes
Forked from guillefix/lc.py
Created March 27, 2018 08:32
Show Gist options
  • Save back2yes/a145ea911742ee3b18b3e1ddacd0d68c to your computer and use it in GitHub Desktop.
Save back2yes/a145ea911742ee3b18b3e1ddacd0d68c to your computer and use it in GitHub Desktop.
temporary workaround to get Conv2dLocal to work in PyTorch
# coding: utf-8
# In[1]:
import math
import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.nn as nn
Module = nn.Module
import collections
from itertools import repeat
# In[2]:
def _ntuple(n):
def parse(x):
if isinstance(x, collections.Iterable):
return x
return tuple(repeat(x, n))
return parse
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
# In[3]:
class _ConvNd(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding, groups, bias):
super(_ConvNd, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if out_channels % groups != 0:
raise ValueError('out_channels must be divisible by groups')
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.transposed = transposed
self.output_padding = output_padding
self.groups = groups
if transposed:
self.weight = Parameter(torch.Tensor(
in_channels, out_channels // groups, *kernel_size))
else:
self.weight = Parameter(torch.Tensor(
out_channels, in_channels // groups, *kernel_size))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def __repr__(self):
s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0,) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1,) * len(self.dilation):
s += ', dilation={dilation}'
if self.output_padding != (0,) * len(self.output_padding):
s += ', output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
s += ')'
return s.format(name=self.__class__.__name__, **self.__dict__)
# In[4]:
class Conv2dLocal(Module):
    """Locally-connected 2-D layer: a convolution whose weights are not shared.

    Every output location ``(i, j)`` has its own filter bank ``weight[i, j]``,
    so the input spatial size must be known at construction time.

    Args:
        in_height, in_width: spatial size of the input the layer is built for.
        in_channels, out_channels: number of input / output channels.
        kernel_size, stride, padding, dilation: int or pair.
        bias: if True, learn an untied bias per output channel and location.

    Parameters:
        weight: ``(out_height, out_width, out_channels, in_channels, kH, kW)``.
        bias: ``(out_channels, out_height, out_width)`` or ``None``.

    Raises:
        ValueError: if the configuration yields a non-positive output size
            (e.g. the dilated kernel is larger than the padded input).
    """

    def __init__(self, in_height, in_width, in_channels, out_channels,
                 kernel_size, stride=1, padding=0, bias=True, dilation=1):
        super(Conv2dLocal, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.in_height = in_height
        self.in_width = in_width
        self.out_height = self._output_extent(in_height, 0)
        self.out_width = self._output_extent(in_width, 1)
        # Without this check a too-small input silently gives an empty weight
        # (or an opaque "negative dimension" error from torch).
        if self.out_height <= 0 or self.out_width <= 0:
            raise ValueError(
                'Conv2dLocal: input of size {}x{} gives an empty output ({}x{}); '
                'check kernel_size, stride, padding and dilation'.format(
                    in_height, in_width, self.out_height, self.out_width))
        # One (out_channels, in_channels, kH, kW) filter bank per output location.
        self.weight = Parameter(torch.Tensor(
            self.out_height, self.out_width,
            out_channels, in_channels, *self.kernel_size))
        if bias:
            # The bias is untied as well: one value per channel and location.
            self.bias = Parameter(torch.Tensor(
                out_channels, self.out_height, self.out_width))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def _output_extent(self, size, axis):
        """Return the number of sliding positions along spatial ``axis`` (0=H, 1=W)."""
        effective_kernel = self.dilation[axis] * (self.kernel_size[axis] - 1) + 1
        return int((size + 2 * self.padding[axis] - effective_kernel)
                   // self.stride[axis] + 1)

    def reset_parameters(self):
        """Fill weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
        with ``fan_in = in_channels * kH * kW``."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, input):
        """Apply the layer to a 4-D ``input`` via ``conv2d_local``.

        The weight was sized for an input of spatial size
        ``(in_height, in_width)``; other sizes are expected to fail.
        """
        return conv2d_local(
            input, self.weight, self.bias, stride=self.stride,
            padding=self.padding, dilation=self.dilation)
# In[5]:
unfold = F.unfold
# In[6]:
def conv2d_local(input, weight, bias=None, padding=0, stride=1, dilation=1):
    """Apply a locally-connected (untied-weight) 2-D convolution.

    Each output location ``(i, j)`` is computed with its own filter
    ``weight[i, j]`` applied to the matching input patch from ``F.unfold``.

    Args:
        input: 4-D tensor ``(N, inC, H, W)``.
        weight: 6-D tensor ``(outH, outW, outC, inC, kH, kW)``.
        bias: optional tensor expandable (``expand_as``) to the output
            ``(N, outC, outH, outW)``, e.g. ``(outC, outH, outW)``.
        padding, stride, dilation: int or pair, forwarded to ``F.unfold``.
            The positional order (padding before stride) differs from
            ``F.conv2d``; prefer passing them by keyword.

    Returns:
        Tensor of shape ``(N, outC, outH, outW)``.

    Raises:
        NotImplementedError: if ``input`` is not 4-D or ``weight`` is not 6-D.
        RuntimeError: if the input channel count or the number of sliding
            positions produced by ``input`` does not match ``weight``.
    """
    if input.dim() != 4:
        raise NotImplementedError("Input Error: Only 4D input Tensors supported (got {}D)".format(input.dim()))
    if weight.dim() != 6:
        # outH x outW x outC x inC x kH x kW
        raise NotImplementedError("Input Error: Only 6D weight Tensors supported (got {}D)".format(weight.dim()))
    outH, outW, outC, inC, kH, kW = weight.size()
    if input.size(1) != inC:
        raise RuntimeError(
            "conv2d_local: input has {} channels but weight expects {}".format(
                input.size(1), inC))
    batch = input.size(0)
    locations = outH * outW
    # N x [inC * kH * kW] x [number of sliding positions]
    cols = unfold(input, (kH, kW), dilation=dilation, padding=padding, stride=stride)
    if cols.size(2) != locations:
        raise RuntimeError(
            "conv2d_local: input yields {} output locations but weight has {}x{}={}".format(
                cols.size(2), outH, outW, locations))
    # One small matmul per location: (L, N, K) @ (L, K, outC) -> (L, N, outC).
    # bmm over locations avoids the broadcasting matmul, which can materialise
    # an N-fold expanded copy of the weight. reshape (not view) also accepts
    # non-contiguous weights.
    cols = cols.permute(2, 0, 1)
    kernels = weight.reshape(locations, outC, inC * kH * kW).transpose(1, 2)
    out = torch.bmm(cols, kernels)
    # (L, N, outC) -> (N, outC, L) -> (N, outC, outH, outW); L is row-major over (i, j).
    out = out.permute(1, 2, 0).reshape(batch, outC, outH, outW)
    if bias is not None:
        out = out + bias.expand_as(out)
    return out
# In[8]:
# lc = Conv2dLocal(3, 3, 64, 2,3)
# In[9]:
# lc(torch.autograd.Variable(torch.randn((1,64,3,3))))
# In[43]:
# x=torch.autograd.Variable(torch.randn((64,6,6)))
# In[47]:
# lc._backend.SpatialConvolutionLocal??
# In[58]:
# from torch.nn import Conv2dLocal
# In[59]:
# Conv2dLocal??
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment