Last active
April 2, 2019 15:18
-
-
Save ageron/04dd44734ac5c5193ccbb2a14f4fb0e7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from tensorflow import keras | |
import traceback | |
class ResidualBlock(keras.layers.Layer):
    """Stack of Dense layers whose output is added back onto its input.

    Args:
        n_layers: number of Dense layers in the stack.
        n_neurons: number of units in each Dense layer.
        **kwargs: forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, n_layers, n_neurons, **kwargs):
        super().__init__(**kwargs)
        # Kept as plain attributes so get_config() can report them.
        self.n_layers = n_layers
        self.n_neurons = n_neurons
        stack = []
        for _ in range(n_layers):
            stack.append(keras.layers.Dense(n_neurons, activation="elu",
                                            kernel_initializer="he_normal"))
        self.hidden = stack

    def call(self, inputs):
        """Run ``inputs`` through every hidden layer, then add ``inputs``."""
        activations = inputs
        for dense in self.hidden:
            activations = dense(activations)
        # Skip connection: the untouched input is summed with the stack output.
        return inputs + activations

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        return dict(super().get_config(),
                    n_layers=self.n_layers,
                    n_neurons=self.n_neurons)
class ResidualRegressor(keras.models.Model):
    """Subclassed Keras model: Dense(30) -> residual blocks -> Dense output.

    Args:
        output_dim: number of units in the final Dense layer.
        **kwargs: forwarded unchanged to ``keras.models.Model``.
    """

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        dense_options = dict(activation="elu", kernel_initializer="he_normal")
        self.hidden1 = keras.layers.Dense(30, **dense_options)
        self.block1 = ResidualBlock(2, 30)
        self.block2 = ResidualBlock(2, 30)
        self.out = keras.layers.Dense(output_dim)

    def call(self, inputs):
        """Compute the model output for ``inputs``."""
        features = self.hidden1(inputs)
        # The same block1 instance is called on each of the 1 + 3 passes.
        passes = 1 + 3
        for _ in range(passes):
            features = self.block1(features)
        features = self.block2(features)
        return self.out(features)

    def get_config(self):
        """Return the base model config extended with ``output_dim``.

        NOTE: the run log recorded at the end of this file shows the
        ``super().get_config()`` call raising NotImplementedError when this
        model was nested in a Sequential model and saved.
        """
        return dict(super().get_config(), output_dim=self.output_dim)
class ResidualRegressorLayer(keras.layers.Layer):
    """Layer version of the residual regressor (same architecture as a Layer).

    Args:
        output_dim: number of units in the final Dense layer.
        **kwargs: forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        dense_options = dict(activation="elu", kernel_initializer="he_normal")
        self.hidden1 = keras.layers.Dense(30, **dense_options)
        self.block1 = ResidualBlock(2, 30)
        self.block2 = ResidualBlock(2, 30)
        self.out = keras.layers.Dense(output_dim)

    def call(self, inputs):
        """Compute the layer output for ``inputs``."""
        features = self.hidden1(inputs)
        # The same block1 instance is called on each of the 1 + 3 passes.
        passes = 1 + 3
        for _ in range(passes):
            features = self.block1(features)
        features = self.block2(features)
        return self.out(features)

    def get_config(self):
        """Return the base layer config extended with ``output_dim``."""
        return dict(super().get_config(), output_dim=self.output_dim)
# Driver script: train each model variant on random data, then try to save it
# to HDF5, reload it and resume training.

# Random regression data: four splits of 1000 samples, 10 features, 1 target.
X_train, X_valid, X_test, X_new = np.random.randn(4, 1000, 10)
y_train, y_valid, y_test, y_new = np.random.randn(4, 1000, 1)

# Three ways of packaging the same architecture:
#   model1: the subclassed Model used directly
#   model2: the subclassed Model nested inside a Sequential model
#   model3: the custom Layer nested inside a Sequential model
model1 = ResidualRegressor(1)
model2 = keras.models.Sequential([keras.layers.Input(shape=[10]), ResidualRegressor(1)])
model3 = keras.models.Sequential([ResidualRegressorLayer(1, input_shape=[10])])

for model in (model1, model2, model3):
    print("-" * 80)
    model.compile(loss="mse", optimizer="nadam")
    # validation_data is passed as a tuple, the form documented by Keras.
    history = model.fit(X_train, y_train, epochs=5,
                        validation_data=(X_valid, y_valid))
    score = model.evaluate(X_test, y_test)
    y_pred = model.predict(X_new)
    try:
        model.save("my_custom_model.h5")
    except NotImplementedError:
        # Report the failure and move on to the next variant.
        traceback.print_exc()
        continue
    print("Save worked fine!")
    # Every custom class that may appear in a saved config is registered, so
    # any variant whose save succeeded can also be rebuilt on load.
    model = keras.models.load_model(
        "my_custom_model.h5",
        custom_objects={"ResidualRegressorLayer": ResidualRegressorLayer,
                        "ResidualRegressor": ResidualRegressor,
                        "ResidualBlock": ResidualBlock})
    print("Load worked fine!")
    history = model.fit(X_train, y_train, epochs=5,
                        validation_data=(X_valid, y_valid))
    print("I can even continue training!")
