From 4e9ebede3fa227815da8dd342bfbdf1ee19df2f8 Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Thu, 10 Oct 2024 11:18:57 -0600 Subject: [PATCH] triangle tests --- pymatsolver/__init__.py | 3 ++- pymatsolver/solvers.py | 6 +++--- tests/test_Triangle.py | 16 ++++++++++++---- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/pymatsolver/__init__.py b/pymatsolver/__init__.py index a7d8e9d..9c9dfdb 100644 --- a/pymatsolver/__init__.py +++ b/pymatsolver/__init__.py @@ -24,6 +24,7 @@ .. autosummary:: :toctree: generated/ + Triangle Forward Backward @@ -60,7 +61,7 @@ } # Simple solvers -from .solvers import Diagonal, Forward, Backward +from .solvers import Diagonal, Triangle, Forward, Backward from .wrappers import WrapDirect from .wrappers import WrapIterative diff --git a/pymatsolver/solvers.py b/pymatsolver/solvers.py index 5c2ae3d..72accd8 100644 --- a/pymatsolver/solvers.py +++ b/pymatsolver/solvers.py @@ -455,7 +455,7 @@ def _solve_multiple(self, rhs): return rhs / self._diagonal[:, None] -class TriangularSolver(Base): +class Triangle(Base): """A solver for a diagonal matrix. Parameters @@ -512,7 +512,7 @@ def transpose(self): return trans -class Forward(TriangularSolver): +class Forward(Triangle): """A solver for a lower triangular matrix. Parameters @@ -538,7 +538,7 @@ def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accur super().__init__(A, lower=True, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs) -class Backward(TriangularSolver): +class Backward(Triangle): """A solver for ann upper triangular matrix. Parameters diff --git a/tests/test_Triangle.py b/tests/test_Triangle.py index 8befb63..727ffa1 100644 --- a/tests/test_Triangle.py +++ b/tests/test_Triangle.py @@ -6,19 +6,27 @@ TOL = 1e-12 -@pytest.mark.parametrize("solver", [pymatsolver.Forward, pymatsolver.Backward]) -def test_solve(solver): +@pytest.mark.parametrize("solver", [pymatsolver.Triangle, pymatsolver.Forward, pymatsolver.Backward]) +@pytest.mark.parametrize("transpose", [True, False]) +def test_solve(solver, transpose): n = 50 nrhs = 20 A = sp.rand(n, n, 0.4) + sp.identity(n) sol = np.ones((n, nrhs)) if solver is pymatsolver.Backward: A = sp.triu(A) + lower = False else: A = sp.tril(A) - rhs = A @ sol + lower = True + + if transpose: + rhs = A.T @ sol + Ainv = solver(A, lower=lower).T + else: + rhs = A @ sol + Ainv = solver(A, lower=lower) - Ainv = solver(A) npt.assert_allclose(Ainv * rhs, sol, atol=TOL) npt.assert_allclose(Ainv * rhs[:, 0], sol[:, 0], atol=TOL)