Last active
September 28, 2016 14:02
-
-
Save libfun/dadd3b0208bfe53249fecb7a29c7c906 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import theano.tensor as T | |
from lasagne.layers import Layer | |
class Unpool3DLayer(Layer):
    """
    3D Unpooling layer

    This layer performs unpooling (upsampling) over the last three axes
    (axes 2, 3 and 4) of a 5D tensor; the first two axes are left unchanged.

    Parameters
    ----------
    incoming : a :class:`Layer` instance or tuple
        The layer feeding into this layer, or the expected input shape.
    ds : integer or iterable
        The integer upscaling factor along each of the last three axes.
        If an integer, the same factor is used for all three axes. If an
        iterable, it must have exactly three elements. Every factor must
        be >= 1.
    mode : {'repeat', 'bed_of_nails'}
        Unpooling mode. 'repeat' repeats each input value ``ds[i]`` times
        along axis ``2 + i``. 'bed_of_nails' writes each input value at
        the output position whose indices are the input indices multiplied
        by ``ds`` and fills every other output entry with zeros.
        Default is 'repeat'.
    **kwargs
        Any additional keyword arguments are passed to the :class:`Layer`
        superclass.

    Raises
    ------
    ValueError
        If ``mode`` is not a supported mode, if ``ds`` does not describe
        three factors, or if any factor is smaller than 1.
    """

    # Supported values of the ``mode`` constructor argument.
    _MODES = ('repeat', 'bed_of_nails')

    def __init__(self, incoming, ds, mode='repeat', **kwargs):
        super(Unpool3DLayer, self).__init__(incoming, **kwargs)
        # Reject unknown modes up front instead of silently treating them
        # as 'repeat' in get_output_for.
        if mode not in self._MODES:
            raise ValueError('mode must be one of %s, got %r'
                             % (self._MODES, mode))
        self.mode = mode
        if isinstance(ds, int):
            # A scalar factor is promoted to the same factor on all 3 axes.
            ds = (ds, ds, ds)
        else:
            ds = tuple(ds)
        if len(ds) != 3:
            raise ValueError('ds must have len == 3')
        # A factor below 1 would yield empty outputs ('repeat') or an
        # invalid zero-step slice ('bed_of_nails').
        if any(d < 1 for d in ds):
            raise ValueError('ds factors must be >= 1, got %r' % (ds,))
        self.ds = ds

    def get_output_shape_for(self, input_shape):
        """Return ``input_shape`` with axes 2-4 multiplied by ``self.ds``.

        Entries that are ``None`` (dimension unknown when the graph is
        built) are propagated as ``None`` instead of raising on
        ``None * int``.
        """
        output_shape = list(input_shape)
        for axis, factor in zip((2, 3, 4), self.ds):
            size = input_shape[axis]
            output_shape[axis] = None if size is None else size * factor
        return tuple(output_shape)

    def get_output_for(self, input, **kwargs):
        """Return the unpooled symbolic tensor for ``input``.

        Assumes ``input`` is a 5D Theano tensor (axes 2-4 are indexed
        below).
        """
        d0, d1, d2 = self.ds
        if self.mode == 'bed_of_nails':
            # Symbolic output shape: axes 2-4 scaled by the factors.
            output_shape = self.get_output_shape_for(input.shape)
            # Use the input's dtype rather than forcing float32, so the
            # write below needs no cast for other input dtypes.
            res = T.zeros(output_shape, dtype=input.dtype)
            # The strided view res[:, :, ::d0, ::d1, ::d2] has exactly the
            # input's shape (each output axis is input axis * factor), so
            # every input value lands at index (i*d0, j*d1, k*d2) and all
            # other entries stay zero.
            return T.set_subtensor(res[:, :, ::d0, ::d1, ::d2], input)
        # 'repeat': nearest-neighbour style upsampling along each axis.
        return input.repeat(d0, axis=2).repeat(d1, axis=3).repeat(d2, axis=4)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment