Skip to content

Commit

Permalink
rewrote psd_solve to use cho_solve
Browse files Browse the repository at this point in the history
  • Loading branch information
calebweinreb committed Jul 29, 2023
1 parent d19e92a commit 4e5345f
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions dynamax/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jaxtyping import Array, Int
from scipy.optimize import linear_sum_assignment
from typing import Optional
from jax.scipy.linalg import cho_factor, cho_solve

def has_tpu():
try:
Expand Down Expand Up @@ -198,10 +199,12 @@ def find_permutation(
return perm


def psd_solve(A,b):
def psd_solve(A, b, diagonal_boost=1e-6):
"""A wrapper for coordinating the linalg solvers used in the library for psd matrices."""
A = A + 1e-6
return jnp.linalg.solve(A,b)
A = symmetrize(A) + diagonal_boost * jnp.eye(A.shape[-1])
L, lower = cho_factor(A, lower=True)
x = cho_solve((L, lower), b)
return x

def symmetrize(A):
"""Symmetrize one or more matrices."""
Expand Down

0 comments on commit 4e5345f

Please sign in to comment.