Created
March 5, 2017 21:30
-
-
Save jmsword/f39fe298f33ab8fe2468e659570debc9 to your computer and use it in GitHub Desktop.
KNN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math
import random
from random import randint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
# Read in the data; the raw file has no header row, hence header=None.
df = pd.read_csv(
    'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data',
    header=None,
)
# Column names, in the order the columns appear in the file.
cols = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'class']
# Assign the column names.
df.columns = cols
# Scatter plot of sepal width against sepal length.
x = df['sepal_length']
y = df['sepal_width']
plt.scatter(x, y)
plt.xlabel('Sepal Length')
plt.ylabel('Sepal Width')
plt.show()
# Pick a random row of the data set to act as the query point.
# Calling seed() with no argument seeds from system entropy or the clock.
random.seed()
# Draw a *position* (not an index label) because iloc is purely positional;
# this stays correct even if df ever carries a non-default index.
pt = df.iloc[random.randrange(len(df))]
# Determine the distance from the random point.
def dist_from_pt(p, origin=None):
    """Return the Euclidean distance between ``p`` and a reference point.

    Only the ``sepal_length`` and ``sepal_width`` attributes are compared.

    Args:
        p: Object exposing ``sepal_length`` and ``sepal_width`` attributes
            (for example one row of the data frame).
        origin: Reference point exposing the same two attributes. When
            omitted, the module-level ``pt`` is used, so one-argument calls
            such as ``DataFrame.apply(func=dist_from_pt, axis=1)`` keep
            working.

    Returns:
        The straight-line distance as a float.
    """
    if origin is None:
        origin = pt
    # hypot computes sqrt(dx**2 + dy**2) without squaring by hand.
    return math.hypot(
        origin.sepal_length - p.sepal_length,
        origin.sepal_width - p.sepal_width,
    )
# Store each row's distance to the query point in a new column.
df['dist_from_pt'] = df[['sepal_length', 'sepal_width']].apply(
    func=dist_from_pt, axis=1
)
# Sort nearest-first. A stable sort keeps equidistant rows in their original
# order, so the neighbours chosen at the k-boundary do not depend on how an
# unstable sort happens to arrange ties.
df_sorted = df.sort_values(by='dist_from_pt', ascending=True, kind='mergesort')
# Define the knn function.
def knn(k, neighbors=None):
    """Return the most common class among the ``k`` nearest rows.

    Args:
        k: Number of neighbours to poll; must be at least 1. Values larger
            than the number of rows simply poll every row.
        neighbors: DataFrame with a ``class`` column whose rows are already
            ordered nearest-first. Defaults to the module-level
            ``df_sorted``.

    Returns:
        The class label with the highest count among the first ``k`` rows.

    Raises:
        ValueError: If ``k`` is smaller than 1. A negative ``k`` would
            otherwise silently slice "all but the last |k| rows", and
            ``k == 0`` would fail with an unhelpful IndexError.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if neighbors is None:
        neighbors = df_sorted
    # iloc makes the "first k rows" selection explicitly positional; the
    # frame's index labels are shuffled by the sort and must not be used.
    nearest = neighbors['class'].iloc[:k]
    # value_counts orders labels by descending frequency.
    return nearest.value_counts().index[0]
# Print the majority class among the k nearest neighbours (here k = 25).
print(knn(25))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment