Last active
October 30, 2025 08:58
-
-
Save drscotthawley/d23564cf9cfa6fa922233a6eff8b4675 to your computer and use it in GitHub Desktop.
Code & utils re. a convolutional VAE with residual connections
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import matplotlib.pyplot as plt | |
| import torch.nn.functional as F | |
| import wandb | |
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with optional batch norm and an optional identity skip.

    Both convolutions keep the channel count and use ``padding=1``, so the
    output has the same shape as the input; that is what allows the input to
    be added back onto the output when ``use_skip`` is set.

    Args:
        channels: number of input and output channels.
        use_skip: if True, add the block input to the result of the second
            conv/norm before the final activation.
        use_bn: if True, follow each convolution with ``BatchNorm2d`` and
            disable the convolution bias; if False, use ``nn.Identity`` and
            enable the bias.
        act: activation factory (a class or zero-argument callable that
            returns a module). Kept on ``self.act`` unchanged.
        dropout: dropout probability applied after the first activation.
            Only active while the module is in training mode.
    """

    def __init__(self, channels, use_skip=True, use_bn=True, act=nn.GELU, dropout=0.4):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=not use_bn)
        self.bn1 = nn.BatchNorm2d(channels) if use_bn else nn.Identity()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=not use_bn)
        self.bn2 = nn.BatchNorm2d(channels) if use_bn else nn.Identity()
        self.use_skip, self.act = use_skip, act
        self.dropout = dropout
        # Build the activations once so that they are registered submodules.
        # Calling ``act()`` inside forward() would create a brand-new module on
        # every call, which silently discards the weights of parameterised
        # activations such as nn.PReLU.
        self.act1, self.act2 = act(), act()

    def forward(self, x):
        """Apply conv-norm-act-dropout-conv-norm, the optional skip, then act."""
        out = self.act1(self.bn1(self.conv1(x)))
        out = F.dropout(out, self.dropout, training=self.training)
        out = self.bn2(self.conv2(out))
        if self.use_skip:
            out = out + x
        return self.act2(out)
class ResNetVAEEncoder(nn.Module):
    """ResNet-style encoder that produces a 1D latent vector of length latent_dim.

    A stem convolution is followed by three levels of residual blocks with
    ``base_channels``, ``2*base_channels`` and ``4*base_channels`` channels.
    Between levels the feature map is halved with 2x2 average pooling and
    widened with a 1x1 convolution. The result is globally average-pooled and
    projected by a linear layer to ``2*latent_dim`` values.

    Args:
        in_channels: number of channels of the input images.
        latent_dim: length of the latent vector.
        base_channels: channel count of the first level.
        blocks_per_level: number of ``ResidualBlock`` modules in each level.
        use_skips: forwarded to every ``ResidualBlock`` as ``use_skip``.
        use_bn: use batch norm (and bias-free convolutions) throughout.
        act: activation factory (class or zero-argument callable).
    """

    def __init__(self, in_channels, latent_dim=3, base_channels=32, blocks_per_level=4, use_skips=True, use_bn=True, act=nn.GELU):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, base_channels, 3, padding=1, bias=not use_bn)
        self.bn1 = nn.BatchNorm2d(base_channels) if use_bn else nn.Identity()
        channels = [base_channels, base_channels*2, base_channels*4]
        self.levels = nn.ModuleList([
            nn.ModuleList([ResidualBlock(ch, use_skips, use_bn, act=act) for _ in range(blocks_per_level)])
            for ch in channels
        ])
        self.transitions = nn.ModuleList([
            nn.Conv2d(channels[i], channels[i+1], 1, bias=not use_bn)
            for i in range(len(channels)-1)
        ])
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(base_channels*4, 2*latent_dim)
        self.act = act
        # Instantiate the stem activation once so it is a registered submodule
        # rather than a throwaway object created on every forward pass.
        self.stem_act = act()

    def forward(self, x):
        """Return ``(mean, logvar)``, each of shape ``(batch, latent_dim)``."""
        x = self.stem_act(self.bn1(self.conv1(x)))
        for i, level in enumerate(self.levels):
            if i > 0:  # shrink down, then widen to this level's channel count
                x = F.avg_pool2d(x, 2)
                x = self.transitions[i-1](x)
            for block in level:
                x = block(x)
        x = self.global_avg_pool(x)
        x = self.fc(x.flatten(start_dim=1))
        mean, logvar = x.chunk(2, dim=1)  # first half: mean, second half: log variance
        return mean, logvar
class ResNetVAEDecoder(nn.Module):
    """Mirror image of ResNetVAEEncoder: maps a latent vector back to an image.

    A linear layer projects the latent vector to a
    ``(4*base_channels, start_dim, start_dim)`` feature map. Three levels of
    residual blocks follow, with channel counts ``4*base_channels``,
    ``2*base_channels`` and ``base_channels``. After each of the first two
    levels the map is upsampled 2x (bilinear) and narrowed with a 1x1
    convolution, so the output is ``4*start_dim`` pixels on a side.

    Args:
        out_channels: number of channels of the reconstructed image.
        latent_dim: length of the latent vector.
        base_channels: channel count of the last (highest-resolution) level.
        blocks_per_level: number of ``ResidualBlock`` modules in each level.
        use_skips: forwarded to every ``ResidualBlock`` as ``use_skip``.
        use_bn: use batch norm (and bias-free 1x1 transitions) throughout.
        act: activation factory; passed to the residual blocks and stored.
        start_dim: side length of the initial feature map. The default of 7
            yields 28x28 outputs.
    """

    def __init__(self, out_channels, latent_dim=3, base_channels=32, blocks_per_level=4, use_skips=True, use_bn=True, act=nn.GELU, start_dim=7):
        super().__init__()
        channels = [base_channels, base_channels*2, base_channels*4][::-1]  # reversed from encoder
        self.channels = channels
        self.start_dim = start_dim  # starting spatial dimension
        # Project the latent to channels[0] * start_dim * start_dim features.
        self.fc = nn.Linear(latent_dim, channels[0] * self.start_dim * self.start_dim)
        self.levels = nn.ModuleList([
            nn.ModuleList([ResidualBlock(ch, use_skips, use_bn, act=act) for _ in range(blocks_per_level)])
            for ch in channels
        ])
        self.transitions = nn.ModuleList([
            nn.Conv2d(channels[i], channels[i+1], 1, bias=not use_bn)
            for i in range(len(channels)-1)
        ])
        self.final_conv = nn.Conv2d(base_channels, out_channels, 3, padding=1)
        self.act = act

    def forward(self, z):
        """Decode latent ``z`` of shape ``(batch, latent_dim)`` into image logits."""
        x = self.fc(z).view(-1, self.channels[0], self.start_dim, self.start_dim)  # project to spatial
        last = len(self.levels) - 1
        for i, level in enumerate(self.levels):
            for block in level:
                x = block(x)
            if i < last:  # not last level: grow 2x and reduce channels
                x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
                x = self.transitions[i](x)
        return self.final_conv(x)
class ResNetVAEEncoderSpatial(nn.Module):
    """Encoder whose latent is a small image rather than a flat vector.

    The input is downsampled twice by a factor of 2, e.g. for MNIST
    1x28x28 -> 1x7x7, and a 1x1 convolution emits ``2*latent_channels``
    channels that are split into mean and log variance.
    """

    def __init__(self, in_channels, latent_channels=1, base_channels=32, blocks_per_level=4, use_skips=True, use_bn=True, act=nn.GELU):
        super().__init__()
        conv_bias = not use_bn
        widths = [base_channels, base_channels*2, base_channels*4]

        # Stem: lift the input to base_channels.
        self.conv1 = nn.Conv2d(in_channels, base_channels, 3, padding=1, bias=conv_bias)
        self.bn1 = nn.BatchNorm2d(base_channels) if use_bn else nn.Identity()

        # One stack of residual blocks per resolution level.
        self.levels = nn.ModuleList()
        for width in widths:
            stack = nn.ModuleList()
            for _ in range(blocks_per_level):
                stack.append(ResidualBlock(width, use_skips, use_bn, act=act))
            self.levels.append(stack)

        # 1x1 convolutions that widen the channels between consecutive levels.
        self.transitions = nn.ModuleList()
        for narrow, wide in zip(widths[:-1], widths[1:]):
            self.transitions.append(nn.Conv2d(narrow, wide, 1, bias=conv_bias))

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        # 1x1 conv producing the stacked (mean, logvar) channels.
        self.channel_proj = nn.Conv2d(in_channels=widths[-1], out_channels=2*latent_channels, kernel_size=1)
        self.act = act

    def forward(self, x):
        """Return ``(mean, logvar)`` as two equally sized channel groups."""
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.act()(h)
        for level_idx, stack in enumerate(self.levels):
            if level_idx:  # every level after the first starts by shrinking
                h = F.avg_pool2d(h, 2)
                h = self.transitions[level_idx - 1](h)
            for block in stack:
                h = block(h)
        stats = self.channel_proj(h)
        mean, logvar = stats.chunk(2, dim=1)  # split channels into mean / log variance
        return mean, logvar
class ResNetVAEDecoderSpatial(nn.Module):
    """Mirror image of ResNetVAEEncoderSpatial.

    A 1x1 convolution lifts the latent image to ``4*base_channels`` channels;
    three levels of residual blocks follow, with a 2x bilinear upsample and a
    1x1 channel-reducing convolution after each of the first two levels.
    """

    def __init__(self, out_channels, latent_channels=1, base_channels=32, blocks_per_level=4, use_skips=True, use_bn=True, act=nn.GELU):
        super().__init__()
        # Widest level first: the reverse of the encoder's ordering.
        widths = [base_channels*4, base_channels*2, base_channels]
        self.channels = widths

        # 1x1 conv from latent channels to the widest level.
        self.channel_proj = nn.Conv2d(in_channels=latent_channels, out_channels=widths[0], kernel_size=1)

        self.levels = nn.ModuleList()
        for width in widths:
            stack = nn.ModuleList()
            for _ in range(blocks_per_level):
                stack.append(ResidualBlock(width, use_skips, use_bn, act=act))
            self.levels.append(stack)

        self.transitions = nn.ModuleList()
        for wide, narrow in zip(widths[:-1], widths[1:]):
            self.transitions.append(nn.Conv2d(wide, narrow, 1, bias=not use_bn))

        self.final_conv = nn.Conv2d(base_channels, out_channels, 3, padding=1)
        self.act = act

    def forward(self, z):
        """Decode the latent image ``z`` into an output image (logits)."""
        h = self.channel_proj(z)
        n_levels = len(self.levels)
        for level_idx, stack in enumerate(self.levels):
            for block in stack:
                h = block(h)
            if level_idx == n_levels - 1:
                continue  # the last level is not followed by an upsample
            h = F.interpolate(h, scale_factor=2, mode='bilinear', align_corners=False)
            h = self.transitions[level_idx](h)
        return self.final_conv(h)
class ResNetVAE(nn.Module):
    """Main VAE class: a ResNet encoder/decoder pair with reparameterised sampling.

    Args:
        data_channels: image channels (1 for MNIST, 3 for CIFAR10, etc.).
        latent_dim: dimensionality of the vector latent space; only used when
            ``spatial`` is False. Bigger = less compression, better
            reconstruction.
        act: activation factory passed to the encoder and decoder.
        spatial: if True, use the spatial encoder/decoder whose latent is a
            small image; otherwise use the vector-latent encoder/decoder.
        latent_channels: number of channels of the spatial latent image; only
            used when ``spatial`` is True.
    """

    def __init__(self,
                 data_channels=1,
                 latent_dim=3,
                 act=nn.GELU,
                 spatial=True,
                 latent_channels=1,
                 ):
        super().__init__()
        if spatial:
            self.encoder = ResNetVAEEncoderSpatial(data_channels, latent_channels=latent_channels, act=act)
            self.decoder = ResNetVAEDecoderSpatial(data_channels, latent_channels=latent_channels, act=act)
        else:
            self.encoder = ResNetVAEEncoder(data_channels, latent_dim=latent_dim, act=act)
            self.decoder = ResNetVAEDecoder(data_channels, latent_dim=latent_dim, act=act)

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            ``(z, x_hat, mu, log_var, z_hat)`` where ``z`` is ``mu`` and
            ``log_var`` concatenated along dim 1, ``z_hat`` is the sampled
            latent and ``x_hat`` is the decoder output for ``z_hat``.
        """
        mu, log_var = self.encoder(x)
        # Redundant with mu/log_var, but kept because other lesson code expects z.
        z = torch.cat([mu, log_var], dim=1)
        # Reparameterisation trick: z_hat = mu + sigma * eps, sigma = exp(log_var / 2).
        z_hat = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        x_hat = self.decoder(z_hat)
        return z, x_hat, mu, log_var, z_hat
