Source code for gemseo.utils.linear_solver

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# Contributors:
#    INITIAL AUTHORS - API and implementation and/or documentation
#        :author: Francois Gallard, Charlie Vanaret
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
"""
Linear solvers wrapper
**********************
"""
from __future__ import annotations

import logging

import numpy as np
import scipy.sparse.linalg as scipy_linalg
from scipy.sparse.base import issparse
from scipy.sparse.linalg import bicgstab
from scipy.sparse.linalg import cgs

LOGGER = logging.getLogger(__name__)


class LinearSolver:
    """Solve a linear system Ax=b.

    The system is first handed to the solver selected by name in
    :attr:`AVAILABLE_SOLVERS`. If that solver returns a positive status
    (tolerance not reached), ``bicgstab`` is tried from the current iterate,
    and ``cgs`` is tried in turn if ``bicgstab`` returns a negative status.
    """

    LGMRES = "lgmres"
    AVAILABLE_SOLVERS = {LGMRES: scipy_linalg.lgmres}

    def __init__(self):
        """Constructor."""
        # Used to store (v,Av) pairs for restart and multiple RHS.
        # The same list is passed to the solver at every call of solve().
        self.outer_v = []

    @staticmethod
    def _check_linear_solver(linear_solver):
        """Check that linear solver is available.

        Args:
            linear_solver: The name of the linear solver
                to solve the linear problem.

        Returns:
            The solver function registered under this name.

        Raises:
            AttributeError: If no solver is registered under this name.
        """
        solver = LinearSolver.AVAILABLE_SOLVERS.get(linear_solver)
        if solver is None:
            raise AttributeError(
                f"Invalid linear solver {linear_solver!r} for scipy sparse: "
                f"available solvers are {sorted(LinearSolver.AVAILABLE_SOLVERS)}"
            )
        return solver

    @staticmethod
    def _check_b(a_mat, b_vec):
        """Check the dimensions of the vector b and convert it to ndarray if sparse.

        For lgmres needs.

        Args:
            a_mat: The matrix A (only used to report its shape on error).
            b_vec: The vector b.

        Returns:
            The real part of the vector b, as a dense array.

        Raises:
            ValueError: If b is two-dimensional with more than one column.
        """
        if len(b_vec.shape) == 2 and b_vec.shape[1] != 1:
            LOGGER.error(
                "Incompatible dimensions in linear system Ax=b, A "
                "shape is %s and b shape is %s",
                a_mat.shape,
                b_vec.shape,
            )
            raise ValueError(
                "Second member of the linear system must be a column vector"
            )
        if issparse(b_vec):
            b_vec = b_vec.toarray()
        return b_vec.real

    @staticmethod
    def _relative_tolerance_keyword(solver):
        """Return the keyword under which a solver takes its relative tolerance.

        Depending on the SciPy release, the iterative solvers name this
        argument either ``rtol`` or ``tol``; ``rtol`` is preferred when the
        solver signature exposes it.

        Args:
            solver: The solver function to inspect.

        Returns:
            ``"rtol"`` if the solver accepts it, ``"tol"`` otherwise.
        """
        import inspect

        try:
            parameters = inspect.signature(solver).parameters
        except (TypeError, ValueError):
            # The signature cannot be introspected: keep the historical name.
            return "tol"
        return "rtol" if "rtol" in parameters else "tol"

    @staticmethod
    def _residual_norm(a_mat, sol, b_vec):
        """Compute the Euclidean norm of the residual Ax-b.

        Args:
            a_mat: The matrix A.
            sol: The candidate solution x, as a one-dimensional array.
            b_vec: The vector b, either one-dimensional or a column vector.

        Returns:
            The Euclidean norm of ``a_mat.dot(sol) - b_vec``.
        """
        # b_vec.T turns a column (n, 1) into a row (1, n) so that the
        # subtraction is element-wise instead of broadcasting to (n, n).
        diff = a_mat.dot(sol) - b_vec.T
        return np.linalg.norm(np.asarray(diff).ravel())

    @staticmethod
    def _run_fallback(solver, a_mat, b_vec, x0, maxiter, atol, base_msg):
        """Run a fallback solver and log its residual and status.

        Args:
            solver: The fallback solver function, called as
                ``solver(a_mat, b_vec, x0, maxiter=..., atol=...)``.
            a_mat: The matrix A.
            b_vec: The vector b.
            x0: The starting guess.
            maxiter: The maximum number of iterations.
            atol: The absolute tolerance.
            base_msg: The prefix of the logged messages.

        Returns:
            The solution and the status returned by the fallback solver.
        """
        sol, info = solver(a_mat, b_vec, x0, maxiter=maxiter, atol=atol)
        res = LinearSolver._residual_norm(a_mat, sol, b_vec)
        LOGGER.warning(f"{base_msg} --- --- residual = {res}")
        LOGGER.warning(f"{base_msg} --- --- info = {info}")
        return sol, info

    def solve(self, a_mat, b_vec, linear_solver="lgmres", **options):
        """Solve the linear system :math:`Ax=b`.

        Args:
            a_mat: The matrix :math:`A` of the system, can be a sparse matrix.
            b_vec: The second member :math:`b` of the system.
            linear_solver: The name of linear solver.
            **options: The options of the linear solver.
                ``tol`` (or ``rtol``) is the relative tolerance, 1e-8 by
                default; it is also used as the absolute tolerance ``atol``.
                ``maxiter`` is capped to 50 times the size of ``b_vec``.

        Returns:
            The solution :math:`x` such that :math:`Ax=b`, as a column vector.

        Raises:
            AttributeError: If the linear solver is not available.
            ValueError: If ``b_vec`` is not a vector.
        """
        scipy_linear_solver = LinearSolver._check_linear_solver(linear_solver)
        # check the dimensions of b
        b_vec = LinearSolver._check_b(a_mat, b_vec)

        # "tol" is the historical option name and wins over "rtol"; both are
        # removed from the options so that only the keyword understood by the
        # installed solver is forwarded.
        tolerance = options.pop("rtol", 1e-8)
        tolerance = options.pop("tol", tolerance)
        tolerance_keyword = LinearSolver._relative_tolerance_keyword(
            scipy_linear_solver
        )
        options[tolerance_keyword] = tolerance
        # The absolute tolerance always follows the relative one.
        options["atol"] = tolerance

        max_iterations = 50 * len(b_vec)
        options["maxiter"] = min(
            options.get("maxiter", max_iterations), max_iterations
        )

        # solve the system
        sol, info = scipy_linear_solver(
            A=a_mat, b=b_vec, outer_v=self.outer_v, **options
        )

        base_msg = "scipy linear solver algorithm stop info: "
        if info > 0:
            LOGGER.warning(
                base_msg
                + "convergence to tolerance not achieved, number of iterations"
            )
            LOGGER.warning(base_msg + "--- trying bicgstab method")
            sol, info = LinearSolver._run_fallback(
                bicgstab,
                a_mat,
                b_vec,
                sol,
                max_iterations,
                options["atol"],
                base_msg,
            )
            if info < 0:
                LOGGER.warning(f"{base_msg} --- trying cgs method")
                sol, info = LinearSolver._run_fallback(
                    cgs,
                    a_mat,
                    b_vec,
                    sol,
                    max_iterations,
                    options["atol"],
                    base_msg,
                )
        elif info < 0:
            LOGGER.error(base_msg + "illegal input or breakdown")

        return np.atleast_2d(sol).T