Source code for gemseo.post.kmeans
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Contributors:
# INITIAL AUTHORS - API and implementation and/or documentation
# :author: Francois Gallard
# OTHER AUTHORS - MACROSCOPIC CHANGES
"""A k-means classification of the optimization history."""
from __future__ import annotations
from typing import TYPE_CHECKING
from numpy import array
from sklearn import cluster
from sklearn.preprocessing import StandardScaler
from gemseo.post.opt_post_processor import OptPostProcessor
if TYPE_CHECKING:
from pathlib import Path
from gemseo.utils.matplotlib_figure import FigSizeType
[docs]
class KMeans(OptPostProcessor):
    """Perform a k-means clustering on the optimization history.

    The design vectors stored in the database are standardized
    (zero mean and unit variance per component) before being clustered.
    The default number of clusters is 5 and can be changed
    with the ``n_clusters`` option.

    The k-means construction relies on the ``MiniBatchKMeans`` class
    of the ``cluster`` module of the
    `scikit-learn library <https://scikit-learn.org/stable/modules/generated/
    sklearn.cluster.MiniBatchKMeans.html>`_.

    The cluster label of each design vector is stored in the database
    under the output name ``"KM_cluster"``
    and in ``materials_for_plotting``.
    """

    def _run(
        self,
        save: bool = True,
        show: bool = False,
        file_path: str | Path = "",
        directory_path: str | Path = "",
        file_name: str = "",
        file_extension: str = "",
        fig_size: FigSizeType = (),
        n_clusters: int = 5,
    ) -> None:
        """
        Args:
            n_clusters: The number of clusters.
        """  # noqa: D205, D212, D415
        # The figure-related arguments belong to the common post-processing
        # signature; they are not used here because the clustering itself
        # does not draw any figure.
        self.__build_clusters(n_clusters=n_clusters)

    def __build_clusters(
        self,
        n_clusters: int = 5,
    ) -> None:
        """Cluster the design vectors of the database and store their labels.

        Each design vector of the history receives its cluster label
        both in the database (output name ``"KM_cluster"``)
        and in ``materials_for_plotting``
        (keyed by the tuple of the real part of the design vector).

        Args:
            n_clusters: The number of clusters.
        """
        x_history = self.database.get_x_vect_history()
        # Standardize each design variable so that variables with large
        # magnitudes do not dominate the Euclidean distances used by k-means.
        x_scaled = StandardScaler().fit_transform(array(x_history))
        algorithm = cluster.MiniBatchKMeans(
            n_clusters=n_clusters, n_init="auto"
        )
        # Fit the model; the cluster membership of each sample is in labels_.
        algorithm.fit(x_scaled)
        labels = algorithm.labels_.astype(int)
        for x_vect, label in zip(x_history, labels):
            self.database.store(x_vect, {"KM_cluster": int(label)})
            self.materials_for_plotting[tuple(x_vect.real)] = label