# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - API and implementation and/or documentation
#        :author: Francois Gallard
#        :author: Damien Guenot
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
"""Plot the derivatives of the functions."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from matplotlib import pyplot
from numpy import arange
from numpy import atleast_2d
from numpy import ndarray
from numpy import where

from gemseo.post.opt_post_processor import OptPostProcessor
from gemseo.utils.string_tools import repr_variable

if TYPE_CHECKING:
    # Typing-only imports: evaluated by static checkers, never at runtime.
    from collections.abc import Iterable
    from collections.abc import Mapping

    from matplotlib.figure import Figure

LOGGER = logging.getLogger(__name__)

# NOTE(review): a Sphinx "[docs]" viewcode marker stood here (a NameError at
# import time) and the enclosing `class ...(OptPostProcessor):` header was
# lost in extraction; the docstring and constant below belong to that class
# (likely GradientSensitivity) — TODO restore the class header from upstream.
"""Derivatives of the objective and constraints at a given iteration."""

# Default (width, height) of the generated figure, in inches.
DEFAULT_FIG_SIZE = (10.0, 10.0)

def _plot(
    self,
    iteration: int | None = None,
    scale_gradients: bool = False,
    compute_missing_gradients: bool = False,
) -> None:
    """
    Args:
        iteration: The iteration to plot the sensitivities.
            Can use either positive or negative indexing,
            e.g. 5 for the 5-th iteration
            or -2 for the penultimate one.
            If None, use the iteration of the optimum.
        scale_gradients: If True, normalize each gradient
            w.r.t. the design variables.
        compute_missing_gradients: Whether to compute the gradients at the
            selected iteration if they were not computed by the algorithm.

            .. warning::
               Activating this option may add considerable computation time
               depending on the cost of the gradient evaluation.
               This option will not compute the gradients if the
               :class:`.OptimizationProblem` instance was imported from an HDF5
               file. This option requires an :class:`.OptimizationProblem` with a
               callable objective and constraints.
    """  # noqa: D205, D212, D415
    # NOTE(review): this method was truncated in extraction; the
    # `scale_gradients`/`compute_missing_gradients` parameters, the inner
    # gradients call and the final figure registration are reconstructed
    # from the surviving fragments — confirm against upstream history.
    # Pick the design point: the optimum by default, otherwise the
    # requested iteration (positive or negative database index).
    if iteration is None:
        design_value = self.opt_problem.solution.x_opt
    else:
        design_value = self.opt_problem.database.get_x_vect(iteration)

    # The third argument is the mapping of output names to gradients; the
    # helper name is inferred from the sibling fragment returning
    # dict[str, ndarray] — TODO confirm the exact name upstream.
    fig = self.__generate_subplots(
        self._get_design_variable_names(),
        design_value,
        self.__get_output_gradients(
            design_value,
            scale_gradients=scale_gradients,
            compute_missing_gradients=compute_missing_gradients,
        ),
    )
    # Register the figure with the post-processor so it can be shown/saved.
    self._add_figure(fig)

# NOTE(review): corrupted fragment — the `def` line of this method was lost
# in extraction. From the docstring and the return annotation it is the
# private gradient getter used by `_plot`, taking at least `design_value`
# plus the scaling/recomputation options its docstring fragments describe —
# TODO restore the signature from upstream history.
self,
design_value: ndarray,
) -> dict[str, ndarray]:
"""Return the gradients of all the output variable at a given design value.

Args:
design_value: The value of the design vector.
w.r.t. the design variables.
selected iteration if they were not computed by the algorithm.

.. warning::
Activating this option may add considerable computation time
depending on the cost of the gradient evaluation.
This option will not compute the gradients if the
:class:.OptimizationProblem instance was imported from an HDF5
file. This option requires an :class:.OptimizationProblem with a

Returns:
indexed by the names of the output,
e.g. "output_name" for a mono-dimensional output,
or "output_name_i" for the i-th component of a multidimensional output.
"""
# NOTE(review): the call target (presumably an evaluate-functions call on
# self.opt_problem) and the binding of its result were lost; only its
# keyword arguments survive below — TODO confirm upstream.
try:
design_value,
no_db_no_norm=True,
eval_jac=True,
eval_observables=False,
normalize=False,
)
except NotImplementedError:
# Best-effort path: a problem without callable functions cannot
# recompute missing gradients, so log and continue with what exists.
LOGGER.info(
"The missing gradients for an OptimizationProblem without "
"callable functions cannot be computed."
)

function_names = self.opt_problem.get_all_function_name()
# NOTE(review): the loop body was almost entirely lost; only an `else:`,
# a closing parenthesis and a `continue` remain, presumably from the
# per-function gradient lookup / skip-if-absent logic — TODO restore
# from upstream history.
for function_name in function_names:
else:
)
continue

)

def __generate_subplots(
self,
design_names: Iterable[str],
design_value: ndarray,
# NOTE(review): a `gradients: Mapping[str, ndarray]` parameter was
# evidently lost here — the Raises section below refers to it and
# `Mapping` is imported under TYPE_CHECKING for this purpose.
) -> Figure:
"""Generate the gradients subplots from the data.

Args:
design_names: The names of the design variables.
design_value: The reference value for x.
w.r.t. the design variables.

Returns:

Raises:
ValueError: If gradients is empty.
"""
# NOTE(review): the guard condition (an emptiness test on the gradients
# mapping) was lost in extraction; only the message and raise survive.
msg = "No gradients to plot at current iteration."
raise ValueError(msg)

# Lay the bar charts out on a two-column grid; the row-count computation
# (presumably derived from the number of gradients) was lost.
n_cols = 2

fig, axes = pyplot.subplots(
nrows=n_rows, ncols=n_cols, sharex=True, figsize=self.DEFAULT_FIG_SIZE
)

# Force a 2-D axes array so axes[i][j] indexing also works for one row.
axes = atleast_2d(axes)
abscissa = arange(len(design_value))
# NOTE(review): the body handling a sign-changed objective
# (self._change_obj) was lost.
if self._change_obj:

# NOTE(review): the `for output_name, gradient in ...` loop header and
# the bar-heights argument were lost; the statements below are the loop
# body drawing one bar chart per output across the 2-column grid.
i = j = 0
font_size = 12
rotation = 90
axe = axes[i][j]
axe.bar(
abscissa,
align="center",
)
axe.grid()
axe.set_axisbelow(True)
axe.set_title(output_name)
axe.set_xticks(abscissa)
axe.set_xticklabels(design_names, fontsize=font_size, rotation=rotation)
# Update y labels spacing
vis_labels = [
label for label in axe.get_yticklabels() if label.get_visible() is True
]
pyplot.setp(vis_labels[::2], visible=False)
# Advance to the next cell of the grid (row-major order).
if j == n_cols - 1:
j = 0
i += 1
else:
j += 1

# Re-apply x tick labels on the last touched subplot after the loop.
if j == n_cols - 1:
axe = axes[i][j]
axe.set_xticks(abscissa)
axe.set_xticklabels(design_names, fontsize=font_size, rotation=rotation)

# NOTE(review): the method continues past the visible source — the title
# is presumably applied (fig.suptitle) and the figure returned.
title = (
"Derivatives of objective and constraints with respect to design variables"
)