Source code for gemseo.wrappers.filtering_discipline

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - initial API and implementation and/or initial
#                           documentation
#        :author: Matthias De Lozzo
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
from __future__ import annotations

from typing import Any
from typing import Iterable
from typing import Mapping

from gemseo.core.discipline import MDODiscipline


[docs]class FilteringDiscipline(MDODiscipline): """The FilteringDiscipline is a MDODiscipline wrapping another MDODiscipline, for a subset of inputs and outputs.""" def __init__( self, discipline: MDODiscipline, inputs_names: Iterable[str] | None = None, outputs_names: Iterable[str] | None = None, keep_in: bool = True, keep_out: bool = True, ) -> None: """ Args: discipline: The original discipline. inputs_names: The names of the inputs of interest. If ``None``, use all the inputs. outputs_names: The names of the outputs of interest. If ``None``, use all the outputs. keep_in: Whether to the inputs of interest. Otherwise, remove them. keep_out: Whether to the outputs of interest. Otherwise, remove them. """ self.discipline = discipline super().__init__(name=discipline.name) original_inputs_names = discipline.get_input_data_names() original_outputs_names = discipline.get_output_data_names() if not inputs_names: inputs_names = original_inputs_names elif not keep_in: inputs_names = list(set(original_inputs_names) - set(inputs_names)) if not outputs_names: outputs_names = original_outputs_names elif not keep_out: outputs_names = list(set(original_outputs_names) - set(outputs_names)) self.input_grammar.update(inputs_names) self.output_grammar.update(outputs_names) self.default_inputs = self.__filter_inputs(self.discipline.default_inputs) removed_inputs = set(original_inputs_names) - set(inputs_names) diff_inputs = set(self.discipline._differentiated_inputs) - removed_inputs self.add_differentiated_inputs(list(diff_inputs)) removed_outputs = set(original_outputs_names) - set(outputs_names) diff_outputs = set(self.discipline._differentiated_outputs) - removed_outputs self.add_differentiated_outputs(list(diff_outputs)) def _run(self) -> None: self.discipline.execute(self.get_input_data()) self.store_local_data(**self.__filter_inputs(self.discipline.local_data)) self.store_local_data(**self.__filter_outputs(self.discipline.local_data)) def _compute_jacobian( self, inputs: Iterable[str] | None = None, outputs: Iterable[str] | None = None, ) -> None: self.discipline._compute_jacobian(inputs, outputs) self._init_jacobian(inputs, outputs, with_zeros=True) jac = self.discipline.jac for output_name in self.get_output_data_names(): for input_name in self.get_input_data_names(): self.jac[output_name][input_name] = jac[output_name][input_name] def __filter_inputs(self, data: Mapping[str, Any]): """Filter a mapping by input names. Args: data: The original mapping. Returns: The mapping filtered by input names. """ return {name: data[name] for name in self.get_input_data_names()} def __filter_outputs(self, data): """Filter a mapping by output names. Args: data: The original mapping. Returns: The mapping filtered by output names. """ return {name: data[name] for name in self.get_output_data_names()}