Source code for gemseo.utils.testing
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Contributors:
# Matthias De Lozzo
# Antoine DECHAUME
from __future__ import annotations
from typing import Any
from typing import Mapping
from matplotlib.testing.decorators import image_comparison as mpl_image_comparison
from numpy import array_equal
from numpy import ndarray
from numpy.linalg import norm
from gemseo.utils.data_conversion import flatten_nested_dict
[docs]def compare_dict_of_arrays(
dict_of_arrays: Mapping[str, ndarray],
other_dict_of_arrays: Mapping[str, ndarray],
tolerance: float = 0.0,
) -> bool:
"""Check if two dictionaries of NumPy arrays are equal.
These dictionaries can be nested.
Args:
dict_of_arrays: A dictionary of NumPy arrays.
other_dict_of_arrays: Another dictionary of NumPy arrays.
tolerance: A relative tolerance.
The dictionaries are approximately equal
if for any key ``reference_name`` of ``reference_dict_of_arrays``,
``norm(dict_of_arrays[name]-reference_dict_of_arrays[name])
/(1+norm(reference_dict_of_arrays))<= cache_tol``
Returns:
Whether the dictionaries are equal.
"""
if any(isinstance(value, Mapping) for value in dict_of_arrays.values()):
dict_of_arrays = flatten_nested_dict(dict_of_arrays)
other_dict_of_arrays = flatten_nested_dict(other_dict_of_arrays)
for key, value in other_dict_of_arrays.items():
if key not in dict_of_arrays:
return False
if tolerance:
if norm(dict_of_arrays[key] - value) > tolerance * (1.0 + norm(value)):
return False
else:
if not array_equal(dict_of_arrays[key], value):
return False
return True
[docs]def image_comparison(*args: Any, **kwargs: Any) -> None:
"""Compare matplotlib images generated by the tests with reference ones.
This overloads :meth:`matplotlib.testing.decorators.image_comparison` by using
``"default"`` as ``style`` if missing. Use ``["png"]`` as ``extensions`` if missing.
"""
if "style" not in kwargs: # pragma: no cover
kwargs["style"] = "default"
if "extensions" not in kwargs:
kwargs["extensions"] = ["png"]
return mpl_image_comparison(*args, **kwargs)