# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Contributors:
# INITIAL AUTHORS - initial API and implementation and/or initial
# documentation
# :author: Damien Guenot
# :author: Francois Gallard, refactoring
# OTHER AUTHORS - MACROSCOPIC CHANGES
"""Optimization library wrappers base class."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any
from typing import Final
import numpy
from gemseo.algos._unsuitability_reason import _UnsuitabilityReason
from gemseo.algos.driver_library import DriverDescription
from gemseo.algos.driver_library import DriverLibrary
from gemseo.algos.first_order_stop_criteria import KKTReached
from gemseo.algos.first_order_stop_criteria import is_kkt_residual_norm_reached
from gemseo.algos.first_order_stop_criteria import kkt_residual_computation
from gemseo.algos.opt_problem import OptimizationProblem
from gemseo.algos.stop_criteria import FtolReached
from gemseo.algos.stop_criteria import XtolReached
from gemseo.algos.stop_criteria import is_f_tol_reached
from gemseo.algos.stop_criteria import is_x_tol_reached
if TYPE_CHECKING:
from collections.abc import Mapping
from numpy import ndarray
from gemseo.core.mdofunctions.mdo_function import MDOFunction
@dataclass
class OptimizationAlgorithmDescription(DriverDescription):
    """The description of an optimization algorithm."""

    handle_equality_constraints: bool = False
    """Whether the optimization algorithm handles equality constraints."""

    handle_inequality_constraints: bool = False
    """Whether the optimization algorithm handles inequality constraints."""

    handle_multiobjective: bool = False
    """Whether the optimization algorithm handles multiple objectives."""

    positive_constraints: bool = False
    """Whether the optimization algorithm requires positive constraints."""

    problem_type: OptimizationProblem.ProblemType = (
        OptimizationProblem.ProblemType.NON_LINEAR
    )
    """The type of problem (see :attr:`.OptimizationProblem.ProblemType`)."""
class OptimizationLibrary(DriverLibrary):
    """Base optimization library defining a collection of optimization algorithms.

    Typically used as:

    #. Instantiate an :class:`.OptimizationLibrary`.
    #. Select the algorithm with :attr:`.algo_name`.
    #. Solve an :class:`.OptimizationProblem` with :meth:`.execute`.

    Notes:
        The missing current values
        of the :class:`.DesignSpace` attached to the :class:`.OptimizationProblem`
        are automatically initialized
        with the method :meth:`.DesignSpace.initialize_missing_current_values`.
    """

    # Names of the options common to the optimization algorithms;
    # the values are the keys expected in the ``options`` of ``execute``.
    MAX_ITER = "max_iter"
    F_TOL_REL = "ftol_rel"
    F_TOL_ABS = "ftol_abs"
    X_TOL_REL = "xtol_rel"
    X_TOL_ABS = "xtol_abs"
    _KKT_TOL_ABS = "kkt_tol_abs"
    _KKT_TOL_REL = "kkt_tol_rel"
    STOP_CRIT_NX = "stop_crit_n_x"
    # Maximum step for the line search
    LS_STEP_SIZE_MAX = "max_ls_step_size"
    # Maximum number of line search steps (per iteration).
    LS_STEP_NB_MAX = "max_ls_step_nb"
    MAX_FUN_EVAL = "max_fun_eval"
    MAX_TIME = "max_time"
    PG_TOL = "pg_tol"
    SCALING_THRESHOLD: Final[str] = "scaling_threshold"
    VERBOSE = "verbose"

    # Default values used by __init__ and _pre_run when the corresponding
    # option is not provided by the caller.
    __DEFAULT_FTOL_ABS: Final[float] = 0.0
    """The default absolute tolerance for the objective."""

    __DEFAULT_FTOL_REL: Final[float] = 0.0
    """The default relative tolerance for the objective."""

    __DEFAULT_XTOL_ABS: Final[float] = 0.0
    """The default absolute tolerance for the design variables."""

    __DEFAULT_XTOL_REL: Final[float] = 0.0
    """The default relative tolerance for the design variables."""

    __DEFAULT_KKT_ABS_TOL: Final[float] = 0.0
    """The default absolute tolerance for the Karush-Kuhn-Tucker (KKT) conditions."""

    __DEFAULT_KKT_REL_TOL: Final[float] = 0.0
    """The default relative tolerance for the Karush-Kuhn-Tucker (KKT) conditions."""

    __DEFAULT_STOP_CRIT_N_X: Final[int] = 3
    """The default minimum number of iterations to assess tolerance."""
def __init__(self) -> None:  # noqa:D107
    super().__init__()
    # Stopping tolerances on the objective value; overwritten in _pre_run
    # from the caller-provided options.
    self._ftol_abs = self.__DEFAULT_FTOL_ABS
    self._ftol_rel = self.__DEFAULT_FTOL_REL
    # Stopping tolerances on the design variables; overwritten in _pre_run.
    self._xtol_abs = self.__DEFAULT_XTOL_ABS
    self._xtol_rel = self.__DEFAULT_XTOL_REL
    # Tolerances on the KKT residual norm; overwritten in _pre_run.
    self.__kkt_abs_tol = self.__DEFAULT_KKT_ABS_TOL
    self.__kkt_rel_tol = self.__DEFAULT_KKT_REL_TOL
    # Reference KKT residual norm, computed lazily by _check_kkt_from_database
    # at the first iterate where all values and gradients are available.
    self.__ref_kkt_norm = None
    # Number of past iterates over which the f/x tolerances are assessed.
    self._stop_crit_n_x = self.__DEFAULT_STOP_CRIT_N_X
def __algorithm_handles(self, algo_name: str, eq_constraint: bool) -> bool:
    """Check if the algorithm handles equality or inequality constraints.

    Args:
        algo_name: The name of the algorithm.
        eq_constraint: Whether the constraints are equality ones.

    Returns:
        Whether the algorithm handles the passed type of constraints.

    Raises:
        KeyError: If the algorithm is not in the library.
    """
    if algo_name not in self.descriptions:
        msg = f"Algorithm {algo_name} not in library {self.__class__.__name__}."
        raise KeyError(msg)

    # The description declares separately whether each constraint type
    # is supported.
    description = self.descriptions[algo_name]
    if eq_constraint:
        return description.handle_equality_constraints
    return description.handle_inequality_constraints
# TODO: API: rename to algorithm_handles_equality_constraints
def algorithm_handles_eqcstr(self, algo_name: str) -> bool:
    """Check if an algorithm handles equality constraints.

    Args:
        algo_name: The name of the algorithm.

    Returns:
        Whether the algorithm handles equality constraints.

    Raises:
        KeyError: If the algorithm is not in the library.
    """
    return self.__algorithm_handles(algo_name, True)
# TODO: API: rename to algorithm_handles_inequality_constraints
def algorithm_handles_ineqcstr(self, algo_name: str) -> bool:
    """Check if an algorithm handles inequality constraints.

    Args:
        algo_name: The name of the algorithm.

    Returns:
        Whether the algorithm handles inequality constraints.

    Raises:
        KeyError: If the algorithm is not in the library.
    """
    return self.__algorithm_handles(algo_name, False)
# TODO: API: rename to is_algo_requires_positive_constraints
def is_algo_requires_positive_cstr(self, algo_name: str) -> bool:
    """Check if an algorithm requires positive constraints.

    Args:
        algo_name: The name of the algorithm.

    Returns:
        Whether the algorithm requires positive constraints.
    """
    description = self.descriptions[algo_name]
    return description.positive_constraints
def _check_constraints_handling(
    self, algo_name: str, problem: OptimizationProblem
) -> None:
    """Check if problem and algorithm are consistent for constraints handling.

    Args:
        algo_name: The name of the algorithm.
        problem: The optimization problem to be solved.

    Raises:
        ValueError: If the problem has equality or inequality constraints
            that the algorithm cannot handle.
    """
    if problem.has_eq_constraints() and not self.algorithm_handles_eqcstr(
        algo_name
    ):
        # f-string instead of dated %-formatting, consistent with the rest
        # of the class.
        msg = (
            f"Requested optimization algorithm "
            f"{algo_name} can not handle equality constraints."
        )
        raise ValueError(msg)
    if problem.has_ineq_constraints() and not self.algorithm_handles_ineqcstr(
        algo_name
    ):
        msg = (
            f"Requested optimization algorithm "
            f"{algo_name} can not handle inequality constraints."
        )
        raise ValueError(msg)
def get_right_sign_constraints(self):
    """Transform the problem constraints into their opposite sign counterpart.

    This is done if the algorithm requires positive constraints.
    """
    # Negate the constraints only when both conditions hold; the guards are
    # nested to keep the original evaluation order.
    if self.problem.has_ineq_constraints():
        if self.is_algo_requires_positive_cstr(self.algo_name):
            return [-constraint for constraint in self.problem.constraints]
    return self.problem.constraints
def _pre_run(
    self, problem: OptimizationProblem, algo_name: str, **options: Any
) -> None:
    """To be overridden by subclasses.

    Specific method to be executed just before _run method call.

    The missing current values of the :class:`.DesignSpace` are initialized
    with the method :meth:`.DesignSpace.initialize_missing_current_values`.

    Args:
        problem: The optimization problem.
        algo_name: The name of the algorithm.
        **options: The options of the algorithm,
            see the associated JSON file.

    Raises:
        ValueError: If the maximum number of iterations cannot be determined
            from the options.
    """
    super()._pre_run(problem, algo_name, **options)
    self._check_constraints_handling(algo_name, problem)

    # The maximum iteration number may be given under its canonical name
    # or under the library-specific name declared in OPTIONS_MAP.
    if self.MAX_ITER in options:
        max_iter = options[self.MAX_ITER]
    elif (
        self.MAX_ITER in self.OPTIONS_MAP
        and self.OPTIONS_MAP[self.MAX_ITER] in options
    ):
        max_iter = options[self.OPTIONS_MAP[self.MAX_ITER]]
    else:
        msg = "Could not determine the maximum number of iterations."
        raise ValueError(msg)

    # Store the stopping-criteria settings used by the iteration callbacks.
    self._ftol_rel = options.get(self.F_TOL_REL, self.__DEFAULT_FTOL_REL)
    self._ftol_abs = options.get(self.F_TOL_ABS, self.__DEFAULT_FTOL_ABS)
    self._xtol_rel = options.get(self.X_TOL_REL, self.__DEFAULT_XTOL_REL)
    self._xtol_abs = options.get(self.X_TOL_ABS, self.__DEFAULT_XTOL_ABS)
    self.__ineq_tolerance = options.get(self.INEQ_TOLERANCE, problem.ineq_tolerance)
    self._stop_crit_n_x = options.get(
        self.STOP_CRIT_NX, self.__DEFAULT_STOP_CRIT_N_X
    )
    # None (not the 0.0 defaults) when absent: a None pair disables the
    # KKT-based stopping callback below.
    self.__kkt_abs_tol = options.get(self._KKT_TOL_ABS, None)
    self.__kkt_rel_tol = options.get(self._KKT_TOL_REL, None)
    self.init_iter_observer(max_iter)
    require_gradient = self.descriptions[self.algo_name].require_gradient
    # The KKT criterion needs gradients, so it is only checked for
    # gradient-based algorithms and when a KKT tolerance was provided.
    if (
        self.__kkt_abs_tol is not None or self.__kkt_rel_tol is not None
    ) and require_gradient:
        problem.add_callback(
            self._check_kkt_from_database, each_new_iter=False, each_store=True
        )
    problem.design_space.initialize_missing_current_values()
    if problem.differentiation_method == self.DifferentiationMethod.COMPLEX_STEP:
        problem.design_space.to_complex()
    # First, evaluate all functions at x_0. Some algorithms don't do this
    function_values, _ = self.problem.evaluate_functions(
        eval_jac=require_gradient,
        eval_obj=True,
        eval_observables=False,
        normalize=options.get(
            self.NORMALIZE_DESIGN_SPACE_OPTION, self._NORMALIZE_DS
        ),
    )

    # Optionally rescale the objective, constraints and observables by their
    # values at x_0 when these exceed the given threshold.
    scaling_threshold = options.get(self.SCALING_THRESHOLD)
    if scaling_threshold is not None:
        self.problem.objective = self.__scale(
            self.problem.objective,
            function_values[self.problem.objective.name],
            scaling_threshold,
        )
        self.problem.constraints = [
            self.__scale(
                constraint, function_values[constraint.name], scaling_threshold
            )
            for constraint in self.problem.constraints
        ]
        self.problem.observables = [
            self.__scale(
                observable, function_values[observable.name], scaling_threshold
            )
            for observable in self.problem.observables
        ]
@classmethod
def _get_unsuitability_reason(
    cls,
    algorithm_description: OptimizationAlgorithmDescription,
    problem: OptimizationProblem,
) -> _UnsuitabilityReason:
    """Return the reason why the algorithm cannot solve the problem, if any."""
    reason = super()._get_unsuitability_reason(algorithm_description, problem)
    if reason:
        return reason

    # Constraint-type support is checked before the problem type,
    # mirroring _check_constraints_handling.
    handles_eq = algorithm_description.handle_equality_constraints
    if problem.has_eq_constraints() and not handles_eq:
        return _UnsuitabilityReason.EQUALITY_CONSTRAINTS

    handles_ineq = algorithm_description.handle_inequality_constraints
    if problem.has_ineq_constraints() and not handles_ineq:
        return _UnsuitabilityReason.INEQUALITY_CONSTRAINTS

    # A linear-only algorithm cannot solve a non-linear problem.
    if (
        problem.pb_type == problem.ProblemType.NON_LINEAR
        and algorithm_description.problem_type == problem.ProblemType.LINEAR
    ):
        return _UnsuitabilityReason.NON_LINEAR_PROBLEM

    return reason
def new_iteration_callback(self, x_vect: ndarray) -> None:
    """Verify the design variable and objective value stopping criteria.

    Args:
        x_vect: The current design variables values.

    Raises:
        FtolReached: If the defined relative or absolute function
            tolerance is reached.
        XtolReached: If the defined relative or absolute x tolerance
            is reached.
    """
    # The parent callback checks max_iter and updates the progress bar.
    super().new_iteration_callback(x_vect)

    problem = self.problem
    n_last_iterates = self._stop_crit_n_x
    if is_f_tol_reached(problem, self._ftol_rel, self._ftol_abs, n_last_iterates):
        raise FtolReached
    if is_x_tol_reached(problem, self._xtol_rel, self._xtol_abs, n_last_iterates):
        raise XtolReached
def _check_kkt_from_database(self, x_vect: ndarray) -> None:
    """Verify, if required, KKT norm stopping criterion at each database storage.

    Args:
        x_vect: The design variables values at which the criterion is checked.

    Raises:
        KKTReached: If the absolute tolerance on the KKT residual is reached.
    """
    database = self.problem.database
    function_names = [
        self.problem.get_objective_name(),
        *self.problem.get_constraint_names(),
    ]
    # The KKT residual can only be computed once the value and the gradient
    # of the objective and of every constraint are stored at x_vect;
    # otherwise skip the check for this storage.
    for function_name in function_names:
        gradient_name = database.get_gradient_name(function_name)
        if database.get_function_value(gradient_name, x_vect) is None:
            return
        if database.get_function_value(function_name, x_vect) is None:
            return

    # The first complete evaluation provides the reference residual norm
    # used for the relative tolerance.
    if self.__ref_kkt_norm is None:
        self.__ref_kkt_norm = kkt_residual_computation(
            self.problem, x_vect, self.__ineq_tolerance
        )

    if is_kkt_residual_norm_reached(
        self.problem,
        x_vect,
        kkt_abs_tol=self.__kkt_abs_tol,
        kkt_rel_tol=self.__kkt_rel_tol,
        ineq_tolerance=self.__ineq_tolerance,
        reference_residual=self.__ref_kkt_norm,
    ):
        raise KKTReached
@staticmethod
def __scale(
    function: MDOFunction,
    function_value: Mapping[str, ndarray],
    scaling_threshold: float,
) -> MDOFunction:
    """Scale a function based on its value on the current design values.

    Args:
        function: The function.
        function_value: The function value of reference for scaling.
        scaling_threshold: The threshold on the reference function value
            that triggers scaling.

    Returns:
        The scaled function, or the original one when no component of the
        reference value exceeds the threshold.
    """
    magnitudes = numpy.absolute(function_value)
    needs_scaling = magnitudes > scaling_threshold
    if not needs_scaling.any():
        return function

    # Divide only the components above the threshold; the others keep
    # their original scale (division by one).
    divisors = numpy.where(needs_scaling, magnitudes, 1.0)
    scaled_function = function / divisors
    # Use same function name for consistency with name used in database
    scaled_function.name = function.name
    return scaled_function