# Created April 11, 2021 10:42
# Save ishrikrishna/6b1f9dc0ba150f041bb3f9f195b9d472 to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
# Testing a simple neural network with 3 hidden layers, implemented in Python.
# Multiple examples are fed in at once as a single batch (one example per column).
import numpy as np | |
from skimage.io import imread, imshow | |
def get_img(name, reader=None):
    """Load an image and flatten it into a single-column intensity vector.

    Each pixel's value is the sum of its channels; a 2-D (single-channel)
    image is used as-is. Pixels are emitted in row-major order.

    Args:
        name: image path, passed unchanged to ``reader``.
        reader: callable ``reader(name) -> array``. Defaults to the
            module-level ``imread`` (skimage.io). Supply another loader,
            or a stub in tests.

    Returns:
        numpy.ndarray of dtype float64 and shape (height * width, 1).
    """
    if reader is None:
        reader = imread
    im = np.asarray(reader(name))
    # Collapse the channel axis (and any trailing axes) with one vectorised
    # sum instead of a per-pixel Python loop. NumPy accumulates small
    # integer dtypes such as uint8 in the platform integer, so 3 x 255 does
    # not wrap around.
    if im.ndim > 2:
        im = im.sum(axis=tuple(range(2, im.ndim)))
    return im.astype(np.float64).reshape(-1, 1)
# Training set. get_img() turns each image into one column, so image_flat
# has one row per pixel and one column per example; every image must
# therefore have the same number of pixels.
train_files = [
    "test_3.png",
    "test_3_1.png",
    "test_3_2.png",
    "test_3_3.png",
    "test_3_4.png",
    "test_3_5.png",
    "test_3_6.png",
    "d1.jpg",
    "d2.png",
    "d3.jpeg",
    "d4.jpeg",
    "d5.jpeg",
    "d6.jpeg",
]
# Gather the columns first and stack once: np.append in a loop copies the
# whole growing matrix on every call.
image_flat = np.hstack([get_img(file_name) for file_name in train_files])
# Training hyper-parameters and targets.
n = image_flat.shape[1]  # number of examples: one per column of image_flat
y = np.ones((1, n))  # the target output for every example is 1
lr = 1  # gradient-descent step size
epoch = 256  # number of training iterations
hidden = 1024  # units in each of the three hidden layers

# NOTE(review): every target is 1, so a network that always outputs 1 already
# minimises the cost; negative examples are needed for a real classifier.
# NOTE(review): inputs are raw channel sums and the weights are unscaled
# N(0, 1) draws, so z1 is likely large and the sigmoids saturate; consider
# normalising the inputs in get_img so predict() sees the same scaling.

# Parameters (weights ~ N(0, 1), biases start at zero), each followed by the
# forward pass through that layer. Every a<k> has one column per example.
w1 = np.random.randn(hidden, image_flat.shape[0])
b1 = np.zeros((hidden, 1))
z1 = np.dot(w1, image_flat) + b1
a1 = 1 / (1 + np.exp(-z1))  # sigmoid

w2 = np.random.randn(hidden, a1.shape[0])
b2 = np.zeros((hidden, 1))
z2 = np.dot(w2, a1) + b2
a2 = 1 / (1 + np.exp(-z2))

w3 = np.random.randn(hidden, a2.shape[0])
b3 = np.zeros((hidden, 1))
z3 = np.dot(w3, a2) + b3
a3 = 1 / (1 + np.exp(-z3))

w4 = np.random.randn(1, a3.shape[0])  # single output unit
b4 = np.zeros((1, 1))
z4 = np.dot(w4, a3) + b4
a4 = 1 / (1 + np.exp(-z4))

print("OUT: ", a4)
# Cost: squared error (a4 - y)^2 / 2, averaged over the n examples.
cost = np.sum((a4 - y) ** 2 / 2, axis=1) / n
# Per-example dCost/da4 without the 1/n factor; the training loop divides
# its summed gradients by n.
cost_prime = a4 - y
print("Cost: ", cost)
# Full-batch gradient descent. Each iteration backpropagates through all n
# examples (one per column of image_flat), updates every parameter, then
# recomputes the forward pass and cost with the new parameters.
for _ in range(epoch):
    # Backward pass: d<k> = dCost/dz<k>, one column per example, using
    # sigmoid'(z) = a * (1 - a). All deltas are computed with the
    # pre-update weights.
    d4 = cost_prime * a4 * (1 - a4)
    d3 = np.dot(w4.T, d4) * a3 * (1 - a3)
    d2 = np.dot(w3.T, d3) * a2 * (1 - a2)
    d1 = np.dot(w2.T, d2) * a1 * (1 - a1)

    # Gradient of the mean cost for each weight matrix: the layer's delta
    # times the activation feeding INTO it (previous layer's output, or the
    # raw input for layer 1), summed over examples via the matrix product.
    w4 -= lr * np.dot(d4, a3.T) / n
    w3 -= lr * np.dot(d3, a2.T) / n
    w2 -= lr * np.dot(d2, a1.T) / n
    w1 -= lr * np.dot(d1, image_flat.T) / n

    # Bias gradient: the delta summed over examples. Biases are decremented,
    # not overwritten.
    b4 -= lr * np.sum(d4, axis=1, keepdims=True) / n
    b3 -= lr * np.sum(d3, axis=1, keepdims=True) / n
    b2 -= lr * np.sum(d2, axis=1, keepdims=True) / n
    b1 -= lr * np.sum(d1, axis=1, keepdims=True) / n

    # Forward pass with the updated parameters.
    z1 = np.dot(w1, image_flat) + b1
    a1 = 1 / (1 + np.exp(-z1))
    z2 = np.dot(w2, a1) + b2
    a2 = 1 / (1 + np.exp(-z2))
    z3 = np.dot(w3, a2) + b3
    a3 = 1 / (1 + np.exp(-z3))
    z4 = np.dot(w4, a3) + b4
    a4 = 1 / (1 + np.exp(-z4))

    cost = np.sum((a4 - y) ** 2 / 2, axis=1) / n
    cost_prime = a4 - y

print("Final OUT: ", a4)
print(f"Final Cost after {epoch} iter: ", cost)
def predict(img_name):
    """Run the trained network on one image, print its output and return it.

    Args:
        img_name: image path, loaded and flattened with get_img(). The image
            must have as many pixels as the training images, since w1 has one
            column per input pixel.

    Returns:
        The (1, 1) sigmoid output of the final layer.
    """
    x = get_img(img_name)
    # Same forward pass as training, biases included, so the prediction
    # comes from the model that was actually trained.
    a1 = 1 / (1 + np.exp(-(np.dot(w1, x) + b1)))
    a2 = 1 / (1 + np.exp(-(np.dot(w2, a1) + b2)))
    a3 = 1 / (1 + np.exp(-(np.dot(w3, a2) + b3)))
    a4 = 1 / (1 + np.exp(-(np.dot(w4, a3) + b4)))
    print("Predict OUT: ", a4)
    return a4
# Score a handful of images with the trained network. Apart from
# test_3.png, none of these files is in the training list.
for image_name in (
    "test_3174.png",
    "test_3.png",
    "ss.png",
    "d7.jpeg",
    "d8.jpeg",
    "d9.jpeg",
):
    predict(image_name)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.