Created
October 7, 2019 14:45
-
-
Save Nastaliss/12d3274944dfa6382905eb58e3adb223 to your computer and use it in GitHub Desktop.
This is an example of a simple implementation of the k-means algorithm. Please tweak the constants as you like :)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from math import ceil, sqrt | |
from random import randint | |
import numpy as np | |
from matplotlib import pyplot as plt | |
class Point(object):
    """A point in the 2-D plane, stored as plain ``x`` / ``y`` attributes."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __str__(self):
        template = 'Coordinates of this point: x: {}, y: {}'
        return template.format(self.x, self.y)

    def __repr__(self):
        # Same text as __str__ so points print readably inside lists.
        return self.__str__()

    def get_distance_from(self, point):
        """Return the Euclidean distance between this point and ``point``."""
        delta_x = point.x - self.x
        delta_y = point.y - self.y
        return sqrt(delta_x ** 2 + delta_y ** 2)
def generate_random_points(points_count, min, max):
    """Return a list of ``points_count`` random points.

    Each coordinate is an integer drawn with ``randint(min, max)``, i.e.
    uniformly from the closed interval [min, max]. A non-positive
    ``points_count`` yields an empty list.
    """
    # NOTE: ``min``/``max`` shadow the builtins; the names are kept because
    # callers may pass them as keyword arguments.
    generated = []
    for _ in range(points_count):
        # x is drawn before y, matching the original random-number order.
        x = randint(min, max)
        y = randint(min, max)
        generated.append(Point(x, y))
    return generated
# ─── CODE ───────────────────────────────────────────────────────────────────────

# CONSTANTS DEFINITIONS #
POINTS_COUNT = 100
INITIAL_BARYCENTERS_COUNT = 3
BARYCENTER_COLORS = ['#ff0000', '#00ff00', '#0000ff']
MIN_X_Y = 0
MAX_X_Y = 1000
MAX_STEPS = 4
# END OF CONSTANTS DEFINITIONS #


def _color_for(index):
    """Return the display colour for cluster ``index``, cycling through BARYCENTER_COLORS."""
    # Cycling lets INITIAL_BARYCENTERS_COUNT exceed the palette size without an IndexError.
    return BARYCENTER_COLORS[index % len(BARYCENTER_COLORS)]


def _closest_barycenter_index(point, barycenters):
    """Return the index of the barycenter nearest to ``point``.

    On ties the lowest index wins, because ``min`` returns the first minimal item.
    """
    return min(range(len(barycenters)),
               key=lambda index: point.get_distance_from(barycenters[index]))


def _mean_point(cluster, fallback):
    """Return the centroid of ``cluster``, or ``fallback`` if the cluster is empty."""
    if not cluster:
        # An empty cluster has no centroid; keeping the previous barycenter
        # avoids a ZeroDivisionError and lets it attract points again later.
        return fallback
    return Point(sum(point.x for point in cluster) / len(cluster),
                 sum(point.y for point in cluster) / len(cluster))


points = generate_random_points(POINTS_COUNT, MIN_X_Y, MAX_X_Y)
initial_barycenters = generate_random_points(
    INITIAL_BARYCENTERS_COUNT, MIN_X_Y, MAX_X_Y)
barycenters = deepcopy(initial_barycenters)
clusters = [[] for _ in range(INITIAL_BARYCENTERS_COUNT)]

# ─── K-MEANS ────────────────────────────────────────────────────────────────────
for step in range(MAX_STEPS):
    # ─── CLUSTERIZATION ─────────────────────────────────────────────────────────
    # Start each step from empty clusters. Otherwise assignments from earlier
    # steps accumulate, every point is counted once per past step, and the new
    # barycenters are averaged over stale data.
    clusters = [[] for _ in range(INITIAL_BARYCENTERS_COUNT)]
    for point in points:
        clusters[_closest_barycenter_index(point, barycenters)].append(point)

    # ─── DISPLAY ────────────────────────────────────────────────────────────────
    plt.subplot(ceil(MAX_STEPS / 2), 2, step + 1)
    plt.axis('equal')
    plt.title('Step {}'.format(step))
    for cluster_index, cluster in enumerate(clusters):
        plt.scatter([point.x for point in cluster],
                    [point.y for point in cluster],
                    c=_color_for(cluster_index))
    # Draw the barycenters last, larger and on a higher z-order, so that
    # same-coloured cluster points do not hide them.
    plt.scatter([point.x for point in barycenters],
                [point.y for point in barycenters],
                c=[_color_for(index) for index in range(len(barycenters))],
                edgecolors='#000000', s=100, zorder=3)

    # ─── NEW BARYCENTERS CALCULATION ────────────────────────────────────────────
    barycenters = [_mean_point(cluster, barycenter)
                   for cluster, barycenter in zip(clusters, barycenters)]
    # Here we go again

plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment