# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Contributors:
# INITIAL AUTHORS - API and implementation and/or documentation
# :author: Francois Gallard
# :author: Isabelle Santos
# :author: Giulio Gargantini
"""A discipline for solving ordinary differential equations (ODEs)."""
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING
from typing import Any
from typing import Final
from gemseo.algos.ode.factory import ODESolverLibraryFactory
from gemseo.algos.ode.ode_problem import ODEProblem
from gemseo.core.discipline.discipline import Discipline
from gemseo.disciplines.ode.ode_function import ODEFunction
from gemseo.utils.constants import READ_ONLY_EMPTY_DICT
from gemseo.utils.data_conversion import concatenate_dict_of_arrays_to_array
from collections.abc import Iterable
from gemseo.algos.ode.base_ode_solver_library import BaseODESolverLibrary
from gemseo.typing import RealArray
from gemseo.typing import StrKeyMapping
class ODEDiscipline(Discipline):
"""A discipline for solving Ordinary Differential Equations (ODE)."""
# The annotations below declare the state set up in ``__init__``
# and read back in ``_run``.
_rhs_discipline: Discipline
"""The discipline defining the RHS of the ODE."""
termination_event_disciplines: Iterable[Discipline]
"""The disciplines defining the stopping conditions."""
_ode_problem: ODEProblem
"""The ODE problem to be solved."""
_output_trajectory: bool
"""Whether to output the state trajectories in addition to the final states."""
__design_variables_names: Iterable[str]
"""The names of the design variables of the ODE."""
__final_state_names: tuple[str, ...]
"""The names of the variables at final time."""
__final_time_name: str
"""The name of the variable for the final time."""
__initial_state_names: tuple[str, ...]
"""The names of the variables for the initial conditions."""
__initial_time_name: str
"""The name of the variable for the initial time."""
__ode_solver: BaseODESolverLibrary
"""The ODE solver."""
__ode_solver_options: Mapping[str, Any]
"""The options of the ODE solver."""
__state_names: Iterable[str] | Mapping[str, str]
"""The names of the state variables, possibly bound to the
names of their time derivatives."""
# The names of the time derivatives of the state variables: either the values
# of a ``state_names`` mapping or ``f"{state}_dot"`` for each state variable.
__state_dot_names: Iterable[str]
__time_name: str
"""The name of the time variable."""
__trajectory_state_names: Iterable[str]
"""The names of the trajectories of the state variables."""
# Output name under which ``_run`` stores the termination time.
__TERMINATION_TIME: Final[str] = "termination_time"
"""The string constant for termination time."""
# Input/output name of the times of interest (see ``__init__`` and ``_run``).
__TIMES: Final[str] = "times"
"""The string constant for times."""
def __init__(
rhs_discipline: Discipline,
times: RealArray,
time_name: str = "time",
state_names: Iterable[str] | Mapping[str, str] = (),
initial_state_names: Mapping[str, str] = READ_ONLY_EMPTY_DICT,
initial_time_name: str = "",
final_state_names: Mapping[str, str] = READ_ONLY_EMPTY_DICT,
final_time_name: str = "",
state_trajectory_names: Mapping[str, str] = READ_ONLY_EMPTY_DICT,
return_trajectories: bool = False,
name: str = "",
termination_event_disciplines: Iterable[Discipline] = (),
solve_at_algorithm_times: bool = False,
ode_solver_name: str = "RK45",
**ode_solver_settings: Any,
rhs_discipline: The discipline defining the right-hand side function
of the ODE.
times: Either the initial and final times
or the times of interest where the state must be stored,
including the initial and final times.
When only initial and final times are provided,
the times of interest are the instants chosen by the ODE solver
to compute the state trajectories.
time_name: The name of the time variable.
state_names: Either the names of the state variables,
passed as ``(state_name, ...)``,
or the names of the state variables
bound to the associated ``rhs_discipline`` outputs,
passed as ``{state_name: output_name, ...}``.
If empty, use all the ``rhs_discipline`` inputs except ``time_name``.
initial_state_names: The names of the state variables
bound to the names of the variables denoting the initial conditions.
If empty,
use ``"initial_state"`` for a state variable named ``"state"``.
initial_time_name: The name of the variable for the initial time.
If empty, use ``f"initial_{time_name}"``.
final_state_names: The names of the state variables
bound to their names at final time.
If empty,
use ``"final_state"`` for a state variable named ``"state"``.
final_time_name: The name of the variable for the final time.
If empty, use ``f"final_{time_name}"``.
state_trajectory_names: The names of the state variables
bound to the names of their trajectories.
If empty,
use ``"state"`` for a state variable named ``"state"``.
return_trajectories: Whether to output
both the trajectories of the state variables
and their values at final time.
Otherwise, output only their values at final time.
termination_event_disciplines: The disciplines encoding termination events.
Each discipline must have the same inputs as ``rhs_discipline``
and only one output, defined as an array of size 1,
indicating the value of an event function.
The resolution of the ODE problem stops
when one of the event functions crosses the threshold 0.
If empty, the integration covers the entire time interval.
solve_at_algorithm_times: Whether to solve the ODE
at the times chosen by the algorithm.
ode_solver_name: The name of the ODE solver.
**ode_solver_settings: The settings of the ODE solver.
ValueError: If an expected state variable does not appear in
the default inputs of ``rhs_discipline``.
""" # noqa: D205, D212, D415
# Set up the names of the time, state, initial/final and design variables,
# build the ODE problem from ``rhs_discipline`` and declare the default inputs.
# Trajectories are output when requested explicitly, when trajectory names
# are given, or when more than the initial and final times are provided.
# NOTE(review): this expression can evaluate to the ``state_trajectory_names``
# mapping rather than a bool, despite the ``bool`` annotation; consider bool(...).
self._rhs_discipline = rhs_discipline
self._output_trajectory = (
return_trajectories or state_trajectory_names or len(times) > 2
self.termination_event_disciplines = termination_event_disciplines
# Define the names of the time variable and of the initial and final times.
self.__time_name = time_name
self.__initial_time_name = initial_time_name or f"initial_{self.__time_name}"
self.__final_time_name = final_time_name or f"final_{self.__time_name}"
# Define the names of the state variables and their time derivatives
if state_names:
if isinstance(state_names, Mapping):
self.__state_names = state_names.keys()
self.__state_dot_names = state_names.values()
# NOTE(review): ``state_names`` is typed as a generic ``Iterable``; a one-shot
# iterator would be exhausted by the tuple below and by the later passes over
# ``state_names``. Converting it to a tuple first would be safer.
self.__state_names = state_names
self.__state_dot_names = tuple(f"{state}_dot" for state in state_names)
# Fallback: every input of the RHS discipline except the time variable.
self.__state_names = tuple(
for name in rhs_discipline.io.input_grammar.names
if name != time_name
self.__state_dot_names = tuple(
f"{state}_dot" for state in self.__state_names
# Each state variable must have a default value in ``rhs_discipline``.
# NOTE(review): this checks the ``state_names`` argument, not
# ``self.__state_names``, so nothing is checked when ``state_names`` is empty.
missing_names = set(state_names) - set(rhs_discipline.default_input_data)
if missing_names:
msg = f"Missing default inputs in rhs_discipline for {missing_names}."
raise ValueError(msg)
# Design variables: the default inputs of ``rhs_discipline`` that are neither
# the time variable nor a state variable.
excluded_names = [self.__time_name, *self.__state_names]
self.__design_variables_names = tuple(
for name in rhs_discipline.default_input_data
if name not in excluded_names
# NOTE(review): here and below (event functions, RHS function, final-state and
# trajectory names), the ``state_names`` argument is iterated rather than
# ``self.__state_names``. With the default ``state_names=()`` these end up
# empty, unless ``state_names`` is rebound in code not shown. Confirm whether
# ``self.__state_names`` was intended.
self.__initial_state_names = tuple(
initial_state_names.get(state_name, f"initial_{state_name}")
for state_name in state_names
# Instantiate the ODE solver; its settings are passed on at each execution.
self.__ode_solver = ODESolverLibraryFactory().create(ode_solver_name)
self.__ode_solver_options = ode_solver_settings
# Default initial conditions: the value already in the local data if any,
# otherwise the default input of ``rhs_discipline`` for the state variable.
mapping_initial_state = {
initial_name: self.local_data.get(
initial_name, rhs_discipline.default_input_data[state_name]
for (initial_name, state_name) in zip(
self.__initial_state_names, self.__state_names
# Default design-variable values, resolved the same way.
mapping_parameters = {
parameter: self.local_data.get(
parameter, rhs_discipline.default_input_data[parameter]
for parameter in self.__design_variables_names
# Default initial and final times: the first and last elements of ``times``.
mapping_inputs = {
self.__initial_time_name: self.local_data.get(
self.__initial_time_name, times[0]
self.__final_time_name: self.local_data.get(
self.__final_time_name, times[-1]
# The times of interest become an input only when trajectories are output.
if self._output_trajectory:
mapping_inputs[self.__TIMES] = times
for termination_discipline in self.termination_event_disciplines:
# Define ODEProblem
initial_state = concatenate_dict_of_arrays_to_array(
# Each termination discipline is wrapped as a terminal event function.
event_functions = tuple(
ODEFunction(termination_discipline, state_names, time_name, terminal=True)
for termination_discipline in termination_event_disciplines
ode_func = ODEFunction(rhs_discipline, state_names, time_name)
self._ode_problem = ODEProblem(
# Define the names for the trajectories and final states.
self.__final_state_names = tuple(
final_state_names.get(state_name, f"final_{state_name}")
for state_name in state_names
if self._output_trajectory:
self.__trajectory_state_names = tuple(
state_trajectory_names.get(state_name, state_name)
for state_name in state_names
# NOTE(review): an empty mapping is stored where a tuple of names is expected;
# an empty tuple would match the annotation.
self.__trajectory_state_names = READ_ONLY_EMPTY_DICT
# Initialize inputs and outputs
self.default_input_data = mapping_inputs
output_names = [
if return_trajectories:
def _run(self, input_data: StrKeyMapping) -> StrKeyMapping | None:
# Solve the ODE problem for the current inputs.
# Return the final states, plus the state trajectories when
# ``self._output_trajectory`` is set.
# Design-variable values are read from the local data. The time bounds and
# times of interest are also read from it and are ``None`` when absent.
mapping_parameters = {
k: self.local_data[k] for k in self.__design_variables_names
initial_time=self.local_data.get(self.__initial_time_name, None),
final_time=self.local_data.get(self.__final_time_name, None),
times=self.local_data.get(self.__TIMES, None),
for termination_discipline in self.termination_event_disciplines:
# The initial state is the concatenation of the initial-condition inputs.
self._ode_problem.initial_state = concatenate_dict_of_arrays_to_array(
input_data, names=self.__initial_state_names
# Run the solver, then read its result from ``self._ode_problem.result``.
self.__ode_solver.execute(self._ode_problem, **self.__ode_solver_options)
result = self._ode_problem.result
# Raise instead of returning the output of a non-converged solve.
if not result.algorithm_has_converged:
msg = (
f"ODE solver {result.algorithm_name} failed to converge. "
f"Message = {result.algorithm_termination_message}"
raise RuntimeError(msg)
# NOTE(review): zipping names with ``result.final_state`` assumes its first
# axis enumerates the state variables. Confirm against ``ODEResult``.
output_data = dict(zip(self.__final_state_names, result.final_state))
# NOTE(review): as shown, the trajectory mapping built below is not stored.
# It is presumably meant to be merged into ``output_data``; confirm.
if self._output_trajectory:
dict(zip(self.__trajectory_state_names, result.state_trajectories))
output_data[self.__TERMINATION_TIME] = result.termination_time
output_data[self.__TIMES] = result.times
return output_data