Created June 1, 2018 21:25
Save blackheaven/16ca2e2d7f0d88e6801a63e5276e6d31 to your computer and use it in GitHub Desktop.
Coding Dojo (18-05-30) on K-means
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Kata | |
( kmeans | |
, randV | |
, prettifier | |
) where | |
import Data.List(groupBy, sortBy, minimumBy, unfoldr) | |
import Data.Function(on) | |
import System.Random(newStdGen, randomRs) | |
-- | A point in the plane, as an @(x, y)@ pair.
type Vector = (Float, Float)

-- | A group of points that belong to the same cluster.
type Cluster = [Vector]

-- | Type of a function measuring how far apart two points are.
type Distance = Vector -> Vector -> Float

-- | The state passed from one round of the algorithm to the next: the
-- current partition of all input points into clusters.
type UnitOfWork = [Cluster]
-- | Partition @vectors@ into at most @nbClusters@ clusters by k-means
-- (Lloyd) iteration.
--
-- The points are first dealt round-robin into groups by 'initVector'. Each
-- call to 'round'' then reassigns every point to the cluster whose mean is
-- nearest. Iteration stops as soon as a round leaves the partition
-- unchanged (a fixed point, so further rounds could not change it), or
-- after @maxRounds@ rounds, whichever comes first.
--
-- A cluster that receives no points simply disappears, so fewer than
-- @nbClusters@ clusters may be returned.
kmeans :: Int -> [Vector] -> [Cluster]
kmeans nbClusters vectors = settle maxRounds (initVector nbClusters vectors)
  where
    -- Hard cap on the number of rounds. It guarantees termination even if
    -- successive partitions never compare equal, e.g. with NaN coordinates
    -- (NaN /= NaN).
    maxRounds :: Int
    maxRounds = 100

    -- Run rounds until the partition is stable or the budget runs out.
    settle :: Int -> UnitOfWork -> UnitOfWork
    settle budget current
      | budget <= 0 = current
      | next == current = current
      | otherwise = settle (budget - 1) next
      where
        next = round' current
-- | Type of a function that summarises a cluster as one representative point.
type Mean = Cluster -> Vector

-- | Centroid of a cluster: the component-wise arithmetic mean of its points.
--
-- Precondition: the cluster is non-empty. 'foldr1' fails on an empty list,
-- and division by a zero count would be meaningless anyway. Clusters
-- produced by 'groupCluster' always contain at least one point.
means :: Mean
means cluster = centroid (foldr1 addVectors cluster)
  where
    addVectors (x1, y1) (x2, y2) = (x1 + x2, y1 + y2)
    centroid (sumX, sumY) = (sumX / count, sumY / count)
    count = fromIntegral (length cluster)
-- | Euclidean distance between two points.
--
-- The differences are squared by plain multiplication rather than @** 2@.
-- This avoids routing an integral exponent through floating-point @pow@.
distance :: Distance
distance (x1, y1) (x2, y2) = sqrt (dx * dx + dy * dy)
  where
    dx = x1 - x2
    dy = y1 - y2
-- | Initial partition: deal the points round-robin into @nbClusters@ groups.
-- Point 1 goes to group 1, point 2 to group 2, and so on, wrapping around.
--
-- Returns at most @nbClusters@ groups, and fewer when there are fewer points
-- than groups. A non-positive @nbClusters@ yields no clusters. That guard
-- also keeps 'cycle' from being applied to an empty list, which would
-- crash.
initVector :: Int -> [Vector] -> UnitOfWork
initVector nbClusters vectors
  | nbClusters <= 0 = []
  | otherwise = groupCluster (zip clusterIndexes vectors)
  where
    clusterIndexes = cycle [1 .. nbClusters]
-- | Type of one step of the clustering algorithm.
type Round = UnitOfWork -> UnitOfWork

-- | One k-means iteration.
--
-- It computes the mean of every current cluster. It then regroups all points
-- by the mean each one is closest to, using 'minimumBy' over 'distance'.
-- The nearest mean itself is the grouping key, so the new clusters come out
-- sorted by that key.
round' :: Round
round' previousRound =
  groupCluster [ (nearestMean point, point) | point <- concat previousRound ]
  where
    centroids = map means previousRound
    nearestMean point = minimumBy (compare `on` distance point) centroids
-- | Group values by their key.
--
-- The pairs are sorted by key with a stable sort, so values sharing a key
-- keep their input order. Each run of equal keys then becomes one group.
-- Groups come out in ascending key order and are never empty. 'Ord' already
-- implies 'Eq', so no separate 'Eq' constraint is needed.
groupCluster :: Ord key => [(key, vector)] -> [[vector]]
groupCluster indexedVectors =
  map (map snd) (groupBy ((==) `on` fst) (sortBy (compare `on` fst) indexedVectors))
-- | A lazy, infinite list of random points.
--
-- Both coordinates are drawn by 'randomRs' from the range @(ll, lh)@. The
-- generator comes from 'newStdGen'.
randV :: Float -> Float -> IO [(Float, Float)]
randV ll lh = unfoldr takePair . randomRs (ll, lh) <$> newStdGen
  where
    -- Consume two coordinates per point. The catch-all clause makes the
    -- helper total instead of a partial lambda pattern.
    takePair (x : y : rest) = Just ((x, y), rest)
    takePair _ = Nothing
-- | Render a cluster as text. The output is a line of 15 dashes, then one
-- line per point (formatted with 'show'), each terminated by a newline.
prettifier :: Cluster -> String
prettifier points = unlines (separator : map show points)
  where
    separator = replicate 15 '-'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.