In [None]:
# Render matplotlib figures inline, in the output area of the cell that draws them.
%matplotlib inline


# Gaussian Mixtures

Load the Iris dataset and group its samples into clusters with a Gaussian mixture model.


In [None]:
from __future__ import division, unicode_literals

from numpy import array

from gemseo.api import configure_logger, load_dataset
from gemseo.core.dataset import Dataset
from gemseo.mlearning.api import create_clustering_model

# Set up logging through the gemseo API helper; no arguments are passed,
# so whatever defaults the helper defines are used.
configure_logger()

## Create dataset
We import the Iris benchmark dataset through the API.



In [None]:
iris = load_dataset("IrisDataset")

# Read the values and the variable names of the parameter group once,
# looking the group identifier up a single time.
parameter_group = iris.PARAMETER_GROUP
data = iris.get_data_by_group(parameter_group)
variables = iris.get_names(parameter_group)
print(variables)

# Store those values, labelled with their variable names, in a new dataset
# that the clustering model is trained on.
dataset = Dataset("sepal_and_petal")
dataset.set_from_array(data, variables)

## Create clustering model
We know that there are three classes of Iris plants.
We will thus try to identify three clusters.



In [None]:
# Fit a Gaussian mixture with one component per expected Iris class.
# The fit starts from a random initialisation, so the seed is pinned to keep
# the cluster labels (and every output derived from them) identical on each
# Restart & Run All. ``random_state`` is expected to be forwarded to the
# underlying scikit-learn estimator -- TODO confirm for the installed GEMSEO
# version.
model = create_clustering_model(
    "GaussianMixture", data=dataset, n_components=3, random_state=0
)
model.learn()
print(model)

## Predict output
Once the model is trained, we can use it to predict the cluster of a new sample.



In [None]:
# Measurements of the single sample to classify, keyed by variable name.
measurements = {
    "sepal_length": 4.5,
    "sepal_width": 3.0,
    "petal_length": 1.0,
    "petal_width": 0.2,
}
# Each measurement is wrapped in a one-element array before being passed
# to the model.
input_value = {name: array([value]) for name, value in measurements.items()}
output_value = model.predict(input_value)
print(output_value)

## Plot clusters
Add the predicted cluster labels to the dataset and show them on a scatter matrix.



In [None]:
# Name under which the predicted cluster labels are stored in the dataset
# and later referenced by the plot.
label_name = "gm_specy"

# The labels are reshaped into a single column (one row per label) before
# being added to the "labels" group.
label_column = model.labels.reshape((-1, 1))
dataset.add_variable(label_name, label_column, group="labels", cache_as_input=False)

# Scatter matrix of the dataset, with the stored labels passed as the classifier.
dataset.plot("ScatterMatrix", kde=True, classifier=label_name)