Skip to content

Instantly share code, notes, and snippets.

@leconteur
Created March 3, 2017 01:41
Show Gist options
  • Save leconteur/3d6feeba2750807c959c7b4f42bdd902 to your computer and use it in GitHub Desktop.
Save leconteur/3d6feeba2750807c959c7b4f42bdd902 to your computer and use it in GitHub Desktop.
from sklearn.datasets import fetch_mldata
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
def main(n_components=128, n_neighbors=5, random_state=None):
    """Train and evaluate a PCA + k-NN classifier on the MNIST digits.

    The pipeline downloads MNIST, splits it into train and test sets using
    ``train_test_split``'s default proportions, displays the first training
    sample as a 28x28 image, projects the data with PCA, fits a k-nearest
    neighbours classifier on the projected training set and prints a
    classification report for the held-out test set.

    Args:
        n_components: Number of principal components kept by the PCA.
        n_neighbors: Number of neighbours used by the k-NN classifier.
        random_state: Seed forwarded to ``train_test_split``. ``None``
            (the default) keeps the split non-deterministic; pass an int
            for a reproducible run.
    """
    # NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and
    # removed in 0.22; on newer versions the replacement is
    # fetch_openml("mnist_784"). Kept here because the file-level import
    # still targets fetch_mldata.
    mnist = fetch_mldata('MNIST original')
    X, y = mnist.data, mnist.target
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=random_state
    )

    # Visual sanity check: show the first training sample as an image.
    plt.imshow(X_train[0].reshape((28, 28)))
    plt.show()

    # Fit the projection on the training data only, so that no information
    # from the test set leaks into the learned components.
    pca = PCA(n_components=n_components, svd_solver="arpack")
    X_train_reduced = pca.fit_transform(X_train)

    knn = KNeighborsClassifier(n_neighbors=n_neighbors)
    knn.fit(X_train_reduced, y_train)

    # The test set goes through the projection learned on the training set.
    y_pred = knn.predict(pca.transform(X_test))
    print(classification_report(y_test, y_pred))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment