Created October 11, 2017 00:49
-
-
Save keunwoochoi/6d9e7d200582384a3bdc2ca69b35d4f9 to your computer and use it in GitHub Desktop.
Keras-unet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras | |
from keras import backend as K | |
from keras.layers.convolutional import Conv2D, Conv2DTranspose | |
from keras.layers import Input, Dense, Activation | |
from keras.layers import concatenate # functional interface | |
from keras.models import Model | |
from keras.layers.advanced_activations import LeakyReLU | |
# Height and width of the square input that get_unet() builds the model for
# (2 ** 9 == 512).
N_INPUT = 2 ** 9
def get_unet(n_input=None, n_ch_exps=(4, 5, 6, 6, 7, 7), kernels=(5, 5),
             n_channels=1):
    """Build a fully-convolutional U-Net that outputs a same-size 1-channel map.

    The encoder is a stack of stride-2 ``Conv2D`` + ``LeakyReLU(alpha=0.2)``
    stages; each stage halves height and width. The decoder mirrors it with
    stride-2 ``Conv2DTranspose`` (ReLU) stages. Each upsampled tensor is
    concatenated with the encoder output of the same spatial size (skip
    connection). A final stride-2 ``Conv2DTranspose`` with a sigmoid
    activation restores the input resolution with a single output channel.

    All arguments default to the values this function originally hard-coded,
    so ``get_unet()`` builds the same model as before.

    Args:
        n_input: Height and width of the square input. ``None`` means use the
            module-level ``N_INPUT`` (looked up at call time). Must be
            divisible by ``2 ** len(n_ch_exps)`` so the skip connections line
            up.
        n_ch_exps: Filter-count exponents for the encoder stages; stage ``i``
            uses ``2 ** n_ch_exps[i]`` filters. The decoder reuses them in
            reverse order, excluding the deepest stage.
        kernels: Kernel size used by every (transposed) convolution.
        n_channels: Number of channels of the input image.

    Returns:
        An uncompiled ``keras.models.Model`` mapping the input tensor to the
        sigmoid output tensor.

    Raises:
        ValueError: If ``n_ch_exps`` is empty, if ``n_input`` is not divisible
            by ``2 ** len(n_ch_exps)``, or if the backend reports an image
            data format other than ``'channels_first'``/``'channels_last'``.
    """
    if n_input is None:
        n_input = N_INPUT
    n_ch_exps = tuple(n_ch_exps)
    if not n_ch_exps:
        raise ValueError('n_ch_exps must contain at least one stage.')
    n_stages = len(n_ch_exps)
    # Every stage halves the size with 'same' padding. The size must divide
    # evenly, or a decoder stage's upsampled tensor will not match its skip
    # tensor.
    if n_input % (2 ** n_stages) != 0:
        raise ValueError(
            'n_input ({}) must be divisible by 2 ** len(n_ch_exps) = {}.'
            .format(n_input, 2 ** n_stages))

    data_format = K.image_data_format()
    if data_format == 'channels_first':
        ch_axis = 1
        input_shape = (n_channels, n_input, n_input)
    elif data_format == 'channels_last':
        ch_axis = 3
        input_shape = (n_input, n_input, n_channels)
    else:
        # Without this branch ch_axis/input_shape would be unbound below.
        raise ValueError(
            'Unsupported image data format: {!r}'.format(data_format))

    inp = Input(shape=input_shape)

    # Encoder: keep every stage's activation for the skip connections.
    encodeds = []
    enc = inp
    for l_idx, n_ch in enumerate(n_ch_exps):
        enc = Conv2D(2 ** n_ch, kernels,
                     strides=(2, 2), padding='same',
                     kernel_initializer='he_normal')(enc)
        enc = LeakyReLU(name='encoded_{}'.format(l_idx),
                        alpha=0.2)(enc)
        encodeds.append(enc)

    # Decoder: start from the deepest encoder output and walk back up.
    dec = enc
    decoder_n_chs = n_ch_exps[::-1][1:]
    for l_idx, n_ch in enumerate(decoder_n_chs):
        # Encoder stage whose output has the same spatial size as `dec`
        # after this stage's 2x upsampling.
        l_idx_rev = n_stages - l_idx - 2
        dec = Conv2DTranspose(2 ** n_ch, kernels,
                              strides=(2, 2), padding='same',
                              kernel_initializer='he_normal',
                              activation='relu',
                              name='decoded_{}'.format(l_idx))(dec)
        dec = concatenate([dec, encodeds[l_idx_rev]],
                          axis=ch_axis)

    # Last upsampling back to the input resolution; its name continues the
    # decoded_<i> sequence (decoded_5 with the default stages).
    outp = Conv2DTranspose(1, kernels,
                           strides=(2, 2), padding='same',
                           kernel_initializer='glorot_normal',
                           activation='sigmoid',
                           name='decoded_{}'.format(len(decoder_n_chs)))(dec)
    unet = Model(inputs=inp, outputs=outp)
    return unet
if __name__ == "__main__":
    # Script entry point: build the default network and print its layer summary.
    unet_model = get_unet()
    unet_model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.