Last active
March 10, 2026 13:14
-
-
Save larsr/a80c59942c769129f819a2e0b5546ca3 to your computer and use it in GitHub Desktop.
Self Organizing Map SOM with digits example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env -S uvx --with scikit-learn,matplotlib python | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn import datasets | |
| from sklearn.preprocessing import scale | |
| # --- minimal SOM --- | |
def train_som(data, gx, gy, n_iter=5000, lr=0.5, sigma=4.0, seed=None):
    """Train a minimal self-organizing map (SOM) on *data*.

    Args:
        data: Array-like of samples, one per row; converted to float.
        gx: Number of grid nodes along the first map axis.
        gy: Number of grid nodes along the second map axis.
        n_iter: Number of single-sample training steps.
        lr: Initial learning rate.
        sigma: Initial width of the Gaussian neighbourhood, in grid units.
        seed: Optional seed for a private random generator. When ``None``
            (the default) the global ``np.random`` state is used, so
            callers that rely on ``np.random.seed`` keep working.

    Returns:
        A pair ``(W, winner)``. ``W`` is the weight array; for 2-D *data*
        its shape is ``(gx, gy, n_features)``. ``winner(x)`` returns the
        grid index ``(i, j)`` of the node whose weight is closest to ``x``.

    Raises:
        ValueError: If *data* contains no samples.
    """
    data = np.asarray(data, dtype=float)
    if len(data) == 0:
        raise ValueError("train_som requires at least one sample")

    # Same call API either way: the np.random module or a RandomState.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Initialise every node from a randomly chosen sample (fancy indexing
    # copies), then scale each weight vector to unit length.
    W = data[rng.choice(len(data), (gx, gy))]
    norms = np.linalg.norm(W, axis=-1, keepdims=True)
    # An all-zero sample has norm 0; dividing by it would fill the map
    # with NaNs that never recover, so leave such vectors untouched.
    W /= np.where(norms > 0, norms, 1.0)

    xx, yy = np.meshgrid(np.arange(gx), np.arange(gy), indexing='ij')

    def winner(x):
        """Return the grid index of the best-matching unit for *x*."""
        distances = np.linalg.norm(W - x, axis=-1)
        return np.unravel_index(distances.argmin(), (gx, gy))

    for t in range(n_iter):
        x = data[rng.randint(len(data))]
        # Radius and learning rate decay together; both reach half of
        # their initial value at t == n_iter / 2.
        decay = 1 + t / (n_iter / 2)
        radius = sigma / decay
        rate = lr / decay
        bx, by = winner(x)
        grid_dist_sq = (xx - bx) ** 2 + (yy - by) ** 2
        neighbourhood = np.exp(-grid_dist_sq / (2 * radius ** 2))
        # In-place update so the `winner` closure keeps seeing current W.
        W += rate * neighbourhood[..., None] * (x - W)

    return W, winner
# --- data ---
# Load the 10-class digits set, rescale it with sklearn's `scale`, and keep
# the class labels for plotting.
digits = datasets.load_digits(n_class=10)
data = scale(digits.data)
num = digits.target

# --- train ---
# Fit a 50x50 map; `winner` maps a sample to its best-matching grid node.
W, winner = train_som(data, 50, 50, n_iter=5000, lr=0.5, sigma=4.0)

# --- plot ---
# Write each sample's digit label at the centre of its winning cell,
# coloured by digit class.
plt.figure(figsize=(8, 8))
label_font = {'weight': 'bold', 'size': 11}
for sample, label in zip(data, num):
    node_x, node_y = winner(sample)
    plt.text(
        node_x + .5,
        node_y + .5,
        str(label),
        color=plt.cm.rainbow(label / 10.),
        fontdict=label_font,
    )
plt.axis([0, 50, 0, 50])
plt.show()
Author
larsr
commented
Mar 10, 2026
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment