Source code for gemseo.utils.comparisons

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
# Matthias De Lozzo
# Antoine DECHAUME
"""Data comparison tools."""

from __future__ import annotations

from collections.abc import Mapping
from functools import partial

from numpy import allclose
from numpy import array_equal
from numpy import asarray
from numpy import ndarray
from numpy import str_

from gemseo.typing import SparseOrDenseRealArray
from gemseo.utils.compatibility.scipy import sparse_classes
from gemseo.utils.data_conversion import flatten_nested_dict

DataToCompare = (
    Mapping[str, SparseOrDenseRealArray]
    | Mapping[str, Mapping[str, SparseOrDenseRealArray]]
)


# TODO: add runtime optimization to detect,
# until some point in gemseo process (first algo iteration?),
# if sparse arrays are actually used and
# then allow this function to use a specialized implementation.
[docs] def compare_dict_of_arrays( dict_of_arrays: DataToCompare, other_dict_of_arrays: DataToCompare, tolerance: float = 0.0, nan_are_equal: bool = False, ) -> bool: """Check if two dictionaries of NumPy and/or SciPy sparse arrays are equal. The dictionaries can be nested, in which case they are flattened. If the tolerance is set, then arrays are considered equal if ``norm(dict_of_arrays[name] - other_dict_of_arrays[name]) /(1 + norm(other_dict_of_arrays[name])) <= tolerance``. Args: dict_of_arrays: A dictionary of NumPy arrays and/or SciPy sparse matrices. other_dict_of_arrays: Another dictionary of NumPy arrays and/or SciPy sparse matrices. tolerance: The relative tolerance. If 0.0, the array must be exactly equal. nan_are_equal: Whether to compare NaN's as equal. Returns: Whether the dictionaries are equal. """ # Flatten the dictionaries if nested if any(isinstance(value, Mapping) for value in dict_of_arrays.values()): dict_of_arrays = flatten_nested_dict(dict_of_arrays) if any(isinstance(value, Mapping) for value in other_dict_of_arrays.values()): other_dict_of_arrays = flatten_nested_dict(other_dict_of_arrays) # Check the keys if dict_of_arrays.keys() != other_dict_of_arrays.keys(): return False if tolerance: compare_arrays = partial( allclose, rtol=tolerance, atol=tolerance, equal_nan=nan_are_equal, ) else: compare_arrays = partial(array_equal, equal_nan=nan_are_equal) # Check the values for key, array_ in dict_of_arrays.items(): other_array = other_dict_of_arrays[key] if not isinstance(array_, (ndarray, sparse_classes)): array_ = asarray(array_) if not isinstance(other_array, (ndarray, sparse_classes)): other_array = asarray(other_array) if array_.shape != other_array.shape: return False array_is_dense = isinstance(array_, ndarray) other_array_is_dense = isinstance(other_array, ndarray) if array_is_dense or other_array_is_dense: if not array_is_dense: array_ = array_.toarray() if not other_array_is_dense: other_array = other_array.toarray() # Sparsity is kept only when both arrays are sparse else: array_ = array_.data.reshape(-1) other_array = other_array.data.reshape(-1) if array_.dtype.type is str_: if array_ != other_array: return False elif not compare_arrays(array_, other_array): return False return True