diff --git a/pymatsolver/solvers.py b/pymatsolver/solvers.py index e4cba32..3c403b4 100644 --- a/pymatsolver/solvers.py +++ b/pymatsolver/solvers.py @@ -40,6 +40,9 @@ class Base(ABC): Extra keyword arguments. If there are any left here a warning will be raised. """ + __numpy_ufunc__ = True + __array_ufunc__ = None + def __init__( self, A, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs ): diff --git a/tests/test_Basic.py b/tests/test_Basic.py index faf287a..125b1a8 100644 --- a/tests/test_Basic.py +++ b/tests/test_Basic.py @@ -55,6 +55,9 @@ def test_basic_solve(): npt.assert_equal(Ainv @ rhs2d, rhs2d) npt.assert_equal(Ainv @ rhs3d, rhs3d) + npt.assert_equal(rhs @ Ainv, rhs) + npt.assert_equal(rhs.T * Ainv, rhs) + # use Diagonal solver as a concrete instance of the Base to test for some errors @@ -86,4 +89,16 @@ def test_errors_and_warnings(): IdentitySolver(np.full((4, 4), 1), check_rtol=0.0) with pytest.raises(ValueError, match="check_atol must.*"): - IdentitySolver(np.full((4, 4), 1), check_atol=-1.0) \ No newline at end of file + IdentitySolver(np.full((4, 4), 1), check_atol=-1.0) + + with pytest.raises(ValueError, match="Expected a vector of length.*"): + Ainv = IdentitySolver(np.eye(4, 4)) + Ainv @ np.ones(3) + + with pytest.raises(ValueError, match="Second to last dimension should be.*"): + Ainv = IdentitySolver(np.eye(4, 4)) + Ainv @ np.ones((3, 2)) + + with pytest.warns(FutureWarning, match="In Future pymatsolver v0.4.0, passing a vector.*"): + Ainv = IdentitySolver(np.eye(4, 4)) + Ainv @ np.ones((4, 1)) \ No newline at end of file