# Source code for gemseo.algos.linear_solvers.linear_problem

# Copyright 2021 IRT Saint Exupéry,
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - API and implementation and/or documentation
#        :author: Francois Gallard
"""Linear equations problem."""

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

import matplotlib.pyplot as plt
from numpy.linalg import norm
from scipy.sparse.linalg import LinearOperator

from gemseo.algos.base_problem import BaseProblem

if TYPE_CHECKING:
    from matplotlib.figure import Figure
    from numpy import ndarray

    from gemseo.utils.compatibility.scipy import ArrayType
    from gemseo.utils.compatibility.scipy import SparseArrayType

class LinearProblem(BaseProblem):
    r"""Representation of the linear equations' system :math:`Ax = b`.

    It also contains the solution, and some properties of the system such as the
    symmetry or positive definiteness.
    """

    rhs: ndarray | None
    """The right-hand side of the equation."""

    lhs: LinearOperator | SparseArrayType
    """The left-hand side of the equation.

    If ``None``, the problem can't be solved and the user has to set it after init.
    """

    solution: ndarray | None
    """The current solution of the problem."""

    is_converged: bool | None
    """If the solution is_converged."""

    convergence_info: int | str | None
    """The information provided by the solver if convergence occurred or not."""

    is_symmetric: bool
    """Whether the LHS is symmetric."""

    is_positive_def: bool
    """Whether the LHS is positive definite."""

    is_lhs_linear_operator: bool
    """Whether the LHS is a linear operator."""

    solver_options: dict[str, Any] | None
    """The options passed to the solver."""

    solver_name: str | None
    """The solver name."""

    residuals_history: list[float] | None
    """The convergence history of residuals."""

    def __init__(
        self,
        lhs: ArrayType | LinearOperator,
        rhs: ndarray | None = None,
        solution: ndarray | None = None,
        is_symmetric: bool = False,
        is_positive_def: bool = False,
        is_converged: bool | None = None,
    ) -> None:
        """
        Args:
            lhs: The left-hand side (matrix or linear operator) of the problem.
            rhs: The right-hand side (vector) of the problem.
            solution: The current solution.
            is_symmetric: Whether to assume that the LHS is symmetric.
            is_positive_def: Whether to assume that the LHS is positive definite.
            is_converged: Whether the solution is converged to the specified
                tolerance.
                If ``False``, the algorithm stopped before convergence.
                If ``None``, no run was performed.
        """  # noqa: D205, D212, D415
        self.rhs = rhs
        self.lhs = lhs
        self.solution = solution
        self.is_converged = is_converged
        self.convergence_info = None
        self.is_symmetric = is_symmetric
        self.is_positive_def = is_positive_def
        self.is_lhs_linear_operator = isinstance(lhs, LinearOperator)
        # Solver-related state stays unset until something else fills it in.
        self.solver_options = None
        self.solver_name = None
        self.residuals_history = None

    def compute_residuals(
        self,
        relative_residuals: bool = True,
        store: bool = False,
        current_x: ndarray | None = None,
    ) -> float:
        """Compute the L2 norm of the residuals of the problem.

        Args:
            relative_residuals: If ``True``, return
                norm(lhs.solution-rhs)/norm(rhs),
                else return norm(lhs.solution-rhs).
            store: Whether to store the residuals value in the
                residuals_history attribute.
            current_x: Compute the residuals associated with current_x,
                If ``None``, compute them from the solution attribute.

        Returns:
            The residuals value.

        Raises:
            ValueError: If :attr:`.rhs` is ``None``,
                or if :attr:`.solution` is ``None`` and ``current_x`` is ``None``.
        """
        if self.rhs is None:
            raise ValueError("Missing RHS.")

        if current_x is None:
            current_x = self.solution
            if current_x is None:
                raise ValueError("Missing solution.")

        # ``dot`` is available on dense arrays, sparse matrices and
        # ``LinearOperator`` alike, so the LHS type does not matter here.
        res = norm(self.lhs.dot(current_x) - self.rhs)

        if relative_residuals:
            res = res / norm(self.rhs)

        if store:
            # The history is created lazily on the first stored value.
            if self.residuals_history is None:
                self.residuals_history = []
            self.residuals_history.append(res)

        return res

    def plot_residuals(self) -> Figure:
        """Plot the residuals' convergence in log scale.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: When the residuals' history is empty.
        """
        if self.residuals_history is None or len(self.residuals_history) == 0:
            raise ValueError(
                "Residuals history is empty. "
                " Use the 'store_residuals' option for the solver."
            )

        fig = plt.figure(figsize=(11.0, 6.0))
        axes = fig.gca()
        axes.plot(self.residuals_history, color="black", lw=2)
        axes.set_yscale("log")
        axes.set_title(f"Linear solver '{self.solver_name}' convergence")
        axes.set_ylabel("Residuals norm (log)")
        axes.set_xlabel("Iterations")
        return fig

    def check(self) -> None:
        """Check the consistency of the dimensions of the LHS and RHS.

        The LHS must be two-dimensional, its number of rows must match the first
        dimension of the RHS, and the RHS must be either one-dimensional or have a
        last dimension of size 1.

        Raises:
            ValueError: When the shapes are inconsistent.
        """
        lhs_shape = self.lhs.shape
        rhs_shape = self.rhs.shape

        if (
            (len(lhs_shape) != 2)
            or (lhs_shape[0] != rhs_shape[0])
            or (len(rhs_shape) != 1 and rhs_shape[-1] != 1)
        ):
            # Exceptions do not interpolate "%s" placeholders the way logging
            # calls do, so the shapes are formatted into the message here.
            raise ValueError(
                "Incompatible dimensions in linear system Ax=b,"
                f" A shape is {lhs_shape} and b shape is {rhs_shape}"
            )