Skip to content

Commit

Permalink
triangle tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Oct 10, 2024
1 parent 9305f05 commit 4e9ebed
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
3 changes: 2 additions & 1 deletion pymatsolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
.. autosummary::
:toctree: generated/
Triangle
Forward
Backward
Expand Down Expand Up @@ -60,7 +61,7 @@
}

# Simple solvers
from .solvers import Diagonal, Forward, Backward
from .solvers import Diagonal, Triangle, Forward, Backward
from .wrappers import WrapDirect
from .wrappers import WrapIterative

Expand Down
6 changes: 3 additions & 3 deletions pymatsolver/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _solve_multiple(self, rhs):
return rhs / self._diagonal[:, None]


class TriangularSolver(Base):
class Triangle(Base):
"""A solver for a diagonal matrix.
Parameters
Expand Down Expand Up @@ -512,7 +512,7 @@ def transpose(self):
return trans


class Forward(TriangularSolver):
class Forward(Triangle):
"""A solver for a lower triangular matrix.
Parameters
Expand All @@ -538,7 +538,7 @@ def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accur
super().__init__(A, lower=True, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)


class Backward(TriangularSolver):
class Backward(Triangle):
"""A solver for ann upper triangular matrix.
Parameters
Expand Down
16 changes: 12 additions & 4 deletions tests/test_Triangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,27 @@

TOL = 1e-12

@pytest.mark.parametrize("solver", [pymatsolver.Forward, pymatsolver.Backward])
def test_solve(solver):
@pytest.mark.parametrize("solver", [pymatsolver.Triangle, pymatsolver.Forward, pymatsolver.Backward])
@pytest.mark.parametrize("transpose", [True, False])
def test_solve(solver, transpose):
n = 50
nrhs = 20
A = sp.rand(n, n, 0.4) + sp.identity(n)
sol = np.ones((n, nrhs))
if solver is pymatsolver.Backward:
A = sp.triu(A)
lower = False
else:
A = sp.tril(A)
rhs = A @ sol
lower = True

if transpose:
rhs = A.T @ sol
Ainv = solver(A, lower=lower).T
else:
rhs = A @ sol
Ainv = solver(A, lower=lower)

Ainv = solver(A)
npt.assert_allclose(Ainv * rhs, sol, atol=TOL)
npt.assert_allclose(Ainv * rhs[:, 0], sol[:, 0], atol=TOL)

Expand Down

0 comments on commit 4e9ebed

Please sign in to comment.