@maweigert
Created April 1, 2025 09:37
MLP decision boundary viz
"""
Plots the decision boundary and the loss as a function of training time/epochs
for perceptron (no hidden layer) or single hidden layer MLP for difefrent datasets ('moons', 'blobs', 'xor')
requires:
pip install torch torchvision torchaudio matplotlib tqdm seaborn moviepy scipy scikit-learn
run like:
python mlp_viz.py --dataset moons --hidden 128 --outdir out/
(set --hidden to 0 for no hidden layer aka perceptron)
"""
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
from sklearn.datasets import make_moons, make_blobs
from scipy.stats import qmc
from scipy.spatial import distance_matrix, KDTree
import networkx as nx
import argparse
import moviepy
import seaborn as sns
# prefer CUDA, then Apple MPS, then fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))

def fig_to_array_rgb(fig):
    fig.canvas.draw()
    width, height = fig.get_size_inches() * fig.get_dpi()  # get actual pixel dimensions
    width, height = int(width), int(height)
    img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape((height, width, 4))
    return img[..., :3]

def hexagonal_grid(rows, r=1.0):
    cols = 3 * (rows // 4) + 2
    dx = 2 * r            # horizontal distance between centers
    dy = np.sqrt(3) * r   # vertical distance between centers
    coords = []
    for row in range(rows):
        for col in range(cols if row % 2 == 0 else cols - 1):
            x = col * dx + (row % 2) * (dx / 2)  # offset every other row
            y = row * dy
            coords.append((x, y))
    labels = np.arange(len(coords)) % 3
    coords = np.array(coords)
    coords -= coords.mean(0)
    coords /= coords.max(0)
    return coords, labels
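
# Example usage (sketch; hexagonal_grid is not wired into the CLI datasets below):
#   coords, labels = hexagonal_grid(rows=8)   # hex-lattice points, centred and rescaled to roughly [-1, 1]
#   # labels simply cycle through 0, 1, 2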

def label_points(x: np.ndarray, k=5):
    # Step 1: construct a k-NN graph with distance-based edge weights
    dist_mat = distance_matrix(x, x)  # currently unused
    tree = KDTree(x)
    G = nx.Graph()
    for i, (dd, idx) in enumerate(zip(*tree.query(x, k=k))):
        for d, j in zip(dd[1:], idx[1:]):
            G.add_edge(i, j, weight=d)
    # Step 2: split the points into two classes via an approximate max-cut of the graph
    curr_cut_size, partition = nx.approximation.one_exchange(G, seed=1)
    labels = np.zeros(len(x), dtype=int)
    labels[np.array(list(partition[0]))] = 1
    return labels
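
# Example usage (sketch; not called by the main script):
#   labels = label_points(sample_points(10), k=5)   # binary labels from an approximate max-cut of the k-NN graph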

def sample_points(n: int):
    # Poisson-disk (blue-noise) samples in the unit square
    x = qmc.PoissonDisk(d=2, radius=1 / 1.3 / n).random(n ** 2)
    return x
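
# Example (sketch): sample_points(10) returns at most 10**2 blue-noise points in [0, 1]^2.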

# Define the MLP
class SimpleNet(nn.Module):
    def __init__(self, n_hidden=256, activation='relu', num_classes: int = 3):
        super().__init__()
        if activation == 'relu':
            act = nn.ReLU
        elif activation == 'gelu':
            act = nn.GELU
        elif activation == 'elu':
            act = nn.ELU
        elif activation == 'leaky_relu':
            act = nn.LeakyReLU
        elif activation == 'tanh':
            act = nn.Tanh
        else:
            raise ValueError(f"Activation function {activation} not supported")
        self.num_classes = num_classes
        if n_hidden == 0:
            self.mlp = nn.Sequential(
                nn.Linear(2, num_classes)
            )
        else:
            self.mlp = nn.Sequential(
                nn.Linear(2, n_hidden),
                act(),
                # nn.Linear(n_hidden, n_hidden),
                # act(),
                nn.Linear(n_hidden, num_classes)
            )

    def forward(self, x):
        x = self.mlp(x)
        return x
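
# Example instantiations (sketch; the main script below uses num_classes=2):
#   model = SimpleNet(n_hidden=128, activation='relu', num_classes=2)   # 2 -> 128 -> 2 MLP
#   model = SimpleNet(n_hidden=0, num_classes=2)                        # single linear layer, i.e. a perceptron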

def decision_boundary(model, n_grid: int = 100):
    x_min, x_max = -.1, 1.1
    y_min, y_max = -.1, 1.1
    yy, xx = np.meshgrid(np.linspace(x_min, x_max, n_grid),
                         np.linspace(y_min, y_max, n_grid))
    grid_points = torch.from_numpy(np.c_[yy.ravel(), xx.ravel()]).float().to(device)
    grid_points.requires_grad = True
    model.eval()
    model.zero_grad()
    logits = model(grid_points)
    decision = logits[..., 0] - logits[..., 1]
    jaco = torch.autograd.functional.jacobian(lambda x: model(x).mean(dim=0), grid_points)
    grad = np.linalg.norm(jaco.detach().cpu().numpy(), axis=-1)
    grad = grad.reshape((2,) + yy.shape)
    decision = decision.detach().cpu().numpy().reshape(yy.shape)
    return yy, xx, decision, grad
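
# Example call (as used in the training loop below):
#   yy, xx, dec, grad = decision_boundary(model, 300)
#   # dec:  logit difference (class 0 minus class 1) on a 300x300 grid over [-0.1, 1.1]^2
#   # grad: per-class norms of the Jacobian of the mean logits w.r.t. each grid point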

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", type=str, default='moons')
    parser.add_argument("-n", "--num_epochs", type=int, default=800)
    parser.add_argument("-p", "--num_points", type=int, default=100)
    parser.add_argument("--hidden", type=int, default=256)
    parser.add_argument("--activation", type=str, default='relu')
    parser.add_argument("-o", "--outdir", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == 'moons':
        X, Y = make_moons(n_samples=args.num_points, noise=0.2)
        X = X - X.min(0, keepdims=True)
        X /= X.max(0, keepdims=True)
    elif args.dataset == 'blobs':
        X, Y = make_blobs(n_samples=args.num_points, n_features=2, centers=2, cluster_std=0.5)
        X = X - X.min(0, keepdims=True)
        X /= X.max(0, keepdims=True)
    elif args.dataset == 'xor':
        X, Y = make_blobs(n_samples=args.num_points, n_features=2, centers=np.array([[1, 0], [0, 1], [0, 0], [1, 1]]), cluster_std=0.1)
        Y = Y // 2
        X = X - X.min(0, keepdims=True)
        X /= X.max(0, keepdims=True)
    else:
        raise ValueError(f"Dataset {args.dataset} not supported")

    name = f'{args.dataset}_{args.activation}_{args.hidden}'

    if args.outdir is not None:
        args.outdir = Path(args.outdir)
        args.outdir.mkdir(parents=True, exist_ok=True)

    # Repeat the dataset along a batch dimension of 128
    Xb = torch.from_numpy(X).float().unsqueeze(0).repeat(128, 1, 1)
    Yb = torch.from_numpy(Y).long().unsqueeze(0).repeat(128, 1)
    Xb, Yb = Xb.to(device), Yb.to(device)

    # Training setup
    model = SimpleNet(args.hidden, args.activation, num_classes=2).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print(f'Number of parameters: {num_params}')
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)

    plt.ion()
    fig, axs = plt.subplots(1, 2, figsize=(9, 3.6), num=1, clear=True)

    # Training loop
    tq = tqdm(range(args.num_epochs))
    losses = []
    imgs = []
    for epoch in tq:
        optimizer.zero_grad()
        u = model(Xb)
        u = u.permute(0, 2, 1)
        loss = criterion(u, Yb)
        if epoch > 0:
            # no update at epoch 0, so the first frame shows the untrained model
            loss.backward()
            optimizer.step()
        tq.set_postfix(loss=loss.item())
        U = np.moveaxis(u[0].detach().cpu().numpy(), 0, -1)
        losses.append(loss.item())

        if epoch % 10 == 0:
            yy, xx, dec, grad = decision_boundary(model, 300)
            grad = grad[0] - grad[1]
            grad_facet = (1234556 * grad) % 1
            for i, ax in enumerate(axs):
                ax.clear()
            axs[0].pcolormesh(xx, yy, dec, cmap='coolwarm', shading='auto', clim=(-5, 5))
            bound = np.exp(-10 * np.abs(dec))
            bound = bound[..., None] * np.ones((1, 1, 4))
            bound[:, :, :3] = 1 - bound[:, :, :3]
            axs[0].pcolormesh(xx, yy, bound, shading='auto')
            axs[0].set_title(f'Logit difference (Epoch {epoch})', fontsize=8)
            axs[1].plot(losses, color='C1')
            axs[1].set_title('Loss', fontsize=8)
            axs[1].set_xlim(0, args.num_epochs)
            axs[1].set_ylim(0, 1)
            axs[1].set_xlabel('Epoch')
            axs[1].set_ylabel('Loss')
            sns.despine(ax=axs[1])
            for i, ax in enumerate(axs[:1]):
                kwargs = dict(alpha=1., marker='.') if i == 0 else dict(alpha=.4, marker='.')
                ax.scatter(*X[Y == 0].T[::-1], c='C3', **kwargs, label='Class 0')
                ax.scatter(*X[Y == 1].T[::-1], c='C0', **kwargs, label='Class 1')
            axs[0].legend(frameon=False, loc=(-.6, .5))
            if epoch == 0:
                fig.tight_layout()
            imgs.append(fig_to_array_rgb(fig).copy())

    if args.outdir is not None:
        clip = moviepy.ImageSequenceClip(list(imgs), fps=10)
        clip.write_videofile(str(args.outdir / f'{name}.mp4'), audio=False)