# Source code for gemseo.mlearning.regression.rbf

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - initial API and implementation and/or initial
#                         documentation
#        :author: Francois Gallard, Matthias De Lozzo
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
r"""The RBF network for regression.

The radial basis function surrogate discipline expresses the model output
as a weighted sum of kernel functions centered on the learning input data:

.. math::

    y = w_1K(\|x-x_1\|;\epsilon) + w_2K(\|x-x_2\|;\epsilon) + \ldots
        + w_nK(\|x-x_n\|;\epsilon)

and the coefficients :math:`(w_1, w_2, \ldots, w_n)` are estimated
by least squares minimization.

Dependence
----------
The RBF model relies on the Rbf class of the
`scipy library
<https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.Rbf.html>`_.
"""

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Callable
from typing import ClassVar
from typing import Final
from typing import Union

from numpy import average
from numpy import exp
from numpy import finfo
from numpy import log
from numpy import newaxis
from numpy import sqrt
from numpy.linalg import norm
from scipy.interpolate import Rbf
from strenum import StrEnum

from gemseo.mlearning.core.supervised import SavedObjectType as _SavedObjectType
from gemseo.mlearning.regression.regression import BaseMLRegressionAlgo

if TYPE_CHECKING:
from collections.abc import Iterable

from gemseo.datasets.io_dataset import IODataset
from gemseo.mlearning.core.ml_algo import TransformerType
from gemseo.typing import RealArray

# Value type of the mapping returned by ``RBFRegressor._get_objects_to_save``:
# the base saved-object types extended with ``float`` and ``Callable``.
SavedObjectType = Union[_SavedObjectType, float, Callable]

[docs]
class RBFRegressor(BaseMLRegressionAlgo):
r"""Regression based on radial basis functions (RBFs).

This model relies on the SciPy class :class:`scipy.interpolate.Rbf`.
"""

der_function: Callable[[RealArray], RealArray]
"""The derivative of the radial basis function."""

y_average: RealArray
"""The mean of the learning output data."""

SHORT_ALGO_NAME: ClassVar[str] = "RBF"
LIBRARY: ClassVar[str] = "SciPy"

EUCLIDEAN: Final[str] = "euclidean"

[docs]
class Function(StrEnum):

GAUSSIAN = "gaussian"
LINEAR = "linear"
CUBIC = "cubic"
QUINTIC = "quintic"
THIN_PLATE = "thin_plate"

def __init__(
    self,
    data: IODataset,
    transformer: TransformerType = BaseMLRegressionAlgo.IDENTITY,
    input_names: Iterable[str] | None = None,
    output_names: Iterable[str] | None = None,
    function: Function | Callable[[float, float], float] = Function.MULTIQUADRIC,
    der_function: Callable[[RealArray], RealArray] | None = None,
    epsilon: float | None = None,
    smooth: float = 0.0,
    norm: str | Callable[[RealArray, RealArray], float] = "euclidean",
) -> None:
    r"""
    Args:
        function: The radial basis function taking a radius :math:`r` as input,
            representing a distance between two points.
            If it is a string,
            then it must be one of the following:

            - ``"multiquadric"`` for :math:`\sqrt{(r/\epsilon)^2 + 1}`,
            - ``"inverse"`` for :math:`1/\sqrt{(r/\epsilon)^2 + 1}`,
            - ``"gaussian"`` for :math:`\exp(-(r/\epsilon)^2)`,
            - ``"linear"`` for :math:`r`,
            - ``"cubic"`` for :math:`r^3`,
            - ``"quintic"`` for :math:`r^5`,
            - ``"thin_plate"`` for :math:`r^2\log(r)`.

            If it is a callable,
            then it must take the two arguments ``self`` and ``r`` as inputs,
            e.g. ``lambda self, r: sqrt((r/self.epsilon)**2 + 1)``
            for the multiquadric function.
            The epsilon parameter will be available as ``self.epsilon``.
            Other keyword arguments passed in will be available as well.
        der_function: The derivative of the radial basis function,
            only to be provided if ``function`` is a callable
            and if the use of the model with its derivative is required.
            If ``None`` and if ``function`` is a callable,
            an error will be raised when the Jacobian is requested.
            If ``None`` and if ``function`` is a string,
            the class will look for its internal implementation
            and will raise an error if it is missing.
            The ``der_function`` shall take three arguments
            (``input_data``, ``norm_input_data``, ``eps``).
            For an RBF of the form function(:math:`r`),
            der_function(:math:`x`, :math:`|x|`, :math:`\epsilon`) shall
            return :math:`\epsilon^{-1} x/|x| f'(|x|/\epsilon)`.
        epsilon: An adjustable constant for Gaussian or multiquadric functions.
            If ``None``, use the average distance between input data.
        smooth: The degree of smoothness,
            ``0`` involving an interpolation of the learning points.
        norm: The distance metric to be used,
            either a distance function name `known by SciPy
            <https://docs.scipy.org/doc/scipy/reference/generated/
            scipy.spatial.distance.cdist.html>`_
            or a function that computes the distance between two points.
    """  # noqa: D205 D212 D415
    super().__init__(
        data,
        transformer=transformer,
        input_names=input_names,
        output_names=output_names,
        function=function,
        epsilon=epsilon,
        smooth=smooth,
        norm=norm,
    )
    # Placeholder until ``_fit`` stores the mean of the learning output data.
    self.y_average = 0.0
    self.der_function = der_function
