"""
Example of a k-NN classifier on the iris data: k-fold cross-validation to estimate accuracy and choose k, then a confusion matrix on a held-out test set.

"""

import numpy as np
import matplotlib.pyplot as plt

from sklearn.neighbors import KNeighborsClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Load the iris data set: 150 samples in R^4, each labelled with one of
# 3 classes encoded as the integers 0, 1, 2.
iris = datasets.load_iris()
X, y = iris.data, iris.target

# Hold out 30% of the samples as a test set (train_test_split would use 25%
# if test_size were omitted); random_state=0 makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# Baseline: estimate the accuracy of a 3-nearest-neighbour classifier with
# 5-fold cross-validation on the training set only, so the test set stays
# untouched until the final evaluation.
knn = KNeighborsClassifier(n_neighbors=3)
scores = cross_val_score(knn, X_train, y_train, cv=5, scoring='accuracy')
# Report the result instead of silently discarding it.
print(f"k=3, 5-fold CV accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})")

# Vary the number of nearest neighbours from 1 to 7 (inclusive; range's stop
# is exclusive, hence 8) and record each model's mean 5-fold CV accuracy on
# the training set.
k_range = range(1, 8)
k_scores = [
    cross_val_score(
        KNeighborsClassifier(n_neighbors=k), X_train, y_train,
        cv=5, scoring='accuracy',
    ).mean()
    for k in k_range
]

# Plot accuracy against k to pick a good value by eye.
plt.plot(k_range, k_scores, marker='o')
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.show()

# k=5 looks like a good choice from the CV sweep; refit on the whole training
# set and evaluate once on the held-out test set.
classifier = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)

# Build and display the confusion matrix (rows: true class, columns:
# predicted class), labelled with the species names rather than 0, 1, 2.
y_pred = classifier.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names)
disp.plot(cmap=plt.cm.Blues)
# Set the title through the display's axes; assigning to `plt.title` would
# overwrite the pyplot function with a string and draw no title at all.
disp.ax_.set_title('Confusion Matrix')
plt.show()