# Copyright 2021 IRT Saint Exupéry,
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# Contributors:
#    INITIAL AUTHORS - API and implementation and/or documentation
#        :author: Francois Gallard
"""A k-means classification of the optimization history."""
from __future__ import division, unicode_literals

import logging
from typing import Optional, Tuple, Union

from numpy import array
from numpy import int as np_int
from sklearn import cluster
from sklearn.preprocessing import StandardScaler

from gemseo.post.opt_post_processor import OptPostProcessor
from gemseo.utils.py23_compat import Path

LOGGER = logging.getLogger(__name__)

class KMeans(OptPostProcessor):
    """The **KMeans** post processing performs a k-means clustering on optimization
    history.

    The default number of clusters is 5 and can be modified in option.

    The k-means construction depends on the :code:`MiniBatchKMeans` class
    of the :code:`cluster` module
    of the `scikit-learn library <https://scikit-learn.org/stable/modules/
    generated/sklearn.cluster.MiniBatchKMeans.html>`_ .
    """

    def _run(
        self,
        save=True,  # type: bool
        show=False,  # type: bool
        file_path=None,  # type: Optional[Path]
        directory_path=None,  # type: Optional[Union[str,Path]]
        file_name=None,  # type: Optional[str]
        file_extension=None,  # type: Optional[str]
        fig_size=None,  # type: Optional[Tuple[float, float]]
        n_clusters=5,  # type: int
    ):  # type: (...) -> None
        """Cluster the optimization history.

        Only ``n_clusters`` is read by this method; the other options are
        accepted in the signature but are not used here.

        Args:
            n_clusters: The number of clusters.
        """
        self.__build_clusters(n_clusters=n_clusters)

    def __build_clusters(
        self,
        n_clusters=5,  # type: int
    ):  # type: (...) -> None
        """Build the clusters.

        The input history read from the database is standardized, clustered with
        :code:`MiniBatchKMeans`, and the cluster index of each point is recorded
        both in the database (under the ``"KM_cluster"`` key) and in
        ``self.out_data_dict`` (keyed by the real part of the point, as a tuple).

        Args:
            n_clusters: The number of clusters.
        """
        x_history = self.database.get_x_history()
        # Standardize the samples before clustering.
        x_scaled = StandardScaler().fit_transform(array(x_history))

        algorithm = cluster.MiniBatchKMeans(n_clusters=n_clusters)
        # The estimator must be fitted before ``labels_`` is available.
        algorithm.fit(x_scaled)

        # Cluster membership of each sample, in the order of ``x_history``.
        # The builtin ``int`` is the dtype the removed ``numpy.int`` alias
        # referred to.
        y_pred = algorithm.labels_.astype(int)

        for x_vect, label in zip(x_history, y_pred):
            # NOTE(review): the callee was lost in the source extraction;
            # ``self.database.store`` is reconstructed from the surviving
            # ``(..., {"KM_cluster": ...})`` arguments -- confirm against the
            # Database API.
            self.database.store(x_vect, {"KM_cluster": int(label)})
            self.out_data_dict[tuple(x_vect.real)] = label