# Source code for gemseo_umdo.use_cases.spring_mass_model.model

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""The |g|-free spring-mass model."""
from __future__ import annotations

from typing import Sequence

from numpy import arange
from numpy.typing import NDArray
from scipy.integrate import odeint

class SpringMassModel:
    r"""The |g|-free spring-mass model :math:`m\frac{d^2z(t)}{dt^2} = -kz(t) + mg`.

    This model computes the time displacement of an object attached to a spring
    as a function of the stiffness of the spring.

    It also computes its maximum displacement.

    The equations are

    .. math::

       m\frac{d^2z}{dt^2} = -kz + mg

    where :math:`z` is the displacement of the object, :math:`m` its mass,
    :math:`k` the stiffness of the spring and :math:`g` the gravity acceleration.
    """

    def __init__(
        self,
        mass: float = 1.5,
        initial_state: tuple[float, float] = (0, 0),
        initial_time: float = 0.0,
        final_time: float = 10.0,
        time_step: float = 0.1,
        gravity: float = 9.8,
    ) -> None:
        """
        Args:
            mass: The mass of the object.
            initial_state: The initial position and velocity of the object.
            initial_time: The initial time.
            final_time: The final time
                (excluded from the time grid, as with ``numpy.arange``).
            time_step: The time step.
            gravity: The gravity acceleration.
        """  # noqa: D205 D212 D415
        self.__mass = mass
        self.__gravity = gravity
        self.__initial_state = initial_state
        # Time grid on which the ODE solution is returned.
        self.__time = arange(initial_time, final_time, time_step)
        # The evaluation cost is taken as the inverse of the time step:
        # a finer grid means a more expensive evaluation.
        self.__cost = 1.0 / time_step

    @property
    def cost(self) -> float:
        """The evaluation cost, equal to the inverse of the time step."""
        return self.__cost

    def __call__(self, stiffness: float = 2.25) -> tuple[NDArray[float], float]:
        """Compute the displacement of the object w.r.t. the stiffness of the spring.

        Args:
            stiffness: The stiffness of the spring.

        Returns:
            The displacement of the object at the different times,
            as well as its maximum displacement.
        """
        # odeint returns one row per time and one column per state component
        # (position, velocity); only the position is kept.
        displacements = odeint(
            self.__integration_func,
            self.__initial_state,
            self.__time,
            args=(stiffness, self.__mass, self.__gravity),
        )[:, 0]
        return displacements, displacements.max()

    @staticmethod
    def __integration_func(
        state: Sequence[float], t: float, k: float, m: float, g: float
    ) -> list[float]:
        """Compute the derivative of the state at a given time.

        The second-order equation :math:`m z'' = -k z + m g` is written
        as the first-order system :math:`(z, v)' = (v, -k z / m + g)`.

        Args:
            state: The position and the velocity of the object.
            t: The time; unused because the system is autonomous,
                but required by the signature expected by ``odeint``.
            k: The stiffness of the spring.
            m: The mass of the object.
            g: The gravity acceleration.

        Returns:
            The derivative of the position (i.e. the velocity),
            the derivative of the velocity (i.e. the acceleration).
        """
        position, velocity = state[0], state[1]
        return [velocity, -k * position / m + g]