Source code for gemseo.post.dataset.scatter_plot_matrix

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - initial API and implementation and/or initial
#                           documentation
#        :author: Matthias De Lozzo
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
r"""Draw a scatter matrix from a :class:`.Dataset`.

The :class:`.ScatterMatrix` class implements the scatter plot matrix,
which is a way to visualize :math:`n` samples of a
multi-dimensional vector

.. math::

   x=(x_1,x_2,\ldots,x_d)\in\mathbb{R}^d

in several 2D subplots where the (i,j) subplot represents the cloud
of points

.. math::

   \left(x_i^{(k)},x_j^{(k)}\right)_{1\leq k \leq n}

while the (i,i) subplot represents the empirical distribution of the samples

.. math::

   x_i^{(1)},\ldots,x_i^{(n)}

by means of a histogram or a kernel density estimator.

A variable name can be passed to :class:`.ScatterMatrix`
by means of the ``classifier`` keyword in order to color the points
according to the value of this variable. This is useful when the data is
labeled.
"""

from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import Any
from typing import Union

if TYPE_CHECKING:
    from gemseo.datasets.dataset import Dataset

from gemseo.post.dataset._trend import Trend as _Trend
from gemseo.post.dataset._trend import TrendFunctionCreator
from gemseo.post.dataset.dataset_plot import DatasetPlot

ScatterMatrixOption = Union[bool, int, str, None, Sequence[str]]


class ScatterMatrix(DatasetPlot):
    """Scatter plot matrix."""

    Trend = _Trend
    """The type of trend."""

    def __init__(
        self,
        dataset: Dataset,
        variable_names: Iterable[str] = (),
        classifier: str = "",
        kde: bool = False,
        size: int = 25,
        marker: str = "o",
        plot_lower: bool = True,
        plot_upper: bool = True,
        trend: Trend | TrendFunctionCreator = Trend.NONE,
        **options: Any,
    ) -> None:
        """
        Args:
            variable_names: The names of the variables to plot.
                If empty, all the variables of the dataset are plotted.
            classifier: The name of the variable used to group the data.
                If empty, the data are not grouped.
            kde: How to represent the distributions on the diagonal.
                If ``True``, use a kernel-density estimator.
                Otherwise, use histograms.
            size: The size of the points.
            marker: The marker for the points.
            plot_lower: Whether to plot the lower part.
            plot_upper: Whether to plot the upper part.
            trend: The trend function to be added on the scatter plots
                or a function creating a trend function
                from a set of *xy*-points.
            **options: The options of the underlying pandas scatter matrix.
        """  # noqa: D205, D212, D415
        # The extra keyword arguments are forwarded as a single ``options``
        # setting, next to the named settings.
        super().__init__(
            dataset,
            variable_names=variable_names,
            classifier=classifier,
            kde=kde,
            size=size,
            marker=marker,
            plot_lower=plot_lower,
            plot_upper=plot_upper,
            trend=trend,
            options=options,
        )

    def _create_specific_data_from_dataset(self) -> tuple[tuple[str, str, int] | None]:
        """
        Returns:
            A 1-tuple holding the dataset column associated with the classifier,
            or ``None`` when no classifier is set.

        Raises:
            ValueError: When the classifier is not a variable name of the dataset.
        """  # noqa: D205, D212, D415
        group_name = self._specific_settings.classifier

        # No classifier requested: nothing to look up in the dataset.
        if not group_name:
            return (None,)

        if group_name not in self.dataset.variable_names:
            msg = (
                f"{group_name} cannot be used as a classifier "
                f"because it is not a variable name; "
                f"available ones are: {self.dataset.variable_names}."
            )
            raise ValueError(msg)

        # Only the second item returned by ``_get_label`` is kept.
        column = self._get_label(group_name)[1]
        return (column,)