Skip to content

Commit

Permalink
confirm status
Browse files Browse the repository at this point in the history
  • Loading branch information
qiyang-ustc committed Apr 12, 2024
1 parent 0ca9055 commit ce3b9c7
Show file tree
Hide file tree
Showing 17 changed files with 37 additions and 4,333 deletions.
6 changes: 6 additions & 0 deletions eigensolver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
__all__ = ['eigs']

from .jitted_functions import *
from .jax_backend import *

eigs = JaxBackend().eigs
15 changes: 15 additions & 0 deletions eigensolver/cpu_eig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import jax
import jax.numpy as jnp
import numpy as np

__all__ = ["cpu_eig"]

def cpu_eig_host(H):
res = np.linalg.eig(H)
print(res)
return res

def cpu_eig(H):
result_shape = (jax.ShapeDtypeStruct(H.shape[0:1], H.dtype),
jax.ShapeDtypeStruct(H.shape, H.dtype))
return jax.pure_callback(cpu_eig_host, result_shape, H)
1 change: 1 addition & 0 deletions eigensolver/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['JaxBackend']

from typing import Any, Optional, Tuple, Callable, List, Text, Type, Sequence
from typing import Union
Expand Down
7 changes: 4 additions & 3 deletions eigensolver/jitted_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import types
import numpy as np
Tensor = Any
from .cpu_eig import cpu_eig

def _iterative_classical_gram_schmidt(jax: types.ModuleType) -> Callable:

Expand Down Expand Up @@ -642,7 +643,7 @@ def _check_eigvals_convergence_eig(jax):
@functools.partial(jax.jit, static_argnums=(2, 3))
def check_eigvals_convergence(beta_m: float, Hm: jax.Array,
tol: float, numeig: int) -> bool:
eigvals, eigvecs = jax.numpy.linalg.eig(Hm)
eigvals, eigvecs = cpu_eig(Hm)
# TODO (mganahl) confirm that this is a valid matrix norm)
Hm_norm = jax.numpy.linalg.norm(Hm)
thresh = jax.numpy.maximum(
Expand Down Expand Up @@ -793,7 +794,7 @@ def implicitly_restarted_arnoldi_method(

def outer_loop(carry):
Hm, Vm, fm, it, numits, ar_converged, _, _, = carry
evals, _ = jax.numpy.linalg.eig(Hm)
evals, _ = cpu_eig(Hm)
shifts, _ = sort_fun(evals)
# perform shifted QR iterations to compress arnoldi factorization
# Note that ||fk|| typically decreases as one iterates the outer loop
Expand Down Expand Up @@ -861,7 +862,7 @@ def cond_fun(carry):

Hm = (numits > jax.numpy.arange(num_krylov_vecs))[:, None] * Hm * (
numits > jax.numpy.arange(num_krylov_vecs))[None, :]
eigvals, U = jax.numpy.linalg.eig(Hm)
eigvals, U = cpu_eig(Hm)
inds = sort_fun(eigvals)[1][:numeig]
vectors = get_vectors(Vm, U, inds, numeig)
return eigvals[inds], [
Expand Down
27 changes: 0 additions & 27 deletions jax_eigensolver/.github/workflows/makefile.yml

This file was deleted.

160 changes: 0 additions & 160 deletions jax_eigensolver/.gitignore

This file was deleted.

Loading

0 comments on commit ce3b9c7

Please sign in to comment.