Skip to content

Instantly share code, notes, and snippets.

@archydeberker
Created March 4, 2018 21:37
Show Gist options
  • Save archydeberker/73689d4721f7d4c77e1eec7f6dc404cc to your computer and use it in GitHub Desktop.
Save archydeberker/73689d4721f7d4c77e1eec7f6dc404cc to your computer and use it in GitHub Desktop.
A simple CNN in PyTorch
class Net(nn.Module):
    """A simple 5-layer CNN, configurable through a hyperparameter dictionary.

    Based upon the network outlined in the PyTorch intro tutorial
    (http://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#define-the-network).

    Architecture: conv(5x5) -> ReLU -> maxpool(2x2) -> conv(5x5) -> ReLU ->
    maxpool(2x2) -> linear -> ReLU -> linear -> ReLU -> linear (10 outputs).

    The input must have 3 channels. Both convolutions use 5x5 kernels without
    padding, and each pooling step halves the spatial size. The ``* 5 * 5`` in
    ``fc1`` therefore corresponds to 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).
    ``forward`` rejects inputs whose feature map does not match.

    Args:
        hyperparam_dict: Optional dict with any of the keys ``'conv1_size'``
            and ``'conv2_size'`` (convolution output channels), and
            ``'fc1_size'`` and ``'fc2_size'`` (hidden linear-layer widths).
            Missing keys fall back to ``standard_hyperparams()``. ``None`` or
            an empty dict uses the defaults entirely. The values are copied
            into a new dict, so later changes to the caller's dict do not
            affect the model.
    """

    def __init__(self, hyperparam_dict=None):
        super(Net, self).__init__()
        # Start from the defaults and overlay caller-supplied values, so a
        # partial dict works instead of raising KeyError. Building a new dict
        # also keeps the caller's object from being aliased into the model.
        params = self.standard_hyperparams()
        if hyperparam_dict:
            params.update(hyperparam_dict)
        self.hyperparam_dict = params

        self.conv1 = nn.Conv2d(3, params['conv1_size'], 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(params['conv1_size'], params['conv2_size'], 5)
        self.fc1 = nn.Linear(params['conv2_size'] * 5 * 5, params['fc1_size'])
        self.fc2 = nn.Linear(params['fc1_size'], params['fc2_size'])
        self.fc3 = nn.Linear(params['fc2_size'], 10)

    def forward(self, x):
        """Return the raw scores produced by ``fc3`` (no softmax is applied).

        Raises:
            ValueError: if the feature map after the second pooling step does
                not contain exactly ``fc1.in_features`` elements per sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Take the expected size from the layer itself, not from the mutable
        # hyperparameter dict, so forward always agrees with fc1.
        n_features = self.fc1.in_features
        channels, height, width = x.shape[-3:]
        if channels * height * width != n_features:
            # Without this check, view(-1, n) can silently fold the pixels of
            # a wrongly-sized input into extra "batch" rows.
            raise ValueError(
                "expected %d features per sample after the conv layers, got "
                "%d (%dx%dx%d); the input spatial size is probably not 32x32"
                % (n_features, channels * height * width,
                   channels, height, width)
            )
        x = x.view(-1, n_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def standard_hyperparams(self):
        """Return a new dict holding the default layer sizes.

        These values are used when no dict (or only a partial dict) is passed
        to ``__init__``.
        """
        return {
            'conv1_size': 6,
            'conv2_size': 16,
            'fc1_size': 120,
            'fc2_size': 84,
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment