From 4e5345fc8f295cc156ca431db0e056afecd69417 Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Sat, 29 Jul 2023 08:44:42 -0400 Subject: [PATCH] rewrote psd_solve to use cho_solve --- dynamax/utils/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dynamax/utils/utils.py b/dynamax/utils/utils.py index 018f4c1c..b78d3421 100644 --- a/dynamax/utils/utils.py +++ b/dynamax/utils/utils.py @@ -9,6 +9,7 @@ from jaxtyping import Array, Int from scipy.optimize import linear_sum_assignment from typing import Optional +from jax.scipy.linalg import cho_factor, cho_solve def has_tpu(): try: @@ -198,10 +199,12 @@ def find_permutation( return perm -def psd_solve(A,b): +def psd_solve(A, b, diagonal_boost=1e-6): """A wrapper for coordinating the linalg solvers used in the library for psd matrices.""" - A = A + 1e-6 - return jnp.linalg.solve(A,b) + A = symmetrize(A) + diagonal_boost * jnp.eye(A.shape[-1]) + L, lower = cho_factor(A, lower=True) + x = cho_solve((L, lower), b) + return x def symmetrize(A): """Symmetrize one or more matrices."""