Last active
September 30, 2016 05:12
-
-
Save standarderror/27129036f98d8e478987fba93b28f0f2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Choose the image(s) to inspect and plot the first one.
# Slicing with 102:103 (instead of indexing with 102) keeps a leading axis of
# length 1, so `im` can be handed to predict() below as a one-item batch.
im = allX[102:103]
plt.axis('off')
plt.imshow(im[0].astype('uint8'))  # cast to uint8 before display
plt.gcf().set_size_inches(2, 2)

# Run the image(s) through the 1st conv layer only: wrap `conv_1` in its own
# DNN that shares the session of `model`.
m2 = tflearn.DNN(conv_1, session=model.session)
yhat = m2.predict(im)

# Slice off the outputs for the first image, as an ndarray for vis_conv.
yhat_1 = np.array(yhat[0])
def vis_conv(v, ix, iy, ch, cy, cx, p=0):
    """Tile the channels of one conv-layer output into a single 2-D image.

    The input is interpreted as an ``(iy, ix, ch)`` array. Every channel is
    framed with a 1-pixel border of value ``p`` and the framed channels are
    laid out on a ``cy`` x ``cx`` grid, channel ``c`` landing at grid row
    ``c // cx`` and grid column ``c % cx``.

    Args:
        v: Array-like holding ``iy * ix * ch`` values (any shape that can be
            reshaped to ``(iy, ix, ch)``).
        ix: Size of the second (column) axis of each channel image.
        iy: Size of the first (row) axis of each channel image.
        ch: Number of channels; must equal ``cy * cx``.
        cy: Number of grid rows.
        cx: Number of grid columns.
        p: Constant used for the 1-pixel border around each tile.

    Returns:
        A 2-D ndarray of shape ``(cy * (iy + 2), cx * (ix + 2))``.

    Raises:
        ValueError: If ``cy * cx != ch`` or ``v`` cannot be reshaped to
            ``(iy, ix, ch)``.
    """
    if cy * cx != ch:
        raise ValueError(
            "grid of %d x %d tiles cannot hold %d channels" % (cy, cx, ch)
        )

    v = np.reshape(v, (iy, ix, ch))

    # Frame each channel with a 1-pixel border so neighbouring tiles stay
    # visually separated; the channel axis itself is not padded.
    npad = ((1, 1), (1, 1), (0, 0))
    v = np.pad(v, pad_width=npad, mode='constant', constant_values=p)
    ix += 2
    iy += 2

    # Split the channel axis into the (cy, cx) grid, then interleave the grid
    # axes with the spatial axes so a final reshape yields the tiled image.
    v = np.reshape(v, (iy, ix, cy, cx))
    v = np.transpose(v, (2, 0, 3, 1))  # cy, iy, cx, ix
    v = np.reshape(v, (cy * iy, cx * ix))
    return v
# h_conv1 - processed image
# Parameters describing the first conv layer's output for vis_conv.
ix = 64  # img size (second / column axis)
iy = 64  # img size (first / row axis)
ch = 32  # number of channels
cy = 4   # grid from channels: 32 = 4x8
cx = 8

# Tile all channels into one 2-D array and show it as a greyscale image.
v = vis_conv(yhat_1, ix, iy, ch, cy, cx)
plt.figure(figsize=(12, 12))
plt.imshow(v, cmap="Greys_r", interpolation='nearest')
# NOTE(review): the trailing semicolon is presumably there to suppress the
# echoed return value when this runs as the last line of a notebook cell.
plt.axis('off');
## Acknowledgements: @rgr on Stack Overflow, http://stackoverflow.com/a/35247876
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment