Created
April 1, 2025 09:37
-
-
Save maweigert/0aa6f1193ad3246050316035cea0e581 to your computer and use it in GitHub Desktop.
MLP decision boundary viz
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Plots the decision boundary and the loss as a function of training time/epochs | |
for perceptron (no hidden layer) or single hidden layer MLP for different datasets ('moons', 'blobs', 'xor')
requires: | |
pip install torch torchvision torchaudio matplotlib tqdm seaborn moviepy scipy scikit-learn | |
run like: | |
python mlp_viz.py --dataset moons --hidden 128 --outdir out/ | |
(set --hidden to 0 for no hidden layer aka perceptron) | |
""" | |
import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import moviepy
import networkx as nx
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from scipy.spatial import distance_matrix, KDTree
from scipy.stats import qmc
from sklearn.datasets import make_moons, make_blobs
from tqdm import tqdm

# Prefer CUDA, then Apple MPS, otherwise fall back to CPU.
# (Previously "mps" was chosen whenever CUDA was absent, which raises a
# RuntimeError on machines without an MPS backend, e.g. Linux CPU boxes.)
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
def fig_to_array_rgb(fig):
    """Render a matplotlib figure and return its pixels as an (H, W, 3) uint8 RGB array."""
    fig.canvas.draw()
    w_in, h_in = fig.get_size_inches()
    dpi = fig.get_dpi()
    w_px, h_px = int(w_in * dpi), int(h_in * dpi)
    # buffer_rgba() yields the raw RGBA framebuffer; drop the alpha channel.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    return rgba.reshape((h_px, w_px, 4))[..., :3]
def hexagonal_grid(rows, r=1.0):
    """Generate a hexagonal lattice of points, centered and scaled so each axis maxes at 1.

    Odd rows are shifted half a cell and hold one fewer point than even rows.
    Returns (coords, labels) where labels cycle 0,1,2 over the points.
    """
    cols = 3 * (rows // 4) + 2
    dx = 2 * r               # horizontal spacing between centers
    dy = np.sqrt(3) * r      # vertical spacing between centers
    pts = [
        (c * dx + (row % 2) * (dx / 2), row * dy)  # offset every other row
        for row in range(rows)
        for c in range(cols if row % 2 == 0 else cols - 1)
    ]
    labels = np.arange(len(pts)) % 3
    coords = np.array(pts)
    coords -= coords.mean(0)
    coords /= coords.max(0)
    return coords, labels
def label_points(x: np.ndarray, k=5):
    """Assign binary labels to points via an approximate max-cut on a k-NN graph.

    Builds a k-nearest-neighbour graph weighted by Euclidean distance, then
    partitions it with networkx's one-exchange max-cut heuristic so that
    nearby points tend to land in opposite classes.

    Parameters
    ----------
    x : (n, d) array of point coordinates
    k : number of neighbours queried per point (includes the point itself)

    Returns
    -------
    (n,) int array of 0/1 labels
    """
    # NOTE: a full distance_matrix(x, x) was previously computed here and
    # never used — removed (it was O(n^2) memory/time for nothing).
    tree = KDTree(x)
    dists, nbrs = tree.query(x, k=k)
    G = nx.Graph()
    for i, (dd, idx) in enumerate(zip(dists, nbrs)):
        # skip entry 0: each point's nearest neighbour is itself (distance 0)
        for d, j in zip(dd[1:], idx[1:]):
            G.add_edge(i, j, weight=d)
    _cut_size, partition = nx.approximation.one_exchange(G, seed=1)
    labels = np.zeros(len(x), dtype=int)
    labels[np.array(list(partition[0]))] = 1
    return labels
def sample_points(n: int):
    """Draw up to n**2 quasi-uniform points in the unit square via Poisson-disk sampling.

    The disk radius shrinks with n so roughly n**2 points can fit; fewer may
    be returned if the sampler saturates.
    """
    sampler = qmc.PoissonDisk(d=2, radius=1 / 1.3 / n)
    return sampler.random(n**2)
# Define the MLP | |
class SimpleNet(nn.Module):
    """Perceptron (n_hidden=0) or single-hidden-layer MLP mapping 2-D inputs to class logits.

    Parameters
    ----------
    n_hidden : hidden width; 0 builds a plain linear classifier (perceptron)
    activation : one of 'relu', 'gelu', 'elu', 'leaky_relu', 'tanh'
    num_classes : number of output logits

    Raises
    ------
    ValueError : if `activation` is not a supported name
    """

    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'elu': nn.ELU,
        'leaky_relu': nn.LeakyReLU,
        'tanh': nn.Tanh,
    }

    def __init__(self, n_hidden=256, activation='relu', num_classes: int = 3):
        super().__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError(f"Activation function {activation} not supported")
        act = self._ACTIVATIONS[activation]
        self.num_classes = num_classes
        if n_hidden == 0:
            layers = [nn.Linear(2, num_classes)]
        else:
            layers = [nn.Linear(2, n_hidden), act(), nn.Linear(n_hidden, num_classes)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
def decision_boundary(model, n_grid: int = 100):
    """Evaluate a classifier's decision function and input-gradient magnitude on a 2-D grid.

    Samples an n_grid x n_grid grid over [-0.1, 1.1]^2 (slightly larger than
    the [0,1]^2 data range), runs `model` on every grid point, and returns:

    - yy, xx: the meshgrid coordinate arrays (note the swapped yy/xx order,
      matched by the swapped scatter axes in the plotting code)
    - decision: logit[0] - logit[1] reshaped to the grid — sign gives the
      predicted class, magnitude the confidence margin; assumes the model
      outputs (at least) 2 logits
    - grad: per-logit gradient-norm maps of shape (2,) + grid shape

    Uses the module-global `device`.
    """
    x_min, x_max = -.1, 1.1
    y_min, y_max = -.1, 1.1
    yy, xx = np.meshgrid(np.linspace(x_min, x_max, n_grid),
                         np.linspace(y_min, y_max, n_grid))
    # Flatten the grid into an (n_grid**2, 2) float tensor of input points.
    grid_points = torch.from_numpy(np.c_[yy.ravel(), xx.ravel()]).float().to(device)
    grid_points.requires_grad = True
    model.eval()
    model.zero_grad()
    logits = model(grid_points)
    # Two-class decision function: positive -> class 0, negative -> class 1.
    decision = logits[..., 0] - logits[..., 1]
    # Jacobian of the mean logit vector w.r.t. all grid inputs; since each
    # point only influences its own logits, this yields per-point gradients
    # (scaled by 1/N from the mean).
    jaco = torch.autograd.functional.jacobian(lambda x: model(x).mean(dim=0), grid_points)
    grad = np.linalg.norm(jaco.detach().cpu().numpy(), axis=-1)
    # Reshape to (num_logits, n_grid, n_grid); assumes 2 output classes —
    # NOTE(review): would need adjusting for num_classes != 2.
    grad = grad.reshape((2,) + yy.shape)
    decision = decision.detach().cpu().numpy().reshape(yy.shape)
    return yy, xx, decision, grad
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", type=str, default='moons')
    parser.add_argument("-n", "--num_epochs", type=int, default=800)
    parser.add_argument("-p", "--num_points", type=int, default=100)
    parser.add_argument("--hidden", type=int, default=256)
    parser.add_argument("--activation", type=str, default='relu')
    parser.add_argument("-o", "--outdir", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Build the 2-class toy dataset, then rescale features into [0, 1]^2 so
    # they match the plotting window used by decision_boundary().
    if args.dataset == 'moons':
        X, Y = make_moons(n_samples=args.num_points, noise=0.2)
    elif args.dataset == 'blobs':
        X, Y = make_blobs(n_samples=args.num_points, n_features=2, centers=2, cluster_std=0.5)
    elif args.dataset == 'xor':
        X, Y = make_blobs(n_samples=args.num_points, n_features=2,
                          centers=np.array([[1, 0], [0, 1], [0, 0], [1, 1]]), cluster_std=0.1)
        Y = Y // 2  # collapse the four blob ids into the two XOR classes
    else:
        raise ValueError(f"Dataset {args.dataset} not supported")
    X = X - X.min(0, keepdims=True)
    X /= X.max(0, keepdims=True)

    name = f'{args.dataset}_{args.activation}_{args.hidden}'
    if args.outdir is not None:
        args.outdir = Path(args.outdir)
        args.outdir.mkdir(parents=True, exist_ok=True)

    # Replicate the full dataset into a batch of 128 identical copies — a
    # cheap way to get full-batch gradient descent without a dataloader.
    Xb = torch.from_numpy(X).float().unsqueeze(0).repeat(128, 1, 1).to(device)
    Yb = torch.from_numpy(Y).long().unsqueeze(0).repeat(128, 1).to(device)

    # Training setup
    model = SimpleNet(args.hidden, args.activation, num_classes=2).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print(f'Number of parameters: {num_params}')
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)

    plt.ion()
    fig, axs = plt.subplots(1, 2, figsize=(9, 3.6), num=1, clear=True)

    # Training loop: one frame is rendered every 10 epochs.
    tq = tqdm(range(args.num_epochs))
    losses = []
    imgs = []
    for epoch in tq:
        optimizer.zero_grad()
        u = model(Xb)
        # CrossEntropyLoss expects (batch, classes, points)
        loss = criterion(u.permute(0, 2, 1), Yb)
        # Skip the update at epoch 0 so the first frame shows the untrained boundary.
        if epoch > 0:
            loss.backward()
            optimizer.step()
        tq.set_postfix(loss=loss.item())
        losses.append(loss.item())

        if epoch % 10 == 0:
            # The gradient map returned by decision_boundary is not plotted here.
            # (Previously an unused U = np.moveaxis(...) was computed every epoch
            # and an unused grad_facet every frame — both removed.)
            yy, xx, dec, _grad = decision_boundary(model, 300)
            for ax in axs:
                ax.clear()
            axs[0].pcolormesh(xx, yy, dec, cmap='coolwarm', shading='auto', clim=(-5, 5))
            # Overlay an RGBA layer whose alpha peaks where |logit diff| ~ 0,
            # drawing a soft dark band along the decision boundary.
            bound = np.exp(-10 * np.abs(dec))
            bound = bound[..., None] * np.ones((1, 1, 4))
            bound[:, :, :3] = 1 - bound[:, :, :3]
            axs[0].pcolormesh(xx, yy, bound, shading='auto')
            axs[0].set_title(f'Logit difference (Epoch {epoch})', fontsize=8)
            axs[1].plot(losses, color='C1')
            axs[1].set_title('Loss', fontsize=8)
            axs[1].set_xlim(0, args.num_epochs)
            axs[1].set_ylim(0, 1)
            axs[1].set_xlabel('Epoch')
            axs[1].set_ylabel('Loss')
            sns.despine(ax=axs[1])
            # Scatter the training data on the boundary panel; the transposed,
            # reversed coordinates match the swapped meshgrid axes.
            kwargs = dict(alpha=1., marker='.')
            axs[0].scatter(*X[Y == 0].T[::-1], c='C3', **kwargs, label='Class 0')
            axs[0].scatter(*X[Y == 1].T[::-1], c='C0', **kwargs, label='Class 1')
            axs[0].legend(frameon=False, loc=(-.6, .5))
            if epoch == 0:
                fig.tight_layout()
            imgs.append(fig_to_array_rgb(fig).copy())

    if args.outdir is not None:
        clip = moviepy.ImageSequenceClip(list(imgs), fps=10)
        # write_videofile requires a string filename, not a pathlib.Path
        clip.write_videofile(str(args.outdir / f'{name}.mp4'), audio=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment