Plot the KNN Model in Python

KNN, also known as the K-nearest neighbors algorithm, is a supervised machine learning algorithm used mainly for classification (it can also be used for regression). It is known as a lazy learner because it does not fit an explicit formula during training; instead, it stores the training data and computes distances to it every time a prediction is requested. That can make it hard to see how KNN actually divides up a dataset, so it is a good idea to plot the model and look at the classification regions directly. In this post, we will train a KNN model on a dataset and visualize it in Python using scikit-learn and Matplotlib.

Check how to do hyperparameter tuning of the KNN model.

KNN Model in Python

KNN is best known as a supervised classifier, although nearest-neighbor search is also used for regression and in some unsupervised settings. It uses a distance measure, such as Euclidean or Manhattan distance, to compute the distance between the incoming data point and the training data points, and then assigns the incoming point to a class based on its nearest neighbors. The KNN model is called a lazy model because this distance computation happens for every incoming data point at prediction time. So there is no trained formula to save: fitting the model essentially stores the training data, and the real work is done each time we ask for a prediction.

The parameter K is very important in the KNN model because it decides how many neighbors the model consults when classifying an incoming data point. If we set K to 3, the model finds the 3 training points with the smallest distances to the incoming point and classifies it by majority vote among them. An odd value of K is usually recommended because it avoids tied votes when there are two classes; with more than two classes, ties are still possible.
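To make the distance-and-vote idea concrete, here is a minimal sketch of KNN written from scratch with only the standard library. The function name knn_predict and the toy data are made up for illustration; this is not how scikit-learn implements the algorithm internally.

```python
import math
from collections import Counter

def knn_predict(train_points, train_labels, query, k=3):
    """Classify `query` by majority vote among its k nearest training points."""
    # Euclidean distance from the query to every training point
    distances = [
        (math.dist(point, query), label)
        for point, label in zip(train_points, train_labels)
    ]
    # Keep the k pairs with the smallest distances
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Majority vote over the labels of those k neighbors
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

# Toy data, invented for this example: two well-separated groups
train_points = [(1.0, 1.0), (1.2, 0.8), (0.9, 1.1), (5.0, 5.0), (5.2, 4.8), (4.9, 5.1)]
train_labels = ["a", "a", "a", "b", "b", "b"]

print(knn_predict(train_points, train_labels, (1.1, 1.0), k=3))  # a
print(knn_predict(train_points, train_labels, (5.0, 4.9), k=3))  # b
```

Note that nothing is learned ahead of time: every call to knn_predict recomputes all the distances, which is exactly what makes the method lazy.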

Importing the Dataset

You have probably come across the Iris dataset at some point, as it is one of the most popular datasets in machine learning. It contains measurements (sepal length, sepal width, petal length, and petal width) for 150 flowers from three iris species. In this section, we will use it to build a KNN model that classifies a flower based on its input values, and then plot that model.

# importing the modules
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier

# Load the Iris dataset
iris = datasets.load_iris()
X = iris.data[:, :2]  # keep only the first two features (sepal length, sepal width)
y = iris.target

Here, X and y are the input and output values, respectively. We keep only the first two features (sepal length and sepal width) so that the model can be plotted in two dimensions. You can use various plots and preprocessing steps to understand the dataset better, but we will move on to training the KNN model.
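If you want a quick look at what was loaded before training, a short check like the following may help. It only prints attributes that load_iris already provides; nothing here is required for the rest of the post.

```python
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data[:, :2]   # first two columns only
y = iris.target

print(iris.feature_names[:2])   # names of the two features kept in X
print(iris.target_names)        # class names behind the integer labels in y
print(X.shape, y.shape)         # (150, 2) (150,)
print(np.bincount(y))           # number of samples per class
```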

Training the KNN Model

The main parameter of the KNN model is the value of K, which scikit-learn calls n_neighbors (it defaults to 5 if you do not set it). For our model, we will set K to 3. You should choose the value based on your own dataset.

# Create and fit the KNN model
k = 3 
knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(X, y)

Our model has now been trained on the dataset. From here on, we assume that you have a trained model and want to plot it to see how it classifies the data.
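If you are unsure which K suits your data, one common approach is to compare a few values with cross-validation. The sketch below is only an illustration: the list of candidate values and the choice of 5 folds are arbitrary assumptions, and the scores depend on the dataset.

```python
from sklearn import datasets
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target

# Candidate values are an arbitrary choice for illustration
candidate_ks = [1, 3, 5, 7, 9, 11]
mean_scores = {}
for candidate in candidate_ks:
    model = KNeighborsClassifier(n_neighbors=candidate)
    # Accuracy averaged over 5 cross-validation folds
    scores = cross_val_score(model, X, y, cv=5)
    mean_scores[candidate] = scores.mean()
    print(f"k={candidate}: mean accuracy {mean_scores[candidate]:.3f}")

best_k = max(mean_scores, key=mean_scores.get)
print("Best k among the candidates:", best_k)
```

We keep K at 3 in the rest of this post so that the plots stay easy to read.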

Plot the KNN Model

Now it is time to plot the KNN model that we have just trained. To draw the decision boundary, we first generate a mesh grid that covers the feature space and predict the class at every grid point; we then plot the training data on top of those predicted regions.

# Generate a mesh grid to plot the decision boundary
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
h = 0.02  # step size of the mesh grid
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# Predict the class for each mesh point
Z = knn.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
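If the meshgrid, ravel, and reshape steps feel opaque, the tiny example below shows what they produce on a 3 x 2 grid. The toy_ names and the labeling rule are invented for illustration and stand in for the real grid and for knn.predict.

```python
import numpy as np

# A tiny 3 x 2 grid instead of the fine grid used for the plot
toy_xx, toy_yy = np.meshgrid(np.arange(0, 3), np.arange(0, 2))
print(toy_xx)
# [[0 1 2]
#  [0 1 2]]
print(toy_yy)
# [[0 0 0]
#  [1 1 1]]

# Flatten both arrays and pair them up: one (x, y) row per grid point
grid_points = np.c_[toy_xx.ravel(), toy_yy.ravel()]
print(grid_points.shape)  # (6, 2)

# Hypothetical stand-in for knn.predict: one label per grid point
labels = (grid_points[:, 0] >= 2).astype(int)

# Reshape back to the grid shape so that contourf can color each cell
toy_Z = labels.reshape(toy_xx.shape)
print(toy_Z)
# [[0 0 1]
#  [0 0 1]]
```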

Once the predictions for the grid are ready, we can plot the decision regions and the data points on top of them.

# Plot the decision boundary and the data points
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, Z, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title(f'KNN (k={k}) Decision Boundary')
plt.show()

As you can see, the trained model has been plotted. The three background colors show the regions in which the model predicts each of the three classes, and the borders between them are the decision boundaries.

Plot the KNN Model With Circular Region

We can also plot an incoming data point together with its nearest neighbors, which are highlighted with circles. We will use the same model and grid; the only additions are the incoming point, its neighbors, and a legend:

# Generate a mesh grid to plot the decision boundary
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
h = 0.02
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# Predict the class for each mesh point
Z = knn.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# Plot the decision boundary and the data points
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, Z, alpha=0.8)

# Plot the data points
scatter = plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k')

# Plot the neighbors for an incoming data point
incoming_point = np.array([[6.0, 3.0]])  # Example incoming data point
neighbors = knn.kneighbors(incoming_point, n_neighbors=k, return_distance=False)
neighbor_points = X[neighbors.flatten()]
plt.scatter(neighbor_points[:, 0], neighbor_points[:, 1], s=200, facecolors='none', edgecolors='black')

# Plot the incoming data point
plt.scatter(incoming_point[:, 0], incoming_point[:, 1], s=200, c='red', marker='x')

# Add legend
legend_elements = scatter.legend_elements()[0]
legend_elements.append(plt.Line2D([0], [0], marker='x', color='red', lw=0, markersize=8, label='Incoming Point'))
# target_names is a NumPy array, so convert it to a list before appending the extra label
plt.legend(legend_elements, list(iris.target_names) + ['Incoming Point'])

plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title(f'KNN (k={k}) Decision Boundary with Incoming Point and Neighbors')
plt.show()

The data points inside the black circles are the K training points closest to the incoming point. The incoming data point is represented by the red cross.
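The plot shows where the neighbors are, but not how far away they are or how they voted. A small sketch like the following prints those details for the same incoming point; it relies only on the kneighbors, predict, and predict_proba methods of KNeighborsClassifier.

```python
import numpy as np
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier

iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target

k = 3
knn = KNeighborsClassifier(n_neighbors=k).fit(X, y)
incoming_point = np.array([[6.0, 3.0]])  # same example point as in the plot

# Distances to, and row indices of, the k nearest training points
distances, indices = knn.kneighbors(incoming_point, n_neighbors=k)
for dist, idx in zip(distances[0], indices[0]):
    print(f"row {idx}: distance {dist:.3f}, class {iris.target_names[y[idx]]}")

# The predicted class is the majority vote over those neighbors
predicted = knn.predict(incoming_point)[0]
print("Predicted class:", iris.target_names[predicted])

# With the default uniform weights, this is the share of neighbors per class
vote_shares = knn.predict_proba(incoming_point)[0]
print("Vote shares:", vote_shares)
```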

Summary

KNN is a well-known classification algorithm that is popular for small datasets. It relies on the simple idea of measuring distances and classifying a new data point according to its nearest neighbors. In this post, we saw how to plot a KNN model in Python, first as a decision boundary and then together with an incoming point and its neighbors.
