From 581f65a1115a0fe2795dbc1f628f05873f2f8766 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 10 Jan 2025 19:36:47 +0100 Subject: [PATCH] Adapt to Solve changes in Scipy 1.15 1. Use actual Solve Op to infer output dtype as CholSolve outputs a different dtype than basic Solve in Scipy==1.15 2. Tweaked test related to https://github.com/pymc-devs/pytensor/issues/1152 3. Tweak tolerage --- pytensor/tensor/slinalg.py | 7 ++++--- tests/tensor/test_blockwise.py | 2 +- tests/tensor/test_slinalg.py | 22 ++++++++++++++-------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 4904259d25..325567918a 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -259,9 +259,10 @@ def make_node(self, A, b): raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.") # Infer dtype by solving the most simple case with 1x1 matrices - o_dtype = scipy.linalg.solve( - np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype) - ).dtype + inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)] + out_arr = [[None]] + self.perform(None, inp_arr, out_arr) + o_dtype = out_arr[0][0].dtype x = tensor(dtype=o_dtype, shape=b.type.shape) return Apply(self, [A, b], [x]) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 8ce40d48ef..51862562ac 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -590,7 +590,7 @@ def core_scipy_fn(A, b): A_val_copy, b_val_copy ) np.testing.assert_allclose( - out, expected_out, atol=1e-5 if config.floatX == "float32" else 0 + out, expected_out, atol=1e-4 if config.floatX == "float32" else 0 ) # Confirm input was destroyed diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 3d4b6697b8..f46d771938 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -169,7 +169,12 @@ def test_eigvalsh_grad(): ) -class TestSolveBase(utt.InferShapeTester): +class TestSolveBase: + class SolveTest(SolveBase): + def perform(self, node, inputs, outputs): + A, b = inputs + outputs[0][0] = scipy.linalg.solve(A, b) + @pytest.mark.parametrize( "A_func, b_func, error_message", [ @@ -191,16 +196,16 @@ def test_make_node(self, A_func, b_func, error_message): with pytest.raises(ValueError, match=error_message): A = A_func() b = b_func() - SolveBase(b_ndim=2)(A, b) + self.SolveTest(b_ndim=2)(A, b) def test__repr__(self): np.random.default_rng(utt.fetch_seed()) A = matrix() b = matrix() - y = SolveBase(b_ndim=2)(A, b) + y = self.SolveTest(b_ndim=2)(A, b) assert ( y.__repr__() - == "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0" + == "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0" ) @@ -239,8 +244,9 @@ def test_correctness(self): A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) A_val = np.dot(A_val.transpose(), A_val) - assert np.allclose( - scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val) + np.testing.assert_allclose( + scipy.linalg.solve(A_val, b_val, assume_a="gen"), + gen_solve_func(A_val, b_val), ) A_undef = np.array( @@ -253,7 +259,7 @@ def test_correctness(self): ], dtype=config.floatX, ) - assert np.allclose( + np.testing.assert_allclose( scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val) ) @@ -450,7 +456,7 @@ def test_solve_dtype(self): fn = function([A, b], x) x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype)) - assert x.dtype == x_result.dtype + assert x.dtype == x_result.dtype, (A_dtype, b_dtype) def test_cho_solve():