Created
February 24, 2017 20:56
-
-
Save lucasdavid/4cd49f68e5a4002a2611c18c876a1998 to your computer and use it in GitHub Desktop.
Show that we cannot call model.predict inside the generator that feeds model.fit.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from keras import backend as K | |
from keras.datasets import cifar10 | |
from keras.engine import Input, Model | |
from keras.models import Sequential | |
from keras.layers import Dense, Lambda | |
def build_model(x_shape):
    """Build the shared embedding network and the triplet training network.

    The embedding network is a two-layer fully connected model (2048 ReLU
    units per layer). The training network applies that same network, with
    shared weights, to three inputs (anchor, positive, negative) and stacks
    the three resulting embeddings along a new trailing axis, which is the
    layout `triplet_loss` indexes as `y_pred[:, :, i]`.

    Args:
        x_shape: shape tuple of a single input sample, without the batch
            dimension.

    Returns:
        A `(b_net, t_net)` tuple: the embedding network on its own, and the
        three-input training network built on top of it.
    """
    b_net = Sequential([
        Dense(2048, activation='relu', name='fc1', input_shape=x_shape),
        Dense(2048, activation='relu', name='fc2'),
    ])

    img_a = Input(shape=x_shape)
    img_b = Input(shape=x_shape)
    img_c = Input(shape=x_shape)

    # Each embedding gets a trailing axis and the three are joined on it.
    # The third slot must be the embedding of the third input; using
    # `_x[1]` twice would make the "negative" a copy of the positive.
    x = Lambda(lambda _x: K.concatenate((
        K.expand_dims(_x[0]),
        K.expand_dims(_x[1]),
        K.expand_dims(_x[2]),
    )), name='concat')([b_net(img_a), b_net(img_b), b_net(img_c)])

    t_net = Model(input=[img_a, img_b, img_c], output=x)
    return b_net, t_net
def triplet_loss(y_true, y_pred):
    """Triplet loss used in https://arxiv.org/pdf/1503.03832.pdf.

    Computes `sum(max(||a - p||^2 - ||a - n||^2 + alpha, 0))` over the batch,
    where `a`, `p` and `n` are the anchor, positive and negative embeddings.

    Args:
        y_true: unused; present only because Keras passes it to every loss.
        y_pred: tensor indexed as `y_pred[:, :, i]`, with `i` = 0, 1, 2
            selecting the anchor, positive and negative embeddings.

    Returns:
        Scalar tensor holding the summed hinge loss of the batch.
    """
    alpha = 1
    a, p, n = y_pred[:, :, 0], y_pred[:, :, 1], y_pred[:, :, 2]

    # Squared distances, as in the paper. They also keep the gradient
    # defined when anchor and positive coincide, unlike sqrt at zero.
    pos_dist = K.sum(K.square(a - p), axis=-1)
    neg_dist = K.sum(K.square(a - n), axis=-1)

    # Hinge: triplets that already satisfy the margin contribute nothing,
    # so the loss is bounded below by zero.
    return K.sum(K.maximum(pos_dist - neg_dist + alpha, 0.))
