# Source code for gemseo.utils.derivatives.complex_step

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - API and implementation and/or documentation
#       :author : Francois Gallard
#    OTHER AUTHORS   - MACROSCOPIC CHANGES

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar

from numpy import bool_
from numpy import complex128
from numpy import dtype
from numpy import ndarray
from numpy import where
from numpy import zeros
from numpy.linalg import norm

from gemseo.core.parallel_execution.callable_parallel_execution import (
CallableParallelExecution,
)
from gemseo.utils.derivatives.approximation_modes import ApproximationMode

if TYPE_CHECKING:
from collections.abc import Sequence

[docs]
r"""Complex step approximator, performing a second-order gradient calculation.

Enable a much lower step than real finite differences,
typically 1e-30,
since there is no cancellation error due to a difference calculation.

.. math::

\frac{df(x)}{dx} \approx Im\left( \frac{f(x+j*\\delta x)}{\\delta x} \right)

See
Martins, Joaquim RRA, Peter Sturdza, and Juan J. Alonso.
"The complex-step derivative approximation."
ACM Transactions on Mathematical Software (TOMS) 29.3 (2003): 245-262.
"""

_APPROXIMATION_MODE = ApproximationMode.COMPLEX_STEP

_DEFAULT_STEP: ClassVar[complex] = 1e-20

# NOTE(review): this reads like a property setter (it receives ``value`` and
# assigns ``self._step``); its decorator line seems to be missing from this
# listing. Confirm against upstream.
def step(self, value) -> None:  # noqa:D102
    """Set the differentiation step.

    A value with a non-zero imaginary part is stored as that imaginary part
    (so ``1e-20j`` and ``1e-20`` are equivalent). Any other value is stored
    unchanged.

    Args:
        value: The new step, real or complex.
    """
    imaginary_part = value.imag
    self._step = imaginary_part if imaginary_part != 0 else value

# NOTE(review): "[docs]" is HTML-listing residue, not Python. The ``def`` header
# line that should precede ``self,`` is missing from this listing. The body also
# appears truncated after the complex-point check: no return statement is
# visible even though ``-> ndarray`` is declared. Restore both from upstream.
# ``step`` and ``x_indices`` are not used in the visible part of the body.
[docs]
self,
x_vect: ndarray,
step: complex | None = None,
x_indices: Sequence[int] | None = None,
**kwargs: Any,
) -> ndarray:
# The method perturbs the imaginary part of the input, so an input that already
# has a non-zero imaginary part cannot be handled: raise ValueError instead.
if norm(x_vect.imag) != 0.0:
msg = (
"Impossible to check the gradient at a complex "
"point using the complex step method."
)
raise ValueError(msg)

# NOTE(review): the ``def`` header line that should precede ``self,`` is missing
# from this listing. Judging by its use of a parallel executor, this is the
# parallel counterpart of the sequential loop further down; confirm the name
# upstream.
self,
input_values: ndarray,
n_perturbations: int,
input_perturbations: ndarray,
step: float,
**kwargs: Any,
) -> list[ndarray]:
# Evaluate all perturbed inputs through a parallel executor and return one
# complex-step quotient per perturbation. ``step`` is not used by this body.
# The kwargs are stored on the instance, presumably for ``self._wrap_function``
# to forward to the wrapped function (not visible here).
self._function_kwargs = kwargs
functions = [self._wrap_function] * n_perturbations
parallel_execution = CallableParallelExecution(functions, **self._parallel_args)

# Each perturbed input is complex: the real point plus one column of
# imaginary perturbations.
perturbed_inputs: list[ndarray[Any, dtype[complex128]]] = [
input_values + input_perturbations[:, perturbation_index]
for perturbation_index in range(n_perturbations)
]
perturbed_outputs = parallel_execution.execute(perturbed_inputs)

# df/dx_k ~= Im(f(x + j h_k e_k)) / h_k, with h_k the imaginary perturbation.
# NOTE(review): the divisor is read at row == column ``perturbation_index``.
# If ``input_perturbations`` comes from ``_generate_perturbations``, column j is
# non-zero only at row ``input_indices[j]``. The diagonal read is then correct
# only when input_indices[j] == j for every j. For any other index subset it
# reads 0.0 and the division yields inf/nan.
# ``input_perturbations[:, perturbation_index].imag.sum()`` would be
# index-agnostic. Confirm and fix.
return [
perturbed_outputs[perturbation_index].imag
/ input_perturbations[perturbation_index, perturbation_index].imag
for perturbation_index in range(n_perturbations)
]

# NOTE(review): this listing appears to be missing three pieces:
#   - the ``def`` header line that should precede ``self,``;
#   - the line that opens the parenthesised quotient inside the loop (note the
#     dangling ``)``), presumably a call that collects each quotient;
#   - the return statement promised by ``-> ndarray``.
# As shown, the loop body is not valid Python. Restore from upstream.
self,
input_values: ndarray,
n_perturbations: int,
input_perturbations: ndarray,
step: float,
**kwargs: Any,
) -> ndarray:
# Sequential variant: evaluate each perturbed input one at a time through
# ``self.f_pointer`` and form the complex-step quotient Im(f) / Im(perturbation).
for perturbation_index in range(n_perturbations):
perturbated_input = (
input_values + input_perturbations[:, perturbation_index]
)
perturbated_output = self.f_pointer(perturbated_input, **kwargs)
# NOTE(review): the divisor is read at row == column ``perturbation_index``.
# That entry is the actual perturbation only when the perturbed component of
# column j is row j. If the matrix comes from ``_generate_perturbations`` with
# input_indices[j] != j, the entry is 0.0. Confirm and fix, e.g. with
# ``input_perturbations[:, perturbation_index].imag.sum()``.
perturbated_output.imag
/ input_perturbations[perturbation_index, perturbation_index].imag
)

def _generate_perturbations(
self,
input_values: ndarray,
input_indices: list[int],
step: float,
) -> tuple[ndarray, float | ndarray]:
input_dimension = len(input_values)
n_indices = len(input_indices)
input_perturbations = zeros((input_dimension, n_indices), dtype=complex128)
x_nnz = where(input_values == 0.0, 1.0, input_values)[input_indices]
input_perturbations[input_indices, range(n_indices)] = 1j * x_nnz * step
return input_perturbations, step