Skip to content

Instantly share code, notes, and snippets.

@irhum
Created January 19, 2019 16:07
Show Gist options
  • Save irhum/b9fec41bc29983b3c718fb4b27f74435 to your computer and use it in GitHub Desktop.
Save irhum/b9fec41bc29983b3c718fb4b27f74435 to your computer and use it in GitHub Desktop.
A basic DCGAN implementation
class Generator(nn.Module):
    """DCGAN-style generator that maps latent vectors to images.

    A linear layer projects each latent vector to a flat vector that is
    batch-normalized, passed through ReLU and reshaped into a
    ``(channels, bottom_width, bottom_width)`` feature map. Four transposed
    convolutions (kernel 4, stride 2, padding 1) each double the spatial
    size. The three hidden stages use ``channels // 2``, ``// 4`` and ``// 8``
    feature maps. The last stage emits ``out_channels`` maps squashed into
    ``[0, 1]`` by a sigmoid. The output spatial size is therefore
    ``16 * bottom_width`` (64x64 with the default ``bottom_width=4``).

    Args:
        n_hidden: Size of the latent input vector.
        bottom_width: Spatial size of the initial (smallest) feature map.
        channels: Channel count of the initial feature map.
        out_channels: Number of channels in the generated image. Defaults to
            3, the value that was previously hard-coded.

    Raises:
        ValueError: If ``channels < 8`` (the last hidden stage would have
            zero channels), or if ``bottom_width`` or ``out_channels`` is not
            positive.
    """

    def __init__(self, n_hidden, bottom_width=4, channels=512, out_channels=3):
        super().__init__()
        if channels < 8:
            # channels // 8 would be 0, producing zero-channel layers.
            raise ValueError(f"channels must be >= 8, got {channels}")
        if bottom_width < 1:
            raise ValueError(f"bottom_width must be >= 1, got {bottom_width}")
        if out_channels < 1:
            raise ValueError(f"out_channels must be >= 1, got {out_channels}")

        self.channels = channels
        self.bottom_width = bottom_width
        self.out_channels = out_channels
        n_projected = bottom_width * bottom_width * channels

        # Attribute names are kept stable so existing state_dicts still load.
        self.linear = nn.Linear(n_hidden, n_projected)
        self.dconv1 = nn.ConvTranspose2d(channels, channels // 2, 4, 2, 1)
        self.dconv2 = nn.ConvTranspose2d(channels // 2, channels // 4, 4, 2, 1)
        self.dconv3 = nn.ConvTranspose2d(channels // 4, channels // 8, 4, 2, 1)
        self.dconv4 = nn.ConvTranspose2d(channels // 8, out_channels, 4, 2, 1)
        self.bn0 = nn.BatchNorm1d(n_projected)
        self.bn1 = nn.BatchNorm2d(channels // 2)
        self.bn2 = nn.BatchNorm2d(channels // 4)
        self.bn3 = nn.BatchNorm2d(channels // 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate a batch of images from latent vectors.

        Args:
            x: Latent batch of shape ``(batch, n_hidden)``. In training mode
                BatchNorm needs more than one sample per batch.

        Returns:
            Tensor of shape
            ``(batch, out_channels, 16 * bottom_width, 16 * bottom_width)``
            with values in ``[0, 1]``.
        """
        # Normalize the flat projection, then reshape it into the initial
        # (channels, bottom_width, bottom_width) feature map.
        h = F.relu(self.bn0(self.linear(x)))
        h = h.view(-1, self.channels, self.bottom_width, self.bottom_width)
        h = F.relu(self.bn1(self.dconv1(h)))
        h = F.relu(self.bn2(self.dconv2(h)))
        h = F.relu(self.bn3(self.dconv3(h)))
        # Sigmoid keeps the output range at [0, 1], as in the original.
        return torch.sigmoid(self.dconv4(h))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment