Source code for gemseo.core.mdofunctions.norm_db_function

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - API and implementation and/or documentation
#        :author: Francois Gallard
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
#        :author: Benoit Pauwels - Stacked data management
#               (e.g. iteration index)
#        :author: Gilberto Ruiz Jimenez
"""An MDOFunction subclass to support formulations."""

from __future__ import annotations

from typing import TYPE_CHECKING

from numpy import isnan

from gemseo.algos.database import Database
from gemseo.algos.stop_criteria import DesvarIsNan
from gemseo.algos.stop_criteria import FunctionIsNan
from gemseo.algos.stop_criteria import MaxIterReachedException
from gemseo.core.mdofunctions.mdo_function import MDOFunction

if TYPE_CHECKING:
    from gemseo.algos.opt_problem import OptimizationProblem
    from gemseo.typing import NumberArray


[docs] class NormDBFunction(MDOFunction): """An :class:`.MDOFunction` object to be evaluated from a database.""" def __init__( self, orig_func: MDOFunction, normalize: bool, is_observable: bool, optimization_problem: OptimizationProblem, ) -> None: """ Args: orig_func: The original function to be wrapped. normalize: If ``True``, then normalize the function's input vector. is_observable: If ``True``, new_iter_listeners are not called when function is called (avoid recursive call). optimization_problem: The optimization problem object that contains the function. """ # noqa: D205, D212, D415 self.__orig_func = orig_func self.__is_observable = is_observable self.__optimization_problem = optimization_problem # For performance design_space = self.__optimization_problem.design_space self.__unnormalize_vect = design_space.unnormalize_vect # self.__round_vect = design_space.round_vect self.__unnormalize_grad = design_space.unnormalize_grad self.__evaluate_orig_func = self.__orig_func.evaluate self.__jac_orig_func = orig_func.jac self.__is_max_iter_reached = self.__optimization_problem.is_max_iter_reached super().__init__( self._func_to_wrap, orig_func.name, jac=self._jac_to_wrap, f_type=orig_func.f_type, expr=orig_func.expr, input_names=orig_func.input_names, dim=orig_func.dim, output_names=orig_func.output_names, special_repr=orig_func.special_repr, original_name=orig_func.original_name, expects_normalized_inputs=normalize, ) def _func_to_wrap(self, x_vect: NumberArray) -> NumberArray: """Compute the function to be passed to the optimizer. Args: x_vect: The value of the design variables. Returns: The evaluation of the function for this value of the design variables. Raises: DesvarIsNan: If the design variables contain a NaN value. FunctionIsNan: If a function returns a NaN value when evaluated. MaxIterReachedException: If the maximum number of iterations has been reached. """ # TODO: Add a dedicated function check_has_nan(). if isnan(x_vect).any(): msg = f"Design Variables contain a NaN value: {x_vect}" raise DesvarIsNan(msg) normalize = self.expects_normalized_inputs if normalize: xn_vect = x_vect xu_vect = self.__unnormalize_vect(xn_vect) else: xu_vect = x_vect xn_vect = None # For performance, hash once, and reuse in get/store methods database = self.__optimization_problem.database hashed_xu = database.get_hashable_ndarray(xu_vect) # try to retrieve the evaluation value = database.get_function_value(self.name, hashed_xu) if value is None: if not database.get(hashed_xu) and self.__is_max_iter_reached(): raise MaxIterReachedException # if not evaluated yet, evaluate if normalize: value = self.__evaluate_orig_func(xn_vect) else: value = self.__evaluate_orig_func(xu_vect) if self.__optimization_problem.stop_if_nan and isnan(value).any(): msg = f"The function {self.name} is NaN for x={xu_vect}" raise FunctionIsNan(msg) # store (x, f(x)) in database database.store(hashed_xu, {self.name: value}) return value def _jac_to_wrap(self, x_vect: NumberArray) -> NumberArray: """Compute the gradient of the function to be passed to the optimizer. Args: x_vect: The value of the design variables. Returns: The evaluation of the gradient for this value of the design variables. Raises: FunctionIsNan: If the design variables contain a NaN value. If the evaluation of the jacobian results in a NaN value. """ # TODO: Add a dedicated function check_has_nan(). if isnan(x_vect).any(): msg = f"Design Variables contain a NaN value: {x_vect}" raise FunctionIsNan(msg) normalize = self.expects_normalized_inputs if normalize: xn_vect = x_vect xu_vect = self.__unnormalize_vect(xn_vect) else: xu_vect = x_vect xn_vect = None database = self.__optimization_problem.database design_space = self.__optimization_problem.design_space # try to retrieve the evaluation jac_u = database.get_function_value( Database.get_gradient_name(self.name), xu_vect ) if jac_u is None: if not database.get(xu_vect) and self.__is_max_iter_reached(): raise MaxIterReachedException # if not evaluated yet, evaluate if self.expects_normalized_inputs: jac_n = self.__jac_orig_func(xn_vect) jac_u = self.__unnormalize_grad(jac_n) else: jac_u = self.__jac_orig_func(xu_vect) jac_n = None if isnan(jac_u.data).any() and self.__optimization_problem.stop_if_nan: msg = f"Function {self.name}'s Jacobian is NaN for x={xu_vect}" raise FunctionIsNan(msg) func_name_to_value = {Database.get_gradient_name(self.name): jac_u} # store (x, j(x)) in database database.store(xu_vect, func_name_to_value) else: jac_n = design_space.normalize_grad(jac_u) if self.expects_normalized_inputs: return jac_n.real return jac_u.real