Skip to content

Instantly share code, notes, and snippets.

@borgwang
Created July 21, 2017 18:25
Show Gist options
  • Save borgwang/46583b4773992156de67a3ab40511d34 to your computer and use it in GitHub Desktop.
Save borgwang/46583b4773992156de67a3ab40511d34 to your computer and use it in GitHub Desktop.
kmeans
import numpy as np
import matplotlib.pyplot as plt
# Sample data set: 80 two-dimensional [x, y] points. It is converted to a
# NumPy array further down and used as the input of the k-means run.
data = [[1.658985, 4.285136], [-3.453687, 3.424321], [4.838138, -1.151539],
[-5.37971, -3.362104], [0.972564, 2.924086], [-3.567919, 1.531611],
[0.450614, -3.302219], [-3.487105, -1.724432], [2.668759, 1.594842],
[-3.15648, 3.191137], [3.165506, -3.999838], [-2.786837, -3.099354],
[4.208187, 2.984927], [-2.123337, 2.943366], [0.704199, -0.479481],
[-0.392370, -3.963704], [2.831667, 1.574018], [-0.790153, 3.343144],
[2.943496, -3.357075], [-3.195883, -2.283926], [2.336445, 2.875106],
[-1.786345, 2.554248], [2.190101, -1.906020], [-3.403367, -2.778288],
[1.778124, 3.880832], [-1.688346, 2.230267], [2.592976, -2.054368],
[-4.00725, -3.207066], [2.257734, 3.387564], [-2.679011, 0.785119],
[0.939512, -4.023563], [-3.674424, -2.261084], [2.046259, 2.735279],
[-3.189470, 1.780269], [4.372646, -0.822248], [-2.579316, -3.497576],
[1.889034, 5.190400], [-0.798747, 2.185588], [2.836520, -2.658556],
[-3.837877, -3.253815], [2.096701, 3.886007], [-2.709034, 2.923887],
[3.367037, -3.184789], [-2.121479, -4.232586], [2.329546, 3.179764],
[-3.284816, 3.273099], [3.091414, -3.815232], [-3.762093, -2.432191],
[3.542056, 2.778832], [-1.736822, 4.241041], [2.127073, -2.983680],
[-4.323818, -3.938116], [3.792121, 5.135768], [-4.786473, 3.358547],
[2.624081, -3.260715], [-4.009299, -2.978115], [2.493525, 1.963710],
[-2.513661, 2.642162], [1.864375, -3.176309], [-3.171184, -3.572452],
[2.894220, 2.489128], [-2.562539, 2.884438], [3.491078, -3.947487],
[-2.565729, -2.012114], [3.332948, 3.983102], [-1.616805, 3.573188],
[2.280615, -2.559444], [-2.651229, -3.103198], [2.321395, 3.154987],
[-1.685703, 2.939697], [3.031012, -3.620252], [-4.599622, -2.185829],
[4.196223, 1.126677], [-2.133863, 3.093686], [4.668892, -2.562705],
[-2.793241, -2.149706], [2.884105, 3.043438], [-2.967647, 2.848696],
[4.479332, -1.764772], [-4.905566, -2.911070]]
def cal_dist(d1, d2):
    """Return the Euclidean distance from point ``d1`` to every row of ``d2``.

    Args:
        d1: A single point, given as a sequence or 1-D array of coordinates.
        d2: A 2-D array with one point per row and the same number of
            columns as ``d1`` has coordinates.

    Returns:
        A 1-D float array whose ``i``-th entry is the distance between ``d1``
        and ``d2[i]``.
    """
    # Broadcasting subtracts d1 from each row, so the computation works for
    # any number of dimensions rather than only the first two coordinates.
    diff = np.asarray(d2, dtype=float) - np.asarray(d1, dtype=float)
    return np.sqrt((diff ** 2).sum(axis=1))
def re_classify(centers, data):
    """Assign every point in ``data`` to its closest center.

    Args:
        centers: Array with one cluster center per row.
        data: Iterable of points; each point is passed to ``cal_dist``
            together with ``centers``.

    Returns:
        A 1-D float array with one entry per point, holding the index of the
        center at the smallest distance from that point.
    """
    nearest = [np.argmin(cal_dist(point, centers)) for point in data]
    # Labels are stored as floats, matching an np.zeros-initialised buffer.
    return np.array(nearest, dtype=float)
def re_center(k, classfication, data):
    """Compute the new center of each of the ``k`` clusters.

    Args:
        k: Number of clusters.
        classfication: 1-D array with the cluster index of every point in
            ``data`` (same length as ``data``).
        data: 2-D array with one point per row.

    Returns:
        A ``(k, n_features)`` array. Row ``i`` is the mean of the points
        labelled ``i``. If no point carries label ``i``, row ``i`` is a
        randomly chosen data point instead.
    """
    data = np.asarray(data)
    classfication = np.asarray(classfication)
    centers = []
    for i in range(k):
        members = data[classfication == i]
        if len(members):
            centers.append(members.mean(axis=0))
        else:
            # The mean of an empty selection is NaN, which would poison every
            # later distance computation and convergence check. Re-seed the
            # empty cluster from a random data point instead.
            centers.append(data[np.random.randint(len(data))])
    return np.array(centers)
# Run k-means on the sample points and plot the resulting clusters.
data = np.array(data)
data_size = len(data)
k = 4
max_iter = 100  # hard cap so the loop always ends, even if centers never settle
tol = 1e-4      # largest per-coordinate center movement treated as converged

# Draw the initial centers without replacement: duplicate starting centers
# would leave one cluster empty right after the first assignment step.
center_idxs = np.random.choice(data_size, size=k, replace=False)
centers = data[center_idxs]

for _ in range(max_iter):
    cls_res = re_classify(centers, data)
    new_centers = re_center(k, cls_res, data)
    # Compare the largest absolute shift. A signed mean lets shifts in
    # opposite directions cancel and reports convergence too early.
    if np.max(np.abs(new_centers - centers)) < tol:
        break
    centers = new_centers

# One scatter call per cluster; colors repeat if k exceeds the palette size.
colors = ['r', 'g', 'b', 'y']
for i in range(k):
    members = data[cls_res == i]
    plt.scatter(members[:, 0], members[:, 1], c=colors[i % len(colors)])
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment