Source code for gemseo.post.kmeans

# -*- coding: utf-8 -*-
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# Contributors:
#    INITIAL AUTHORS - API and implementation and/or documentation
#        :author: Francois Gallard
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
"""
A k-means classification of the optimization history
****************************************************
"""
from __future__ import absolute_import, division, unicode_literals

from future import standard_library
from numpy import array
from numpy import int as np_int
from sklearn import cluster
from sklearn.preprocessing import StandardScaler

from gemseo.post.opt_post_processor import OptPostProcessor

standard_library.install_aliases()


from gemseo import LOGGER


[docs]class KMeans(OptPostProcessor): """ The **KMeans** post processing performs a k-means clustering on optimization history. The default number of clusters is 5 and can be modified in option. The k-means construction depends on the :code:`MiniBatchKMeans` class of the :code:`cluster` module of the `scikit-learn library <https://scikit-learn.org/stable/modules/generated/ sklearn.cluster.MiniBatchKMeans.html>`_ . """ def _run(self, n_clusters=5): # pylint: disable=W0221 """ Computes the clustering :param n_clusters: prescribed number of clusters """ self.__build_clusters(n_clusters) def __build_clusters(self, n_clusters=5): """ Builds the clusters :param n_clusters: prescribed number of clusters :type n_clusters: int """ x_history = self.database.get_x_history() x_vars = array(x_history) x_vars_sc = StandardScaler().fit_transform(x_vars) # estimate bandwidth for mean shift algorithm = cluster.MiniBatchKMeans(n_clusters=n_clusters) # predict cluster memberships algorithm.fit(x_vars_sc) y_pred = algorithm.labels_.astype(np_int) for x_vars, y_vars in zip(x_history, y_pred): self.database.store(x_vars, {"KM_cluster": int(y_vars)}) self.out_data_dict[tuple(x_vars.real)] = y_vars