@torch.no_grad()
def test_inference(model, test_ds, idx=None, return_fig=False):
    """Plot inputs (top row) against the model's reconstructions (bottom row).

    Images are drawn as 28x28 grayscale, so this is limited to
    single-channel 28x28 data.

    Args:
        model: module whose forward returns a sequence with the
            reconstruction logits at position 1.
        test_ds: indexable dataset; ``test_ds[i][0]`` must be the image tensor.
        idx: an int, a range, or another iterable of indices. If None, one
            random index is drawn.
        return_fig: if True, return the matplotlib figure instead of showing it.

    Returns:
        The figure when ``return_fig`` is True, otherwise None.
    """
    device = next(model.parameters()).device

    # Normalise idx to a list of indices. The random pick is converted to a
    # plain int because a 0-d tensor can neither be iterated nor len()'d.
    if idx is None:
        idx = [int(torch.randint(len(test_ds), (1,)))]
    elif isinstance(idx, int):
        idx = [idx]
    else:
        idx = list(idx)

    x_batch = torch.stack([test_ds[i][0] for i in idx]).to(device)  # images

    # Run in eval mode, then put the model back in whatever mode it was in,
    # even if the forward pass raises.
    was_training = model.training
    model.eval()
    try:
        result = model.forward(x_batch)
    finally:
        model.train(was_training)

    # Only the reconstruction is needed; indexing (rather than unpacking a
    # fixed number of names) works for any output with at least two entries.
    recon = torch.sigmoid(result[1].view(len(idx), 28, 28))

    fig, axs = plt.subplots(2, len(idx), figsize=(3*len(idx), 4))
    if len(idx) == 1:
        axs = axs.reshape(2, 1)  # subplots returns a 1-D array for one column
    for i in range(len(idx)):
        axs[0, i].imshow(x_batch[i].view(28, 28).cpu(), cmap='gray')
        axs[1, i].imshow(recon[i].cpu(), cmap='gray')
        if i == 0:
            axs[0, 0].set_ylabel('Input', fontsize=12)
            axs[1, 0].set_ylabel('Reconstruction', fontsize=12)

    if return_fig:
        return fig
    plt.show()
