# Source code for gemseo.wrappers.filtering_discipline

# -*- coding: utf-8 -*-
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# Contributors:
#    INITIAL AUTHORS - initial API and implementation and/or initial
#                           documentation
#        :author: Matthias De Lozzo
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
from gemseo.core.discipline import MDODiscipline


class FilteringDiscipline(MDODiscipline):
    """A MDODiscipline wrapping another MDODiscipline for a subset of its variables.

    Only the selected inputs and outputs of the wrapped discipline are exposed;
    execution and linearization are delegated to the wrapped discipline.
    """

    def __init__(
        self,
        discipline,
        inputs_names=None,
        outputs_names=None,
        keep_in=True,
        keep_out=True,
    ):
        """Constructor.

        :param MDODiscipline discipline: discipline to wrap.
        :param list(str) inputs_names: list of inputs names.
            If None, use all inputs. Default: None.
        :param list(str) outputs_names: list of outputs names.
            If None, use all outputs. Default: None.
        :param bool keep_in: if True, keep the list of inputs names.
            Otherwise, remove them from the inputs of the wrapped discipline.
        :param bool keep_out: if True, keep the list of outputs names.
            Otherwise, remove them from the outputs of the wrapped discipline.
        """
        self.discipline = discipline
        super(FilteringDiscipline, self).__init__(name=discipline.name)
        original_inputs_names = discipline.get_input_data_names()
        original_outputs_names = discipline.get_output_data_names()
        inputs_names = self.__select_names(
            original_inputs_names, inputs_names, keep_in
        )
        outputs_names = self.__select_names(
            original_outputs_names, outputs_names, keep_out
        )
        self.input_grammar.initialize_from_data_names(inputs_names)
        self.output_grammar.initialize_from_data_names(outputs_names)
        # The wrapped discipline may have no default value for some of the
        # kept inputs; __filter skips those instead of raising a KeyError.
        self.default_inputs = self.__filter_inputs(self.discipline.default_inputs)

        # Inherit the differentiated variables of the wrapped discipline,
        # except those that have been filtered out.
        removed_inputs = set(original_inputs_names) - set(inputs_names)
        self.add_differentiated_inputs(
            [
                name
                for name in self.discipline._differentiated_inputs
                if name not in removed_inputs
            ]
        )
        removed_outputs = set(original_outputs_names) - set(outputs_names)
        self.add_differentiated_outputs(
            [
                name
                for name in self.discipline._differentiated_outputs
                if name not in removed_outputs
            ]
        )

    @staticmethod
    def __select_names(original_names, names, keep):
        """Select the names of the variables exposed by the filtering discipline.

        :param list(str) original_names: names of the wrapped discipline variables.
        :param list(str) names: names to keep or to remove.
            If None, use original_names.
        :param bool keep: if True, return names as given;
            otherwise, return original_names without names.
        :return: the selected names.
        """
        if names is None:
            return original_names
        if keep:
            return names
        # Preserve the order of original_names: a set difference would make
        # the ordering depend on string hash randomization.
        excluded = set(names)
        return [name for name in original_names if name not in excluded]

    def _run(self):
        self.discipline.execute(self.get_input_data())
        self.store_local_data(**self.__filter_inputs(self.discipline.local_data))
        self.store_local_data(**self.__filter_outputs(self.discipline.local_data))

    def _compute_jacobian(self, inputs=None, outputs=None):
        # NOTE(review): this relies on the wrapped discipline's local data
        # matching the current point; if this discipline's execution was served
        # from its cache, _run (and thus the wrapped execute) was skipped -
        # confirm the wrapped local data cannot be stale here.
        self.discipline._compute_jacobian(inputs, outputs)
        self._init_jacobian(inputs, outputs, with_zeros=True)
        wrapped_jac = self.discipline.jac
        # Fill exactly the blocks prepared by _init_jacobian, i.e. the
        # requested (output, input) pairs, rather than every pair of the
        # grammars: a partial request would otherwise fail on pairs that were
        # neither initialized here nor computed by the wrapped discipline.
        for output_name, output_jac in self.jac.items():
            for input_name in output_jac:
                output_jac[input_name] = wrapped_jac[output_name][input_name]

    @staticmethod
    def __filter(data, keys):
        """Filter a data dictionary by names.

        :param dict data: data dictionary.
        :param list(str) keys: list of dictionary keys;
            keys absent from data are skipped.
        :return: the sub-dictionary of data restricted to keys.
        """
        return {key: data[key] for key in keys if key in data}

    def __filter_inputs(self, data):
        """Filter a data dictionary by inputs names.

        :param dict data: data dictionary.
        :return: the sub-dictionary of data restricted to the inputs names.
        """
        return self.__filter(data, self.get_input_data_names())

    def __filter_outputs(self, data):
        """Filter a data dictionary by outputs names.

        :param dict data: data dictionary.
        :return: the sub-dictionary of data restricted to the outputs names.
        """
        return self.__filter(data, self.get_output_data_names())