""" | |
Output: | |
-------------------------------------------------------------------------------- | |
Train on 1000 samples, validate on 1000 samples | |
Epoch 1/5 | |
1000/1000 [==============================] - 1s 877us/sample - loss: 79.3072 - val_loss: 11.6533 | |
Epoch 2/5 | |
1000/1000 [==============================] - 0s 247us/sample - loss: 8.3333 - val_loss: 5.6323 | |
Epoch 3/5 | |
1000/1000 [==============================] - 0s 239us/sample - loss: 4.6203 - val_loss: 3.9502 | |
Epoch 4/5 | |
1000/1000 [==============================] - 0s 193us/sample - loss: 3.3369 - val_loss: 3.2761 | |
Epoch 5/5 | |
1000/1000 [==============================] - 0s 162us/sample - loss: 2.7189 - val_loss: 2.8614 | |
1000/1000 [==============================] - 0s 24us/sample - loss: 2.9697 | |
-------------------------------------------------------------------------------- | |
Train on 1000 samples, validate on 1000 samples | |
Traceback (most recent call last): | |
File "<ipython-input-320-54bddbfdc178>", line 85, in <module> | |
model.save("my_custom_model.h5") | |
File "/Users/ageron/.virtualenvs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1300, in save | |
# returns a compiled model | |
NotImplementedError: The `save` method requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider using `save_weights`, in order to save the weights of the model. | |
Epoch 1/5 | |
1000/1000 [==============================] - 1s 787us/sample - loss: 432.0514 - val_loss: 40.2053 | |
Epoch 2/5 | |
1000/1000 [==============================] - 0s 194us/sample - loss: 24.1645 - val_loss: 15.6740 | |
Epoch 3/5 | |
1000/1000 [==============================] - 0s 200us/sample - loss: 12.5774 - val_loss: 11.2395 | |
Epoch 4/5 | |
1000/1000 [==============================] - 0s 192us/sample - loss: 9.0193 - val_loss: 8.9401 | |
Epoch 5/5 | |
1000/1000 [==============================] - 0s 162us/sample - loss: 6.8977 - val_loss: 7.4216 | |
1000/1000 [==============================] - 0s 26us/sample - loss: 7.7043 | |
-------------------------------------------------------------------------------- | |
Traceback (most recent call last): | |
File "<ipython-input-320-54bddbfdc178>", line 85, in <module> | |
model.save("my_custom_model.h5") | |
File "/Users/ageron/.virtualenvs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py", line 321, in save | |
save_model(self, filepath, overwrite, include_optimizer) | |
File "/Users/ageron/.virtualenvs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/saving/hdf5_format.py", line 104, in save_model | |
default=serialization.get_json_type).encode('utf8') | |
File "/Users/ageron/.virtualenvs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py", line 328, in get_config | |
'config': layer.get_config() | |
File "<ipython-input-320-54bddbfdc178>", line 43, in get_config | |
base_config = super().get_config() | |
File "/Users/ageron/.virtualenvs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1030, in get_config | |
self._nested_outputs, output_shapes) | |
NotImplementedError | |
Train on 1000 samples, validate on 1000 samples | |
Epoch 1/5 | |
1000/1000 [==============================] - 1s 849us/sample - loss: 591.1464 - val_loss: 33.6940 | |
Epoch 2/5 | |
1000/1000 [==============================] - 0s 226us/sample - loss: 22.1110 - val_loss: 15.7172 | |
Epoch 3/5 | |
1000/1000 [==============================] - 0s 226us/sample - loss: 11.7609 - val_loss: 11.2809 | |
Epoch 4/5 | |
1000/1000 [==============================] - 0s 190us/sample - loss: 8.5177 - val_loss: 9.1430 | |
Epoch 5/5 | |
1000/1000 [==============================] - 0s 151us/sample - loss: 6.8662 - val_loss: 7.7861 | |
1000/1000 [==============================] - 0s 24us/sample - loss: 7.6881 | |
Save worked fine! | |
Load worked fine! | |
Train on 1000 samples, validate on 1000 samples | |
Epoch 1/5 | |
1000/1000 [==============================] - 1s 861us/sample - loss: 5.8443 - val_loss: 6.9330 | |
Epoch 2/5 | |
1000/1000 [==============================] - 0s 226us/sample - loss: 5.0928 - val_loss: 6.2881 | |
Epoch 3/5 | |
1000/1000 [==============================] - 0s 246us/sample - loss: 4.5347 - val_loss: 5.7962 | |
Epoch 4/5 | |
1000/1000 [==============================] - 0s 206us/sample - loss: 4.1041 - val_loss: 5.4129 | |
Epoch 5/5 | |
1000/1000 [==============================] - 0s 158us/sample - loss: 3.7803 - val_loss: 5.1172 | |
I can even continue training! | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment