Implementing KNN in Python

K-nearest neighbors (KNN) is an algorithm that identifies the k data points in a training sample that are nearest to a new observation. Typically, "nearest" is defined by Euclidean (straight-line) distance, although other distance metrics can be used.
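As a small illustration of how the choice of metric changes what "nearest" means, here is a minimal sketch of the two most common choices. It assumes points are equal-length numeric sequences. The helper names euclidean_distance and manhattan_distance are hypothetical, not from any library.

```python
import numpy as np

def euclidean_distance(a, b):
    """Straight-line (L2) distance: square root of the summed squared differences."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.sqrt(np.sum((a - b) ** 2)))

def manhattan_distance(a, b):
    """City-block (L1) distance: sum of the absolute coordinate differences."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.sum(np.abs(a - b)))

# The same pair of points measured under two different metrics.
p, q = [0.0, 0.0], [3.0, 4.0]
print(euclidean_distance(p, q))  # 5.0
print(manhattan_distance(p, q))  # 7.0
```

Because the two metrics can disagree about which training point is closest, the metric is effectively a modeling choice, not just an implementation detail.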

Python is already home to several KNN implementations, the most famous of which is the one in scikit-learn. Still, I believe there is value in writing your own model implementations to learn more about how they work.

First, let's break down what KNN is doing visually, and then code up our own implementation. The visual below (built using D3.js) shows several points that are classified into red and blue groups.

You can hover your mouse over this visual to develop an understanding of how the nearest three points impact the classification of the point.

With k = 3, we identify the three closest points and determine which classification is most common among them. That most common classification becomes our predicted value.
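The same procedure can be sketched in plain Python. The toy coordinates and red/blue labels below are hypothetical, standing in for the points in the visual. classify is an illustrative helper that ranks points with math.dist (Python 3.8+) and takes a majority vote with collections.Counter.

```python
import math
from collections import Counter

# Hypothetical toy data: (x, y) coordinates, each with a "red" or "blue" label.
points = [(1.0, 1.0), (1.5, 2.0), (3.0, 4.0), (5.0, 7.0), (3.5, 5.0), (4.5, 5.0)]
colors = ["red", "red", "blue", "blue", "blue", "red"]

def classify(query, points, colors, k=3):
    # Rank training points by straight-line distance to the query.
    ranked = sorted(range(len(points)), key=lambda i: math.dist(query, points[i]))
    # Keep the labels of the k closest points.
    nearest = [colors[i] for i in ranked[:k]]
    # The most common label among them is the prediction.
    return Counter(nearest).most_common(1)[0][0]

print(classify((4.0, 5.0), points, colors))  # blue
```

Here the three nearest neighbors of (4.0, 5.0) are labeled blue, red, and blue, so the vote goes to blue.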


A few notes before we jump into our own implementation. First, when classifying between two classes, it is common to use an odd value of k so that the vote cannot tie. Second, one downside of KNN compared to many other models is that it must be packaged with its training data to make predictions. Linear regression, for example, only needs its fitted coefficients at prediction time.
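The following small sketch shows why the odd-k advice applies specifically to two classes. is_tied is a hypothetical helper that checks whether the two highest vote counts among the neighbor labels are equal.

```python
from collections import Counter

def is_tied(neighbor_labels):
    """True when two or more labels share the highest vote count."""
    counts = Counter(neighbor_labels).most_common()
    return len(counts) > 1 and counts[0][1] == counts[1][1]

# Two classes, k = 4: a 2-2 split is possible, so the vote can tie.
print(is_tied(["red", "blue", "red", "blue"]))  # True
# Two classes, k = 3: one label always gets at least 2 of the 3 votes.
print(is_tied(["red", "blue", "red"]))  # False
# Three classes: an odd k no longer rules out ties.
print(is_tied(["red", "blue", "green"]))  # True
```

With more than two classes, a tie-breaking rule is still needed regardless of k.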

Now let’s look at my implementation of KNN in Python: only 8 lines of code (excluding imports)! A safer version of this code might also include assertion checks to ensure that inputs have the expected type and shape.

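import numpy as np
from scipy import stats

def knn(new, train, labels, k=3, mode="c"):
    # Squared Euclidean distance from `new` to every row of `train`
    distances = np.sum((new - train) ** 2, axis=1)
    # Indices of the k smallest distances
    k_closest = distances.argsort()[:k]
    # Labels of those k nearest training points
    values = np.take(labels, k_closest)
    if mode == "c":
        # Classification: most common label (keepdims requires SciPy >= 1.9)
        return stats.mode(values, keepdims=False).mode
    elif mode == "r":
        # Regression: mean of the neighbors' labels
        return np.mean(values)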
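Following the earlier note about assertion checks, here is one way a safer variant might look. knn_checked is a hypothetical name, and the specific checks are my assumptions about what "expected type and shape" means here. The sketch also uses np.unique for the vote instead of SciPy, so it runs with NumPy alone. Like scipy.stats.mode, it returns the smallest label when there is a tie. The line-by-line walkthrough that follows refers to the original knn.

```python
import numpy as np

def knn_checked(new, train, labels, k=3, mode="c"):
    """Hypothetical validated variant of knn: same logic, plus input checks."""
    new, train, labels = np.asarray(new), np.asarray(train), np.asarray(labels)
    # train must be 2-D: one row per training point, one column per feature.
    assert train.ndim == 2, "train must have shape (n_samples, n_features)"
    # new must be a single observation with the same number of features.
    assert new.shape == (train.shape[1],), "new must be 1-D with n_features entries"
    # Every training row needs exactly one label.
    assert labels.shape == (train.shape[0],), "labels needs one entry per row of train"
    # k must be a positive integer no larger than the training set.
    assert isinstance(k, (int, np.integer)) and 1 <= k <= len(train), "k out of range"
    assert mode in ("c", "r"), 'mode must be "c" or "r"'

    distances = np.sum((new - train) ** 2, axis=1)
    k_closest = distances.argsort()[:k]
    values = np.take(labels, k_closest)
    if mode == "c":
        # np.unique sorts the labels, so argmax picks the smallest label on a tie.
        uniq, counts = np.unique(values, return_counts=True)
        return uniq[np.argmax(counts)]
    return np.mean(values)
```

Python skips assert statements when run with the -O flag, so code that must always validate its inputs may prefer raising ValueError explicitly.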

Let's look at this function line by line. First, I define a function called knn, which accepts a single new observation (new), the training data (train) with its associated labels (the correct predictions), the number of neighbors k, and the mode, which is either "c" for classification or "r" for regression.

def knn(new, train, labels, k=3, mode="c"):

From there, I compute the squared Euclidean distance from each training point to the new observation. Getting the true Euclidean distances would require taking the square root of these values. However, the square root is monotonically increasing, so it does not change the rank ordering of the points. Because the ranking is all we need, we can skip that step.

distances = np.sum((new - train) ** 2, axis=1)
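A quick numerical check of the claim that skipping the square root leaves the ranking unchanged. The new observation and the four training points are made-up values.

```python
import numpy as np

# Hypothetical data: one new observation and four training points in 2-D.
new = np.array([0.0, 0.0])
train = np.array([[3.0, 4.0], [1.0, 1.0], [0.0, 2.0], [6.0, 8.0]])

squared = np.sum((new - train) ** 2, axis=1)  # squared distances: 25, 2, 4, 100
euclidean = np.sqrt(squared)                  # true distances: 5, ~1.41, 2, 10

# sqrt is monotonically increasing, so both arrays rank the points identically.
print(squared.argsort())  # [1 2 0 3]
print(np.array_equal(squared.argsort(), euclidean.argsort()))  # True
```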

Next, I use NumPy's argsort to get the indices of the training points ordered from nearest to farthest, and index slicing to keep the first k of them. Then I use np.take to pull the labels at those k indices.

k_closest = distances.argsort()[:k]
values = np.take(labels, k_closest)
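The following sketch walks through those two lines on a handful of hypothetical distances and labels. It also shows np.argpartition, an alternative I'm adding here (not used in the original). np.argpartition finds the k smallest distances without fully sorting them, which can help with large training sets. It returns those k indices in no particular order, which is fine for a vote or an average.

```python
import numpy as np

# Hypothetical squared distances from one new observation to five training points.
distances = np.array([4.0, 0.5, 9.0, 1.0, 2.5])
labels = np.array(["blue", "red", "blue", "red", "blue"])
k = 3

order = distances.argsort()          # nearest to farthest: [1 3 4 0 2]
k_closest = order[:k]                # first k indices: [1 3 4]
values = np.take(labels, k_closest)  # labels at those indices
print(values.tolist())  # ['red', 'red', 'blue']

# Partial sort: the first k entries are the k nearest, in arbitrary order.
unordered = np.argpartition(distances, k - 1)[:k]
print(sorted(unordered.tolist()))  # [1, 3, 4]
```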

Finally, I take the mode of the values for classification or the mean for regression. To predict over multiple observations, I can call the function inside a list comprehension:

[knn(i, train, labels) for i in test]
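The list comprehension computes distances one observation at a time. As an alternative, the distance step can be vectorized with NumPy broadcasting. knn_batch below is a hypothetical, classification-only sketch, and it assumes test and train are 2-D numeric arrays with the same number of columns.

```python
import numpy as np

def knn_batch(test, train, labels, k=3):
    """Hypothetical vectorized classifier: predicts every row of test at once."""
    test, train, labels = np.asarray(test, float), np.asarray(train, float), np.asarray(labels)
    # (m, 1, d) - (1, n, d) broadcasts to (m, n, d); summing over the last axis
    # gives an (m, n) matrix of squared distances from each test row to each train row.
    distances = np.sum((test[:, None, :] - train[None, :, :]) ** 2, axis=2)
    # Indices of the k nearest training points for each test row: shape (m, k).
    k_closest = np.argsort(distances, axis=1)[:, :k]
    neighbor_labels = labels[k_closest]  # (m, k) array of neighbor labels
    predictions = []
    for row in neighbor_labels:
        # Majority vote per test row; ties go to the smallest label.
        uniq, counts = np.unique(row, return_counts=True)
        predictions.append(uniq[np.argmax(counts)])
    return np.array(predictions)

train = [[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]]
labels = [0, 0, 0, 1, 1, 1]
print(knn_batch([[0.2, 0.2], [5.5, 5.5]], train, labels))  # [0 1]
```

The trade-off is memory. The intermediate (m, n, d) array can grow large, so the list comprehension (or processing test in chunks) may be preferable for big datasets.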

This was a simple overview of KNN classification and regression using basic numpy and scipy functions!