
knn module

K-nearest neighbors classification model

The k-nearest neighbors classification algorithm predicts the output class of a new input point by majority vote: among the \(k\) training points nearest to this input, it selects the most frequent class. The algorithm may also estimate the probability of belonging to each class as the proportion of the \(k\) nearest neighbors that belong to this class.
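The voting rule and the probability estimate can be sketched in a few lines of pure Python. This is a minimal illustration, assuming Euclidean distance and uniform (unweighted) votes. The function names `knn_class_probabilities` and `knn_predict` are hypothetical and are not part of the gemseo API.

```python
import math
from collections import Counter


def knn_class_probabilities(X, y, x, k):
    """Return {class: fraction of the k nearest training points in that class}."""
    # Sort training indices by Euclidean distance to x and keep the first k.
    nearest = sorted(range(len(X)), key=lambda i: math.dist(X[i], x))[:k]
    # Count how often each class occurs among the k nearest neighbors.
    counts = Counter(y[i] for i in nearest)
    # Normalize the counts by k to obtain probability estimates.
    return {c: counts.get(c, 0) / k for c in sorted(set(y))}


def knn_predict(X, y, x, k):
    """Return the majority class among the k nearest neighbors."""
    probas = knn_class_probabilities(X, y, x, k)
    # max() keeps the first class (in sorted order) if several are tied.
    return max(probas, key=probas.get)


# Toy training set made up for illustration.
X = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.2], [1.0, 1.0], [0.9, 1.1]]
y = [0, 0, 1, 1, 1]
x = [0.0, 0.05]
print(knn_class_probabilities(X, y, x, k=3))  # {0: 0.6666666666666666, 1: 0.3333333333333333}
print(knn_predict(X, y, x, k=3))  # 0
```

Two of the three nearest neighbors of `x` belong to class 0, so class 0 wins the vote even though class 1 is the majority in the whole training set.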

Let \((x_i)_{i=1,\cdots,n_{\text{samples}}}\in \mathbb{R}^{n_{\text{samples}}\times n_{\text{inputs}}}\) and \((y_i)_{i=1,\cdots,n_{\text{samples}}}\in \{1,\cdots,n_{\text{classes}}\}^{n_{\text{samples}}}\) denote the input and output training data respectively.

The procedure for predicting the class of a new input point \(x\in \mathbb{R}^{n_{\text{inputs}}}\) is the following:

Let \(i_1(x), \cdots, i_{n_{\text{samples}}}(x)\) be the indices of the input training points sorted by increasing distance to the prediction point \(x\), i.e.

\[\|x-x_{i_1(x)}\| \leq \cdots \leq \|x-x_{i_{n_{\text{samples}}}(x)}\|.\]

The ordered indices may be formally determined through the inductive formula

\[i_p(x) = \underset{i\in I_p(x)}{\operatorname{argmin}}\|x-x_i\|,\quad p=1,\cdots,n_{\text{samples}}\]


where

\[\begin{split}I_1(x) &= \{1,\cdots,n_{\text{samples}}\},\\ I_{p+1}(x) &= I_p(x)\setminus \{i_p(x)\},\quad p=1,\cdots,n_{\text{samples}}-1,\end{split}\]

that is

\[I_p(x) = \{1,\cdots,n_{\text{samples}}\}\setminus \{i_1(x),\cdots,i_{p-1}(x)\}.\]
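The inductive construction can be checked directly: repeatedly taking the argmin over the remaining index set yields the same ordering as sorting the indices by distance. This is a sketch assuming Euclidean distance; indices are 0-based as in Python rather than 1-based as in the formulas, and `ordered_indices` is a hypothetical helper name.

```python
import math


def ordered_indices(X, x):
    """Build i_1(x), ..., i_n(x) with the inductive argmin formula (0-based)."""
    remaining = set(range(len(X)))  # I_1(x): all training indices
    order = []
    while remaining:
        # i_p(x) = argmin over I_p(x) of ||x - x_i||; ties go to the smallest index.
        i_p = min(remaining, key=lambda i: (math.dist(X[i], x), i))
        order.append(i_p)
        remaining.remove(i_p)  # I_{p+1}(x) = I_p(x) \ {i_p(x)}
    return order


X = [[3.0], [0.5], [2.0], [-1.0]]
x = [0.0]
order = ordered_indices(X, x)
print(order)  # [1, 3, 2, 0]
```

The distances to `x` are 3, 0.5, 2 and 1, hence the order `[1, 3, 2, 0]`. Sorting the indices with `sorted(range(len(X)), key=...)` gives the same list, because Python's sort is stable and therefore also resolves ties in favor of the smallest index.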

Then, denoting by \(\operatorname{mode}(\cdot)\) the mode operator, i.e. the operator that extracts the most frequent element, we define the prediction operator as the mode of the output classes associated with the first \(k\) indices (the classes of the \(k\) nearest neighbors of \(x\)):

\[f(x) = \operatorname{mode}(y_{i_1(x)}, \cdots, y_{i_k(x)})\]
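The prediction \(f(x)\) depends on \(k\). The following worked example applies the formula to a small one-dimensional dataset, made up for illustration, and shows that the predicted class changes between \(k=1\) and \(k=3\). It assumes Euclidean distance, and `predict_class` is a hypothetical helper name.

```python
import math
from collections import Counter


def predict_class(X, y, x, k):
    """f(x) = mode(y_{i_1(x)}, ..., y_{i_k(x)})."""
    # Indices sorted by increasing distance to x.
    order = sorted(range(len(X)), key=lambda i: math.dist(X[i], x))
    # most_common(1) returns the most frequent label; among equally frequent
    # labels, it keeps the one encountered first, i.e. the one of the closer point.
    return Counter(y[i] for i in order[:k]).most_common(1)[0][0]


X = [[0.0], [1.0], [1.2], [5.0]]
y = ["a", "b", "b", "a"]
x = [0.4]
print(predict_class(X, y, x, k=1))  # a
print(predict_class(X, y, x, k=3))  # b
```

With \(k=1\) only the closest point (class `"a"`) votes. With \(k=3\) the neighbors have classes `"a"`, `"b"`, `"b"`, so the mode is `"b"`.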

This concept is implemented by the KNNClassifier class, which inherits from the MLClassificationAlgo class.


The classifier relies on the KNeighborsClassifier class of the scikit-learn library.
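To relate the formula to scikit-learn, the sketch below fits a `KNeighborsClassifier` directly (not through gemseo) and compares its prediction with a hand-computed mode of the \(k\) nearest classes. It assumes scikit-learn's default uniform weighting and the default Euclidean metric. The toy data is made up for illustration.

```python
import math
from collections import Counter

from sklearn.neighbors import KNeighborsClassifier

# Toy one-dimensional training set.
X = [[0.0], [1.0], [1.2], [5.0], [5.5]]
y = ["a", "b", "b", "a", "a"]
x = [0.4]
k = 3

# scikit-learn estimator with uniform votes (its default weighting).
model = KNeighborsClassifier(n_neighbors=k).fit(X, y)
sk_class = model.predict([x])[0]
sk_proba = model.predict_proba([x])[0]  # columns follow model.classes_, i.e. ['a', 'b']

# Same prediction by hand: mode of the classes of the k nearest neighbors.
order = sorted(range(len(X)), key=lambda i: math.dist(X[i], x))
hand_class = Counter(y[i] for i in order[:k]).most_common(1)[0][0]

print(sk_class, hand_class)  # b b
print(sk_proba)  # class 'a': 1/3, class 'b': 2/3
```

The probabilities returned by `predict_proba` are the class proportions among the \(k\) nearest neighbors, which matches the counting rule described at the top of this page.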

class gemseo.mlearning.classification.knn.KNNClassifier(data, transformer=None, input_names=None, output_names=None, n_neighbors=5, **parameters)

Bases: gemseo.mlearning.classification.classification.MLClassificationAlgo

K-nearest neighbors classification algorithm.

Parameters
  • data (Dataset) – learning dataset.

  • transformer (dict(str)) – transformation strategy for data groups. If None, do not transform data. Default: None.

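  • input_names (list(str)) – names of the input variables. If None, use all the input variables of the learning dataset. Default: None.

  • output_names (list(str)) – names of the output variables. If None, use all the output variables of the learning dataset. Default: None.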

  • n_neighbors (int) – number of neighbors \(k\). Default: 5.

  • parameters – other keyword arguments passed to the scikit-learn KNeighborsClassifier.
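As an illustration of what such extra keyword arguments can change, the sketch below calls scikit-learn directly and forwards a dictionary with `**`, on the assumption that KNNClassifier passes `parameters` unchanged to `KNeighborsClassifier`. The `weights="distance"` option is a real scikit-learn parameter; the `parameters` dictionary and the toy data are hypothetical.

```python
from sklearn.neighbors import KNeighborsClassifier

# Toy data: one "a" point close to x, two "b" points farther away.
X = [[0.0], [2.0], [2.5]]
y = ["a", "b", "b"]
x = [[0.5]]

# Hypothetical extra keyword arguments, forwarded unchanged with **.
parameters = {"weights": "distance"}

uniform = KNeighborsClassifier(n_neighbors=3).fit(X, y)
weighted = KNeighborsClassifier(n_neighbors=3, **parameters).fit(X, y)

# Uniform votes: 'a' once, 'b' twice -> 'b'.
print(uniform.predict(x)[0])  # b
# Inverse-distance votes: 'a' weighs 1/0.5 = 2, 'b' weighs 1/1.5 + 1/2 ≈ 1.17 -> 'a'.
print(weighted.predict(x)[0])  # a
```

With distance weighting the vote is no longer the plain mode of the formula above, so the uniform default is the setting that matches the mathematical description on this page.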

LIBRARY = 'scikit-learn'