self.y_average = 0.0
self.der_function = der_function

[docs]
class RBFDerivatives:
    r"""Derivatives of functions used in :class:`.RBFRegressor`.

    For an RBF of the form :math:`f(r)`, :math:`r` scalar,
    the derivative functions are defined by :math:`d(f(r))/dx`,
    with :math:`r=|x|/\epsilon`. The functions are thus defined
    by :math:`df/dx = \epsilon^{-1} x/|x| f'(|x|/\epsilon)`.
    This convention is chosen to avoid division by :math:`|x|` when
    the terms may be cancelled out, as :math:`f'(r)` often has a term
    in :math:`r`.

    Each method is named ``der_<function name>`` so that it can be looked up
    from the name of the radial basis function.
    The arguments only need to broadcast against each other:
    ``norm_input_data`` may be a scalar or an array of norms.
    """

    TOL = finfo(float).eps
    """The tolerance under which a norm is considered to be zero."""

    @classmethod
    def der_multiquadric(
        cls,
        input_data: RealArray,
        norm_input_data: float,
        eps: float,
    ) -> RealArray:
        r"""Compute derivative of :math:`f(r) = \sqrt{r^2 + 1}` w.r.t. :math:`x`.

        Args:
            input_data: The 1D input data.
            norm_input_data: The norm of the input variable.
            eps: The correlation length.

        Returns:
            The derivative of the function.
        """
        return input_data / eps**2 / sqrt((norm_input_data / eps) ** 2 + 1)

    @classmethod
    def der_inverse_multiquadric(
        cls,
        input_data: RealArray,
        norm_input_data: float,
        eps: float,
    ) -> RealArray:
        r"""Compute derivative of :math:`f(r)=1/\sqrt{r^2 + 1}` w.r.t. :math:`x`.

        Args:
            input_data: The 1D input data.
            norm_input_data: The norm of the input variable.
            eps: The correlation length.

        Returns:
            The derivative of the function.
        """
        return -input_data / eps**2 / ((norm_input_data / eps) ** 2 + 1) ** 1.5

    @classmethod
    def der_gaussian(
        cls,
        input_data: RealArray,
        norm_input_data: float,
        eps: float,
    ) -> RealArray:
        r"""Compute derivative of :math:`f(r)=\exp(-r^2)` w.r.t. :math:`x`.

        Args:
            input_data: The 1D input data.
            norm_input_data: The norm of the input variable.
            eps: The correlation length.

        Returns:
            The derivative of the function.
        """
        return -2 * input_data / eps**2 * exp(-((norm_input_data / eps) ** 2))

    @classmethod
    def der_linear(
        cls,
        input_data: RealArray,
        norm_input_data: float,
        eps: float,
    ) -> RealArray:
        """Compute derivative of :math:`f(r)=r` w.r.t. :math:`x`.

        If :math:`x=0`, return 0 (determined up to a tolerance).

        Args:
            input_data: The 1D input data.
            norm_input_data: The norm of the input variable.
            eps: The correlation length.

        Returns:
            The derivative of the function.
        """
        # The boolean mask zeroes the result at the origin,
        # while the shifted denominator prevents a division by zero there.
        return (
            (norm_input_data > cls.TOL)
            * input_data
            / eps
            / (norm_input_data + cls.TOL)
        )

    @classmethod
    def der_cubic(
        cls,
        input_data: RealArray,
        norm_input_data: float,
        eps: float,
    ) -> RealArray:
        """Compute derivative w.r.t. :math:`x` of the function :math:`f(r) = r^3`.

        Args:
            input_data: The 1D input data.
            norm_input_data: The norm of the input variable.
            eps: The correlation length.

        Returns:
            The derivative of the function.
        """
        return 3 * norm_input_data * input_data / eps**3

    @classmethod
    def der_quintic(
        cls,
        input_data: RealArray,
        norm_input_data: float,
        eps: float,
    ) -> RealArray:
        """Compute derivative w.r.t. :math:`x` of the function :math:`f(r) = r^5`.

        Args:
            input_data: The 1D input data.
            norm_input_data: The norm of the input variable.
            eps: The correlation length.

        Returns:
            The derivative of the function.
        """
        return 5 * norm_input_data**3 * input_data / eps**5

    @classmethod
    def der_thin_plate(
        cls,
        input_data: RealArray,
        norm_input_data: float,
        eps: float,
    ) -> RealArray:
        r"""Compute derivative of :math:`f(r) = r^2\log(r)` w.r.t. :math:`x`.

        If :math:`x=0`, return 0 (determined up to a tolerance).

        Args:
            input_data: The 1D input data.
            norm_input_data: The norm of the input variable.
            eps: The correlation length.

        Returns:
            The derivative of the function.
        """
        # The boolean mask zeroes the result at the origin,
        # while the shifted argument keeps the logarithm finite there.
        return (
            (norm_input_data > cls.TOL)
            * input_data
            / eps**2
            * (1 + 2 * log(norm_input_data / eps + cls.TOL))
        )

def _fit(
self,
input_data: RealArray,
output_data: RealArray,
) -> None:
self.y_average = average(output_data, axis=0)
output_data -= self.y_average
args = [*list(input_data.T), output_data]
self.algo = Rbf(
*args,
mode="N-D",
function=self.parameters["function"],
epsilon=self.parameters["epsilon"],
smooth=self.parameters["smooth"],
norm=self.parameters["norm"],
)

def _predict(
self,
input_data: RealArray,
) -> RealArray:
return self.algo(*input_data.T).reshape((len(input_data), -1)) + self.y_average

def _predict_jacobian(
self,
input_data: RealArray,
) -> RealArray:
self._check_available_jacobian()
der_func = self.der_function or getattr(
self.RBFDerivatives, f"der_{self.function}"
)
#             predict_samples                        learn_samples
# Dimensions : ( n_samples , n_outputs , n_inputs , n_learn_samples )
# input_data : ( n_samples ,           , n_inputs ,                 )
# ref_points : (           ,           , n_inputs , n_learn_samples )
# nodes      : (           , n_outputs ,          , n_learn_samples )
# jacobians  : ( n_samples , n_outputs , n_inputs ,                 )
ref_points = self.algo.xi[newaxis, newaxis]
nodes = self.algo.nodes.T[newaxis, :, newaxis]
input_data = input_data[:, newaxis, :, newaxis]
diffs = input_data - ref_points
dists = norm(diffs, axis=2)[:, :, newaxis]
return (nodes * der_func(diffs, dists, eps=self.algo.epsilon)).sum(-1)

def _check_available_jacobian(self) -> None:
"""Check if the Jacobian is available for the given setup.

Raises:
NotImplementedError: Either if the Jacobian computation is not implemented
or if the derivative of the radial basis function is missing.
"""
if self.algo.norm != self.EUCLIDEAN:
msg = "Jacobian is only implemented for Euclidean norm."
raise NotImplementedError(msg)

if callable(self.function) and self.der_function is None:
msg = (
"No der_function is provided."
)
raise NotImplementedError(msg)

def _get_objects_to_save(self) -> dict[str, SavedObjectType]:
    """Return the objects to save, including the RBF-specific attributes.

    Returns:
        The objects collected by the parent class,
        completed with ``y_average`` and ``der_function``.
    """
    saved = super()._get_objects_to_save()
    for attribute_name in ("y_average", "der_function"):
        saved[attribute_name] = getattr(self, attribute_name)
    return saved

@property
def function(self) -> str:
"""The name of the kernel function.

The name is possibly different from self.parameters['function'], as it is mapped
(scipy). Examples: