Since a NN with no activation function is just a chain of matrix multiplications, its outputs are linear combinations of the inputs. Adding something like ReLU introduces a non-linearity, but it can never increase the polynomial degree of the input: a ReLU network computes a piecewise-linear function.
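
To make that concrete, here is a minimal sketch (my addition, with hand-picked toy weights, not part of the original gist) of a one-hidden-layer ReLU network evaluated in plain NumPy: past the last kink, every ReLU is either fully on or fully off, so the output reduces to a single affine function of x, no matter how far we extrapolate.

import numpy as np

# tiny hand-written ReLU net: f(x) = w2 . relu(w1 * x + b1) + b2
w1 = np.array([1.0, -1.0, 2.0])   # hidden weights
b1 = np.array([0.5, 0.5, -1.0])   # hidden biases -> kinks at x = -0.5, 0.5, 0.5
w2 = np.array([0.7, -0.3, 1.2])   # output weights
b2 = 0.1

def f(x):
    return w2 @ np.maximum(0.0, np.outer(w1, x) + b1[:, None]) + b2

# Beyond the last kink (x > 0.5) every hidden unit is either fully active or
# fully inactive, so f is affine there: the slope never changes again.
xs = np.array([1.0, 2.0, 5.0, 10.0])
print(np.diff(f(xs)) / np.diff(xs))   # identical slopes -> linear extrapolation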

For example, below is the output of a network trained to predict f(x) = x^2 on the interval [-1, 1]. The test data lies in the interval [-2, 2]: the network fits the region it has seen almost perfectly, but makes no useful generalisation beyond it.

The same holds for many architectures, including CNNs, RNNs, and others. Is this a significant limitation in practice? What kind of dataset (not a toy one like the one I used below) would demonstrate this limitation?

# Code used to train the network:
from keras.layers import Dense
from keras.activations import relu
from keras.models import Sequential
from keras.optimizers import SGD
from keras.losses import mean_squared_error
import numpy as np
from matplotlib import pyplot as plt

# Single hidden layer of 30 ReLU units followed by a linear output unit.
model = Sequential()
model.add(Dense(30, activation=relu, input_shape=(1,)))
model.add(Dense(1))
model.compile(SGD(), mean_squared_error)

# Train on f(x) = x^2 sampled from [-1, 1]; test on the wider interval [-2, 2].
X_train = np.linspace(-1, 1, 10000).reshape(10000, 1)
Y_train = np.power(X_train, 2)
X_test = np.linspace(-2, 2, 10000).reshape(10000, 1)
Y_test = np.power(X_test, 2)

model.fit(X_train, Y_train, epochs=40)
y_pred = model.predict(X_test)
print(y_pred)

# Plot predictions against the true curve on the test interval.
plt.plot(X_test, y_pred, 'r', label='predictions')
plt.plot(X_test, Y_test, 'b', label='f(x) = x^2 (test data)')
plt.legend(loc='upper right')
plt.savefig("foo.png")
plt.show()
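
As a follow-up check (my addition, not part of the original script), one can verify numerically that the trained model extrapolates as a straight line outside [-1, 1]: its predictions there should sit almost exactly on a linear fit, while the true x^2 curve clearly does not. This assumes the learned kinks all ended up inside the training interval, which is what training on [-1, 1] tends to produce.

# Extrapolation check: outside the training interval the network's output
# should be (almost) a straight line, unlike the true x^2 curve.
mask = X_test[:, 0] > 1.0
x_out, p_out = X_test[mask, 0], y_pred[mask, 0]
line_fit = np.polyfit(x_out, p_out, 1)
print("max deviation of predictions from a straight line:",
      np.abs(p_out - np.polyval(line_fit, x_out)).max())
true_fit = np.polyfit(x_out, x_out ** 2, 1)
print("max deviation of x^2 from its best straight line:",
      np.abs(x_out ** 2 - np.polyval(true_fit, x_out)).max())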