Source code for gemseo.post.dataset.scatter_plot_matrix

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - initial API and implementation and/or initial
#                           documentation
#        :author: Matthias De Lozzo
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
r"""Draw a scatter matrix from a :class:`.Dataset`.

The :class:`.ScatterMatrix` class implements the scatter plot matrix,
which is a way to visualize :math:`n` samples of a
multi-dimensional vector

.. math::

   x=(x_1,x_2,\ldots,x_d)\in\mathbb{R}^d

in several 2D subplots where the (i,j) subplot represents the cloud
of points

.. math::

   \left(x_i^{(k)},x_j^{(k)}\right)_{1\leq k \leq n}

while the (i,i) subplot represents the empirical distribution of the samples

.. math::

   x_i^{(1)},\ldots,x_i^{(n)}

by means of a histogram or a kernel density estimator.

A variable name can be passed to the :meth:`.DatasetPlot.execute` method
by means of the :code:`classifier` keyword in order to color the points
according to the values of this variable. This is useful when the data is
labeled.
"""
from __future__ import annotations

from typing import Sequence

from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pandas.plotting import scatter_matrix

from gemseo.core.dataset import Dataset
from gemseo.post.dataset.dataset_plot import DatasetPlot


class ScatterMatrix(DatasetPlot):
    """Scatter plot matrix."""

    def __init__(
        self,
        dataset: Dataset,
        variable_names: Sequence[str] | None = None,
        classifier: str | None = None,
        kde: bool = False,
        size: int = 25,
        marker: str = "o",
        plot_lower: bool = True,
        plot_upper: bool = True,
    ) -> None:
        """
        Args:
            dataset: The dataset containing the data to plot.
            variable_names: The names of the variables to plot.
                If ``None``, use all the variables of the dataset.
            classifier: The name of the variable to build the cluster.
            kde: The type of the distribution representation.
                If ``True``, plot kernel-density estimator on the diagonal.
                Otherwise, use histograms.
            size: The size of the points.
            marker: The marker for the points.
            plot_lower: Whether to plot the lower part.
            plot_upper: Whether to plot the upper part.
        """
        super().__init__(
            dataset,
            variable_names=variable_names,
            classifier=classifier,
            kde=kde,
            size=size,
            marker=marker,
            plot_lower=plot_lower,
            plot_upper=plot_upper,
        )

    def _plot(
        self,
        fig: None | Figure = None,
        axes: None | Axes = None,
    ) -> list[Figure]:
        """Draw the scatter plot matrix.

        Args:
            fig: The figure forwarded to ``_get_figure_and_axes``.
            axes: The axes forwarded to ``_get_figure_and_axes``.

        Returns:
            A single-element list containing the figure of the scatter matrix.

        Raises:
            ValueError: When the classifier is not a variable of the dataset.
        """
        variable_names = self._param.variable_names
        classifier = self._param.classifier
        kde = self._param.kde
        size = self._param.size
        marker = self._param.marker
        if variable_names is None:
            variable_names = self.dataset.variables

        if classifier is not None and classifier not in self.dataset.variables:
            raise ValueError(
                f"{classifier} cannot be used as a classifier "
                f"because it is not a variable name; "
                f"available ones are: {self.dataset.variables}."
            )

        diagonal = "kde" if kde else "hist"

        dataframe = self.dataset.export_to_dataframe(variable_names=variable_names)
        kwargs = {}
        if classifier is not None:
            # One color per sample, cycling through the palette according to
            # the value of the classifier.
            # NOTE(review): the dictionary lookup presumably relies on the
            # classifier values being integral -- confirm against callers.
            palette = dict(enumerate("bgrcmyk"))
            groups = self.dataset.get_data_by_names([classifier], False)[:, 0:1]
            kwargs["color"] = [palette[group[0] % len(palette)] for group in groups]
            # The classifier only drives the colors: it is not plotted itself.
            _, variable_name = self._get_label(classifier)
            dataframe = dataframe.drop(labels=variable_name, axis=1)

        dataframe.columns = self._get_variables_names(dataframe)
        fig, axes = self._get_figure_and_axes(fig, axes, self.fig_size)
        sub_axes = scatter_matrix(
            dataframe,
            diagonal=diagonal,
            s=size,
            marker=marker,
            figsize=self.fig_size,
            ax=axes,
            **kwargs,
        )

        n_cols = sub_axes.shape[0]
        if not (self._param.plot_lower and self._param.plot_upper):
            # A part of the matrix will be hidden: start by hiding all the
            # x- and y-axes, then re-enable the relevant ones below.
            for i in range(n_cols):
                for j in range(n_cols):
                    sub_axes[i, j].get_xaxis().set_visible(False)
                    sub_axes[i, j].get_yaxis().set_visible(False)

        if not self._param.plot_lower:
            # Hide the sub-plots strictly below the diagonal.
            for i in range(n_cols):
                for j in range(i):
                    sub_axes[i, j].set_visible(False)

            # Show the axes of the diagonal sub-plots.
            for i in range(n_cols):
                sub_axes[i, i].get_xaxis().set_visible(True)
                sub_axes[i, i].get_yaxis().set_visible(True)

        if not self._param.plot_upper:
            # Hide the sub-plots strictly above the diagonal.
            for i in range(n_cols):
                for j in range(i + 1, n_cols):
                    sub_axes[i, j].set_visible(False)

            # Show the x-axes of the last row and the y-axes of the first column.
            for i in range(n_cols):
                sub_axes[-1, i].get_xaxis().set_visible(True)
                sub_axes[i, 0].get_yaxis().set_visible(True)

        # Set the title on the figure actually used for the plot rather than
        # on pyplot's current figure, which may be a different one when the
        # caller provides its own figure.
        fig.suptitle(self.title)
        return [fig]