# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Contributors:
# INITIAL AUTHORS - API and implementation and/or documentation
# :author : Francois Gallard
# OTHER AUTHORS - MACROSCOPIC CHANGES
"""Gradient approximation by complex step."""
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from numpy import bool_
from numpy import complex128
from numpy import dtype
from numpy import ndarray
from numpy import where
from numpy import zeros
from numpy.linalg import norm
from gemseo.core.parallel_execution.callable_parallel_execution import (
CallableParallelExecution,
)
from gemseo.utils.derivatives.approximation_modes import ApproximationMode
from gemseo.utils.derivatives.gradient_approximator import GradientApproximator
if TYPE_CHECKING:
from collections.abc import Sequence
class ComplexStep(GradientApproximator):
    r"""Complex step approximator, performing a second-order gradient calculation.

    Enable a much lower step than real finite differences,
    typically 1e-30,
    since there is no cancellation error due to a difference calculation.

    .. math::

        \frac{df(x)}{dx} \approx Im\left( \frac{f(x+j*\delta x)}{\delta x} \right)

    See
    Martins, Joaquim RRA, Peter Sturdza, and Juan J. Alonso.
    "The complex-step derivative approximation."
    ACM Transactions on Mathematical Software (TOMS) 29.3 (2003): 245-262.
    """

    _APPROXIMATION_MODE = ApproximationMode.COMPLEX_STEP

    # Step used when none is given explicitly.
    _DEFAULT_STEP: ClassVar[complex] = 1e-20

    @GradientApproximator.step.setter
    def step(self, value: complex) -> None:  # noqa:D102
        # A step given with a non-zero imaginary part is reduced to that
        # imaginary part; otherwise the value is stored unchanged.
        if value.imag != 0:
            self._step = value.imag
        else:
            self._step = value

    def f_gradient(
        self,
        x_vect: ndarray,
        step: complex | None = None,
        x_indices: Sequence[int] | None = None,
        **kwargs: Any,
    ) -> ndarray:
        """Approximate the gradient at a real-valued point.

        Args:
            x_vect: The input vector; its imaginary part must be zero.
            step: The step, forwarded as is to the parent implementation.
            x_indices: The input indices, forwarded as is to the parent
                implementation.
            **kwargs: The options forwarded to the parent implementation.

        Returns:
            The result of the parent ``f_gradient``.

        Raises:
            ValueError: When ``x_vect`` has a non-zero imaginary part.
        """
        # The imaginary axis carries the perturbation, so the evaluation point
        # itself must not use it.
        if norm(x_vect.imag) != 0.0:
            msg = (
                "Impossible to check the gradient at a complex "
                "point using the complex step method."
            )
            raise ValueError(msg)

        return super().f_gradient(x_vect, step=step, x_indices=x_indices, **kwargs)

    @staticmethod
    def _get_imaginary_steps(input_perturbations: ndarray) -> ndarray:
        """Return the imaginary step applied by each perturbation column.

        ``_generate_perturbations`` stores a single non-zero entry per column,
        located at the row of the perturbed input component, which is not
        necessarily the diagonal when only a subset of the inputs is perturbed.
        Summing the imaginary parts over the rows recovers this entry wherever
        it is, and equals the diagonal entry when the perturbation is diagonal.

        Args:
            input_perturbations: The complex perturbations,
                one column per perturbation.

        Returns:
            The imaginary steps, one per perturbation column.
        """
        return input_perturbations.imag.sum(axis=0)

    def _compute_parallel_grad(
        self,
        input_values: ndarray,
        n_perturbations: int,
        input_perturbations: ndarray,
        step: float,
        **kwargs: Any,
    ) -> list[ndarray]:
        """Approximate the gradient by evaluating the perturbed points in parallel.

        Args:
            input_values: The point at which the gradient is approximated.
            n_perturbations: The number of perturbed points to evaluate.
            input_perturbations: The complex perturbations,
                one column per perturbed point.
            step: The step; unused here
                as the steps are read from ``input_perturbations``.
            **kwargs: The options stored in ``self._function_kwargs``
                before the parallel execution.

        Returns:
            For each perturbation, the imaginary part of the output
            divided by the imaginary step of this perturbation.
        """
        self._function_kwargs = kwargs
        functions = [self._wrap_function] * n_perturbations
        parallel_execution = CallableParallelExecution(functions, **self._parallel_args)
        perturbed_inputs: list[ndarray] = [
            input_values + input_perturbations[:, perturbation_index]
            for perturbation_index in range(n_perturbations)
        ]
        perturbed_outputs = parallel_execution.execute(perturbed_inputs)
        # Read the steps from the perturbed rows rather than from the diagonal,
        # which is zero or out of range when a subset of inputs is perturbed.
        imaginary_steps = self._get_imaginary_steps(input_perturbations)
        return [
            perturbed_outputs[perturbation_index].imag
            / imaginary_steps[perturbation_index]
            for perturbation_index in range(n_perturbations)
        ]

    def _compute_grad(
        self,
        input_values: ndarray,
        n_perturbations: int,
        input_perturbations: ndarray,
        step: float,
        **kwargs: Any,
    ) -> list[ndarray]:
        """Approximate the gradient by evaluating the perturbed points in sequence.

        Args:
            input_values: The point at which the gradient is approximated.
            n_perturbations: The number of perturbed points to evaluate.
            input_perturbations: The complex perturbations,
                one column per perturbed point.
            step: The step; unused here
                as the steps are read from ``input_perturbations``.
            **kwargs: The options passed to ``self.f_pointer``.

        Returns:
            For each perturbation, the imaginary part of the output
            divided by the imaginary step of this perturbation.
        """
        # Read the steps from the perturbed rows rather than from the diagonal,
        # which is zero or out of range when a subset of inputs is perturbed.
        imaginary_steps = self._get_imaginary_steps(input_perturbations)
        gradient = []
        for perturbation_index in range(n_perturbations):
            perturbed_input = input_values + input_perturbations[:, perturbation_index]
            perturbed_output = self.f_pointer(perturbed_input, **kwargs)
            gradient.append(
                perturbed_output.imag / imaginary_steps[perturbation_index]
            )

        return gradient

    def _generate_perturbations(
        self,
        input_values: ndarray,
        input_indices: list[int],
        step: float,
    ) -> tuple[ndarray, float | ndarray]:
        """Build the purely imaginary perturbations of the selected inputs.

        Args:
            input_values: The point at which the gradient is approximated.
            input_indices: The indices of the input components to perturb.
            step: The step, scaled by each perturbed input component.

        Returns:
            The complex perturbations,
            shaped as ``(len(input_values), len(input_indices))``
            with one non-zero entry per column,
            and the step as passed.
        """
        input_dimension = len(input_values)
        n_indices = len(input_indices)
        input_perturbations = zeros((input_dimension, n_indices), dtype=complex128)
        # Scale the step by the input value, using 1 for zero components so
        # that the perturbation never vanishes.
        x_nnz = where(input_values == 0.0, 1.0, input_values)[input_indices]
        # Column k perturbs the input component input_indices[k] only.
        input_perturbations[input_indices, range(n_indices)] = 1j * x_nnz * step
        return input_perturbations, step