# Source code for gemseo.disciplines.remapping

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#     Matthias De Lozzo
"""A discipline whose inputs and outputs map to those of another."""
from __future__ import annotations

from numbers import Integral
from typing import Dict
from typing import Iterable
from typing import Tuple
from typing import Union

from numpy import empty
from numpy import ndarray

from gemseo.core.discipline import MDODiscipline
from gemseo.utils.python_compatibility import singledispatchmethod

# An original variable name together with the components to select:
# either a single index or an iterable of indices.
Indices = Tuple[str, Union[int, Iterable[int]]]

# Variable names bound to their NumPy array values.
Data = Dict[str, ndarray]

# New variable names bound to original names,
# possibly restricted to some components (see ``Indices``).
NameMapping = Dict[str, Union[str, Indices]]


class RemappingDiscipline(MDODiscipline):
    """A discipline whose inputs and outputs map to those of another.

    An input or output name mapping looks like
    ``{"new_x": "x", "new_y": ("y", components)}``
    where the variable ``"new_x"`` corresponds to the original variable ``"x"``
    and the variable ``"new_y"`` corresponds to some ``components``
    of the original variable ``"y"``.
    ``components`` can be an integer ``i``
    (the ``i``-th component of ``y``; negative integers count from the end),
    a sequence of integers ``[i, j, k]``
    (the ``i``-th, ``j``-th and ``k``-th components of ``y``),
    an iterable of integers ``range(i, j+1)``
    (from the ``i``-th to the ``j``-th components of ``y``)
    or a ``slice``.

    When the original discipline is executed,
    the components of its inputs that are not covered by the input mapping
    take the current default input values of the original discipline.
    """

    _ATTR_TO_SERIALIZE = MDODiscipline._ATTR_TO_SERIALIZE + (
        "_discipline",
        "_empty_original_input_data",
        "_input_mapping",
        "_output_mapping",
    )

    def __init__(
        self,
        discipline: MDODiscipline,
        input_mapping: NameMapping,
        output_mapping: NameMapping,
    ) -> None:
        """.. # noqa: D205 D212 D415

        Args:
            discipline: The original discipline.
            input_mapping: The input names to the original input names.
            output_mapping: The output names to the original output names.

        Raises:
            ValueError: When the original discipline has no default input values.
        """
        if not discipline.default_inputs:
            raise ValueError("The original discipline has no default input values.")

        self._discipline = discipline
        # Templates with the shapes and dtypes of the original default inputs
        # (values uninitialized); used as a fallback when an original input
        # no longer has a default value at execution time.
        self._empty_original_input_data = {
            name: empty(value.shape, dtype=value.dtype)
            for name, value in discipline.default_inputs.items()
        }
        self._input_mapping = self.__format_mapping(input_mapping)
        self._output_mapping = self.__format_mapping(output_mapping)
        super().__init__(name=self._discipline.name)
        self.input_grammar.update(self._input_mapping.keys())
        self.output_grammar.update(self._output_mapping.keys())
        self.default_inputs = self.__convert_from_origin(
            discipline.default_inputs, self._input_mapping
        )

    @property
    def original_discipline(self) -> MDODiscipline:
        """The original discipline."""
        return self._discipline

    @singledispatchmethod
    @staticmethod
    def __cast_mapping_value(value) -> tuple[str, slice | list[int]]:
        """Cast a value of a mapping to ``(original_name, components)``.

        Args:
            value: The value to be cast.

        Returns:
            The original name and its components,
            either as a slice or as a list of indices.

        Raises:
            ValueError: When the value is neither a string
                nor a tuple with a string as first component
                and an integer or iterable of integers as second one.
        """
        raise ValueError(
            "The values of a name mapping should be either a str or a tuple[str, Any]."
        )

    # The registered implementations must not carry evaluated return annotations:
    # ``register`` resolves all the type hints of the function.
    @staticmethod
    @__cast_mapping_value.register
    def _(value: str):  # -> tuple[str, slice]:
        # A bare name selects all the components.
        return value, slice(None)

    @staticmethod
    @__cast_mapping_value.register
    def _(value: tuple):  # -> tuple[str, slice | list[int]]:
        name, indices = value
        if isinstance(indices, Integral):
            # A slice keeps the selection one-dimensional;
            # ``or None`` makes -1 select the last component
            # instead of the empty slice ``slice(-1, 0)``.
            index = int(indices)
            return name, slice(index, index + 1 or None)

        if isinstance(indices, slice):
            return name, indices

        # The components are reused at every execution and linearization,
        # so a one-shot iterable (e.g. a generator) is materialized once.
        return name, list(indices)

    @classmethod
    def __format_mapping(
        cls, mapping: NameMapping
    ) -> dict[str, tuple[str, slice | list[int]]]:
        """Format a mapping as ``{"current_name": ("original_name", components)}``.

        Args:
            mapping: The user mapping.

        Returns:
            The formatted mapping.
        """
        return {
            name: cls.__cast_mapping_value(value) for name, value in mapping.items()
        }

    def _run(self) -> None:
        # Execute the original discipline with the remapped inputs,
        # then expose its outputs under the new names.
        self._discipline.execute(self.__convert_to_origin(self.get_input_data()))
        self.local_data.update(
            self.__convert_from_origin(
                self._discipline.get_output_data(), self._output_mapping
            )
        )

    def _compute_jacobian(
        self,
        inputs: Iterable[str] | None = None,
        outputs: Iterable[str] | None = None,
    ) -> None:
        input_names = list(self._input_mapping) if inputs is None else list(inputs)
        output_names = (
            list(self._output_mapping) if outputs is None else list(outputs)
        )
        # The original discipline only knows the original names.
        self._discipline._compute_jacobian(
            inputs=None
            if inputs is None
            else self.__get_original_names(input_names, self._input_mapping),
            outputs=None
            if outputs is None
            else self.__get_original_names(output_names, self._output_mapping),
        )
        original_jac = self._discipline.jac
        self.jac = {}
        for output_name in output_names:
            original_output_name, output_indices = self._output_mapping[output_name]
            self.jac[output_name] = jac_row = {}
            for input_name in input_names:
                original_input_name, input_indices = self._input_mapping[input_name]
                jac = original_jac[original_output_name][original_input_name]
                # Select the rows, then the columns:
                # ``jac[rows, columns]`` would pair two index lists elementwise
                # instead of extracting the sub-matrix.
                jac_row[input_name] = jac[output_indices][:, input_indices]

    @staticmethod
    def __get_original_names(
        names: Iterable[str], name_mapping: dict[str, tuple[str, slice | list[int]]]
    ) -> list[str]:
        """Return the original names of some current names, without duplicates.

        Args:
            names: The current names.
            name_mapping: The current names mapping to the original ones.

        Returns:
            The original names, in order of first appearance.
        """
        return list(dict.fromkeys(name_mapping[name][0] for name in names))

    @staticmethod
    def __convert_from_origin(
        original_data: Data, name_mapping: dict[str, tuple[str, slice | list[int]]]
    ) -> Data:
        """Convert original data to the current format.

        Args:
            original_data: The original data
                mapping the original names to the corresponding values.
            name_mapping: The current names mapping to the original ones.

        Returns:
            The current data mapping the current names to the corresponding values.
        """
        return {
            new_name: original_data[original_name][indices]
            for new_name, (original_name, indices) in name_mapping.items()
        }

    def __convert_to_origin(self, input_data: Data) -> Data:
        """Convert current input data to the original format.

        Each original input starts from a fresh copy
        of the current default value of the original discipline
        (or, if there is none, of the template built at construction);
        then the components given by the input mapping are overwritten.

        Args:
            input_data: The current input data
                mapping the current input names to the corresponding values.

        Returns:
            The original input data
            mapping the original input names to the corresponding values.
        """
        default_inputs = self._discipline.default_inputs
        original_input_data = {}
        for original_name, template in self._empty_original_input_data.items():
            default_value = default_inputs.get(original_name)
            if default_value is None:
                default_value = template
            # Copy: the components are set in place just below,
            # and neither the defaults nor the template must be modified,
            # nor arrays be shared between executions.
            original_input_data[original_name] = default_value.copy()

        for name, value in input_data.items():
            original_name, indices = self._input_mapping[name]
            original_input_data[original_name][indices] = value

        return original_input_data