Source code for gemseo.algos.ode.scipy_ode.scipy_ode

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - API and implementation and/or documentation
#        :author: Isabelle Santos
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
"""Wrappers for SciPy's ODE solvers.

ODE stands for ordinary differential equation.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Final

from numpy import newaxis
from scipy.integrate import solve_ivp

from gemseo.algos.ode.base_ode_solver_library import BaseODESolverLibrary
from gemseo.algos.ode.base_ode_solver_library import ODESolverDescription
from gemseo.algos.ode.base_ode_solver_settings import BaseODESolverSettings
from gemseo.algos.ode.scipy_ode.settings.base_scipy_ode_jac_settings import (
    BaseScipyODESolverJacSettings,
)
from gemseo.algos.ode.scipy_ode.settings.bdf import BDF_Settings
from gemseo.algos.ode.scipy_ode.settings.dop853 import DOP853_Settings
from gemseo.algos.ode.scipy_ode.settings.lsoda import LSODA_Settings
from gemseo.algos.ode.scipy_ode.settings.radau import Radau_Settings
from gemseo.algos.ode.scipy_ode.settings.rk23 import RK23_Settings
from gemseo.algos.ode.scipy_ode.settings.rk45 import RK45_Settings

if TYPE_CHECKING:
    from gemseo.algos.ode.ode_problem import ODEProblem
    from gemseo.algos.ode.ode_result import ODEResult

LOGGER = logging.getLogger(__name__)


class ScipyODESolverDescription(ODESolverDescription):
    """Describe an ODE solver wrapped from SciPy."""

    library_name: str = "SciPy ODE"
    """The name of the library providing the solver."""
class ScipyODEAlgos(BaseODESolverLibrary):
    """Library of ODE solvers backed by ``scipy.integrate.solve_ivp``.

    ODE is short for ordinary differential equation.
    """

    # Common prefix of the SciPy reference pages listed in ``ALGORITHM_INFOS``.
    __DOC: Final[str] = "https://docs.scipy.org/doc/scipy/reference/generated/"

    ALGORITHM_INFOS: ClassVar[dict[str, ScipyODESolverDescription]] = {
        "RK45": ScipyODESolverDescription(
            algorithm_name="RK45",
            internal_algorithm_name="RK45",
            description="Explicit Runge-Kutta method of order 5(4)",
            website=f"{__DOC}scipy.integrate.RK45.html",
            Settings=RK45_Settings,
        ),
        "RK23": ScipyODESolverDescription(
            algorithm_name="RK23",
            internal_algorithm_name="RK23",
            description="Explicit Runge-Kutta method of order 3(2)",
            website=f"{__DOC}scipy.integrate.RK23.html",
            Settings=RK23_Settings,
        ),
        "DOP853": ScipyODESolverDescription(
            algorithm_name="DOP853",
            internal_algorithm_name="DOP853",
            description="Explicit Runge-Kutta method of order 8",
            website=f"{__DOC}scipy.integrate.DOP853.html",
            Settings=DOP853_Settings,
        ),
        "Radau": ScipyODESolverDescription(
            algorithm_name="Radau",
            internal_algorithm_name="Radau",
            description="Implicit Runge-Kutta method of the Radau IIA type of order 5",
            website=f"{__DOC}scipy.integrate.Radau.html",
            Settings=Radau_Settings,
        ),
        "BDF": ScipyODESolverDescription(
            algorithm_name="BDF",
            internal_algorithm_name="BDF",
            description=(
                "Implicit multi-step variable-order (1 to 5) method based on a backward"
                " differentiation formula for the derivative approximation"
            ),
            website=f"{__DOC}scipy.integrate.BDF.html",
            Settings=BDF_Settings,
        ),
        "LSODA": ScipyODESolverDescription(
            algorithm_name="LSODA",
            internal_algorithm_name="LSODA",
            description="Adams/BDF method with automatic stiffness detection/switching",
            website=f"{__DOC}scipy.integrate.LSODA.html",
            Settings=LSODA_Settings,
        ),
    }

    def _run(self, problem: ODEProblem, **settings: Any) -> ODEResult:
        """Solve the ODE problem with ``solve_ivp`` and fill ``problem.result``.

        Args:
            problem: The ODE problem to integrate.
                Its ``result`` attribute is updated in place.
            **settings: The settings of the algorithm.
                They are filtered with ``BaseODESolverSettings`` as the model
                to exclude before being forwarded to ``solve_ivp``.
        """
        # NOTE(review): the signature announces an ``ODEResult`` but nothing is
        # returned explicitly; the outcome is only written to ``problem.result``.
        solver_options = self._filter_settings(
            settings, model_to_exclude=BaseODESolverSettings
        )

        # Only the algorithms whose settings model derives from the Jacobian-aware
        # base model are given the Jacobian with respect to the state.
        settings_model = self.ALGORITHM_INFOS[self.algo_name].Settings
        if issubclass(settings_model, BaseScipyODESolverJacSettings):
            solver_options["jac"] = problem.jac_function_wrt_state

        if problem.compute_trajectory and not problem.solve_at_algorithm_times:
            solver_options["t_eval"] = problem.evaluation_times

        solution = solve_ivp(
            fun=problem.rhs_function,
            y0=problem.initial_state,
            method=self._algo_name,
            t_span=problem.time_interval,
            events=problem.event_functions,
            **solver_options,
        )

        result = problem.result
        result.algorithm_name = self._algo_name
        result.algorithm_settings = solver_options
        # A negative status is treated as a failure of the integration.
        result.algorithm_has_converged = solution.status >= 0
        result.algorithm_termination_message = solution.message
        result.times = solution.t
        result.n_func_evaluations = solution.nfev
        result.n_jac_evaluations = solution.njev

        if not result.algorithm_has_converged:
            LOGGER.warning(solution.message)

        # Position of the first event function with at least one recorded time,
        # or None when no event time has been recorded at all.
        event_index = next(
            (
                position
                for position, event_times in enumerate(solution.t_events or ())
                if len(event_times)
            ),
            None,
        )

        if problem.compute_trajectory:
            result.state_trajectories = solution.y

        result.terminal_event_index = event_index
        if not problem.event_functions or event_index is None:
            # No recorded event: use the last time and state of the solution.
            result.termination_time = solution.t[-1:]
            result.final_state = solution.y[:, -1][:, newaxis]
        else:
            # Use the times of the detected event and its first recorded state.
            result.termination_time = solution.t_events[event_index]
            result.final_state = solution.y_events[event_index][0][:, newaxis]

    def _get_result(self, problem: ODEProblem) -> ODEResult:
        """Return the result attached to the ODE problem.

        Args:
            problem: The ODE problem that has been solved.

        Returns:
            The ``result`` attribute of the problem.
        """
        return problem.result