"""Various termination criteria for drivers."""

from __future__ import annotations

from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field
from typing import TYPE_CHECKING
from typing import Any
from typing import Final

from numpy import all as np_all
from numpy import allclose
from numpy import average
from numpy import bool_
from numpy import ndarray
from numpy.linalg import norm

from gemseo.algos.lagrange_multipliers import LagrangeMultipliers

if TYPE_CHECKING:
    from gemseo.algos.optimization_problem import OptimizationProblem

class TerminationCriterion(Exception):  # noqa: N818
    """Base exception raised to make a driver stop, whatever the reason."""
class FunctionIsNan(TerminationCriterion):  # noqa: N818
    """Stop the driver because a function value or its Jacobian contains NaN."""
class DesvarIsNan(TerminationCriterion):  # noqa: N818
    """Stop the driver because the design variables contain NaN."""
class MaxIterReachedException(TerminationCriterion):  # noqa: N818
    """Raised once the maximum number of iterations has been reached."""
class MaxTimeReached(TerminationCriterion):  # noqa: N818
    """Raised once the maximum execution time has been reached."""
class FtolReached(TerminationCriterion):  # noqa: N818
    """Raised once the ``f_tol_rel`` or ``f_tol_abs`` criterion is reached."""
class XtolReached(TerminationCriterion):  # noqa: N818
    """Raised once the ``x_tol_rel`` or ``x_tol_abs`` criterion is reached."""
class KKTReached(TerminationCriterion):
    """Raised once the Karush-Kuhn-Tucker (KKT) residual norm criterion is reached."""
# Database key used to cache the KKT residual norm of a design point.
KKT_RESIDUAL_NORM: Final[str] = "KKT residual norm"
"""The name under which the KKT residual norm is stored in a database."""
@dataclass
class BaseToleranceTester:
    """The base class to test the tolerance with respect to a reference value.

    The reference value corresponds to the coordinate-wise average of the values
    associated to the last iterations.
    """

    absolute: float = 0.0
    """The absolute tolerance."""

    relative: float = 0.0
    """The relative tolerance."""

    n_last_iterations: int = 3
    """The number of last points to compute the reference."""

    # Subclasses give it a default; it stores an exception *class*, raised as-is.
    termination_criterion: type[TerminationCriterion] = field(init=False)
    """The exception class raised when the tolerance criterion is met."""

    def check(
        self, problem: OptimizationProblem, raise_exception: bool = False, **kwargs: Any
    ) -> bool:
        """Check whether the tolerance criterion is met.

        Args:
            problem: The optimization problem to which the database is attached.
            raise_exception: Whether to raise :attr:`termination_criterion`
                when the tolerance criterion is met.
            **kwargs: The options of the tester, forwarded to :meth:`_check`.

        Returns:
            Whether the tolerance criterion is met.

        Raises:
            TerminationCriterion: When the tolerance criterion is met
                and ``raise_exception`` is ``True``.
        """
        tolerance_criterion_is_reached = self._check(problem, **kwargs)
        if raise_exception and tolerance_criterion_is_reached:
            raise self.termination_criterion
        return tolerance_criterion_is_reached

    @abstractmethod
    def _check(self, problem: OptimizationProblem, *args: Any, **kwargs: Any) -> bool:
        """Check whether the tolerance criterion is met.

        Args:
            problem: The optimization problem to which the database is attached.
            *args: The positional options of the tester.
            **kwargs: The options of the tester.

        Returns:
            Whether the tolerance criterion is met.
        """
@dataclass
class ObjectiveToleranceTester(BaseToleranceTester):
    """A tolerance tester for the objective.

    The criterion is met when the objective values of the last ``n_last_iterations``
    design points are all close to their average (with respect to the absolute and
    relative tolerances), provided that at least one of these points is feasible.
    """

    termination_criterion: type[TerminationCriterion] = field(
        default=FtolReached, init=False
    )

    def _check(self, problem: OptimizationProblem, *args: Any, **kwargs: Any) -> bool:  # noqa: D102
        database = problem.database
        if len(database) < self.n_last_iterations:
            return False

        # Checks that there is at least one feasible point
        x_values = database.get_last_n_x_vect(self.n_last_iterations)
        if not any(
            problem.constraints.is_point_feasible(database[x_val]) for x_val in x_values
        ):
            return False

        obj_name = problem.objective.name
        # Skip the design points where the objective has not been stored.
        f_values = [
            f_value
            for f_value in (
                database.get_function_value(obj_name, x_val) for x_val in x_values
            )
            if f_value is not None
        ]
        # Not enough objective values to build a meaningful reference.
        if len(f_values) < self.n_last_iterations:
            return False

        f_average = average(f_values)
        return np_all([
            allclose(f_val, f_average, atol=self.absolute, rtol=self.relative)
            for f_val in f_values
        ])
@dataclass
class DesignToleranceTester(BaseToleranceTester):
    """A tolerance tester for the design vector.

    The criterion is met when each of the last ``n_last_iterations`` design vectors
    is close to their coordinate-wise average (with respect to the absolute and
    relative tolerances), provided that at least one of them is feasible.
    """

    termination_criterion: TerminationCriterion = field(default=XtolReached, init=False)

    def _check(self, problem: OptimizationProblem, *args: Any, **kwargs: Any) -> bool:  # noqa: D102
        database = problem.database
        if len(database) < self.n_last_iterations:
            return False

        last_points = database.get_last_n_x_vect(self.n_last_iterations)

        # Require at least one feasible point among the last ones.
        has_feasible_point = any(
            problem.constraints.is_point_feasible(database[point])
            for point in last_points
        )
        if not has_feasible_point:
            return False

        center = average(last_points, axis=0)
        closeness = [
            allclose(point, center, atol=self.absolute, rtol=self.relative)
            for point in last_points
        ]
        return np_all(closeness)
@dataclass
class KKTConditionsTester(BaseToleranceTester):
    """A tester for the Karush-Kuhn-Tucker (KKT) conditions.

    The criterion is met when the KKT residual norm at the tested point is lower than
    or equal to the larger of the absolute tolerance and of the relative tolerance
    scaled by the reference KKT norm.
    """

    termination_criterion: TerminationCriterion = field(default=KKTReached, init=False)

    ineq_tolerance: float = 0.0
    """The tolerance for the inequality constraints."""

    kkt_norm: float = 0.0
    """The reference KKT norm."""

    def _check(self, problem: OptimizationProblem, input_vector: ndarray) -> bool:  # noqa: D102
        residual = kkt_residual_computation(problem, input_vector, self.ineq_tolerance)
        threshold = max(self.absolute, self.relative * self.kkt_norm)
        return residual <= threshold
def kkt_residual_computation(
    opt_problem: OptimizationProblem,
    x_vect: ndarray,
    ineq_tolerance: float = 1e-4,
) -> float:
    """Compute the KKT residual norm.

    This implementation is inspired by Svanberg's Matlab implementation of the MMA
    algorithm, see :cite:`svanberg1998method`.

    The result is cached in the database of the problem under
    :data:`KKT_RESIDUAL_NORM`, so that a second call at the same point reuses it.

    Args:
        opt_problem: The optimization problem containing an optimization history.
        x_vect: The design point vector where the KKT conditions are tested.
        ineq_tolerance: The tolerance to consider a constraint as active;
            only used when the problem has constraints.

    Returns:
        The KKT residual norm.
    """
    database = opt_problem.database
    residual = database.get_function_value(KKT_RESIDUAL_NORM, x_vect)
    if residual is not None:
        return residual

    lagrange = LagrangeMultipliers(opt_problem)
    if opt_problem.constraints:
        lagrange.compute(x_vect, ineq_tolerance=ineq_tolerance)
        residual = lagrange.kkt_residual + lagrange.constraint_violation
    else:
        # Without constraints, the KKT residual reduces to the objective gradient norm.
        residual = norm(lagrange.get_objective_jacobian(x_vect))

    database.store(x_vect, {KKT_RESIDUAL_NORM: residual})
    return residual
def is_x_tol_reached(
    opt_problem: OptimizationProblem,
    x_tol_rel: float = 1e-6,
    x_tol_abs: float = 1e-6,
    n_x: int = 2,
) -> bool | bool_:
    """Test whether the tolerance on the design variables is reached.

    The coordinate-wise average of the last ``n_x`` design vectors is taken as a
    center, then it is checked that all these vectors are close to this center with
    respect to the relative and absolute tolerances. At least one of these design
    vectors must be feasible for the tolerance to be considered as reached.

    Args:
        opt_problem: The optimization problem containing the iterations.
        x_tol_rel: The relative tolerance.
        x_tol_abs: The absolute tolerance.
        n_x: The number of design vectors to account for.

    Returns:
        Whether the tolerance on the design variables is reached.
    """
    tester = DesignToleranceTester(
        absolute=x_tol_abs, relative=x_tol_rel, n_last_iterations=n_x
    )
    return tester.check(opt_problem)
def is_f_tol_reached(
    opt_problem: OptimizationProblem,
    f_tol_rel: float = 1e-6,
    f_tol_abs: float = 1e-6,
    n_x: int = 2,
) -> bool | bool_:
    """Test whether the tolerance on the objective function is reached.

    The average objective value of the last ``n_x`` design vectors is taken as a
    reference, then it is checked that all these objective values are close to this
    reference with respect to the relative and absolute tolerances. At least one of
    these design vectors must be feasible and all of them must have a stored
    objective value for the tolerance to be considered as reached.

    Args:
        opt_problem: The optimization problem containing the iterations.
        f_tol_rel: The relative tolerance.
        f_tol_abs: The absolute tolerance.
        n_x: The number of design vectors to account for.

    Returns:
        Whether the tolerance on the objective function is reached.
    """
    tester = ObjectiveToleranceTester(
        absolute=f_tol_abs, relative=f_tol_rel, n_last_iterations=n_x
    )
    return tester.check(opt_problem)
def is_kkt_residual_norm_reached(
    opt_problem: OptimizationProblem,
    x_vect: ndarray,
    kkt_abs_tol: float = 0.0,
    kkt_rel_tol: float = 0.0,
    ineq_tolerance: float = 1e-4,
    reference_residual: float = 1.0,
) -> bool:
    """Check whether the KKT conditions are satisfied at a design point.

    Args:
        opt_problem: The optimization problem containing an optimization history.
        x_vect: The design point vector where the KKT conditions are tested.
        kkt_abs_tol: The absolute tolerance on the KKT condition residual.
        kkt_rel_tol: The relative tolerance on the KKT condition residual,
            applied to ``reference_residual``.
        ineq_tolerance: The tolerance to consider a constraint as active.
        reference_residual: The reference KKT condition residual.

    Returns:
        Whether the absolute or the relative KKT residual norm criterion is reached.
    """
    return KKTConditionsTester(
        absolute=kkt_abs_tol,
        relative=kkt_rel_tol,
        ineq_tolerance=ineq_tolerance,
        kkt_norm=reference_residual,
    ).check(opt_problem, input_vector=x_vect)