import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow.keras.layers as layers

# Spectral-normalized Conv2D, adapted from:
# https://github.com/IShengFang/SpectralNormalizationKeras/blob/master/SpectralNormalizationKeras.py
class ConvSN2D(layers.Conv2D):
    def build(self, input_shape):
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        # Running estimate of the first left-singular vector for the power
        # iteration. A plain tf.Variable with MEAN aggregation keeps the
        # per-replica copies consistent under a distribution strategy.
        #self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
        #                         initializer=tf.keras.initializers.RandomNormal(0, 1),
        #                         name='sn',
        #                         trainable=False)
        self.u = tf.Variable(
            tf.random.normal((1, self.kernel.shape.as_list()[-1]), dtype=tf.float32),
            aggregation=tf.VariableAggregation.MEAN,
            trainable=False)

        # Set input spec.
        self.input_spec = layers.InputSpec(ndim=self.rank + 2,
                                           axes={channel_axis: input_dim})
        self.built = True
    def call(self, inputs, training=None):
        def _l2normalize(v, eps=1e-12):
            return v / (K.sum(v ** 2) ** 0.5 + eps)

        def power_iteration(W, u):
            # According to the paper, a single power-iteration step is enough.
            _u = u
            _v = _l2normalize(K.dot(_u, K.transpose(W)))
            _u = _l2normalize(K.dot(_v, W))
            return _u, _v

        # Spectral normalization: flatten the kernel to a 2-D matrix and
        # estimate its largest singular value with one power-iteration step.
        W_shape = self.kernel.shape.as_list()
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        _u, _v = power_iteration(W_reshaped, self.u)
        # Calculate sigma, the estimated largest singular value.
        sigma = K.dot(_v, W_reshaped)
        sigma = K.dot(sigma, K.transpose(_u))
        # Normalize the kernel so its spectral norm is ~1.
        W_bar = W_reshaped / sigma
        # Reshape back to the kernel shape; the running estimate `u` is
        # only updated during training.
        if training == False:
            W_bar = K.reshape(W_bar, W_shape)
        else:
            with tf.control_dependencies([self.u.assign(_u)]):
                W_bar = K.reshape(W_bar, W_shape)

        outputs = K.conv2d(
            inputs,
            W_bar,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs
def upsampling2d_tpu(inputs, scale=2):
    # Nearest-neighbour upsampling by `scale` via repeat_elements,
    # which is TPU-friendly.
    x = K.repeat_elements(inputs, scale, axis=1)
    x = K.repeat_elements(x, scale, axis=2)
    return x
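# --- Usage sketch (my addition, not part of the original gist). It stacks
# ConvSN2D and upsampling2d_tpu in a toy generator-style network; the input
# shape, filter counts, loss, and optimizer settings are illustrative
# assumptions, not values from the gist.

def build_toy_model():
    inputs = layers.Input((8, 8, 64))
    x = layers.Lambda(upsampling2d_tpu)(inputs)                 # 8x8 -> 16x16
    x = ConvSN2D(32, 3, padding='same', activation='relu')(x)  # spectrally normalized conv
    x = layers.Lambda(upsampling2d_tpu)(x)                      # 16x16 -> 32x32
    x = ConvSN2D(3, 3, padding='same', activation='tanh')(x)
    return tf.keras.models.Model(inputs, x)

model = build_toy_model()
model.compile(optimizer=tf.keras.optimizers.Adam(2e-4), loss='mse')
model.summary()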