Skip to content

Instantly share code, notes, and snippets.

@jackdeadman
Created May 28, 2017 08:30
Show Gist options
Save jackdeadman/038ef25d09ad116a0a70826e642facc6 to your computer and use it in GitHub Desktop.
Code to visualise how Oja's rule finds the first principal component
import matplotlib.pyplot as plt
import numpy as np
# Axes title for the first animation frame (plt.cla() in the training loop
# clears it on later frames).
plt.title("Using Oja's rule to find principal component")
# Scale a vector to unit l2 norm
def normalise(v):
    """Return *v* scaled to unit Euclidean (L2) length.

    Parameters
    ----------
    v : array_like
        Vector to normalise; lists and tuples are accepted as well as arrays.

    Returns
    -------
    numpy.ndarray
        ``v / ||v||_2``. A zero vector has no direction to preserve, so a
        zero array of the same shape is returned instead of 0/0 NaNs.
    """
    v = np.asarray(v)
    norm = np.linalg.norm(v)
    if norm == 0:
        # Avoid 0/0 -> NaN (and the RuntimeWarning numpy emits for it).
        return np.zeros(v.shape, dtype=np.result_type(v.dtype, float))
    return v / norm
# Plots the direction of a vector with a line
def plot_line(v, **args):
    """Draw the line through the origin spanned by direction vector *v*.

    Parameters
    ----------
    v : array_like
        Direction vector; its first two components are used as (x, y).
    **args
        Forwarded unchanged to ``plt.plot`` (e.g. ``c``, ``linewidth``,
        ``label``).
    """
    v = np.asarray(v)
    # An eigenvector is only defined up to sign, so v or -v may be found.
    # Sampling t symmetrically about 0 makes the drawn segment identical for
    # v and -v, so no sign canonicalisation is needed. (np.abs(v) would not be
    # a sign flip: it reflects mixed-sign vectors into a different direction.)
    t = np.arange(-10, 11).reshape(-1, 1)
    points = t * v  # one row per sample point along the line
    plt.plot(points[:, 0], points[:, 1], **args)
# Setup data: y is a noisy linear function of x, so the cloud has one
# dominant direction of variance.
x = np.random.normal(10, 5, 100)
y = 2 + .3 * x + np.random.normal(0, 1, 100)
X = np.array([x, y])  # shape (2, 100): one row per feature, one column per sample

# Center data (each feature to zero mean)
X = X - X.mean(axis=1, keepdims=True)

# Get PC the traditional way: eigen-decomposition of the covariance matrix.
cov = np.cov(X)  # rows of X are the variables -> 2x2 covariance
vals, vectors = np.linalg.eigh(cov)
# eigh returns eigenvalues in ascending order and eigenvectors as the
# *columns* of `vectors`, so the principal component is the last column.
principal_component = vectors[:, -1]

# initial weights: each component uniform in [-0.5, 0.5)
w = np.random.rand(2) - 0.5
# Activation function
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``, elementwise.

    Evaluated as ``exp(-log(1 + exp(-x)))`` using ``np.logaddexp``, which is
    mathematically identical but does not overflow: the naive ``np.exp(-x)``
    overflows (with a RuntimeWarning) for x below about -709, whereas here
    large negative inputs underflow quietly to 0.

    Parameters
    ----------
    x : scalar or array_like
        Pre-activation value(s).

    Returns
    -------
    numpy scalar or ndarray
        Values in [0, 1], same shape as *x*.
    """
    return np.exp(-np.logaddexp(0, -x))
# Learning params
lr = 0.03
epochs = 2
TITLE = "Using Oja's rule to find principal component"


def _draw_frame(weights):
    """Draw one frame: reference PC, the data points and the weight direction."""
    # plt.cla() clears the title along with the artists, so re-set it here.
    plt.title(TITLE)
    plot_line(principal_component, c='y', linewidth=2, label='Principal Component')
    plt.plot(X[0], X[1], 'ro')
    plot_line(normalise(weights), c='b', label='Weight')
    plt.legend(loc='lower right')


plt.ion()
# Do learning: `epochs` passes over the data, one weight update per sample.
for _ in range(epochs):
    for pre in X.T:  # columns of X, i.e. one centred sample at a time
        # NOTE: Oja's rule is defined for a linear neuron (post = w . pre).
        # Squashing through a sigmoid keeps post in (0, 1) -- presumably to
        # keep updates small on this unscaled data -- but convergence to the
        # exact principal eigenvector is then no longer guaranteed.
        post = sigmoid(np.dot(pre, w))
        # Hebbian term (pre * post) minus Oja's decay term (post**2 * w),
        # which is what normalises w in Oja's rule.
        w += lr * (pre * post - (post ** 2) * w)
        # Draw after the update so the frame shows the new weight.
        _draw_frame(w)
        plt.pause(0.001)
        plt.cla()
plt.ioff()

# Final static frame. `principal_component` is a *column* of `vectors`
# (eigh returns eigenvectors as columns); a row such as vectors[1] is not
# an eigenvector.
_draw_frame(w)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment