# Source code for gemseo.disciplines.splitter

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""A discipline splitting an input variable."""

from __future__ import annotations

from numbers import Integral
from typing import TYPE_CHECKING

from numpy import ndarray
from scipy.sparse import eye

from gemseo.core.discipline import MDODiscipline

if TYPE_CHECKING:
    from collections.abc import Iterable


class Splitter(MDODiscipline):
    """A discipline splitting an input variable.

    Several output variables containing slices of the input variable are extracted.

    Examples:
        >>> discipline = Splitter("alpha", {"beta": [0, 1], "delta": [2, 3], "gamma": 4})
        >>> discipline.execute({"alpha": array([1.0, 2.0, 3.0, 4.0, 5.0])})
        >>> delta = discipline.local_data["delta"]  # delta = array([3.0, 4.0])
    """

    def __init__(
        self,
        input_name: str,
        output_names_to_input_indices: dict[str, Iterable[int] | int],
    ) -> None:
        """
        Args:
            input_name: The name of the input to split.
            output_names_to_input_indices: The input indices associated with the
                output names. A single index can be given as an integer;
                the corresponding output then contains a single component.
                This mapping is copied and left unmodified.
        """  # noqa: D205, D212, D415
        self.__input_name = input_name
        # Build a private mapping rather than rewriting the caller's one in place,
        # so that later changes to the caller's dict do not affect this discipline.
        self.__slicing_structure = {
            output_name: self.__normalize_indices(input_indices)
            for output_name, input_indices in output_names_to_input_indices.items()
        }
        super().__init__()
        self.input_grammar.update_from_names([input_name])
        self.output_grammar.update_from_names(self.__slicing_structure.keys())

    @staticmethod
    def __normalize_indices(
        input_indices: Iterable[int] | int | ndarray,
    ) -> list[int] | ndarray:
        """Return the input indices in a form usable to index the input and its Jacobian.

        Args:
            input_indices: Either a single index, an iterable of indices
                or a NumPy array of indices.

        Returns:
            The NumPy array unchanged, a one-element list for a single index,
            or a new flat list containing the indices otherwise.
        """
        if isinstance(input_indices, ndarray):
            # Kept as is so that NumPy indexing semantics are unchanged.
            return input_indices

        if isinstance(input_indices, Integral):
            # Wrap a scalar index so that the output stays an array, not a scalar.
            return [input_indices]

        # Materialize any other iterable (list, tuple, range, generator, ...)
        # as a flat list; wrapping it in a list would build a nested index.
        return list(input_indices)

    def _run(self) -> None:
        # Each output is the part of the input located at its indices.
        input_data = self.local_data[self.__input_name]
        for output_name, input_indices in self.__slicing_structure.items():
            self.local_data[output_name] = input_data[input_indices]

    def _compute_jacobian(
        self,
        inputs: Iterable[str] | None = None,
        outputs: Iterable[str] | None = None,
    ) -> None:
        # The Jacobian is filled for all the outputs,
        # whatever the values of ``inputs`` and ``outputs``.
        self._init_jacobian(init_type=self.InitJacobianType.SPARSE)
        # The derivative of input[indices] with respect to the input
        # is made of the rows of the identity matrix at these indices.
        identity = eye(self.local_data[self.__input_name].size, format="csr")
        for output_name, input_indices in self.__slicing_structure.items():
            self.jac[output_name][self.__input_name] = identity[input_indices, :]