def triplets_gen(X, y, embedding_net,
                 batch_size=32, window_size=64,
                 anchor_label=1, shuffle=True):
    """Endlessly yield batches of (anchor, positive, hard-negative) triplets.

    The samples are visited in windows of `window_size`. Inside a window,
    every sample labelled `anchor_label` is a positive and every other sample
    is a negative. Each ordered pair of positives (including a sample paired
    with itself) forms an (anchor, positive) pair, completed with the negative
    whose embedding is closest to the anchor's embedding.

    NOTE: `embedding_net.predict` is called from inside the generator. This
    script exists to demonstrate that doing so fails when the generator
    feeds `fit_generator`, so the call is kept on purpose.

    Args:
        X: array of samples, indexed along its first axis.
        y: array of labels with one entry per sample.
        embedding_net: object exposing `predict(x, batch_size=...)` and
            returning one embedding row per sample.
        batch_size: maximum number of triplets per yielded batch; also
            forwarded to `embedding_net.predict`.
        window_size: number of samples mined for triplets at a time.
        anchor_label: label identifying anchors and positives.
        shuffle: whether to shuffle the sample order once, up front.

    Yields:
        `([anchors, positives, negatives], dummy_y)`, where `dummy_y` is a
        zero array of shape `(n_triplets_in_batch, 1, 1)` that the loss
        ignores.

    Raises:
        ValueError: if every window of a full cycle lacks either positives
            or negatives, so no triplet can ever be produced.
    """
    indices = np.arange(y.shape[0])
    if shuffle:
        np.random.shuffle(indices)

    # Window start offsets wrap around modulo `span`. When the data is no
    # larger than one window there is a single window starting at 0
    # (the plain modulo would divide by zero or go negative).
    span = indices.shape[0] - window_size
    max_skips = max(span, 1)
    skipped = 0
    window_offset = 0

    while True:
        window_indices = indices[window_offset: window_offset + window_size]
        window_offset = ((window_offset + window_size) % span
                         if span > 0 else 0)

        X_window, y_window = X[window_indices], y[window_indices]
        positives = X_window[np.where(y_window == anchor_label)]
        negatives = X_window[np.where(y_window != anchor_label)]

        # No triplet can be formed without both classes; move on, but give
        # up once a whole cycle of windows has been unusable.
        if positives.shape[0] == 0 or negatives.shape[0] == 0:
            skipped += 1
            if skipped >= max_skips:
                raise ValueError(
                    'no window contains both anchor_label samples and '
                    'samples of another label')
            continue
        skipped = 0

        f_positives = embedding_net.predict(positives, batch_size=batch_size)
        f_negatives = embedding_net.predict(negatives, batch_size=batch_size)
        # Stand-ins that bypass the network, kept for debugging:
        # f_positives = np.random.rand(positives.shape[0], 2048)
        # f_negatives = np.random.rand(positives.shape[0], 2048)

        # Select only hard-negatives triplets (p.4). The hardest negative
        # depends on the anchor alone, so it is computed once per anchor.
        hardest = np.asarray([
            np.argmin(np.sum((f_negatives - f_a) ** 2, axis=-1))
            for f_a in f_positives])
        anchor_negatives = negatives[hardest]

        # Enumerate (anchor, positive) pairs with the positive varying
        # slowest and the anchor varying fastest.
        n_pos = positives.shape[0]
        anchor_idx = np.tile(np.arange(n_pos), n_pos)
        positive_idx = np.repeat(np.arange(n_pos), n_pos)

        for start in range(0, anchor_idx.shape[0], batch_size):
            a_batch = anchor_idx[start:start + batch_size]
            p_batch = positive_idx[start:start + batch_size]
            # X is converted to list (i.e. multiple inputs),
            # unused y is made out of dummy values.
            yield ([positives[a_batch],
                    positives[p_batch],
                    anchor_negatives[a_batch]],
                   np.zeros((a_batch.shape[0], 1, 1)))
def show_model_throws_error():
    """Train the triplet network from generators that call `predict`.

    Builds and compiles the networks, loads CIFAR-10, flattens every image
    into a vector, and runs `fit_generator` with training and validation
    triplet generators that both use the embedding network. As the function
    name says, this is expected to raise.
    """
    embedding_net, training_net = build_model(x_shape=(3072,))
    training_net.compile(optimizer='adam', loss=triplet_loss)

    train_split, valid_split = cifar10.load_data()

    generators = []
    for images, labels in (train_split, valid_split):
        # One row per image, one flat label per row.
        flat_images = images.reshape(images.shape[0], -1)
        flat_labels = labels.ravel()
        generators.append(
            triplets_gen(flat_images, flat_labels,
                         embedding_net=embedding_net))
    train_data, valid_data = generators

    training_net.fit_generator(
        train_data,
        samples_per_epoch=100,
        nb_epoch=10,
        validation_data=valid_data,
        nb_val_samples=100)
# Run the demonstration only when executed as a script, not on import.
if __name__ == '__main__':
    show_model_throws_error()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment