Created March 4, 2018 21:37
Save archydeberker/73689d4721f7d4c77e1eec7f6dc404cc to your computer and use it in GitHub Desktop.
A simple CNN in Pytorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Net(nn.Module):
    """A simple 5-layer CNN (2 conv + 3 fully connected), configurable by
    passing a hyperparameter dictionary at initialization.

    Based upon the one outlined in the PyTorch intro tutorial
    (http://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#define-the-network).

    Architecture::

        conv1(3 -> conv1_size, 5x5) -> ReLU -> maxpool(2x2)
        conv2(conv1_size -> conv2_size, 5x5) -> ReLU -> maxpool(2x2)
        flatten -> fc1 -> ReLU -> fc2 -> ReLU -> fc3 (10 outputs)

    Input is a batched tensor of shape (batch, 3, H, W). ``fc1`` expects a
    ``conv2_size x 5 x 5`` feature map, which the two unpadded 5x5 convs and
    two 2x2 pools produce for H, W in 32..35 (32x32 being the intended size).
    """

    def __init__(self, hyperparam_dict=None):
        """Build the layers.

        Args:
            hyperparam_dict: Optional mapping with any of the keys
                ``'conv1_size'``, ``'conv2_size'``, ``'fc1_size'`` and
                ``'fc2_size'``. Missing keys fall back to the values returned
                by ``standard_hyperparams``; ``None`` or an empty mapping
                yields the defaults entirely. The values are copied into a new
                dict, so later changes to the caller's mapping do not affect
                ``self.hyperparam_dict``.
        """
        super().__init__()
        # Start from the defaults and overlay the caller's values: a partial
        # dict then works instead of raising KeyError, and the caller's dict
        # is never aliased by this module.
        merged = self.standard_hyperparams()
        if hyperparam_dict:
            merged.update(hyperparam_dict)
        self.hyperparam_dict = merged

        self.conv1 = nn.Conv2d(3, merged['conv1_size'], 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(merged['conv1_size'], merged['conv2_size'], 5)
        # 5 * 5 is the spatial size of the conv2 feature map after pooling.
        self.fc1 = nn.Linear(merged['conv2_size'] * 5 * 5, merged['fc1_size'])
        self.fc2 = nn.Linear(merged['fc1_size'], merged['fc2_size'])
        self.fc3 = nn.Linear(merged['fc2_size'], 10)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape (batch, 3, H, W); see the class docstring for
                the valid H, W range.

        Returns:
            Tensor of shape (batch, 10) holding the raw ``fc3`` outputs
            (no softmax or other activation is applied).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten each sample while keeping the batch dimension fixed, so an
        # input of the wrong spatial size fails in fc1 with a shape mismatch
        # instead of being silently re-split into extra rows, as
        # view(-1, conv2_size * 5 * 5) would do.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def standard_hyperparams(self):
        """Return a fresh dict of the default layer sizes.

        A new dict is built on every call, so callers may mutate the result.
        """
        return {
            'conv1_size': 6,
            'conv2_size': 16,
            'fc1_size': 120,
            'fc2_size': 84,
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.