Source code for gemseo.problems.mdo.sellar.utils

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Utils for the customizable Sellar MDO problem."""

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING
from typing import Any

from numpy import array
from numpy import atleast_2d
from numpy import float64
from numpy import ndarray
from numpy import ones
from numpy import zeros

from gemseo.core.data_converters.json import JSONGrammarDataConverter
from gemseo.core.grammars.json_grammar import JSONGrammar
from gemseo.problems.mdo.sellar import WITH_2D_ARRAY
from gemseo.problems.mdo.sellar.variables import ALPHA
from gemseo.problems.mdo.sellar.variables import BETA
from gemseo.problems.mdo.sellar.variables import GAMMA
from gemseo.problems.mdo.sellar.variables import X_1
from gemseo.problems.mdo.sellar.variables import X_2
from gemseo.problems.mdo.sellar.variables import X_SHARED
from gemseo.problems.mdo.sellar.variables import Y_1
from gemseo.problems.mdo.sellar.variables import Y_2

if TYPE_CHECKING:
    from collections.abc import Iterable
    from collections.abc import Iterator

    from gemseo.mda.base_mda import BaseMDA
    from gemseo.typing import RealArray


def get_initial_data(names: Iterable[str] = (), n: int = 1) -> dict[str, RealArray]:
    """Generate an initial solution for the MDO problem.

    Args:
        names: The names of the discipline inputs.
            If empty, the default values of all the variables are returned.
            The names without default value are ignored.
        n: The size of the local design variables and coupling variables.

    Returns:
        The default values of the discipline inputs.
    """
    shared_design_value = array([1.0, 0.0], dtype=float64)
    if WITH_2D_ARRAY:  # pragma: no cover
        # The shared design variables are stored as a 2D array in this mode.
        shared_design_value = atleast_2d(shared_design_value)

    default_values = {
        X_1: zeros(n),
        X_2: zeros(n),
        X_SHARED: shared_design_value,
        Y_1: ones(n, dtype=float64),
        Y_2: ones(n, dtype=float64),
        ALPHA: array([3.16]),
        BETA: array([24.0]),
        GAMMA: array([0.2]),
    }

    if names:
        return {
            name: default_values[name] for name in names if name in default_values
        }

    return default_values
def get_y_opt(mda: BaseMDA) -> ndarray:
    """Return the optimal ``y`` array.

    Only the real part of the first component of each coupling variable is used.

    Args:
        mda: The mda.

    Returns:
        The optimal ``y`` array,
        made of the values of the first and second coupling variables.
    """
    return array([mda.io.data[name][0].real for name in (Y_1, Y_2)])
class DataConverterFor2DArray(JSONGrammarDataConverter):
    """A data converter where ``x_shared`` is not a ndarray.

    The value of ``x_shared`` is wrapped in an extra leading dimension;
    any other variable is handled by the parent converter.
    """

    def convert_value_to_array(self, name: str, value: Any) -> ndarray:  # noqa: D102 # pragma: no cover
        if name != X_SHARED:
            return super().convert_value_to_array(name, value)

        # Unwrap: the array is the first item of the stored value.
        return value[0]

    def convert_array_to_value(self, name: str, array_: Any) -> Any:  # noqa: D102 # pragma: no cover
        if name != X_SHARED:
            return super().convert_array_to_value(name, array_)

        # Wrap: add a leading dimension around the array.
        return array([array_])
@contextmanager
def set_data_converter() -> Iterator[None]:
    """Set the data converter according to whether 2D array shall be used.

    When ``WITH_2D_ARRAY`` is enabled,
    ``DataConverterFor2DArray`` is set as ``JSONGrammar.DATA_CONVERTER_CLASS``
    while the ``with`` block runs
    and ``JSONGrammarDataConverter`` is set back when leaving it,
    including when the block raises an exception.
    Otherwise, the context manager changes nothing.

    Yields:
        Nothing.
    """
    if WITH_2D_ARRAY:  # pragma: no cover
        JSONGrammar.DATA_CONVERTER_CLASS = DataConverterFor2DArray
        try:
            yield
        finally:
            # Reset even on error: this is a class-level attribute, so leaving
            # the 2D converter in place would affect every later JSONGrammar.
            JSONGrammar.DATA_CONVERTER_CLASS = JSONGrammarDataConverter
    else:
        yield