@torch.no_grad()
def test_inference_spatial(model, test_ds, idx=None, return_fig=False):
    """Plot inputs, the latent mean image and reconstructions in three rows.

    Inputs and reconstructions are drawn as 28x28 grayscale, so this is
    limited to single-channel 28x28 data. The latent mean is min-max
    normalised per sample for display.

    Args:
        model: module whose forward returns a sequence with the
            reconstruction logits at position 1 and the latent mean at
            position 2.
        test_ds: indexable dataset; ``test_ds[i][0]`` must be the image tensor.
        idx: an int, a range, or another iterable of indices. If None, one
            random index is drawn.
        return_fig: if True, return the matplotlib figure instead of showing it.

    Returns:
        The figure when ``return_fig`` is True, otherwise None.
    """
    device = next(model.parameters()).device

    # Normalise idx to a list of indices. The random pick is converted to a
    # plain int because a 0-d tensor can neither be iterated nor len()'d.
    if idx is None:
        idx = [int(torch.randint(len(test_ds), (1,)))]
    elif isinstance(idx, int):
        idx = [idx]
    else:
        idx = list(idx)

    x_batch = torch.stack([test_ds[i][0] for i in idx]).to(device)  # images

    # Run in eval mode, then put the model back in whatever mode it was in,
    # even if the forward pass raises.
    was_training = model.training
    model.eval()
    try:
        result = model.forward(x_batch)
    finally:
        model.train(was_training)

    recon, mu = result[1], result[2]
    recon = torch.sigmoid(recon.view(len(idx), 28, 28))

    # Normalize mu for visualization.
    # Squeeze only dim 1 (the channel axis): a bare squeeze() would also drop
    # the batch axis when a single sample is plotted.
    mu = mu.squeeze(1)
    mu_flat = mu.reshape(len(idx), -1)
    mu_min = mu_flat.min(dim=1, keepdim=True)[0]
    mu_max = mu_flat.max(dim=1, keepdim=True)[0]
    # The small epsilon avoids division by zero for a constant latent.
    mu_norm = ((mu_flat - mu_min) / (mu_max - mu_min + 1e-8)).view_as(mu)

    fig, axs = plt.subplots(3, len(idx), figsize=(3*len(idx), 6))
    if len(idx) == 1:
        axs = axs.reshape(3, 1)  # subplots returns a 1-D array for one column
    for i in range(len(idx)):
        axs[0, i].imshow(x_batch[i].view(28, 28).cpu(), cmap='gray')
        axs[1, i].imshow(mu_norm[i].cpu(), cmap='viridis')  # middle row: mu
        axs[2, i].imshow(recon[i].cpu(), cmap='gray')
        if i == 0:
            axs[0, 0].set_ylabel('Input', fontsize=12)
            axs[1, 0].set_ylabel('Latent z_μ', fontsize=12)
            axs[2, 0].set_ylabel('Reconstruction', fontsize=12)

    if return_fig:
        return fig
    plt.show()
def log_example_images(model, test_ds, epoch, spatial=True):
    """Log a reconstruction figure for the first five test items to wandb.

    Does nothing when there is no active wandb run. ``spatial`` selects
    between the spatial and the vector-latent plotting routine.
    """
    if wandb.run is None:
        return
    make_figure = test_inference_spatial if spatial else test_inference
    fig = make_figure(model, test_ds, idx=range(5), return_fig=True)
    payload = {"reconstructions": wandb.Image(fig), "epoch": epoch}
    wandb.log(payload)
    # Close explicitly; otherwise open figures pile up across epochs.
    plt.close(fig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment