# Source code for gemseo.algos.linear_solvers.base_linear_solver_library
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Contributors:
# INITIAL AUTHORS - API and implementation and/or documentation
# :author: Francois Gallard
# OTHER AUTHORS - MACROSCOPIC CHANGES
"""Base class for libraries of linear solvers."""
from __future__ import annotations
import logging
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from uuid import uuid4
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import LinearOperator
from scipy.sparse.linalg import spilu
from gemseo.algos._unsuitability_reason import _UnsuitabilityReason
from gemseo.algos.base_algorithm_library import AlgorithmDescription
from gemseo.algos.base_algorithm_library import BaseAlgorithmLibrary
from gemseo.algos.linear_solvers.base_linear_solver_settings import (
BaseLinearSolverSettings,
)
if TYPE_CHECKING:
from numpy import ndarray
from gemseo.algos.linear_solvers.linear_problem import LinearProblem
from gemseo.typing import SparseOrDenseRealArray
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# [docs]
@dataclass
class LinearSolverDescription(AlgorithmDescription):
    """Describe a linear solver.

    On top of the generic algorithm description,
    this records the requirements of the solver
    on the left-hand side (LHS) of the linear system
    and the class of its settings.
    """

    lhs_must_be_symmetric: bool = False
    """Whether the solver requires a symmetric left-hand side matrix."""

    lhs_must_be_positive_definite: bool = False
    """Whether the solver requires a positive definite left-hand side matrix."""

    lhs_must_be_linear_operator: bool = False
    """Whether the left-hand side must be a linear operator."""

    Settings: type[BaseLinearSolverSettings] = BaseLinearSolverSettings
    """The class of the settings of the linear solver."""
# [docs]
class BaseLinearSolverLibrary(BaseAlgorithmLibrary):
    """Common base of the libraries wrapping linear solvers."""

    file_path: Path
    """The path of the file in which the linear problem is pickled.

    It is initialized to ``linear_system.pck`` and replaced by a uniquely named
    path each time a non-converged problem is saved by :meth:`_post_run`.
    """

    _problem: LinearProblem
    """The linear problem to solve."""

    def __init__(self, algo_name: str) -> None:  # noqa:D107
        super().__init__(algo_name)
        self.file_path = Path("linear_system.pck")

    @staticmethod
    def _build_ilu_preconditioner(
        lhs: SparseOrDenseRealArray,
    ) -> LinearOperator:
        """Build a preconditioner from an incomplete LU factorization.

        Args:
            lhs: The matrix of the linear system.

        Returns:
            The linear operator whose matrix-vector product
            is the solve of the incomplete LU factorization of ``lhs``.
        """
        # spilu is given the matrix in compressed sparse column format.
        csc_lhs = csc_matrix(lhs)
        factorization = spilu(csc_lhs)
        return LinearOperator(
            shape=lhs.shape,
            dtype=lhs.dtype,
            matvec=factorization.solve,
        )

    @classmethod
    def _get_unsuitability_reason(
        cls,
        algorithm_description: LinearSolverDescription,
        problem: LinearProblem,
    ) -> _UnsuitabilityReason:
        """Return the reason why a solver cannot handle a linear problem.

        The reason found by the parent class takes precedence.
        Otherwise, the symmetry, positive definiteness and linear operator
        requirements are checked in this order and the first mismatch is reported.

        Args:
            algorithm_description: The description of the linear solver.
            problem: The linear problem to be solved.

        Returns:
            The first unsuitability reason found,
            or the (falsy) reason of the parent class when all the checks pass.
        """
        reason = super()._get_unsuitability_reason(algorithm_description, problem)
        if reason:
            return reason

        # In each condition, the problem property is evaluated first and
        # the solver requirement is only read when it can change the outcome.
        if not problem.is_symmetric and algorithm_description.lhs_must_be_symmetric:
            reason = _UnsuitabilityReason.NOT_SYMMETRIC
        elif (
            not problem.is_positive_def
            and algorithm_description.lhs_must_be_positive_definite
        ):
            reason = _UnsuitabilityReason.NOT_POSITIVE_DEFINITE
        elif (
            problem.is_lhs_linear_operator
            and not algorithm_description.lhs_must_be_linear_operator
        ):
            reason = _UnsuitabilityReason.NOT_LINEAR_OPERATOR

        return reason

    def _pre_run(
        self,
        problem: LinearProblem,
        **settings: Any,
    ) -> None:
        """Store the name of the algorithm in the problem before the run.

        Args:
            problem: The linear problem to be solved.
            **settings: The settings of the solver; they are not used here.
        """
        problem.solver_name = self._algo_name

    def _post_run(
        self,
        problem: LinearProblem,
        result: ndarray,
        **settings: Any,
    ) -> None:
        """Warn about a non-converged run and save the problem if requested.

        Args:
            problem: The linear problem that has been solved.
            result: The result of the run; it is not used here.
            **settings: The settings of the solver,
                which must include ``"save_when_fail"``.
        """
        has_failed = not problem.is_converged
        if has_failed:
            LOGGER.warning(
                "The linear solver %s did not converge.", problem.solver_name
            )

        # The setting is looked up whatever the convergence status;
        # the problem is pickled only when the setting is enabled and the run failed.
        if not settings["save_when_fail"] or not has_failed:
            return

        # A random UUID in the file name prevents overwriting a previous dump.
        dump_path = Path(f"linear_system_{uuid4()}.pck")
        with dump_path.open("wb") as stream:
            pickle.dump(problem, stream)

        LOGGER.warning(
            "Linear solver failed, saving problem to file: %s", dump_path
        )
        self.file_path = dump_path