Skip to content

Instantly share code, notes, and snippets.

@larsr
Last active March 10, 2026 13:14
Show Gist options
  • Select an option

  • Save larsr/a80c59942c769129f819a2e0b5546ca3 to your computer and use it in GitHub Desktop.

Select an option

Save larsr/a80c59942c769129f819a2e0b5546ca3 to your computer and use it in GitHub Desktop.
Self-Organizing Map (SOM) with digits example
#!/usr/bin/env -S uvx --with scikit-learn,matplotlib python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.preprocessing import scale
# --- minimal SOM ---
def train_som(data, gx, gy, n_iter=5000, lr=0.5, sigma=4.0, seed=None):
    """Train a ``gx`` x ``gy`` self-organizing map on the rows of ``data``.

    Weights are initialized from randomly chosen samples scaled to unit
    length, then updated online: at step ``t`` one random sample is drawn,
    its best-matching unit (BMU) is found by Euclidean distance, and every
    node is pulled toward the sample with a Gaussian neighborhood weight
    centered on the BMU. Both the learning rate and the neighborhood radius
    decay as ``1 / (1 + 2 * t / n_iter)``.

    Args:
        data: Array-like of shape (n_samples, n_features).
        gx, gy: Grid dimensions.
        n_iter: Number of online update steps.
        lr: Initial learning rate.
        sigma: Initial neighborhood radius in grid units; should be > 0.
        seed: Optional seed for a private ``np.random.RandomState``. When
            ``None`` (the default) the global ``np.random`` state is used,
            so ``np.random.seed(...)`` still controls reproducibility.

    Returns:
        ``(W, winner)``: ``W`` has shape (gx, gy, n_features), and
        ``winner(x)`` returns the ``(i, j)`` grid index of the node nearest
        to ``x``. ``winner`` reads ``W`` directly, so it uses the trained
        weights.

    Raises:
        ValueError: if ``data`` is not a non-empty 2-D array.
    """
    data = np.asarray(data, dtype=float)
    if data.ndim != 2 or len(data) == 0:
        raise ValueError("data must be a non-empty 2-D array (n_samples, n_features)")
    rng = np.random if seed is None else np.random.RandomState(seed)

    W = data[rng.choice(len(data), (gx, gy))]
    norms = np.linalg.norm(W, axis=-1, keepdims=True)
    # An all-zero sample would divide to NaN, and argmin returns the index of
    # the first NaN it sees, so a single NaN node would win every BMU search
    # and ruin training. Leave such rows as zero vectors instead.
    W /= np.where(norms == 0, 1.0, norms)

    # Grid coordinates of every node, used for neighborhood distances.
    xx, yy = np.meshgrid(np.arange(gx), np.arange(gy), indexing='ij')

    def winner(x):
        """Return the (i, j) grid index of the node whose weight is nearest x."""
        return np.unravel_index(np.linalg.norm(W - x, axis=-1).argmin(), (gx, gy))

    for t in range(n_iter):
        x = data[rng.randint(len(data))]
        decay = 1 + t / (n_iter / 2)
        s = sigma / decay
        alpha = lr / decay
        bx, by = winner(x)
        # Gaussian neighborhood on grid distance from the BMU.
        h = np.exp(-((xx - bx) ** 2 + (yy - by) ** 2) / (2 * s ** 2))
        W += alpha * h[..., None] * (x - W)
    return W, winner
# --- data ---
N_CLASSES = 10           # digit classes 0-9; also normalizes the label colours
GRID_X, GRID_Y = 50, 50  # SOM grid size, shared by training and the plot axes
digits = datasets.load_digits(n_class=N_CLASSES)
data = scale(digits.data)  # standardize each pixel feature before training
num = digits.target
# --- train ---
W, winner = train_som(data, GRID_X, GRID_Y, n_iter=5000, lr=0.5, sigma=4.0)
# --- plot ---
# Draw every sample's digit label at its best-matching unit on the SOM grid,
# coloured by class so clusters of the same digit stand out.
plt.figure(figsize=(8, 8))
for x, t in zip(data, num):
    w = winner(x)
    plt.text(w[0] + .5, w[1] + .5, str(t),
             color=plt.cm.rainbow(t / N_CLASSES),
             fontdict={'weight': 'bold', 'size': 11})
plt.axis([0, GRID_X, 0, GRID_Y])
plt.show()
@larsr
Copy link
Author

larsr commented Mar 10, 2026

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment