k-nearest neighbors#

A KNNClassifier is a k-nearest neighbors model based on scikit-learn.
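To make the idea behind a k-nearest neighbors classifier concrete, here is a minimal pure-Python sketch: a new point receives the label that is most frequent among its k closest training samples. This is an illustration under simplifying assumptions (Euclidean distance, uniform votes), not the code used by GEMSEO or scikit-learn; the function `knn_predict` and the toy data are hypothetical.

```python
from collections import Counter
from math import dist


def knn_predict(samples, labels, query, n_neighbors=5):
    """Return the majority label among the n_neighbors samples closest to query."""
    # Sort the training indices by Euclidean distance to the query point.
    order = sorted(range(len(samples)), key=lambda i: dist(samples[i], query))
    # Collect the labels of the closest samples and take a majority vote.
    votes = Counter(labels[i] for i in order[:n_neighbors])
    return votes.most_common(1)[0][0]


# Hypothetical toy data: two well-separated clusters in 2D.
samples = [(0.0, 0.0), (0.1, 0.2), (0.2, 0.1), (1.0, 1.0), (0.9, 1.1), (1.1, 0.9)]
labels = [0, 0, 0, 1, 1, 1]
prediction = knn_predict(samples, labels, (0.15, 0.1), n_neighbors=3)
```

An odd number of neighbors avoids tied votes in a two-class problem; with more classes, a tie-breaking rule is still needed, and this sketch simply keeps the label encountered first.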

We want to classify the Iris dataset using a KNN classifier.

from __future__ import annotations

from numpy import array

from gemseo import configure_logger
from gemseo import create_benchmark_dataset
from gemseo.mlearning import create_classification_model

configure_logger()
<RootLogger root (INFO)>

Load Iris dataset#

iris = create_benchmark_dataset("IrisDataset", as_io=True)

Create the classification model#

Then, we build the k-NN classification model from the Iris dataset, train it, and display it.

model = create_classification_model("KNNClassifier", data=iris)
model.learn()
model
KNNClassifier(input_names=(), n_neighbors=5, output_names=(), parameters={}, transformer={'inputs': <gemseo.mlearning.transformers.scaler.min_max_scaler.MinMaxScaler object at 0x704937e17c50>})
  • based on the scikit-learn library
  • built from 150 learning samples
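The printed representation shows that the inputs go through a MinMaxScaler. Scaling matters for a distance-based method such as k-NN, because a feature with a large numerical range would otherwise dominate the distance. The sketch below illustrates the textbook min-max formula, (x - min) / (max - min), in plain Python; it is not GEMSEO's MinMaxScaler implementation, and the function `min_max_scale` and the sample columns are hypothetical.

```python
def min_max_scale(columns):
    """Rescale each column of values linearly to the interval [0, 1]."""
    scaled = []
    for values in columns:
        low, high = min(values), max(values)
        span = high - low
        # A constant column has zero span; map it to 0.0 to avoid dividing by zero.
        scaled.append([(v - low) / span if span else 0.0 for v in values])
    return scaled


# Hypothetical feature columns with very different ranges.
lengths_cm = [4.0, 5.0, 6.0, 8.0]
weights_g = [100.0, 300.0, 500.0, 900.0]
scaled_lengths, scaled_weights = min_max_scale([lengths_cm, weights_g])
```

After scaling, both columns span the same interval, so each feature contributes comparably to the distance between two samples.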


Predict output#

Once the model is built, we can use it to predict the class of a new sample.

input_value = {
    "sepal_length": array([4.5]),
    "sepal_width": array([3.0]),
    "petal_length": array([1.0]),
    "petal_width": array([0.2]),
}
output_value = model.predict(input_value)
output_value
{'specy': array([0])}
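The prediction is an integer class index stored under the output name `specy`. A small helper can turn it into a readable species name. The mapping below is an assumption (the usual Iris ordering: 0 for setosa, 1 for versicolor, 2 for virginica) and is not stated on this page, so check it against the dataset before relying on it; `SPECIES_NAMES` and `decode_species` are hypothetical names.

```python
# Assumed label encoding (not confirmed here): 0, 1, 2 in the usual Iris order.
SPECIES_NAMES = {0: "setosa", 1: "versicolor", 2: "virginica"}


def decode_species(prediction):
    """Translate the integer class in a prediction dictionary into a species name."""
    class_index = int(prediction["specy"][0])
    return SPECIES_NAMES[class_index]


# Stand-in for the dictionary returned by model.predict; a plain list replaces
# the NumPy array to keep this sketch free of dependencies.
output_value = {"specy": [0]}
species = decode_species(output_value)
```

Under that assumed encoding, the sample above, with its short and narrow petals, would be read as a setosa.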

