How does the KNN algorithm work?

The letter “K” in KNN stands for the number of neighbors considered around the test sample. During prediction, the algorithm searches for the k nearest neighbors of the sample and takes their majority vote as the predicted class.

flowchart LR
	id1(For a sample) --> id2(Find its k nearest neighbors) --> id3(Take the majority vote of their classes as the prediction)
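As a minimal sketch of this procedure (assuming a NumPy feature matrix and the hypothetical names `X_train`, `y_train`, `x`), prediction for a single sample could look like this:

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x, k=3):
    """Predict the class of sample x by majority vote of its k nearest neighbors."""
    distances = np.linalg.norm(X_train - x, axis=1)       # Euclidean distance to every training point
    nearest = np.argsort(distances)[:k]                    # indices of the k closest points
    return Counter(y_train[nearest]).most_common(1)[0][0]  # majority vote
```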

Finding neighbors requires a distance/similarity metric, typically one of the following norms:

Minkowski distance (Norm $l_p$, the general form)

$$ D(X, Y)=\left(\sum_{i=1}^{n}\left|x_{i}-y_{i}\right|^{p}\right)^{\frac{1}{p}} $$

Manhattan distance (Norm $l_1$, $p=1$)

$$ D(X, Y)=\left(\sum_{i=1}^{n}\left|x_{i}-y_{i}\right|\right) $$

Euclidean distance (Norm $l_2$, $p=2$)

$$ D(X, Y)=\left(\sum_{i=1}^{n}\left|x_{i}-y_{i}\right|^{2}\right)^{\frac{1}{2}} $$

L-infinity distance (Norm $l_\infty$)

$$ D(X, Y)=\max_{i}\left|x_{i}-y_{i}\right| $$

L-negative-infinity distance (Norm $l_{-\infty}$)

$$ D(X, Y)=\min_{i}\left|x_{i}-y_{i}\right| $$
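As a quick worked sketch of these metrics with NumPy (the vectors below are arbitrary examples):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 0.0, 3.0])

def minkowski(x, y, p):
    return np.sum(np.abs(x - y) ** p) ** (1.0 / p)

manhattan = np.sum(np.abs(x - y))          # p = 1, gives 5.0
euclidean = np.sqrt(np.sum((x - y) ** 2))  # p = 2, gives sqrt(13) ~ 3.61
chebyshev = np.max(np.abs(x - y))          # l-infinity, gives 3.0
smallest  = np.min(np.abs(x - y))          # l-negative-infinity, gives 0.0
```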

Feature scaling - Normalization

Why Normalization?

To avoid biasing the distance towards features with larger magnitudes

Standard score

$$ f_{i} \leftarrow \frac{f_{i}-\mu_{i}}{\sigma_{i}} $$

  • Represents the feature value in terms of $\sigma$ units from the mean

  • Works well for populations that are normally distributed

Min-max feature scaling

$$ f_{i} \leftarrow \frac{f_{i}-f_{\min }}{f_{\max }-f_{\min }} $$

  • Scales all feature values to the $[0, 1]$ range
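A small sketch of both scalings on a toy feature matrix (the numbers are made up for illustration):

```python
import numpy as np

X = np.array([[180.0, 70.0],
              [160.0, 60.0],
              [170.0, 80.0]])   # e.g. height in cm, weight in kg

# Standard score: express each feature in standard deviations from its mean
X_standard = (X - X.mean(axis=0)) / X.std(axis=0)

# Min-max scaling: map each feature to the [0, 1] range
X_minmax = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0))
```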

Three KNN algorithms: Brute force, Ball tree, and k-d tree

Brute force method

Training time complexity: $O(1)$

Training space complexity: $O(1)$

Prediction time complexity: $O(knd)$

For each training sample, computing the distance takes $O(d)$; with $n$ training samples this gives $n \cdot d$, and repeating the search for each of the $k$ neighbors gives $O(knd)$.

Prediction space complexity: $O(1)$

The training phase technically does not exist, since all computation is done during prediction, so we have $O(1)$ for both time and space.

The prediction phase is, as the method name suggests, a simple exhaustive search, which in pseudocode is:

Loop through all points k times:

1. Compute the distance between the currently classified sample and the training points, remember the index of the element with the smallest distance (ignore previously selected points)
2. Add the class at found index to the counter 
Return the class with the most votes as a prediction

This is a nested loop structure, where the outer loop takes k steps and the inner loop takes n steps. Step 2 is $O(1)$ and the final vote count is $O(\text{number of classes})$, so they are negligible. Additionally, we have to take into consideration the number of dimensions d, since more dimensions mean longer vectors to compute distances over. Therefore, we have $O(n * k * d)$ time complexity.

As for space complexity, we only need a small vector to count the votes for each class. It is almost always very small and of fixed size, so we can treat it as $O(1)$ space complexity.
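A runnable sketch that mirrors the pseudocode above (the function and variable names are illustrative, not from any particular library):

```python
import numpy as np

def brute_force_knn_predict(X_train, y_train, x, k=3):
    """Exhaustive kNN prediction for a single sample x."""
    n = len(X_train)
    selected = np.zeros(n, dtype=bool)
    votes = {}                                       # class -> count, O(1) extra space
    for _ in range(k):                               # outer loop: k steps
        best_idx, best_dist = -1, np.inf
        for i in range(n):                           # inner loop: n steps
            if selected[i]:
                continue                             # ignore previously selected points
            dist = np.linalg.norm(X_train[i] - x)    # O(d) per distance
            if dist < best_dist:
                best_idx, best_dist = i, dist
        selected[best_idx] = True
        votes[y_train[best_idx]] = votes.get(y_train[best_idx], 0) + 1
    return max(votes, key=votes.get)                 # class with the most votes
```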

Ball tree method

Training time complexity: $O(d * n * log(n))$

Training space complexity: $O(d * n)$

Prediction time complexity: $O(k * log(n))$

Prediction space complexity: $O(1)$

The ball tree algorithm takes another approach to dividing the space where the training points lie. In contrast to k-d trees, which divide space with median-value “cuts”, the ball tree groups points into “balls” organized into a tree structure. They go from the largest (the root, with all points) to the smallest (the leaves, with only a few or even one point). This allows fast nearest neighbor lookup, because nearby points end up in the same, or at least nearby, “balls”.

During the training phase, we only need to construct the ball tree. There are a few algorithms for constructing it, but the one most similar to the k-d tree (called the “k-d construction algorithm” for that reason) is $O(d * n * log(n))$, the same as the k-d tree.

Because of the tree building similarity, the complexities of the prediction phase are also the same as for k-d tree.
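For reference, scikit-learn exposes this structure directly as `sklearn.neighbors.BallTree`; a minimal usage sketch on random data:

```python
import numpy as np
from sklearn.neighbors import BallTree

rng = np.random.default_rng(0)
X_train = rng.random((1000, 3))                  # 1000 points in 3 dimensions

tree = BallTree(X_train, leaf_size=40)           # built once during "training"
dist, ind = tree.query(rng.random((1, 3)), k=5)  # distances and indices of the 5 nearest neighbors
```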

k-d tree method

Check a demonstration video here.

Training time complexity: $O(d * n * log(n))$

Training space complexity: $O(d * n)$

Prediction time complexity: $O(k * log(n))$

Prediction space complexity: $O(1)$

During the training phase, we have to construct the k-d tree. This data structure splits the k-dimensional space (here k means the number of dimensions of the space, not to be confused with k as the number of nearest neighbors!) and allows a faster search for the nearest points, since we “know where to look” in that space. You may think of it as a generalization of a BST to many dimensions. It “cuts” the space with axis-aligned cuts, dividing points into groups in the children nodes.

Constructing the k-d tree is not a machine learning task itself, since it stems from the computational geometry domain, so we won’t cover it in detail, only at a conceptual level. The time complexity is usually $O(d * n * log(n))$, because insertion is $O(log(n))$ (similar to a regular BST) and we have n points from the training dataset, each with d dimensions. I assume an efficient implementation of the data structure, i.e. one that finds the optimal split point (the median in the dimension) in $O(n)$, which is possible with the median of medians algorithm. Space complexity is $O(d * n)$ — note that it depends on the dimensionality d, which makes sense, since more dimensions correspond to more space divisions and larger trees (in addition to larger time complexity for the same reason).

As for the prediction phase, the k-d tree structure naturally supports the “k nearest neighbors query” operation, which is exactly what we need for kNN. The simple approach is to just query k times, removing the point found each time — since a query takes $O(log(n))$, it is $O(k * log(n))$ in total. But since the k-d tree already cuts the space during construction, after a single query we approximately know where to look — we can just search the “surroundings” of that point. Therefore, practical implementations of the k-d tree support querying for all k neighbors at once, with complexity $O(\sqrt{n} + k)$, which is much better for larger values of k.

The above complexities are the average ones, assuming a balanced k-d tree. The $O(log(n))$ times assumed above may degrade up to $O(n)$ for unbalanced trees, but if the median is used during tree construction, we should always get a tree with approximately $O(log(n))$ insertion/deletion/search complexity.
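The equivalent structure is available as `sklearn.neighbors.KDTree`; note that a single `query` call returns all k neighbors at once, as discussed above (the data here is random, just for illustration):

```python
import numpy as np
from sklearn.neighbors import KDTree

rng = np.random.default_rng(0)
X_train = rng.random((1000, 3))

tree = KDTree(X_train, leaf_size=40)             # construction: O(d * n * log(n))
dist, ind = tree.query(rng.random((1, 3)), k=5)  # one query for all 5 nearest neighbors
```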

Choosing the method in practice

To summarize the complexities: brute force is the slowest in the big O notation, while both k-d tree and ball tree have the same lower complexity. How do we know which one to use then?

To get the answer, we have to look at both training and prediction times, which is why I have provided both. The brute force algorithm has only one complexity, for prediction: $O(k * n * d)$. The other algorithms need to create the data structure first, so for training and prediction together they get $O(d * n * log(n) + k * log(n))$, not taking into account the space complexity, which may also be important. Therefore, where the construction of the trees is frequent, the training phase may outweigh their advantage of faster nearest neighbor lookup.

Should we use a k-d tree or a ball tree? It depends on the structure of the data — relatively uniform or “well-behaved” data will make better use of the k-d tree, since the cuts of space will work well (nearby points will end up close in the leaves after all the cuts). For more clustered data, the “balls” of the ball tree will reflect the structure better and therefore allow a faster nearest neighbor search. Fortunately, Scikit-learn supports the “auto” option, which will automatically infer the best data structure from the data.
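In scikit-learn this choice is simply the `algorithm` parameter of `KNeighborsClassifier`; a minimal sketch on synthetic data:

```python
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=1000, n_features=10, random_state=0)

# algorithm can be "brute", "kd_tree", "ball_tree", or "auto" (let scikit-learn decide)
clf = KNeighborsClassifier(n_neighbors=3, algorithm="auto").fit(X, y)
print(clf.predict(X[:5]))
```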

Let’s see this in practice with two case studies I’ve encountered during my studies and work.

Case study 1: classification

The more “traditional” application of kNN is the classification of data. It often involves quite a lot of points, e.g. MNIST has 60k training images and 10k test images. Classification is done offline, which means we first do the training phase and then just use the results during prediction. Therefore, if we want to construct the data structure, we only need to do so once. For a dataset of 10,000 images and k = 3 neighbors, let’s compare the brute force method (which calculates all distances every time) and the k-d tree:

Brute force $(O(k * n))$: 3 * 10,000 = 30,000

k-d tree $(O(k * log(n)))$: 3 * log2(10,000) ≈ 3 * 13 = 39

Comparison: 39 / 30,000 = 0.0013

As you can see, the performance gain is huge! The data structure method uses only a tiny fraction of the brute force time. For most datasets this method is a clear winner.
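If you want to see the effect yourself, a rough timing sketch like the one below can help (synthetic data; actual numbers depend heavily on the machine, dimensionality, and dataset):

```python
import time
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
X_train = rng.random((10_000, 3))
y_train = rng.integers(0, 10, size=10_000)
X_test = rng.random((1_000, 3))

for algorithm in ("brute", "kd_tree"):
    clf = KNeighborsClassifier(n_neighbors=3, algorithm=algorithm).fit(X_train, y_train)
    start = time.perf_counter()
    clf.predict(X_test)
    print(algorithm, f"{time.perf_counter() - start:.3f}s")
```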

Case study 2: real-time smart monitoring

Machine Learning is commonly used for image recognition, often with neural networks. This is very useful for real-time applications, where the model is often integrated with cameras, alarms, etc. The problem with neural networks is that they often detect the same object two or more times — even the best architectures like YOLO have this problem. We can solve it with a nearest neighbor search, using a simple approach:

  • Calculate the center of each bounding box (rectangle)

  • For each rectangle, search for its nearest neighbor (1NN)

  • If points are closer than the selected threshold, merge them (they detect the same object)

The crucial part is searching for the closest center of another bounding box (step 2 above). Which algorithm should be used here? Typically we have only a few moving objects on camera, maybe up to 30–40. For such a small number, the speedup from using data structures for faster lookup is negligible. Each frame is a separate image, so if we wanted to construct a k-d tree, for example, we would have to do so for every frame, which may mean 30 times per second — a huge cost overall. Therefore, in such situations a simple brute force method works fastest and also has the smallest space requirement (which, with heavy neural networks or embedded CPUs in cameras, may be important).
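A sketch of this deduplication step (the box format `[x_min, y_min, x_max, y_max]` and the pixel threshold are assumptions for illustration; “merging” is simplified here to dropping the duplicate):

```python
import numpy as np

def merge_duplicate_boxes(boxes, threshold=20.0):
    """Drop bounding boxes whose center lies within `threshold` pixels of an already kept box."""
    boxes = np.asarray(boxes, dtype=float)
    centers = (boxes[:, :2] + boxes[:, 2:]) / 2      # 1. center of each box
    kept = []
    for i, c in enumerate(centers):
        # 2. brute force 1NN against the boxes kept so far
        if all(np.linalg.norm(c - centers[j]) >= threshold for j in kept):
            kept.append(i)
        # 3. otherwise it is treated as a duplicate detection and dropped
    return boxes[kept]
```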

How to choose a proper k?

K-fold Cross Validation

Cross-validation is a statistical method used to estimate the skill of machine learning models. Here, we evaluate the model for several candidate values of k and pick the one with the best validation score, as sketched below (note that the K in “K-fold” is the number of folds, not the k in KNN).

  • Advantage: Can help avoid overfitting and underfitting

  • Disadvantage: K (the number of folds) should be large enough to give a reliable estimate, which makes the procedure more computationally expensive
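A minimal sketch of choosing k this way with scikit-learn (the dataset and the candidate values of k are arbitrary):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Mean 5-fold cross-validation accuracy for several candidate values of k
scores = {k: cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5).mean()
          for k in (1, 3, 5, 7, 9, 11)}
best_k = max(scores, key=scores.get)
print(scores, best_k)
```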

Reference

For complexity explanation

For three KNN algorithms