diff --git a/dynamax/utils/utils.py b/dynamax/utils/utils.py index b78d3421..e75d6f01 100644 --- a/dynamax/utils/utils.py +++ b/dynamax/utils/utils.py @@ -199,7 +199,7 @@ def find_permutation( return perm -def psd_solve(A, b, diagonal_boost=1e-6): +def psd_solve(A, b, diagonal_boost=1e-9): """A wrapper for coordinating the linalg solvers used in the library for psd matrices.""" A = symmetrize(A) + diagonal_boost * jnp.eye(A.shape[-1]) L, lower = cho_factor(A, lower=True)