From e3f78b729531678650f7ad1e282be405958f5e96 Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Thu, 26 Sep 2024 15:43:09 -0600 Subject: [PATCH] update pardiso symmetric accuracy test and add more descriptive Error class --- pymatsolver/__init__.py | 2 ++ pymatsolver/solvers.py | 8 ++++++-- tests/test_Pardiso.py | 30 +++++++++--------------------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/pymatsolver/__init__.py b/pymatsolver/__init__.py index a144b9f..a7d8e9d 100644 --- a/pymatsolver/__init__.py +++ b/pymatsolver/__init__.py @@ -73,6 +73,8 @@ from .direct import Solver from .direct import SolverLU +from .solvers import PymatsolverAccuracyError + BicgJacobi = BiCGJacobi # backwards compatibility try: diff --git a/pymatsolver/solvers.py b/pymatsolver/solvers.py index 890e508..8ee6e75 100644 --- a/pymatsolver/solvers.py +++ b/pymatsolver/solvers.py @@ -2,6 +2,10 @@ import warnings +class PymatsolverAccuracyError(Exception): + pass + + class Base(): _accuracy_tol = 1e-6 _check_accuracy = False @@ -57,7 +61,7 @@ def _transposeClass(self): def T(self): "The transpose operator for this class" if self._transposeClass is None: - raise Exception( + raise NotImplementedError( 'The transpose for the {} class is not possible.'.format( self.__name__ ) @@ -74,7 +78,7 @@ def _compute_accuracy(self, rhs, x): msg = 'Accuracy on solve is above tolerance: {0:e} > {1:e}'.format( nrm, self.accuracy_tol ) - raise Exception(msg) + raise PymatsolverAccuracyError(msg) def _solve(self, rhs): diff --git a/tests/test_Pardiso.py b/tests/test_Pardiso.py index e51761d..d8e5f79 100644 --- a/tests/test_Pardiso.py +++ b/tests/test_Pardiso.py @@ -84,27 +84,15 @@ def test_n_threads(test_mat_data): with pytest.raises(TypeError): Ainv.n_threads = "2" -# class TestPardisoNotSymmetric: -# -# @classmethod -# def setup_class(cls): -# cls.A = A -# cls.rhs = rhs -# cls.sol = sol -# -# def test(self): -# rhs = self.rhs -# sol = self.sol -# Ainv = pymatsolver.Pardiso(self.A, is_symmetric=True, check_accuracy=True) -# with pytest.raises(Exception): -# Ainv * rhs -# Ainv.clean() -# -# Ainv = pymatsolver.Pardiso(self.A) -# for i in range(3): -# assert np.linalg.norm(Ainv * rhs[:, i] - sol[:, i]) < TOL -# assert np.linalg.norm(Ainv * rhs - sol, np.inf) < TOL -# Ainv.clean() +def test_inacurrate_symmetry(test_mat_data): + A, rhs, sol = test_mat_data + # make A not symmetric + D = sp.diags(np.linspace(2, 3, A.shape[0])) + A = A @ D + Ainv = pymatsolver.Pardiso(A, is_symmetric=True, check_accuracy=True) + with pytest.raises(pymatsolver.PymatsolverAccuracyError): + Ainv * rhs + def test_pardiso_fdem():