From ce3b9c72f84d4222c26b2f36b477487a97ab3457 Mon Sep 17 00:00:00 2001 From: Qi Yang Date: Fri, 12 Apr 2024 21:28:41 +0800 Subject: [PATCH] confirm status --- eigensolver/__init__.py | 6 + eigensolver/cpu_eig.py | 15 + eigensolver/jax_backend.py | 1 + eigensolver/jitted_functions.py | 7 +- .../.github/workflows/makefile.yml | 27 - jax_eigensolver/.gitignore | 160 -- jax_eigensolver/LICENSE | 201 --- jax_eigensolver/Makefile | 2 - jax_eigensolver/README.md | 2 - jax_eigensolver/__init__.py | 0 jax_eigensolver/eigensolver/__init__.py | 0 jax_eigensolver/eigensolver/jax_backend.py | 901 ---------- .../eigensolver/jitted_functions.py | 1506 ----------------- jax_eigensolver/test.py | 10 - jax_eigensolver/tests/test_jax_backend.py | 1243 -------------- .../tests/test_jitted_functions.py | 273 --- test.py | 16 +- 17 files changed, 37 insertions(+), 4333 deletions(-) create mode 100644 eigensolver/cpu_eig.py delete mode 100644 jax_eigensolver/.github/workflows/makefile.yml delete mode 100644 jax_eigensolver/.gitignore delete mode 100644 jax_eigensolver/LICENSE delete mode 100644 jax_eigensolver/Makefile delete mode 100644 jax_eigensolver/README.md delete mode 100644 jax_eigensolver/__init__.py delete mode 100644 jax_eigensolver/eigensolver/__init__.py delete mode 100644 jax_eigensolver/eigensolver/jax_backend.py delete mode 100644 jax_eigensolver/eigensolver/jitted_functions.py delete mode 100644 jax_eigensolver/test.py delete mode 100644 jax_eigensolver/tests/test_jax_backend.py delete mode 100644 jax_eigensolver/tests/test_jitted_functions.py diff --git a/eigensolver/__init__.py b/eigensolver/__init__.py index e69de29..488c7e9 100644 --- a/eigensolver/__init__.py +++ b/eigensolver/__init__.py @@ -0,0 +1,6 @@ +__all__ = ['eigs'] + +from .jitted_functions import * +from .jax_backend import * + +eigs = JaxBackend().eigs \ No newline at end of file diff --git a/eigensolver/cpu_eig.py b/eigensolver/cpu_eig.py new file mode 100644 index 0000000..0cad516 --- /dev/null +++ b/eigensolver/cpu_eig.py @@ -0,0 +1,15 @@ +import jax +import jax.numpy as jnp +import numpy as np + +__all__ = ["cpu_eig"] + +def cpu_eig_host(H): + res = np.linalg.eig(H) + print(res) + return res + +def cpu_eig(H): + result_shape = (jax.ShapeDtypeStruct(H.shape[0:1], H.dtype), + jax.ShapeDtypeStruct(H.shape, H.dtype)) + return jax.pure_callback(cpu_eig_host, result_shape, H) \ No newline at end of file diff --git a/eigensolver/jax_backend.py b/eigensolver/jax_backend.py index 63066aa..29ab227 100644 --- a/eigensolver/jax_backend.py +++ b/eigensolver/jax_backend.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+__all__ = ['JaxBackend'] from typing import Any, Optional, Tuple, Callable, List, Text, Type, Sequence from typing import Union diff --git a/eigensolver/jitted_functions.py b/eigensolver/jitted_functions.py index cfa47d6..e35b5ba 100644 --- a/eigensolver/jitted_functions.py +++ b/eigensolver/jitted_functions.py @@ -18,6 +18,7 @@ import types import numpy as np Tensor = Any +from .cpu_eig import cpu_eig def _iterative_classical_gram_schmidt(jax: types.ModuleType) -> Callable: @@ -642,7 +643,7 @@ def _check_eigvals_convergence_eig(jax): @functools.partial(jax.jit, static_argnums=(2, 3)) def check_eigvals_convergence(beta_m: float, Hm: jax.Array, tol: float, numeig: int) -> bool: - eigvals, eigvecs = jax.numpy.linalg.eig(Hm) + eigvals, eigvecs = cpu_eig(Hm) # TODO (mganahl) confirm that this is a valid matrix norm) Hm_norm = jax.numpy.linalg.norm(Hm) thresh = jax.numpy.maximum( @@ -793,7 +794,7 @@ def implicitly_restarted_arnoldi_method( def outer_loop(carry): Hm, Vm, fm, it, numits, ar_converged, _, _, = carry - evals, _ = jax.numpy.linalg.eig(Hm) + evals, _ = cpu_eig(Hm) shifts, _ = sort_fun(evals) # perform shifted QR iterations to compress arnoldi factorization # Note that ||fk|| typically decreases as one iterates the outer loop @@ -861,7 +862,7 @@ def cond_fun(carry): Hm = (numits > jax.numpy.arange(num_krylov_vecs))[:, None] * Hm * ( numits > jax.numpy.arange(num_krylov_vecs))[None, :] - eigvals, U = jax.numpy.linalg.eig(Hm) + eigvals, U = cpu_eig(Hm) inds = sort_fun(eigvals)[1][:numeig] vectors = get_vectors(Vm, U, inds, numeig) return eigvals[inds], [ diff --git a/jax_eigensolver/.github/workflows/makefile.yml b/jax_eigensolver/.github/workflows/makefile.yml deleted file mode 100644 index 8db022e..0000000 --- a/jax_eigensolver/.github/workflows/makefile.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: Makefile CI - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: configure - run: ./configure - - - name: Install dependencies - run: make - - - name: Run check - run: make check - - - name: Run distcheck - run: make distcheck diff --git a/jax_eigensolver/.gitignore b/jax_eigensolver/.gitignore deleted file mode 100644 index 68bc17f..0000000 --- a/jax_eigensolver/.gitignore +++ /dev/null @@ -1,160 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. 
-*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ diff --git a/jax_eigensolver/LICENSE b/jax_eigensolver/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/jax_eigensolver/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
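The new eigensolver/cpu_eig.py introduced near the top of this patch moves the dense eigendecomposition onto the host via jax.pure_callback, since jax.numpy.linalg.eig has no GPU/TPU lowering and the Arnoldi routines call it under jit. Below is a minimal, self-contained sketch of that same host-callback pattern; the function names and the example matrix are illustrative assumptions, not part of the patch. One caveat worth noting: np.linalg.eig returns complex eigenvalues even for real non-symmetric input, so declaring the result dtype as H.dtype (as the patched cpu_eig does) assumes H is already complex-typed.

```python
import jax
import jax.numpy as jnp
import numpy as np

def eig_host(H):
    # Runs on the host with NumPy; there is no device kernel for general eig.
    return np.linalg.eig(H)

def eig_via_callback(H):
    # The declared result metadata must match what the host callback returns:
    # eigenvalues of shape (n,) and eigenvectors of shape (n, n).
    n = H.shape[0]
    out_shapes = (jax.ShapeDtypeStruct((n,), H.dtype),
                  jax.ShapeDtypeStruct((n, n), H.dtype))
    return jax.pure_callback(eig_host, out_shapes, H)

# Illustrative use: a complex matrix, so np.linalg.eig's output dtypes
# match the declared ShapeDtypeStructs even without x64 enabled.
M = jnp.eye(4, dtype=jnp.complex64)
evals, evecs = jax.jit(eig_via_callback)(M)
```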
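For reference, the eigensolver/__init__.py hunk above re-exports the implicitly restarted Arnoldi solver as `eigs = JaxBackend().eigs`, keeping the TensorNetwork-style call signature. A hedged usage sketch follows; the operator `matvec`, the matrix `H`, and the parameter choices are illustrative only, and running it still requires the tensornetwork dependency that JaxBackend imports.

```python
import jax.numpy as jnp
import numpy as np
from eigensolver import eigs  # exposed in this patch as JaxBackend().eigs

# A linear operator given as a matrix-vector product; extra args follow the vector.
def matvec(x, H):
    return jnp.dot(H, x)

# Complex dtype so the host np.linalg.eig output matches the callback's declared dtype.
H = jnp.asarray(np.random.rand(16, 16).astype(np.complex64))
x0 = jnp.asarray(np.random.rand(16).astype(np.complex64))

# Four eigenpairs of largest magnitude via implicitly restarted Arnoldi.
eta, U = eigs(matvec, [H], initial_state=x0, num_krylov_vecs=10,
              numeig=4, which='LM', maxiter=20)
```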
diff --git a/jax_eigensolver/Makefile b/jax_eigensolver/Makefile deleted file mode 100644 index 4ec2e59..0000000 --- a/jax_eigensolver/Makefile +++ /dev/null @@ -1,2 +0,0 @@ -test: - python3 -m pytest \ No newline at end of file diff --git a/jax_eigensolver/README.md b/jax_eigensolver/README.md deleted file mode 100644 index cb9bab4..0000000 --- a/jax_eigensolver/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# jax_eigensolver -Rescure useful code in TensorNetwork by google for further personal usage. diff --git a/jax_eigensolver/__init__.py b/jax_eigensolver/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/jax_eigensolver/eigensolver/__init__.py b/jax_eigensolver/eigensolver/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/jax_eigensolver/eigensolver/jax_backend.py b/jax_eigensolver/eigensolver/jax_backend.py deleted file mode 100644 index b4e02d5..0000000 --- a/jax_eigensolver/eigensolver/jax_backend.py +++ /dev/null @@ -1,901 +0,0 @@ -# Copyright 2019 The TensorNetwork Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Tuple, Callable, List, Text, Type, Sequence -from typing import Union -from tensornetwork.backends import abstract_backend -from tensornetwork.backends.numpy import decompositions -import varpu.jax_eigensolver.eigensolver.jitted_functions as jitted_functions -import numpy as np -from functools import partial -import warnings -import jax.numpy as jnp - -Tensor = Any -# pylint: disable=abstract-method - -_CACHED_MATVECS = {} -_CACHED_FUNCTIONS = {} - -class JaxBackend(abstract_backend.AbstractBackend): - """See abstract_backend.AbstractBackend for documentation.""" - - def __init__(self, dtype: Optional[np.dtype] = None, - precision: Optional[Text] = None) -> None: - # pylint: disable=global-variable-undefined - global libjax # Jax module - global jnp # jax.numpy module - global jsp # jax.scipy module - super().__init__() - try: - #pylint: disable=import-outside-toplevel - import jax - except ImportError as err: - raise ImportError("Jax not installed, please switch to a different " - "backend or install Jax.") from err - libjax = jax - jnp = libjax.numpy - jsp = libjax.scipy - self.name = "jax" - self._dtype = np.dtype(dtype) if dtype is not None else None - self.jax_precision = precision if precision is not None else libjax.lax.Precision.DEFAULT #pylint: disable=line-too-long - - def tensordot(self, a: Tensor, b: Tensor, - axes: Union[int, Sequence[Sequence[int]]]) -> Tensor: - return jnp.tensordot(a, b, axes, precision=self.jax_precision) - - def reshape(self, tensor: Tensor, shape: Tensor) -> Tensor: - return jnp.reshape(tensor, np.asarray(shape).astype(np.int32)) - - def transpose(self, tensor, perm=None) -> Tensor: - return jnp.transpose(tensor, perm) - - def shape_concat(self, values: Tensor, axis: int) -> Tensor: - return np.concatenate(values, axis) - - def slice(self, tensor: Tensor, start_indices: Tuple[int, ...], - slice_sizes: Tuple[int, ...]) -> Tensor: - if len(start_indices) != 
len(slice_sizes): - raise ValueError("Lengths of start_indices and slice_sizes must be" - "identical.") - return libjax.lax.dynamic_slice(tensor, start_indices, slice_sizes) - - def svd( - self, - tensor: Tensor, - pivot_axis: int = -1, - max_singular_values: Optional[int] = None, - max_truncation_error: Optional[float] = None, - relative: Optional[bool] = False - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - return decompositions.svd( - jnp, - tensor, - pivot_axis, - max_singular_values, - max_truncation_error, - relative=relative) - - def qr( - self, - tensor: Tensor, - pivot_axis: int = -1, - non_negative_diagonal: bool = False - ) -> Tuple[Tensor, Tensor]: - return decompositions.qr(jnp, tensor, pivot_axis, non_negative_diagonal) - - def rq( - self, - tensor: Tensor, - pivot_axis: int = -1, - non_negative_diagonal: bool = False - ) -> Tuple[Tensor, Tensor]: - return decompositions.rq(jnp, tensor, pivot_axis, non_negative_diagonal) - - - def shape_tensor(self, tensor: Tensor) -> Tensor: - return tensor.shape - - def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]: - return tensor.shape - - def sparse_shape(self, tensor: Tensor) -> Tuple[Optional[int], ...]: - return self.shape_tuple(tensor) - - def shape_prod(self, values: Tensor) -> Tensor: - return np.prod(values) - - def sqrt(self, tensor: Tensor) -> Tensor: - return jnp.sqrt(tensor) - - def convert_to_tensor(self, tensor: Tensor) -> Tensor: - if (not isinstance(tensor, (np.ndarray, jnp.ndarray)) - and not jnp.isscalar(tensor)): - raise TypeError(("Expected a `jnp.array`, `np.array` or scalar. " - f"Got {type(tensor)}")) - result = jnp.asarray(tensor) - return result - - def outer_product(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: - return jnp.tensordot(tensor1, tensor2, 0, - precision=self.jax_precision) - - def einsum(self, - expression: str, - *tensors: Tensor, - optimize: bool = True) -> Tensor: - return jnp.einsum(expression, *tensors, optimize=optimize, - precision=self.jax_precision) - - def norm(self, tensor: Tensor) -> Tensor: - return jnp.linalg.norm(tensor) - - def eye(self, - N, - dtype: Optional[np.dtype] = None, - M: Optional[int] = None) -> Tensor: - dtype = dtype if dtype is not None else jnp.float64 - return jnp.eye(N, M=M, dtype=dtype) - - def ones(self, - shape: Tuple[int, ...], - dtype: Optional[np.dtype] = None) -> Tensor: - dtype = dtype if dtype is not None else jnp.float64 - return jnp.ones(shape, dtype=dtype) - - def zeros(self, - shape: Tuple[int, ...], - dtype: Optional[np.dtype] = None) -> Tensor: - dtype = dtype if dtype is not None else jnp.float64 - return jnp.zeros(shape, dtype=dtype) - - def randn(self, - shape: Tuple[int, ...], - dtype: Optional[np.dtype] = None, - seed: Optional[int] = None) -> Tensor: - if not seed: - seed = np.random.randint(0, 2**63) - key = libjax.random.PRNGKey(seed) - - dtype = dtype if dtype is not None else np.dtype(np.float64) - - def cmplx_randn(complex_dtype, real_dtype): - real_dtype = np.dtype(real_dtype) - complex_dtype = np.dtype(complex_dtype) - - key_2 = libjax.random.PRNGKey(seed + 1) - - real_part = libjax.random.normal(key, shape, dtype=real_dtype) - complex_part = libjax.random.normal(key_2, shape, dtype=real_dtype) - unit = ( - np.complex64(1j) - if complex_dtype == np.dtype(np.complex64) else np.complex128(1j)) - return real_part + unit * complex_part - - if np.dtype(dtype) is np.dtype(jnp.complex128): - return cmplx_randn(dtype, jnp.float64) - if np.dtype(dtype) is np.dtype(jnp.complex64): - return cmplx_randn(dtype, jnp.float32) - - 
return libjax.random.normal(key, shape).astype(dtype) - - def random_uniform(self, - shape: Tuple[int, ...], - boundaries: Optional[Tuple[float, float]] = (0.0, 1.0), - dtype: Optional[np.dtype] = None, - seed: Optional[int] = None) -> Tensor: - if not seed: - seed = np.random.randint(0, 2**63) - key = libjax.random.PRNGKey(seed) - - dtype = dtype if dtype is not None else np.dtype(np.float64) - - def cmplx_random_uniform(complex_dtype, real_dtype): - real_dtype = np.dtype(real_dtype) - complex_dtype = np.dtype(complex_dtype) - - key_2 = libjax.random.PRNGKey(seed + 1) - - real_part = libjax.random.uniform( - key, - shape, - dtype=real_dtype, - minval=boundaries[0], - maxval=boundaries[1]) - complex_part = libjax.random.uniform( - key_2, - shape, - dtype=real_dtype, - minval=boundaries[0], - maxval=boundaries[1]) - unit = ( - np.complex64(1j) - if complex_dtype == np.dtype(np.complex64) else np.complex128(1j)) - return real_part + unit * complex_part - - if np.dtype(dtype) is np.dtype(jnp.complex128): - return cmplx_random_uniform(dtype, jnp.float64) - if np.dtype(dtype) is np.dtype(jnp.complex64): - return cmplx_random_uniform(dtype, jnp.float32) - - return libjax.random.uniform( - key, shape, minval=boundaries[0], maxval=boundaries[1]).astype(dtype) - - def eigs(self, #pylint: disable=arguments-differ - A: Callable, - args: Optional[List] = None, - initial_state: Optional[Tensor] = None, - shape: Optional[Tuple[int, ...]] = None, - dtype: Optional[Type[np.number]] = None, - num_krylov_vecs: int = 50, - numeig: int = 6, - tol: float = 1E-8, - which: Text = 'LR', - maxiter: int = 20) -> Tuple[Tensor, List]: - """ - Implicitly restarted Arnoldi method for finding the lowest - eigenvector-eigenvalue pairs of a linear operator `A`. - `A` is a function implementing the matrix-vector - product. - - WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered - at the first invocation of `eigs`, and on any subsequent calls - if the python `id` of `A` changes, even if the formal definition of `A` - stays the same. - Example: the following will jit once at the beginning, and then never again: - - ```python - import jax - import numpy as np - def A(H,x): - return jax.np.dot(H,x) - for n in range(100): - H = jax.np.array(np.random.rand(10,10)) - x = jax.np.array(np.random.rand(10,10)) - res = eigs(A, [H],x) #jitting is triggerd only at `n=0` - ``` - - The following code triggers jitting at every iteration, which - results in considerably reduced performance - - ```python - import jax - import numpy as np - for n in range(100): - def A(H,x): - return jax.np.dot(H,x) - H = jax.np.array(np.random.rand(10,10)) - x = jax.np.array(np.random.rand(10,10)) - res = eigs(A, [H],x) #jitting is triggerd at every step `n` - ``` - - Args: - A: A (sparse) implementation of a linear operator. - Call signature of `A` is `res = A(vector, *args)`, where `vector` - can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`. - args: A list of arguments to `A`. `A` will be called as - `res = A(initial_state, *args)`. - initial_state: An initial vector for the algorithm. If `None`, - a random initial `Tensor` is created using the `backend.randn` method - shape: The shape of the input-dimension of `A`. - dtype: The dtype of the input `A`. If no `initial_state` is provided, - a random initial state with shape `shape` and dtype `dtype` is created. - num_krylov_vecs: The number of iterations (number of krylov vectors). - numeig: The number of eigenvector-eigenvalue pairs to be computed. 
- tol: The desired precision of the eigenvalues. For the jax backend - this has currently no effect, and precision of eigenvalues is not - guaranteed. This feature may be added at a later point. To increase - precision the caller can either increase `maxiter` or `num_krylov_vecs`. - which: Flag for targetting different types of eigenvalues. Currently - supported are `which = 'LR'` (larges real part) and `which = 'LM'` - (larges magnitude). - maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes - equivalent to a simple Arnoldi method. - Returns: - (eigvals, eigvecs) - eigvals: A list of `numeig` eigenvalues - eigvecs: A list of `numeig` eigenvectors - """ - - if args is None: - args = [] - if which not in ('LR', 'LM'): - raise ValueError(f'which = {which} is currently not supported.') - - if numeig > num_krylov_vecs: - raise ValueError('`num_krylov_vecs` >= `numeig` required!') - - if initial_state is None: - if (shape is None) or (dtype is None): - raise ValueError("if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided") - initial_state = self.randn(shape, dtype) - - if not isinstance(initial_state, (jnp.ndarray, np.ndarray)): - raise TypeError("Expected a `jax.array`. Got {}".format( - type(initial_state))) - - if A not in _CACHED_MATVECS: - _CACHED_MATVECS[A] = libjax.tree_util.Partial(libjax.jit(A)) - - if "imp_arnoldi" not in _CACHED_FUNCTIONS: - imp_arnoldi = jitted_functions._implicitly_restarted_arnoldi(libjax) - _CACHED_FUNCTIONS["imp_arnoldi"] = imp_arnoldi - - eta, U, numits = _CACHED_FUNCTIONS["imp_arnoldi"](_CACHED_MATVECS[A], args, - initial_state, - num_krylov_vecs, numeig, - which, tol, maxiter, - self.jax_precision) - # if numeig > numits: - # warnings.warn( - # f"Arnoldi terminated early after numits = {numits}" - # f" < numeig = {numeig} steps. For this value of `numeig `" - # f"the routine will return spurious eigenvalues of value 0.0." - # f"Use a smaller value of numeig, or a smaller value for `tol`") - return eta, U - - def eigsh( - self, #pylint: disable=arguments-differ - A: Callable, - args: Optional[List] = None, - initial_state: Optional[Tensor] = None, - shape: Optional[Tuple[int, ...]] = None, - dtype: Optional[Type[np.number]] = None, - num_krylov_vecs: int = 50, - numeig: int = 6, - tol: float = 1E-8, - which: Text = 'SA', - maxiter: int = 20) -> Tuple[Tensor, List]: - """ - Implicitly restarted Lanczos method for finding the lowest - eigenvector-eigenvalue pairs of a symmetric (hermitian) linear operator `A`. - `A` is a function implementing the matrix-vector - product. - - WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered - at the first invocation of `eigsh`, and on any subsequent calls - if the python `id` of `A` changes, even if the formal definition of `A` - stays the same. 
- Example: the following will jit once at the beginning, and then never again: - - ```python - import jax - import numpy as np - def A(H,x): - return jax.np.dot(H,x) - for n in range(100): - H = jax.np.array(np.random.rand(10,10)) - x = jax.np.array(np.random.rand(10,10)) - res = eigsh(A, [H],x) #jitting is triggerd only at `n=0` - ``` - - The following code triggers jitting at every iteration, which - results in considerably reduced performance - - ```python - import jax - import numpy as np - for n in range(100): - def A(H,x): - return jax.np.dot(H,x) - H = jax.np.array(np.random.rand(10,10)) - x = jax.np.array(np.random.rand(10,10)) - res = eigsh(A, [H],x) #jitting is triggerd at every step `n` - ``` - - Args: - A: A (sparse) implementation of a linear operator. - Call signature of `A` is `res = A(vector, *args)`, where `vector` - can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`. - arsg: A list of arguments to `A`. `A` will be called as - `res = A(initial_state, *args)`. - initial_state: An initial vector for the algorithm. If `None`, - a random initial `Tensor` is created using the `backend.randn` method - shape: The shape of the input-dimension of `A`. - dtype: The dtype of the input `A`. If no `initial_state` is provided, - a random initial state with shape `shape` and dtype `dtype` is created. - num_krylov_vecs: The number of iterations (number of krylov vectors). - numeig: The number of eigenvector-eigenvalue pairs to be computed. - tol: The desired precision of the eigenvalues. For the jax backend - this has currently no effect, and precision of eigenvalues is not - guaranteed. This feature may be added at a later point. To increase - precision the caller can either increase `maxiter` or `num_krylov_vecs`. - which: Flag for targetting different types of eigenvalues. Currently - supported are `which = 'LR'` (larges real part) and `which = 'LM'` - (larges magnitude). - maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes - equivalent to a simple Arnoldi method. - Returns: - (eigvals, eigvecs) - eigvals: A list of `numeig` eigenvalues - eigvecs: A list of `numeig` eigenvectors - """ - - if args is None: - args = [] - if which not in ('SA', 'LA', 'LM'): - raise ValueError(f'which = {which} is currently not supported.') - - if numeig > num_krylov_vecs: - raise ValueError('`num_krylov_vecs` >= `numeig` required!') - - if initial_state is None: - if (shape is None) or (dtype is None): - raise ValueError("if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided") - initial_state = self.randn(shape, dtype) - - if not isinstance(initial_state, (jnp.ndarray, np.ndarray)): - raise TypeError("Expected a `jax.array`. Got {}".format( - type(initial_state))) - - if A not in _CACHED_MATVECS: - _CACHED_MATVECS[A] = libjax.tree_util.Partial(libjax.jit(A)) - - if "imp_lanczos" not in _CACHED_FUNCTIONS: - imp_lanczos = jitted_functions._implicitly_restarted_lanczos(libjax) - _CACHED_FUNCTIONS["imp_lanczos"] = imp_lanczos - - eta, U, numits = _CACHED_FUNCTIONS["imp_lanczos"](_CACHED_MATVECS[A], args, - initial_state, - num_krylov_vecs, numeig, - which, tol, maxiter, - self.jax_precision) - if numeig > numits: - warnings.warn( - f"Arnoldi terminated early after numits = {numits}" - f" < numeig = {numeig} steps. For this value of `numeig `" - f"the routine will return spurious eigenvalues of value 0.0." 
- f"Use a smaller value of numeig, or a smaller value for `tol`") - return eta, U - - def eigsh_lanczos( - self, - A: Callable, - args: Optional[List[Tensor]] = None, - initial_state: Optional[Tensor] = None, - shape: Optional[Tuple] = None, - dtype: Optional[Type[np.number]] = None, - num_krylov_vecs: int = 20, - numeig: int = 1, - tol: float = 1E-8, - delta: float = 1E-8, - ndiag: int = 10, - reorthogonalize: Optional[bool] = False) -> Tuple[Tensor, List]: - """ - Lanczos method for finding the lowest eigenvector-eigenvalue pairs - of a hermitian linear operator `A`. `A` is a function implementing - the matrix-vector product. - WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered - at the first invocation of `eigsh_lanczos`, and on any subsequent calls - if the python `id` of `A` changes, even if the formal definition of `A` - stays the same. - Example: the following will jit once at the beginning, and then never again: - - ```python - import jax - import numpy as np - def A(H,x): - return jax.np.dot(H,x) - for n in range(100): - H = jax.np.array(np.random.rand(10,10)) - x = jax.np.array(np.random.rand(10,10)) - res = eigsh_lanczos(A, [H],x) #jitting is triggerd only at `n=0` - ``` - - The following code triggers jitting at every iteration, which - results in considerably reduced performance - - ```python - import jax - import numpy as np - for n in range(100): - def A(H,x): - return jax.np.dot(H,x) - H = jax.np.array(np.random.rand(10,10)) - x = jax.np.array(np.random.rand(10,10)) - res = eigsh_lanczos(A, [H],x) #jitting is triggerd at every step `n` - ``` - - Args: - A: A (sparse) implementation of a linear operator. - Call signature of `A` is `res = A(vector, *args)`, where `vector` - can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`. - arsg: A list of arguments to `A`. `A` will be called as - `res = A(initial_state, *args)`. - initial_state: An initial vector for the Lanczos algorithm. If `None`, - a random initial `Tensor` is created using the `backend.randn` method - shape: The shape of the input-dimension of `A`. - dtype: The dtype of the input `A`. If no `initial_state` is provided, - a random initial state with shape `shape` and dtype `dtype` is created. - num_krylov_vecs: The number of iterations (number of krylov vectors). - numeig: The number of eigenvector-eigenvalue pairs to be computed. - If `numeig > 1`, `reorthogonalize` has to be `True`. - tol: The desired precision of the eigenvalues. For the jax backend - this has currently no effect, and precision of eigenvalues is not - guaranteed. This feature may be added at a later point. - To increase precision the caller can increase `num_krylov_vecs`. - delta: Stopping criterion for Lanczos iteration. - If a Krylov vector :math: `x_n` has an L2 norm - :math:`\\lVert x_n\\rVert < delta`, the iteration - is stopped. It means that an (approximate) invariant subspace has - been found. - ndiag: The tridiagonal Operator is diagonalized every `ndiag` iterations - to check convergence. This has currently no effect for the jax backend, - but may be added at a later point. 
- reorthogonalize: If `True`, Krylov vectors are kept orthogonal by - explicit orthogonalization (more costly than `reorthogonalize=False`) - Returns: - (eigvals, eigvecs) - eigvals: A jax-array containing `numeig` lowest eigenvalues - eigvecs: A list of `numeig` lowest eigenvectors - """ - if args is None: - args = [] - if num_krylov_vecs < numeig: - raise ValueError('`num_krylov_vecs` >= `numeig` required!') - - if numeig > 1 and not reorthogonalize: - raise ValueError( - "Got numeig = {} > 1 and `reorthogonalize = False`. " - "Use `reorthogonalize=True` for `numeig > 1`".format(numeig)) - if initial_state is None: - if (shape is None) or (dtype is None): - raise ValueError("if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided") - initial_state = self.randn(shape, dtype) - - if not isinstance(initial_state, (jnp.ndarray, np.ndarray)): - raise TypeError("Expected a `jax.array`. Got {}".format( - type(initial_state))) - if A not in _CACHED_MATVECS: - _CACHED_MATVECS[A] = libjax.tree_util.Partial(A) - if "eigsh_lanczos" not in _CACHED_FUNCTIONS: - eigsh_lanczos = jitted_functions._generate_jitted_eigsh_lanczos(libjax) - _CACHED_FUNCTIONS["eigsh_lanczos"] = eigsh_lanczos - eigsh_lanczos = _CACHED_FUNCTIONS["eigsh_lanczos"] - eta, U, numits = eigsh_lanczos(_CACHED_MATVECS[A], args, initial_state, - num_krylov_vecs, numeig, delta, - reorthogonalize, self.jax_precision) - if numeig > numits: - warnings.warn( - f"Lanczos terminated early after numits = {numits}" - f" < numeig = {numeig} steps. For this value of `numeig `" - f"the routine will return spurious eigenvalues of value 0.0." - f"Use a smaller value of numeig, or a smaller value for `tol`") - return eta, U - - def _gmres(self, - A_mv: Callable, - b: Tensor, - A_args: List, - A_kwargs: dict, - x0: Tensor, - tol: float, - atol: float, - num_krylov_vectors: int, - maxiter: int, - M: Optional[Callable] = None) -> Tuple[Tensor, int]: - """ GMRES solves the linear system A @ x = b for x given a vector `b` and - a general (not necessarily symmetric/Hermitian) linear operator `A`. - - As a Krylov method, GMRES does not require a concrete matrix representation - of the n by n `A`, but only a function - `vector1 = A_mv(vector0, *A_args, **A_kwargs)` - prescribing a one-to-one linear map from vector0 to vector1 (that is, - A must be square, and thus vector0 and vector1 the same size). If `A` is a - dense matrix, or if it is a symmetric/Hermitian operator, a different - linear solver will usually be preferable. - - GMRES works by first constructing the Krylov basis - K = (x0, A_mv@x0, A_mv@A_mv@x0, ..., (A_mv^num_krylov_vectors)@x_0) and then - solving a certain dense linear system K @ q0 = q1 from whose solution x can - be approximated. For `num_krylov_vectors = n` the solution is provably exact - in infinite precision, but the expense is cubic in `num_krylov_vectors` so - one is typically interested in the `num_krylov_vectors << n` case. - The solution can in this case be repeatedly - improved, to a point, by restarting the Arnoldi iterations each time - `num_krylov_vectors` is reached. Unfortunately the optimal parameter choices - balancing expense and accuracy are difficult to predict in advance, so - applying this function requires a degree of experimentation. - - In a tensor network code one is typically interested in A_mv implementing - some tensor contraction. This implementation thus allows `b` and `x0` to be - of whatever arbitrary, though identical, shape `b = A_mv(x0, ...)` expects. 
- Reshaping to and from a matrix problem is handled internally. - - The Jax backend version of GMRES uses a homemade implementation that, for - now, is suboptimal for num_krylov_vecs ~ b.size. - - For the same reason as described in eigsh_lancsoz, the function A_mv - should be Jittable (or already Jitted) and, if at all possible, defined - only once at the global scope. A new compilation will be triggered each - time an A_mv with a new function signature is passed in, even if the - 'new' function is identical to the old one (function identity is - undecidable). - - - Args: - A_mv : A function `v0 = A_mv(v, *A_args, **A_kwargs)` where `v0` and - `v` have the same shape. - b : The `b` in `A @ x = b`; it should be of the shape `A_mv` - operates on. - A_args : Positional arguments to `A_mv`, supplied to this interface - as a list. - Default: None. - A_kwargs : In the other backends, keyword arguments to `A_mv`, supplied - as a dictionary. However, the Jax backend does not support - A_mv accepting - keyword arguments since this causes problems with Jit. - Therefore, an error is thrown if A_kwargs is specified. - Default: None. - x0 : An optional guess solution. Zeros are used by default. - If `x0` is supplied, its shape and dtype must match those of - `b`, or an - error will be thrown. - Default: zeros. - tol, atol: Solution tolerance to achieve, - norm(residual) <= max(tol*norm(b), atol). - Default: tol=1E-05 - atol=tol - num_krylov_vectors - : Size of the Krylov space to build at each restart. - Expense is cubic in this parameter. - Default: 20. - maxiter : The Krylov space will be repeatedly rebuilt up to this many - times. Large values of this argument - should be used only with caution, since especially for nearly - symmetric matrices and small `num_krylov_vectors` convergence - might well freeze at a value significantly larger than `tol`. - Default: 1 - M : Inverse of the preconditioner of A; see the docstring for - `scipy.sparse.linalg.gmres`. This is unsupported in the Jax - backend, and NotImplementedError will be raised if it is - supplied. - Default: None. - - - Raises: - ValueError: -if `x0` is supplied but its shape differs from that of `b`. - -if num_krylov_vectors <= 0. - -if tol or atol was negative. - NotImplementedError: - If M is supplied. - - If A_kwargs is supplied. - TypeError: -if the dtype of `x0` and `b` are mismatching. - Returns: - x : The converged solution. It has the same shape as `b`. - info : 0 if convergence was achieved, the number of restarts otherwise. 
- """ - - if M is not None: - raise NotImplementedError("M is not supported by the Jax backend.") - if A_kwargs: - raise NotImplementedError("A_kwargs is not supported by the Jax backend.") - - - if A_mv not in _CACHED_MATVECS: - @libjax.tree_util.Partial - def matrix_matvec(x, *args): - x = x.reshape(b.shape) - result = A_mv(x, *args) - return result.ravel() - _CACHED_MATVECS[A_mv] = matrix_matvec - - if "gmres" not in _CACHED_FUNCTIONS: - _CACHED_FUNCTIONS["gmres"] = jitted_functions.gmres_wrapper(libjax) - gmres_m = _CACHED_FUNCTIONS["gmres"].gmres_m - x, _, n_iter, converged = gmres_m(_CACHED_MATVECS[A_mv], A_args, b.ravel(), - x0, tol, atol, num_krylov_vectors, - maxiter, self.jax_precision) - if converged: - info = 0 - else: - info = n_iter - x = self.reshape(x, b.shape) - return x, info - - def conj(self, tensor: Tensor) -> Tensor: - return jnp.conj(tensor) - - def eigh(self, matrix: Tensor) -> Tuple[Tensor, Tensor]: - return jnp.linalg.eigh(matrix) - - def addition(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: - return tensor1 + tensor2 - - def subtraction(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: - return tensor1 - tensor2 - - def multiply(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: - return tensor1 * tensor2 - - def divide(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: - return tensor1 / tensor2 - - def inv(self, matrix: Tensor) -> Tensor: - if len(matrix.shape) > 2: - raise ValueError("input to numpy backend method `inv` has shape {}." - " Only matrices are supported.".format(matrix.shape)) - return jnp.linalg.inv(matrix) - - def broadcast_right_multiplication(self, tensor1: Tensor, - tensor2: Tensor) -> Tensor: - if len(tensor2.shape) != 1: - raise ValueError("only order-1 tensors are allowed for `tensor2`," - " found `tensor2.shape = {}`".format(tensor2.shape)) - return tensor1 * tensor2 - - def broadcast_left_multiplication(self, tensor1: Tensor, - tensor2: Tensor) -> Tensor: - if len(tensor1.shape) != 1: - raise ValueError("only order-1 tensors are allowed for `tensor1`," - " found `tensor1.shape = {}`".format(tensor1.shape)) - - t1_broadcast_shape = self.shape_concat( - [self.shape_tensor(tensor1), [1] * (len(tensor2.shape) - 1)], axis=-1) - return tensor2 * self.reshape(tensor1, t1_broadcast_shape) - - def sin(self, tensor: Tensor) -> Tensor: - return jnp.sin(tensor) - - def cos(self, tensor: Tensor) -> Tensor: - return jnp.cos(tensor) - - def exp(self, tensor: Tensor) -> Tensor: - return jnp.exp(tensor) - - def log(self, tensor: Tensor) -> Tensor: - return jnp.log(tensor) - - def expm(self, matrix: Tensor) -> Tensor: - if len(matrix.shape) != 2: - raise ValueError("input to numpy backend method `expm` has shape {}." 
- " Only matrices are supported.".format(matrix.shape)) - if matrix.shape[0] != matrix.shape[1]: - raise ValueError("input to numpy backend method `expm` only supports" - " N*N matrix, {x}*{y} matrix is given".format( - x=matrix.shape[0], y=matrix.shape[1])) - # pylint: disable=no-member - return jsp.linalg.expm(matrix) - - def jit(self, fun: Callable, *args: List, **kwargs: dict) -> Callable: - return libjax.jit(fun, *args, **kwargs) - - def sum(self, - tensor: Tensor, - axis: Optional[Sequence[int]] = None, - keepdims: bool = False) -> Tensor: - return jnp.sum(tensor, axis=axis, keepdims=keepdims) - - def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: - if (tensor1.ndim <= 1) or (tensor2.ndim <= 1): - raise ValueError("inputs to `matmul` have to be tensors of order > 1,") - return jnp.matmul(tensor1, tensor2, precision=self.jax_precision) - - def diagonal(self, tensor: Tensor, offset: int = 0, axis1: int = -2, - axis2: int = -1) -> Tensor: - """Return specified diagonals. - - If tensor is 2-D, returns the diagonal of tensor with the given offset, - i.e., the collection of elements of the form a[i, i+offset]. - If a has more than two dimensions, then the axes specified by - axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is - returned. The shape of the resulting array can be determined by removing - axis1 and axis2 and appending an index to the right equal to the size of the - resulting diagonals. - - This function only extracts diagonals. If you - wish to create diagonal matrices from vectors, use diagflat. - - Args: - tensor: A tensor. - offset: Offset of the diagonal from the main diagonal. - axis1, axis2: Axis to be used as the first/second axis of the 2D - sub-arrays from which the diagonals should be taken. - Defaults to second last/last axis. - Returns: - array_of_diagonals: A dim = min(1, tensor.ndim - 2) tensor storing - the batched diagonals. - """ - if axis1 == axis2: - raise ValueError("axis1, axis2 cannot be equal.") - return jnp.diagonal(tensor, offset=offset, axis1=axis1, axis2=axis2) - - def diagflat(self, tensor: Tensor, k: int = 0) -> Tensor: - """ Flattens tensor and creates a new matrix of zeros with its elements - on the k'th diagonal. - Args: - tensor: A tensor. - k : The diagonal upon which to place its elements. - Returns: - tensor: A new tensor with all zeros save the specified diagonal. - """ - return jnp.diag(jnp.ravel(tensor), k=k) - - def trace(self, tensor: Tensor, offset: int = 0, axis1: int = -2, - axis2: int = -1) -> Tensor: - """Return summed entries along diagonals. - - If tensor is 2-D, the sum is over the - diagonal of tensor with the given offset, - i.e., the collection of elements of the form a[i, i+offset]. - If a has more than two dimensions, then the axes specified by - axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is - summed. - - Args: - tensor: A tensor. - offset: Offset of the diagonal from the main diagonal. - axis1, axis2: Axis to be used as the first/second axis of the 2D - sub-arrays from which the diagonals should be taken. - Defaults to second last/last axis. - Returns: - array_of_diagonals: The batched summed diagonals. - """ - if axis1 == axis2: - raise ValueError("axis1, axis2 cannot be equal.") - return jnp.trace(tensor, offset=offset, axis1=axis1, axis2=axis2) - - def abs(self, tensor: Tensor) -> Tensor: - """ - Returns the elementwise absolute value of tensor. - Args: - tensor: An input tensor. - Returns: - tensor: Its elementwise absolute value. 
- """ - return jnp.abs(tensor) - - def sign(self, tensor: Tensor) -> Tensor: - """ - Returns an elementwise tensor with entries - y[i] = 1, 0, -1 where tensor[i] > 0, == 0, and < 0 respectively. - - For complex input the behaviour of this function may depend on the backend. - The Jax backend version returns y[i] = x[i]/sqrt(x[i]^2). - - Args: - tensor: The input tensor. - """ - return jnp.sign(tensor) - - def item(self, tensor): - return tensor.item() - - def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: - """ - Returns the power of tensor a to the value of b. - In the case b is a tensor, then the power is by element - with a as the base and b as the exponent. - In the case b is a scalar, then the power of each value in a - is raised to the exponent of b. - - Args: - a: The tensor that contains the base. - b: The tensor that contains the exponent or a single scalar. - """ - return jnp.power(a, b) - - def eps(self, dtype: Type[np.number]) -> float: - """ - Return machine epsilon for given `dtype` - - Args: - dtype: A dtype. - - Returns: - float: Machine epsilon. - """ - return jnp.finfo(dtype).eps \ No newline at end of file diff --git a/jax_eigensolver/eigensolver/jitted_functions.py b/jax_eigensolver/eigensolver/jitted_functions.py deleted file mode 100644 index cfa47d6..0000000 --- a/jax_eigensolver/eigensolver/jitted_functions.py +++ /dev/null @@ -1,1506 +0,0 @@ -# Copyright 2019 The TensorNetwork Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -from typing import List, Any, Tuple, Callable, Sequence, Text -import collections -import types -import numpy as np -Tensor = Any - -def _iterative_classical_gram_schmidt(jax: types.ModuleType) -> Callable: - - JaxPrecisionType = type(jax.lax.Precision.DEFAULT) - def iterative_classical_gram_schmidt( - vector: jax.Array, - krylov_vectors: jax.Array, - precision: JaxPrecisionType, - iterations: int = 2, - ) -> jax.Array: - """ - Orthogonalize `vector` to all rows of `krylov_vectors`. - - Args: - vector: Initial vector. - krylov_vectors: Matrix of krylov vectors, each row is treated as a - vector. - iterations: Number of iterations. - - Returns: - jax.Array: The orthogonalized vector. - """ - i1 = list(range(1, len(krylov_vectors.shape))) - i2 = list(range(len(vector.shape))) - - vec = vector - overlaps = 0 - for _ in range(iterations): - ov = jax.numpy.tensordot( - krylov_vectors.conj(), vec, (i1, i2), precision=precision) - vec = vec - jax.numpy.tensordot( - ov, krylov_vectors, ([0], [0]), precision=precision) - overlaps = overlaps + ov - return vec, overlaps - - return iterative_classical_gram_schmidt - - -def _generate_jitted_eigsh_lanczos(jax: types.ModuleType) -> Callable: - """ - Helper function to generate jitted lanczos function used - in JaxBackend.eigsh_lanczos. 
The function `jax_lanczos` - returned by this higher-order function has the following - call signature: - ``` - eigenvalues, eigenvectors = jax_lanczos(matvec:Callable, - arguments: List[Tensor], - init: Tensor, - ncv: int, - neig: int, - landelta: float, - reortho: bool) - ``` - `matvec`: A callable implementing the matrix-vector product of a - linear operator. `arguments`: Arguments to `matvec` additional to - an input vector. `matvec` will be called as `matvec(init, *args)`. - `init`: An initial input vector to `matvec`. - `ncv`: Number of krylov iterations (i.e. dimension of the Krylov space). - `neig`: Number of eigenvalue-eigenvector pairs to be computed. - `landelta`: Convergence parameter: if the norm of the current Lanczos vector - - `reortho`: If `True`, reorthogonalize all krylov vectors at each step. - This should be used if `neig>1`. - - Args: - jax: The `jax` module. - Returns: - Callable: A jitted function that does a lanczos iteration. - - """ - JaxPrecisionType = type(jax.lax.Precision.DEFAULT) - - @functools.partial(jax.jit, static_argnums=(3, 4, 5, 6, 7)) - def jax_lanczos(matvec: Callable, arguments: List, init: jax.Array, - ncv: int, neig: int, landelta: float, reortho: bool, - precision: JaxPrecisionType) -> Tuple[jax.Array, List]: - """ - Lanczos iteration for symmeric eigenvalue problems. If reortho = False, - the Krylov basis is constructed without explicit re-orthogonalization. - In infinite precision, all Krylov vectors would be orthogonal. Due to - finite precision arithmetic, orthogonality is usually quickly lost. - For reortho=True, the Krylov basis is explicitly reorthogonalized. - - Args: - matvec: A callable implementing the matrix-vector product of a - linear operator. - arguments: Arguments to `matvec` additional to an input vector. - `matvec` will be called as `matvec(init, *args)`. - init: An initial input vector to `matvec`. - ncv: Number of krylov iterations (i.e. dimension of the Krylov space). - neig: Number of eigenvalue-eigenvector pairs to be computed. - landelta: Convergence parameter: if the norm of the current Lanczos vector - falls below `landelta`, iteration is stopped. - reortho: If `True`, reorthogonalize all krylov vectors at each step. - This should be used if `neig>1`. 
- precision: jax.lax.Precision type used in jax.numpy.vdot - - Returns: - jax.Array: Eigenvalues - List: Eigenvectors - int: Number of iterations - """ - shape = init.shape - dtype = init.dtype - iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax) - mask_slice = (slice(ncv + 2), ) + (None,) * len(shape) - def scalar_product(a, b): - i1 = list(range(len(a.shape))) - i2 = list(range(len(b.shape))) - return jax.numpy.tensordot(a.conj(), b, (i1, i2), precision=precision) - - def norm(a): - return jax.numpy.sqrt(scalar_product(a, a)) - - def body_lanczos(vals): - krylov_vectors, alphas, betas, i = vals - previous_vector = krylov_vectors[i] - def body_while(vals): - pv, kv, _ = vals - pv = iterative_classical_gram_schmidt( - pv, (i > jax.numpy.arange(ncv + 2))[mask_slice] * kv, precision)[0] - return [pv, kv, False] - - def cond_while(vals): - return vals[2] - - previous_vector, krylov_vectors, _ = jax.lax.while_loop( - cond_while, body_while, - [previous_vector, krylov_vectors, reortho]) - - beta = norm(previous_vector) - normalized_vector = previous_vector / beta - Av = matvec(normalized_vector, *arguments) - alpha = scalar_product(normalized_vector, Av) - alphas = alphas.at[i - 1].set(alpha) - betas = betas.at[i].set(beta) - - def while_next(vals): - Av, _ = vals - res = Av - normalized_vector * alpha - krylov_vectors[i - 1] * beta - return [res, False] - - def cond_next(vals): - return vals[1] - - next_vector, _ = jax.lax.while_loop( - cond_next, while_next, - [Av, jax.numpy.logical_not(reortho)]) - next_vector = jax.numpy.reshape(next_vector, shape) - - krylov_vectors = krylov_vectors.at[i].set(normalized_vector) - krylov_vectors = krylov_vectors.at[i + 1].set(next_vector) - - return [krylov_vectors, alphas, betas, i + 1] - - def cond_fun(vals): - betas, i = vals[-2], vals[-1] - norm = betas[i - 1] - return jax.lax.cond(i <= ncv, lambda x: x[0] > x[1], lambda x: False, - [norm, landelta]) - - # note: ncv + 2 because the first vector is all zeros, and the - # last is the unnormalized residual. - krylov_vecs = jax.numpy.zeros((ncv + 2,) + shape, dtype=dtype) - # NOTE (mganahl): initial vector is normalized inside the loop - krylov_vecs = krylov_vecs.at[1].set(init) - - # betas are the upper and lower diagonal elements - # of the projected linear operator - # the first two beta-values can be discarded - # set betas[0] to 1.0 for initialization of loop - # betas[2] is set to the norm of the initial vector. - betas = jax.numpy.zeros(ncv + 1, dtype=dtype) - betas = betas.at[0].set(1.0) - # diagonal elements of the projected linear operator - alphas = jax.numpy.zeros(ncv, dtype=dtype) - initvals = [krylov_vecs, alphas, betas, 1] - krylov_vecs, alphas, betas, numits = jax.lax.while_loop( - cond_fun, body_lanczos, initvals) - # FIXME (mganahl): if the while_loop stopps early at iteration i, alphas - # and betas are 0.0 at positions n >= i - 1. eigh will then wrongly give - # degenerate eigenvalues 0.0. JAX does currently not support - # dynamic slicing with variable slice sizes, so these beta values - # can't be truncated. Thus, if numeig >= i - 1, jitted_lanczos returns - # a set of spurious eigen vectors and eigen values. 
- # If algebraically small EVs are desired, one can initialize `alphas` with - # large positive values, thus pushing the spurious eigenvalues further - # away from the desired ones (similar for algebraically large EVs) - - #FIXME: replace with eigh_banded once JAX supports it - A_tridiag = jax.numpy.diag(alphas) + jax.numpy.diag( - betas[2:], 1) + jax.numpy.diag(jax.numpy.conj(betas[2:]), -1) - eigvals, U = jax.numpy.linalg.eigh(A_tridiag) - eigvals = eigvals.astype(dtype) - - # expand eigenvectors in krylov basis - def body_vector(i, vals): - krv, unitary, vectors = vals - dim = unitary.shape[1] - n, m = jax.numpy.divmod(i, dim) - vectors = vectors.at[n, :].set(vectors[n, :] + krv[m + 1] * unitary[m, n]) - return [krv, unitary, vectors] - - _vectors = jax.numpy.zeros((neig,) + shape, dtype=dtype) - _, _, vectors = jax.lax.fori_loop(0, neig * (krylov_vecs.shape[0] - 1), - body_vector, - [krylov_vecs, U, _vectors]) - - return jax.numpy.array(eigvals[0:neig]), [ - vectors[n] / norm(vectors[n]) for n in range(neig) - ], numits - - return jax_lanczos - - -def _generate_lanczos_factorization(jax: types.ModuleType) -> Callable: - """ - Helper function to generate a jitteed function that - computes a lanczos factoriazation of a linear operator. - Returns: - Callable: A jitted function that does a lanczos factorization. - - """ - JaxPrecisionType = type(jax.lax.Precision.DEFAULT) - - @functools.partial(jax.jit, static_argnums=(6, 7, 8, 9)) - def _lanczos_fact( - matvec: Callable, args: List, v0: jax.Array, - Vm: jax.Array, alphas: jax.Array, betas: jax.Array, - start: int, num_krylov_vecs: int, tol: float, precision: JaxPrecisionType - ): - """ - Compute an m-step lanczos factorization of `matvec`, with - m <=`num_krylov_vecs`. The factorization will - do at most `num_krylov_vecs` steps, and terminate early - if an invariat subspace is encountered. The returned arrays - `alphas`, `betas` and `Vm` will satisfy the Lanczos recurrence relation - ``` - matrix @ Vm - Vm @ Hm - fm * em = 0 - ``` - with `matrix` the matrix representation of `matvec`, - `Hm = jnp.diag(alphas) + jnp.diag(betas, -1) + jnp.diag(betas.conj(), 1)` - `fm=residual * norm`, and `em` a cartesian basis vector of shape - `(1, kv.shape[1])` with `em[0, -1] == 1` and 0 elsewhere. - - Note that the caller is responsible for dtype consistency between - the inputs, i.e. dtypes between all input arrays have to match. - - Args: - matvec: The matrix vector product. - args: List of arguments to `matvec`. - v0: Initial state to `matvec`. - Vm: An array for storing the krylov vectors. The individual - vectors are stored as columns. - The shape of `krylov_vecs` has to be - (num_krylov_vecs + 1, np.ravel(v0).shape[0]). - alphas: An array for storing the diagonal elements of the reduced - operator. - betas: An array for storing the lower diagonal elements of the - reduced operator. - start: Integer denoting the start position where the first - produced krylov_vector should be inserted into `Vm` - num_krylov_vecs: Number of krylov iterations, should be identical to - `Vm.shape[0] + 1` - tol: Convergence parameter. Iteration is terminated if the norm of a - krylov-vector falls below `tol`. - - Returns: - jax.Array: An array of shape - `(num_krylov_vecs, np.prod(initial_state.shape))` of krylov vectors. - jax.Array: The diagonal elements of the tridiagonal reduced - operator ("alphas") - jax.Array: The lower-diagonal elements of the tridiagonal reduced - operator ("betas") - jax.Array: The unnormalized residual of the Lanczos process. 
- float: The norm of the residual. - int: The number of performed iterations. - bool: if `True`: iteration hit an invariant subspace. - if `False`: iteration terminated without encountering - an invariant subspace. - """ - - shape = v0.shape - iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax) - Z = jax.numpy.linalg.norm(v0) - #only normalize if norm > tol, else return zero vector - v = jax.lax.cond(Z > tol, lambda x: v0 / Z, lambda x: v0 * 0.0, None) - Vm = Vm.at[start, :].set(jax.numpy.ravel(v)) - betas = jax.lax.cond( - start > 0, - lambda x: betas.at[start - 1].set(Z), - lambda x: betas, start) - # body of the arnoldi iteration - def body(vals): - Vm, alphas, betas, previous_vector, _, i = vals - Av = matvec(previous_vector, *args) - Av, overlaps = iterative_classical_gram_schmidt( - Av.ravel(), - (i >= jax.numpy.arange(Vm.shape[0]))[:, None] * Vm, precision) - alphas = alphas.at[i].set(overlaps[i]) - norm = jax.numpy.linalg.norm(Av) - Av = jax.numpy.reshape(Av, shape) - # only normalize if norm is larger than threshold, - # otherwise return zero vector - Av = jax.lax.cond(norm > tol, lambda x: Av/norm, lambda x: Av * 0.0, None) - Vm, betas = jax.lax.cond( - i < num_krylov_vecs - 1, - lambda x: (Vm.at[i + 1, :].set(Av.ravel()), betas.at[i].set(norm)), - lambda x: (Vm, betas), - None) - - return [Vm, alphas, betas, Av, norm, i + 1] - - def cond_fun(vals): - # Continue loop while iteration < num_krylov_vecs and norm > tol - norm, iteration = vals[4], vals[5] - counter_done = (iteration >= num_krylov_vecs) - norm_not_too_small = norm > tol - continue_iteration = jax.lax.cond(counter_done, lambda x: False, - lambda x: norm_not_too_small, None) - return continue_iteration - initial_values = [Vm, alphas, betas, v, Z, start] - final_values = jax.lax.while_loop(cond_fun, body, initial_values) - Vm, alphas, betas, residual, norm, it = final_values - return Vm, alphas, betas, residual, norm, it, norm < tol - - return _lanczos_fact - - -def _generate_arnoldi_factorization(jax: types.ModuleType) -> Callable: - """ - Helper function to create a jitted arnoldi factorization. - The function returns a function `_arnoldi_fact` which - performs an m-step arnoldi factorization. - - `_arnoldi_fact` computes an m-step arnoldi factorization - of an input callable `matvec`, with m = min(`it`,`num_krylov_vecs`). - `_arnoldi_fact` will do at most `num_krylov_vecs` steps. - `_arnoldi_fact` returns arrays `kv` and `H` which satisfy - the Arnoldi recurrence relation - ``` - matrix @ Vm - Vm @ Hm - fm * em = 0 - ``` - with `matrix` the matrix representation of `matvec` and - `Vm = jax.numpy.transpose(kv[:it, :])`, - `Hm = H[:it, :it]`, `fm = np.expand_dims(kv[it, :] * H[it, it - 1]`,1) - and `em` a kartesian basis vector of shape `(1, kv.shape[1])` - with `em[0, -1] == 1` and 0 elsewhere. - - Note that the caller is responsible for dtype consistency between - the inputs, i.e. dtypes between all input arrays have to match. - - Args: - matvec: The matrix vector product. This function has to be wrapped into - `jax.tree_util.Partial`. `matvec` will be called as `matvec(x, *args)` - for an input vector `x`. - args: List of arguments to `matvec`. - v0: Initial state to `matvec`. - Vm: An array for storing the krylov vectors. The individual - vectors are stored as columns. The shape of `krylov_vecs` has to be - (num_krylov_vecs + 1, np.ravel(v0).shape[0]). - H: Matrix of overlaps. The shape has to be - (num_krylov_vecs + 1,num_krylov_vecs + 1). 
- start: Integer denoting the start position where the first - produced krylov_vector should be inserted into `Vm` - num_krylov_vecs: Number of krylov iterations, should be identical to - `Vm.shape[0] + 1` - tol: Convergence parameter. Iteration is terminated if the norm of a - krylov-vector falls below `tol`. - - Returns: - kv: An array of krylov vectors - H: A matrix of overlaps - it: The number of performed iterations. - converged: Whether convergence was achieved. - - """ - JaxPrecisionType = type(jax.lax.Precision.DEFAULT) - iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax) - - @functools.partial(jax.jit, static_argnums=(5, 6, 7, 8)) - def _arnoldi_fact( - matvec: Callable, args: List, v0: jax.Array, - Vm: jax.Array, H: jax.Array, start: int, - num_krylov_vecs: int, tol: float, precision: JaxPrecisionType - ) -> Tuple[jax.Array, jax.Array, jax.Array, float, int, - bool]: - """ - Compute an m-step arnoldi factorization of `matvec`, with - m = min(`it`,`num_krylov_vecs`). The factorization will - do at most `num_krylov_vecs` steps. The returned arrays - `kv` and `H` will satisfy the Arnoldi recurrence relation - ``` - matrix @ Vm - Vm @ Hm - fm * em = 0 - ``` - with `matrix` the matrix representation of `matvec` and - `Vm = jax.numpy.transpose(kv[:it, :])`, - `Hm = H[:it, :it]`, `fm = np.expand_dims(kv[it, :] * H[it, it - 1]`,1) - and `em` a cartesian basis vector of shape `(1, kv.shape[1])` - with `em[0, -1] == 1` and 0 elsewhere. - - Note that the caller is responsible for dtype consistency between - the inputs, i.e. dtypes between all input arrays have to match. - - Args: - matvec: The matrix vector product. - args: List of arguments to `matvec`. - v0: Initial state to `matvec`. - Vm: An array for storing the krylov vectors. The individual - vectors are stored as columns. - The shape of `krylov_vecs` has to be - (num_krylov_vecs + 1, np.ravel(v0).shape[0]). - H: Matrix of overlaps. The shape has to be - (num_krylov_vecs + 1,num_krylov_vecs + 1). - start: Integer denoting the start position where the first - produced krylov_vector should be inserted into `Vm` - num_krylov_vecs: Number of krylov iterations, should be identical to - `Vm.shape[0] + 1` - tol: Convergence parameter. Iteration is terminated if the norm of a - krylov-vector falls below `tol`. - Returns: - jax.Array: An array of shape - `(num_krylov_vecs, np.prod(initial_state.shape))` of krylov vectors. - jax.Array: Upper Hessenberg matrix of shape - `(num_krylov_vecs, num_krylov_vecs`) of the Arnoldi processs. - jax.Array: The unnormalized residual of the Arnoldi process. - int: The norm of the residual. - int: The number of performed iterations. - bool: if `True`: iteration hit an invariant subspace. - if `False`: iteration terminated without encountering - an invariant subspace. - """ - - # Note (mganahl): currently unused, but is very convenient to have - # for further development and tests (it's usually more accurate than - # classical gs) - # Call signature: - #```python - # initial_vals = [Av.ravel(), Vm, i, H] - # Av, Vm, _, H = jax.lax.fori_loop( - # 0, i + 1, modified_gram_schmidt_step_arnoldi, initial_vals) - #``` - def modified_gram_schmidt_step_arnoldi(j, vals): #pylint: disable=unused-variable - """ - Single step of a modified gram-schmidt orthogonalization. - Substantially more accurate than classical gram schmidt - Args: - j: Integer value denoting the vector to be orthogonalized. 
- vals: A list of variables: - `vector`: The current vector to be orthogonalized - to all previous ones - `Vm`: jax.array of collected krylov vectors - `n`: integer denoting the column-position of the overlap - <`krylov_vector`|`vector`> within `H`. - Returns: - updated vals. - - """ - vector, krylov_vectors, n, H = vals - v = krylov_vectors[j, :] - h = jax.numpy.vdot(v, vector, precision=precision) - H = H.at[j, n].set(h) - vector = vector - h * v - return [vector, krylov_vectors, n, H] - - shape = v0.shape - Z = jax.numpy.linalg.norm(v0) - #only normalize if norm > tol, else return zero vector - v = jax.lax.cond(Z > tol, lambda x: v0 / Z, lambda x: v0 * 0.0, None) - Vm = Vm.at[start, :].set(jax.numpy.ravel(v)) - H = jax.lax.cond( - start > 0, - lambda x: H.at[x, x - 1].set(Z), - lambda x: H, start) - # body of the arnoldi iteration - def body(vals): - Vm, H, previous_vector, _, i = vals - Av = matvec(previous_vector, *args) - - Av, overlaps = iterative_classical_gram_schmidt( - Av.ravel(), - (i >= jax.numpy.arange(Vm.shape[0]))[:, None] * - Vm, precision) - H = H.at[:, i].set(overlaps) - norm = jax.numpy.linalg.norm(Av) - Av = jax.numpy.reshape(Av, shape) - - # only normalize if norm is larger than threshold, - # otherwise return zero vector - Av = jax.lax.cond(norm > tol, lambda x: Av/norm, lambda x: Av * 0.0, None) - Vm, H = jax.lax.cond( - i < num_krylov_vecs - 1, - lambda x: (Vm.at[i + 1, :].set(Av.ravel()), H.at[i + 1, i].set(norm)), #pylint: disable=line-too-long - lambda x: (x[0], x[1]), - (Vm, H, Av, i, norm)) - - return [Vm, H, Av, norm, i + 1] - - def cond_fun(vals): - # Continue loop while iteration < num_krylov_vecs and norm > tol - norm, iteration = vals[3], vals[4] - counter_done = (iteration >= num_krylov_vecs) - norm_not_too_small = norm > tol - continue_iteration = jax.lax.cond(counter_done, lambda x: False, - lambda x: norm_not_too_small, None) - return continue_iteration - - initial_values = [Vm, H, v, Z, start] - final_values = jax.lax.while_loop(cond_fun, body, initial_values) - Vm, H, residual, norm, it = final_values - return Vm, H, residual, norm, it, norm < tol - - return _arnoldi_fact - -# ###################################################### -# ####### NEW SORTING FUCTIONS INSERTED HERE ######### -# ###################################################### -def _LR_sort(jax): - @functools.partial(jax.jit, static_argnums=(0,)) - def sorter( - p: int, - evals: jax.Array) -> Tuple[jax.Array, jax.Array]: - inds = jax.numpy.argsort(jax.numpy.real(evals), kind='stable')[::-1] - shifts = evals[inds][-p:] - return shifts, inds - return sorter - -def _SA_sort(jax): - @functools.partial(jax.jit, static_argnums=(0,)) - def sorter( - p: int, - evals: jax.Array) -> Tuple[jax.Array, jax.Array]: - inds = jax.numpy.argsort(evals, kind='stable') - shifts = evals[inds][-p:] - return shifts, inds - return sorter - -def _LA_sort(jax): - @functools.partial(jax.jit, static_argnums=(0,)) - def sorter( - p: int, - evals: jax.Array) -> Tuple[jax.Array, jax.Array]: - inds = jax.numpy.argsort(evals, kind='stable')[::-1] - shifts = evals[inds][-p:] - return shifts, inds - return sorter - -def _LM_sort(jax): - @functools.partial(jax.jit, static_argnums=(0,)) - def sorter( - p: int, - evals: jax.Array) -> Tuple[jax.Array, jax.Array]: - inds = jax.numpy.argsort(jax.numpy.abs(evals), kind='stable')[::-1] - shifts = evals[inds][-p:] - return shifts, inds - return sorter - -# #################################################### -# #################################################### - -def 
_shifted_QR(jax): - @functools.partial(jax.jit, static_argnums=(4,)) - def shifted_QR( - Vm: jax.Array, Hm: jax.Array, fm: jax.Array, - shifts: jax.Array, - numeig: int) -> Tuple[jax.Array, jax.Array, jax.Array]: - # compress arnoldi factorization - q = jax.numpy.zeros(Hm.shape[0], dtype=Hm.dtype) - q = q.at[-1].set(1.0) - - def body(i, vals): - Vm, Hm, q = vals - shift = shifts[i] * jax.numpy.eye(Hm.shape[0], dtype=Hm.dtype) - Qj, R = jax.numpy.linalg.qr(Hm - shift) - Hm = R @ Qj + shift - Vm = Qj.T @ Vm - q = q @ Qj - return Vm, Hm, q - - Vm, Hm, q = jax.lax.fori_loop(0, shifts.shape[0], body, - (Vm, Hm, q)) - fk = Vm[numeig, :] * Hm[numeig, numeig - 1] + fm * q[numeig - 1] - return Vm, Hm, fk - return shifted_QR - -def _get_vectors(jax): - @functools.partial(jax.jit, static_argnums=(3,)) - def get_vectors(Vm: jax.Array, unitary: jax.Array, - inds: jax.Array, numeig: int) -> jax.Array: - - def body_vector(i, states): - dim = unitary.shape[1] - n, m = jax.numpy.divmod(i, dim) - states = states.at[n, :].set(states[n,:] + Vm[m, :] * unitary[m, inds[n]]) - return states - - state_vectors = jax.numpy.zeros([numeig, Vm.shape[1]], dtype=Vm.dtype) - state_vectors = jax.lax.fori_loop(0, numeig * Vm.shape[0], body_vector, - state_vectors) - state_norms = jax.numpy.linalg.norm(state_vectors, axis=1) - state_vectors = state_vectors / state_norms[:, None] - return state_vectors - - return get_vectors - -def _check_eigvals_convergence_eigh(jax): - @functools.partial(jax.jit, static_argnums=(3,)) - def check_eigvals_convergence(beta_m: float, Hm: jax.Array, - Hm_norm: float, - tol: float) -> bool: - eigvals, eigvecs = jax.numpy.linalg.eigh(Hm) - # TODO (mganahl) confirm that this is a valid matrix norm) - thresh = jax.numpy.maximum( - jax.numpy.finfo(eigvals.dtype).eps * Hm_norm, - jax.numpy.abs(eigvals) * tol) - vals = jax.numpy.abs(eigvecs[-1, :]) - return jax.numpy.all(beta_m * vals < thresh) - - return check_eigvals_convergence - -def _check_eigvals_convergence_eig(jax): - @functools.partial(jax.jit, static_argnums=(2, 3)) - def check_eigvals_convergence(beta_m: float, Hm: jax.Array, - tol: float, numeig: int) -> bool: - eigvals, eigvecs = jax.numpy.linalg.eig(Hm) - # TODO (mganahl) confirm that this is a valid matrix norm) - Hm_norm = jax.numpy.linalg.norm(Hm) - thresh = jax.numpy.maximum( - jax.numpy.finfo(eigvals.dtype).eps * Hm_norm, - jax.numpy.abs(eigvals[:numeig]) * tol) - vals = jax.numpy.abs(eigvecs[numeig - 1, :numeig]) - return jax.numpy.all(beta_m * vals < thresh) - - return check_eigvals_convergence - -def _implicitly_restarted_arnoldi(jax: types.ModuleType) -> Callable: - """ - Helper function to generate a jitted function to do an - implicitly restarted arnoldi factorization of `matvec`. The - returned routine finds the lowest `numeig` - eigenvector-eigenvalue pairs of `matvec` - by alternating between compression and re-expansion of an initial - `num_krylov_vecs`-step Arnoldi factorization. - - Note: The caller has to ensure that the dtype of the return value - of `matvec` matches the dtype of the initial state. Otherwise jax - will raise a TypeError. - - The function signature of the returned function is - Args: - matvec: A callable representing the linear operator. - args: Arguments to `matvec`. `matvec` is called with - `matvec(x, *args)` with `x` the input array on which - `matvec` should act. - initial_state: An starting vector for the iteration. - num_krylov_vecs: Number of krylov vectors of the arnoldi factorization. 
- numeig: The number of desired eigenvector-eigenvalue pairs. - which: Which eigenvalues to target. Currently supported: `which = 'LR'`. - tol: Convergence flag. If the norm of a krylov vector drops below `tol` - the iteration is terminated. - maxiter: Maximum number of (outer) iteration steps. - Returns: - eta, U: Two lists containing eigenvalues and eigenvectors. - - Args: - jax: The jax module. - Returns: - Callable: A function performing an implicitly restarted - Arnoldi factorization - """ - JaxPrecisionType = type(jax.lax.Precision.DEFAULT) - - arnoldi_fact = _generate_arnoldi_factorization(jax) - - - @functools.partial(jax.jit, static_argnums=(3, 4, 5, 6, 7, 8)) - def implicitly_restarted_arnoldi_method( - matvec: Callable, args: List, initial_state: jax.Array, - num_krylov_vecs: int, numeig: int, which: Text, tol: float, maxiter: int, - precision: JaxPrecisionType - ) -> Tuple[jax.Array, List[jax.Array], int]: - """ - Implicitly restarted arnoldi factorization of `matvec`. The routine - finds the lowest `numeig` eigenvector-eigenvalue pairs of `matvec` - by alternating between compression and re-expansion of an initial - `num_krylov_vecs`-step Arnoldi factorization. - - Note: The caller has to ensure that the dtype of the return value - of `matvec` matches the dtype of the initial state. Otherwise jax - will raise a TypeError. - - NOTE: Under certain circumstances, the routine can return spurious - eigenvalues 0.0: if the Arnoldi iteration terminated early - (after numits < num_krylov_vecs iterations) - and numeig > numits, then spurious 0.0 eigenvalues will be returned. - - Args: - matvec: A callable representing the linear operator. - args: Arguments to `matvec`. `matvec` is called with - `matvec(x, *args)` with `x` the input array on which - `matvec` should act. - initial_state: An starting vector for the iteration. - num_krylov_vecs: Number of krylov vectors of the arnoldi factorization. - numeig: The number of desired eigenvector-eigenvalue pairs. - which: Which eigenvalues to target. - Currently supported: `which = 'LR'` (largest real part). - tol: Convergence flag. If the norm of a krylov vector drops below `tol` - the iteration is terminated. - maxiter: Maximum number of (outer) iteration steps. - precision: jax.lax.Precision used within lax operations. - - Returns: - jax.Array: Eigenvalues - List: Eigenvectors - int: Number of inner krylov iterations of the last arnoldi - factorization. 
- """ - shape = initial_state.shape - dtype = initial_state.dtype - - dim = np.prod(shape).astype(np.int32) - num_expand = num_krylov_vecs - numeig - if not numeig <= num_krylov_vecs <= dim: - raise ValueError(f"num_krylov_vecs must be between numeig <=" - f" num_krylov_vecs <= dim, got " - f" numeig = {numeig}, num_krylov_vecs = " - f"{num_krylov_vecs}, dim = {dim}.") - if numeig > dim: - raise ValueError(f"number of requested eigenvalues numeig = {numeig} " - f"is larger than the dimension of the operator " - f"dim = {dim}") - - # initialize arrays - Vm = jax.numpy.zeros( - (num_krylov_vecs, jax.numpy.ravel(initial_state).shape[0]), dtype=dtype) - Hm = jax.numpy.zeros((num_krylov_vecs, num_krylov_vecs), dtype=dtype) - # perform initial arnoldi factorization - Vm, Hm, residual, norm, numits, ar_converged = arnoldi_fact( - matvec, args, initial_state, Vm, Hm, 0, num_krylov_vecs, tol, precision) - fm = residual.ravel() * norm - - # generate needed functions - shifted_QR = _shifted_QR(jax) - check_eigvals_convergence = _check_eigvals_convergence_eig(jax) - get_vectors = _get_vectors(jax) - - # sort_fun returns `num_expand` least relevant eigenvalues - # (those to be projected out) - if which == 'LR': - sort_fun = jax.tree_util.Partial(_LR_sort(jax), num_expand) - elif which == 'LM': - sort_fun = jax.tree_util.Partial(_LM_sort(jax), num_expand) - else: - raise ValueError(f"which = {which} not implemented") - - it = 1 # we already did one arnoldi factorization - if maxiter > 1: - # cast arrays to correct complex dtype - if Vm.dtype == np.float64: - dtype = np.complex128 - elif Vm.dtype == np.float32: - dtype = np.complex64 - elif Vm.dtype == np.complex128: - dtype = Vm.dtype - elif Vm.dtype == np.complex64: - dtype = Vm.dtype - else: - raise TypeError(f'dtype {Vm.dtype} not supported') - - Vm = Vm.astype(dtype) - Hm = Hm.astype(dtype) - fm = fm.astype(dtype) - - def outer_loop(carry): - Hm, Vm, fm, it, numits, ar_converged, _, _, = carry - evals, _ = jax.numpy.linalg.eig(Hm) - shifts, _ = sort_fun(evals) - # perform shifted QR iterations to compress arnoldi factorization - # Note that ||fk|| typically decreases as one iterates the outer loop - # indicating that iram converges. 
- # ||fk|| = \beta_m in reference above - Vk, Hk, fk = shifted_QR(Vm, Hm, fm, shifts, numeig) - # reset matrices - beta_k = jax.numpy.linalg.norm(fk) - converged = check_eigvals_convergence(beta_k, Hk, tol, numeig) - Vk = Vk.at[numeig:, :].set(0.0) - Hk = Hk.at[numeig:, :].set(0.0) - Hk = Hk.at[:, numeig:].set(0.0) - def do_arnoldi(vals): - Vk, Hk, fk, _, _, _, _ = vals - # restart - Vm, Hm, residual, norm, numits, ar_converged = arnoldi_fact( - matvec, args, jax.numpy.reshape(fk, shape), Vk, Hk, numeig, - num_krylov_vecs, tol, precision) - fm = residual.ravel() * norm - return [Vm, Hm, fm, norm, numits, ar_converged, False] - - def cond_arnoldi(vals): - return vals[6] - - res = jax.lax.while_loop(cond_arnoldi, do_arnoldi, [ - Vk, Hk, fk, - jax.numpy.linalg.norm(fk), numeig, False, - jax.numpy.logical_not(converged) - ]) - - Vm, Hm, fm, norm, numits, ar_converged = res[0:6] - out_vars = [ - Hm, Vm, fm, it + 1, numits, ar_converged, converged, norm - ] - return out_vars - - def cond_fun(carry): - it, ar_converged, converged = carry[3], carry[5], carry[ - 6] - return jax.lax.cond( - it < maxiter, lambda x: x, lambda x: False, - jax.numpy.logical_not(jax.numpy.logical_or(converged, ar_converged))) - - converged = False - carry = [Hm, Vm, fm, it, numits, ar_converged, converged, norm] - res = jax.lax.while_loop(cond_fun, outer_loop, carry) - Hm, Vm = res[0], res[1] - numits, converged = res[4], res[6] - # if `ar_converged` then `norm`is below convergence threshold - # set it to 0.0 in this case to prevent `jnp.linalg.eig` from finding a - # spurious eigenvalue of order `norm`. - Hm = Hm.at[numits, numits - 1].set( - jax.lax.cond(converged, lambda x: Hm.dtype.type(0.0), lambda x: x, - Hm[numits, numits - 1])) - - # if the Arnoldi-factorization stopped early (after `numit` iterations) - # before exhausting the allowed size of the Krylov subspace, - # (i.e. `numit` < 'num_krylov_vecs'), set elements - # at positions m, n with m, n >= `numit` to 0.0. - - # FIXME (mganahl): under certain circumstances, the routine can still - # return spurious 0 eigenvalues: if arnoldi terminated early - # (after numits < num_krylov_vecs iterations) - # and numeig > numits, then spurious 0.0 eigenvalues will be returned - - Hm = (numits > jax.numpy.arange(num_krylov_vecs))[:, None] * Hm * ( - numits > jax.numpy.arange(num_krylov_vecs))[None, :] - eigvals, U = jax.numpy.linalg.eig(Hm) - inds = sort_fun(eigvals)[1][:numeig] - vectors = get_vectors(Vm, U, inds, numeig) - return eigvals[inds], [ - jax.numpy.reshape(vectors[n, :], shape) - for n in range(numeig) - ], numits - - return implicitly_restarted_arnoldi_method - - -def _implicitly_restarted_lanczos(jax: types.ModuleType) -> Callable: - """ - Helper function to generate a jitted function to do an - implicitly restarted lanczos factorization of `matvec`. The - returned routine finds the lowest `numeig` - eigenvector-eigenvalue pairs of `matvec` - by alternating between compression and re-expansion of an initial - `num_krylov_vecs`-step Lanczos factorization. - - Note: The caller has to ensure that the dtype of the return value - of `matvec` matches the dtype of the initial state. Otherwise jax - will raise a TypeError. - - The function signature of the returned function is - Args: - matvec: A callable representing the linear operator. - args: Arguments to `matvec`. `matvec` is called with - `matvec(x, *args)` with `x` the input array on which - `matvec` should act. - initial_state: An starting vector for the iteration. 
- num_krylov_vecs: Number of krylov vectors of the lanczos factorization. - numeig: The number of desired eigenvector-eigenvalue pairs. - which: Which eigenvalues to target. Currently supported: `which = 'LR'` - or `which = 'SR'`. - tol: Convergence flag. If the norm of a krylov vector drops below `tol` - the iteration is terminated. - maxiter: Maximum number of (outer) iteration steps. - Returns: - eta, U: Two lists containing eigenvalues and eigenvectors. - - Args: - jax: The jax module. - Returns: - Callable: A function performing an implicitly restarted - Lanczos factorization - """ - JaxPrecisionType = type(jax.lax.Precision.DEFAULT) - lanczos_fact = _generate_lanczos_factorization(jax) - - @functools.partial(jax.jit, static_argnums=(3, 4, 5, 6, 7, 8)) - def implicitly_restarted_lanczos_method( - matvec: Callable, args: List, initial_state: jax.Array, - num_krylov_vecs: int, numeig: int, which: Text, tol: float, maxiter: int, - precision: JaxPrecisionType - ) -> Tuple[jax.Array, List[jax.Array], int]: - """ - Implicitly restarted lanczos factorization of `matvec`. The routine - finds the lowest `numeig` eigenvector-eigenvalue pairs of `matvec` - by alternating between compression and re-expansion of an initial - `num_krylov_vecs`-step Lanczos factorization. - - Note: The caller has to ensure that the dtype of the return value - of `matvec` matches the dtype of the initial state. Otherwise jax - will raise a TypeError. - - NOTE: Under certain circumstances, the routine can return spurious - eigenvalues 0.0: if the Lanczos iteration terminated early - (after numits < num_krylov_vecs iterations) - and numeig > numits, then spurious 0.0 eigenvalues will be returned. - - References: - http://emis.impa.br/EMIS/journals/ETNA/vol.2.1994/pp1-21.dir/pp1-21.pdf - http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter11.pdf - - Args: - matvec: A callable representing the linear operator. - args: Arguments to `matvec`. `matvec` is called with - `matvec(x, *args)` with `x` the input array on which - `matvec` should act. - initial_state: An starting vector for the iteration. - num_krylov_vecs: Number of krylov vectors of the lanczos factorization. - numeig: The number of desired eigenvector-eigenvalue pairs. - which: Which eigenvalues to target. - Currently supported: `which = 'LR'` (largest real part). - tol: Convergence flag. If the norm of a krylov vector drops below `tol` - the iteration is terminated. - maxiter: Maximum number of (outer) iteration steps. - precision: jax.lax.Precision used within lax operations. - - Returns: - jax.Array: Eigenvalues - List: Eigenvectors - int: Number of inner krylov iterations of the last lanczos - factorization. 
- """ - shape = initial_state.shape - dtype = initial_state.dtype - - dim = np.prod(shape).astype(np.int32) - num_expand = num_krylov_vecs - numeig - #note: the second part of the cond is for testing purposes - if num_krylov_vecs <= numeig < dim: - raise ValueError(f"num_krylov_vecs must be between numeig <" - f" num_krylov_vecs <= dim = {dim}," - f" num_krylov_vecs = {num_krylov_vecs}") - if numeig > dim: - raise ValueError(f"number of requested eigenvalues numeig = {numeig} " - f"is larger than the dimension of the operator " - f"dim = {dim}") - - # initialize arrays - Vm = jax.numpy.zeros( - (num_krylov_vecs, jax.numpy.ravel(initial_state).shape[0]), dtype=dtype) - alphas = jax.numpy.zeros(num_krylov_vecs, dtype=dtype) - betas = jax.numpy.zeros(num_krylov_vecs - 1, dtype=dtype) - - # perform initial lanczos factorization - Vm, alphas, betas, residual, norm, numits, ar_converged = lanczos_fact( - matvec, args, initial_state, Vm, alphas, betas, 0, num_krylov_vecs, tol, - precision) - fm = residual.ravel() * norm - # generate needed functions - shifted_QR = _shifted_QR(jax) - check_eigvals_convergence = _check_eigvals_convergence_eigh(jax) - get_vectors = _get_vectors(jax) - - # sort_fun returns `num_expand` least relevant eigenvalues - # (those to be projected out) - if which == 'LA': - sort_fun = jax.tree_util.Partial(_LA_sort(jax), num_expand) - elif which == 'SA': - sort_fun = jax.tree_util.Partial(_SA_sort(jax), num_expand) - elif which == 'LM': - sort_fun = jax.tree_util.Partial(_LM_sort(jax), num_expand) - else: - raise ValueError(f"which = {which} not implemented") - - it = 1 # we already did one lanczos factorization - def outer_loop(carry): - alphas, betas, Vm, fm, it, numits, ar_converged, _, _, = carry - # pack into alphas and betas into tridiagonal matrix - Hm = jax.numpy.diag(alphas) + jax.numpy.diag(betas, -1) + jax.numpy.diag( - betas.conj(), 1) - evals, _ = jax.numpy.linalg.eigh(Hm) - shifts, _ = sort_fun(evals) - # perform shifted QR iterations to compress lanczos factorization - # Note that ||fk|| typically decreases as one iterates the outer loop - # indicating that iram converges. 
- # ||fk|| = \beta_m in reference above - Vk, Hk, fk = shifted_QR(Vm, Hm, fm, shifts, numeig) - # extract new alphas and betas - alphas = jax.numpy.diag(Hk) - betas = jax.numpy.diag(Hk, -1) - alphas = alphas.at[numeig:].set(0.0) - betas = betas.at[numeig-1:].set(0.0) - - beta_k = jax.numpy.linalg.norm(fk) - Hktest = Hk[:numeig, :numeig] - matnorm = jax.numpy.linalg.norm(Hktest) - converged = check_eigvals_convergence(beta_k, Hktest, matnorm, tol) - - - def do_lanczos(vals): - Vk, alphas, betas, fk, _, _, _, _ = vals - # restart - Vm, alphas, betas, residual, norm, numits, ar_converged = lanczos_fact( - matvec, args, jax.numpy.reshape(fk, shape), Vk, alphas, betas, - numeig, num_krylov_vecs, tol, precision) - fm = residual.ravel() * norm - return [Vm, alphas, betas, fm, norm, numits, ar_converged, False] - - def cond_lanczos(vals): - return vals[7] - - res = jax.lax.while_loop(cond_lanczos, do_lanczos, [ - Vk, alphas, betas, fk, - jax.numpy.linalg.norm(fk), numeig, False, - jax.numpy.logical_not(converged) - ]) - - Vm, alphas, betas, fm, norm, numits, ar_converged = res[0:7] - - out_vars = [ - alphas, betas, Vm, fm, it + 1, numits, ar_converged, converged, norm - ] - return out_vars - - def cond_fun(carry): - it, ar_converged, converged = carry[4], carry[6], carry[7] - return jax.lax.cond( - it < maxiter, lambda x: x, lambda x: False, - jax.numpy.logical_not(jax.numpy.logical_or(converged, ar_converged))) - - converged = False - carry = [alphas, betas, Vm, fm, it, numits, ar_converged, converged, norm] - res = jax.lax.while_loop(cond_fun, outer_loop, carry) - alphas, betas, Vm = res[0], res[1], res[2] - numits, ar_converged, converged = res[5], res[6], res[7] - Hm = jax.numpy.diag(alphas) + jax.numpy.diag(betas, -1) + jax.numpy.diag( - betas.conj(), 1) - # FIXME (mganahl): under certain circumstances, the routine can still - # return spurious 0 eigenvalues: if lanczos terminated early - # (after numits < num_krylov_vecs iterations) - # and numeig > numits, then spurious 0.0 eigenvalues will be returned - Hm = (numits > jax.numpy.arange(num_krylov_vecs))[:, None] * Hm * ( - numits > jax.numpy.arange(num_krylov_vecs))[None, :] - - eigvals, U = jax.numpy.linalg.eigh(Hm) - inds = sort_fun(eigvals)[1][:numeig] - vectors = get_vectors(Vm, U, inds, numeig) - return eigvals[inds], [ - jax.numpy.reshape(vectors[n, :], shape) for n in range(numeig) - ], numits - - return implicitly_restarted_lanczos_method - - -def gmres_wrapper(jax: types.ModuleType): - """ - Allows Jax (the module) to be passed in as an argument rather than imported, - since doing the latter breaks the build. In addition, instantiates certain - of the enclosed functions as concrete objects within a Dict, allowing them to - be cached. This avoids spurious recompilations that would otherwise be - triggered by attempts to pass callables into Jitted functions. - - The important function here is functions["gmres_m"], which implements - GMRES. The other functions are exposed only for testing. - - Args: - ---- - jax: The imported Jax module. 
- - Returns: - ------- - functions: A namedtuple of functions: - functions.gmres_m = gmres_m - functions.gmres_residual = gmres_residual - functions.gmres_krylov = gmres_krylov - functions.gs_step = _gs_step - functions.kth_arnoldi_step = kth_arnoldi_step - functions.givens_rotation = givens_rotation - """ - jnp = jax.numpy - JaxPrecisionType = type(jax.lax.Precision.DEFAULT) - def gmres_m( - A_mv: Callable, A_args: Sequence, b: jax.Array, x0: jax.Array, - tol: float, atol: float, num_krylov_vectors: int, maxiter: int, - precision: JaxPrecisionType) -> Tuple[jax.Array, float, int, bool]: - """ - Solve A x = b for x using the m-restarted GMRES method. This is - intended to be called via jax_backend.gmres. - - Given a linear mapping with (n x n) matrix representation - A = A_mv(*A_args) gmres_m solves - Ax = b (1) - where x and b are length-n vectors, using the method of - Generalized Minimum RESiduals with M iterations per restart (GMRES_M). - - Args: - A_mv: A function v0 = A_mv(v, *A_args) where v0 and v have the same shape. - A_args: A list of positional arguments to A_mv. - b: The b in A @ x = b. - x0: Initial guess solution. - tol, atol: Solution tolerance to achieve, - norm(residual) <= max(tol * norm(b), atol). - tol is also used to set the threshold at which the Arnoldi factorization - terminates. - num_krylov_vectors: Size of the Krylov space to build at each restart. - maxiter: The Krylov space will be repeatedly rebuilt up to this many - times. - Returns: - x: The approximate solution. - beta: Norm of the residual at termination. - n_iter: Number of iterations at termination. - converged: Whether the desired tolerance was achieved. - """ - num_krylov_vectors = min(num_krylov_vectors, b.size) - x = x0 - b_norm = jnp.linalg.norm(b) - tol = max(tol * b_norm, atol) - for n_iter in range(maxiter): - done, beta, x = gmres(A_mv, A_args, b, x, num_krylov_vectors, x0, tol, - b_norm, precision) - if done: - break - return x, beta, n_iter, done - - def gmres(A_mv: Callable, A_args: Sequence, b: jax.Array, - x: jax.Array, num_krylov_vectors: int, x0: jax.Array, - tol: float, b_norm: float, - precision: JaxPrecisionType) -> Tuple[bool, float, jax.Array]: - """ - A single restart of GMRES. - - Args: - A_mv: A function `v0 = A_mv(v, *A_args)` where `v0` and - `v` have the same shape. - A_args: A list of positional arguments to A_mv. - b: The `b` in `A @ x = b`. - x: Initial guess solution. - tol: Solution tolerance to achieve, - num_krylov_vectors : Size of the Krylov space to build. - Returns: - done: Whether convergence was achieved. - beta: Magnitude of residual (i.e. the error estimate). - x: The approximate solution. - """ - r, beta = gmres_residual(A_mv, A_args, b, x) - k, V, R, beta_vec = gmres_krylov(A_mv, A_args, num_krylov_vectors, - x0, r, beta, tol, b_norm, precision) - x = gmres_update(k, V, R, beta_vec, x0) - done = k < num_krylov_vectors - 1 - return done, beta, x - - @jax.jit - def gmres_residual(A_mv: Callable, A_args: Sequence, b: jax.Array, - x: jax.Array) -> Tuple[jax.Array, float]: - """ - Computes the residual vector r and its norm, beta, which is minimized by - GMRES. - - Args: - A_mv: A function v0 = A_mv(v, *A_args) where v0 and - v have the same shape. - A_args: A list of positional arguments to A_mv. - b: The b in A @ x = b. - x: Initial guess solution. - Returns: - r: The residual vector. - beta: Its magnitude. 
- """ - r = b - A_mv(x, *A_args) - beta = jnp.linalg.norm(r) - return r, beta - - def gmres_update(k: int, V: jax.Array, R: jax.Array, - beta_vec: jax.Array, - x0: jax.Array) -> jax.Array: - """ - Updates the solution in response to the information computed by the - main GMRES loop. - - Args: - k: The final iteration which was reached by GMRES before convergence. - V: The Arnoldi matrix of Krylov vectors. - R: The R factor in H = QR where H is the Arnoldi overlap matrix. - beta_vec: Stores the Givens factors used to map H into QR. - x0: The initial guess solution. - Returns: - x: The updated solution. - """ - q = min(k, R.shape[1]) - y = jax.scipy.linalg.solve_triangular(R[:q, :q], beta_vec[:q]) - x = x0 + V[:, :q] @ y - return x - - @functools.partial(jax.jit, static_argnums=(2, 8)) - def gmres_krylov( - A_mv: Callable, A_args: Sequence, n_kry: int, x0: jax.Array, - r: jax.Array, beta: float, tol: float, b_norm: float, - precision: JaxPrecisionType - ) -> Tuple[int, jax.Array, jax.Array, jax.Array]: - """ - Builds the Arnoldi decomposition of (A, v), where v is the normalized - residual of the current solution estimate. The decomposition is - returned as V, R, where V is the usual matrix of Krylov vectors and - R is the upper triangular matrix in H = QR, with H the usual matrix - of overlaps. - - Args: - A_mv: A function `v0 = A_mv(v, *A_args)` where `v0` and - `v` have the same shape. - A_args: A list of positional arguments to A_mv. - n_kry: Size of the Krylov space to build; this is called - num_krylov_vectors in higher level code. - x0: Guess solution. - r: Residual vector. - beta: Magnitude of r. - tol: Solution tolerance to achieve. - b_norm: Magnitude of b in Ax = b. - Returns: - k: Counts the number of iterations before convergence. - V: The Arnoldi matrix of Krylov vectors. - R: From H = QR where H is the Arnoldi matrix of overlaps. - beta_vec: Stores Q implicitly as Givens factors. - """ - n = r.size - err = beta - v = r / beta - - # These will store the Givens rotations used to update the QR decompositions - # of the Arnoldi matrices. - # cos : givens[0, :] - # sine: givens[1, :] - givens = jnp.zeros((2, n_kry), dtype=x0.dtype) - beta_vec = jnp.zeros((n_kry + 1), dtype=x0.dtype) - beta_vec = beta_vec.at[0].set(beta) - V = jnp.zeros((n, n_kry + 1), dtype=x0.dtype) - V = V.at[:, 0].set(v) - R = jnp.zeros((n_kry + 1, n_kry), dtype=x0.dtype) - - # The variable data for the carry call. Each iteration modifies these - # values and feeds the results to the next iteration. - k = 0 - gmres_variables = (k, V, R, beta_vec, err, # < The actual output we need. - givens) # < Modified between iterations. - gmres_constants = (tol, A_mv, A_args, b_norm, n_kry) - gmres_carry = (gmres_variables, gmres_constants) - # The 'x' input for the carry call. Each iteration will receive an ascending - # loop index (from the jnp.arange) along with the constant data - # in gmres_constants. - - def gmres_krylov_work(gmres_carry: GmresCarryType) -> GmresCarryType: - """ - Performs a single iteration of gmres_krylov. See that function for a more - detailed description. - - Args: - gmres_carry: The gmres_carry from gmres_krylov. - Returns: - gmres_carry: The updated gmres_carry. 
- """ - gmres_variables, gmres_constants = gmres_carry - k, V, R, beta_vec, err, givens = gmres_variables - tol, A_mv, A_args, b_norm, _ = gmres_constants - - V, H = kth_arnoldi_step(k, A_mv, A_args, V, R, tol, precision) - R_col, givens = apply_givens_rotation(H[:, k], givens, k) - R = R.at[:, k].set(R_col[:]) - - # Update the residual vector. - cs, sn = givens[:, k] * beta_vec[k] - beta_vec = beta_vec.at[k].set(cs) - beta_vec = beta_vec.at[k + 1].set(sn) - err = jnp.abs(sn) / b_norm - gmres_variables = (k + 1, V, R, beta_vec, err, givens) - return (gmres_variables, gmres_constants) - - def gmres_krylov_loop_condition(gmres_carry: GmresCarryType) -> bool: - """ - This function dictates whether the main GMRES while loop will proceed. - It is equivalent to: - if k < n_kry and err > tol: - return True - else: - return False - where k, n_kry, err, and tol are unpacked from gmres_carry. - - Args: - gmres_carry: The gmres_carry from gmres_krylov. - Returns: - (bool): Whether to continue iterating. - """ - gmres_constants, gmres_variables = gmres_carry - tol = gmres_constants[0] - k = gmres_variables[0] - err = gmres_variables[4] - n_kry = gmres_constants[4] - - def is_iterating(k, n_kry): - return k < n_kry - - def not_converged(args): - err, tol = args - return err >= tol - return jax.lax.cond(is_iterating(k, n_kry), # Predicate. - not_converged, # Called if True. - lambda x: False, # Called if False. - (err, tol)) # Arguments to calls. - - gmres_carry = jax.lax.while_loop(gmres_krylov_loop_condition, - gmres_krylov_work, - gmres_carry) - gmres_variables, gmres_constants = gmres_carry - k, V, R, beta_vec, err, givens = gmres_variables - return (k, V, R, beta_vec) - - VarType = Tuple[int, jax.Array, jax.Array, jax.Array, - float, jax.Array] - ConstType = Tuple[float, Callable, Sequence, jax.Array, int] - GmresCarryType = Tuple[VarType, ConstType] - - - @functools.partial(jax.jit, static_argnums=(6,)) - def kth_arnoldi_step( - k: int, A_mv: Callable, A_args: Sequence, V: jax.Array, - H: jax.Array, tol: float, - precision: JaxPrecisionType) -> Tuple[jax.Array, jax.Array]: - """ - Performs the kth iteration of the Arnoldi reduction procedure. - Args: - k: The current iteration. - A_mv, A_args: A function A_mv(v, *A_args) performing a linear - transformation on v. - V: A matrix of size (n, K + 1), K > k such that each column in - V[n, :k+1] stores a Krylov vector and V[:, k+1] is all zeroes. - H: A matrix of size (K, K), K > k with H[:, k] all zeroes. - Returns: - V, H: With their k'th columns respectively filled in by a new - orthogonalized Krylov vector and new overlaps. - """ - - def _gs_step( - r: jax.Array, - v_i: jax.Array) -> Tuple[jax.Array, jax.Array]: - """ - Performs one iteration of the stabilized Gram-Schmidt procedure, with - r to be orthonormalized against {v} = {v_0, v_1, ...}. - - Args: - r: The new vector which is not in the initially orthonormal set. - v_i: The i'th vector in that set. - Returns: - r_i: The updated r which is now orthonormal with v_i. - h_i: The overlap of r with v_i. - """ - h_i = jnp.vdot(v_i, r, precision=precision) - r_i = r - h_i * v_i - return r_i, h_i - - v = A_mv(V[:, k], *A_args) - v_new, H_k = jax.lax.scan(_gs_step, init=v, xs=V.T) - v_norm = jnp.linalg.norm(v_new) - r_new = v_new / v_norm - # Normalize v unless it is the zero vector. 
- r_new = jax.lax.cond(v_norm > tol, - lambda x: x[0] / x[1], - lambda x: 0.*x[0], - (v_new, v_norm) - ) - H = H.at[:,k].set(H_k) - H = H.at[k+1,k].set(v_norm) - V = V.at[:,k+1].set(r_new) - return V, H - -#################################################################### -# GIVENS ROTATIONS -#################################################################### - @jax.jit - def apply_rotations(H_col: jax.Array, givens: jax.Array, - k: int) -> jax.Array: - """ - Successively applies each of the rotations stored in givens to H_col. - - Args: - H_col : The vector to be rotated. - givens: 2 x K, K > k matrix of rotation factors. - k : Iteration number. - Returns: - H_col : The rotated vector. - """ - rotation_carry = (H_col, 0, k, givens) - - def loop_condition(carry): - i = carry[1] - k = carry[2] - return jax.lax.cond(i < k, lambda x: True, lambda x: False, 0) - - def apply_ith_rotation(carry): - H_col, i, k, givens = carry - cs = givens[0, i] - sn = givens[1, i] - H_i = cs * H_col[i] - sn * H_col[i + 1] - H_ip1 = sn * H_col[i] + cs * H_col[i + 1] - H_col = H_col.at[i].set(H_i) - H_col = H_col.at[i + 1].set(H_ip1) - return (H_col, i + 1, k, givens) - - rotation_carry = jax.lax.while_loop(loop_condition, - apply_ith_rotation, - rotation_carry) - H_col = rotation_carry[0] - return H_col - - @jax.jit - def apply_givens_rotation(H_col: jax.Array, givens: jax.Array, - k: int) -> Tuple[jax.Array, jax.Array]: - """ - Applies the Givens rotations stored in the vectors cs and sn to the vector - H_col. Then constructs a new Givens rotation that eliminates H_col's - k'th element, yielding the corresponding column of the R in H's QR - decomposition. Returns the new column of R along with the new Givens - factors. - - Args: - H_col : The column of H to be rotated. - givens: A matrix representing the cosine and sine factors of the - previous GMRES Givens rotations, in that order - (i.e. givens[0, :] -> the cos factor). - k : Iteration number. - Returns: - R_col : The column of R obtained by transforming H_col. - givens_k: The new elements of givens that zeroed out the k+1'th element - of H_col. - """ - # This call successively applies each of the - # Givens rotations stored in givens[:, :k] to H_col. - H_col = apply_rotations(H_col, givens, k) - - cs_k, sn_k = givens_rotation(H_col[k], H_col[k + 1]) - givens = givens.at[0, k].set(cs_k) - givens = givens.at[1, k].set(sn_k) - - r_k = cs_k * H_col[k] - sn_k * H_col[k + 1] - R_col = H_col.at[k].set(r_k) - R_col = R_col.at[k + 1].set(0.) - return R_col, givens - - @jax.jit - def givens_rotation(v1: float, v2: float) -> Tuple[float, float]: - """ - Given scalars v1 and v2, computes cs = cos(theta) and sn = sin(theta) - so that [cs -sn] @ [v1] = [r] - [sn cs] [v2] [0] - Args: - v1, v2: The scalars. - Returns: - cs, sn: The rotation factors. 
- """ - t = jnp.sqrt(v1**2 + v2**2) - cs = v1 / t - sn = -v2 / t - return cs, sn - - fnames = [ - "gmres_m", "gmres_residual", "gmres_krylov", - "kth_arnoldi_step", "givens_rotation" - ] - functions = [ - gmres_m, gmres_residual, gmres_krylov, kth_arnoldi_step, - givens_rotation - ] - - class Functions: - - def __init__(self, fun_dict): - self.dict = fun_dict - - def __getattr__(self, name): - return self.dict[name] - - return Functions(dict(zip(fnames, functions))) \ No newline at end of file diff --git a/jax_eigensolver/test.py b/jax_eigensolver/test.py deleted file mode 100644 index 5397320..0000000 --- a/jax_eigensolver/test.py +++ /dev/null @@ -1,10 +0,0 @@ -import jax - - -if __name__ == "__main__": - backend = JaxBackend() - m = 100 - A = jax.random.normal(jax.random.PRNGKey(42),(m,m)) - b = jax.random.normal(jax.random.PRNGKey(41),(m,)) - def mapA(x): return A@x - backend.eigs(mapA,initial_state = b) \ No newline at end of file diff --git a/jax_eigensolver/tests/test_jax_backend.py b/jax_eigensolver/tests/test_jax_backend.py deleted file mode 100644 index 537f473..0000000 --- a/jax_eigensolver/tests/test_jax_backend.py +++ /dev/null @@ -1,1243 +0,0 @@ -# Copyright 2019 The TensorNetwork Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tensorflow as tf -import numpy as np -import scipy as sp -import jax -import jax.numpy as jnp -import pytest -from eigensolver import jax_backend -import jax.config as config -from eigensolver import jitted_functions -# pylint: disable=no-member -config.update("jax_enable_x64", True) -np_randn_dtypes = [np.float32, np.float16, np.float64] -np_dtypes = np_randn_dtypes + [np.complex64, np.complex128] -np_not_half = [np.float32, np.float64, np.complex64, np.complex128] - - -def test_tensordot(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(2 * np.ones((2, 3, 4))) - b = backend.convert_to_tensor(np.ones((2, 3, 4))) - actual = backend.tensordot(a, b, ((1, 2), (1, 2))) - expected = np.array([[24.0, 24.0], [24.0, 24.0]]) - np.testing.assert_allclose(expected, actual) - - -def test_tensordot_int(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(2 * np.ones((3, 3, 3))) - b = backend.convert_to_tensor(np.ones((3, 3, 3))) - actual = backend.tensordot(a, b, 1) - expected = jax.numpy.tensordot(a, b, 1) - np.testing.assert_allclose(expected, actual) - - -def test_reshape(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(np.ones((2, 3, 4))) - actual = backend.shape_tuple(backend.reshape(a, np.array((6, 4, 1)))) - assert actual == (6, 4, 1) - - -def test_transpose(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor( - np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]])) - actual = backend.transpose(a, [2, 0, 1]) - expected = np.array([[[1.0, 3.0], [5.0, 7.0]], [[2.0, 4.0], [6.0, 8.0]]]) - np.testing.assert_allclose(expected, actual) - - -def test_transpose_noperm(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor( - np.array([[[1., 2.], [3., 
4.]], [[5., 6.], [7., 8.]]])) - actual = backend.transpose(a) # [2, 1, 0] - actual = backend.transpose(actual, perm=[0, 2, 1]) - expected = np.array([[[1.0, 3.0], [5.0, 7.0]], [[2.0, 4.0], [6.0, 8.0]]]) - np.testing.assert_allclose(expected, actual) - - -def test_shape_concat(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(2 * np.ones((1, 3, 1))) - b = backend.convert_to_tensor(np.ones((1, 2, 1))) - expected = backend.shape_concat((a, b), axis=1) - actual = np.array([[[2.0], [2.0], [2.0], [1.0], [1.0]]]) - np.testing.assert_allclose(expected, actual) - - -def test_slice(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor( - np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])) - actual = backend.slice(a, (1, 1), (2, 2)) - expected = np.array([[5., 6.], [8., 9.]]) - np.testing.assert_allclose(expected, actual) - - -def test_slice_raises_error(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor( - np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])) - with pytest.raises(ValueError): - backend.slice(a, (1, 1), (2, 2, 2)) - - -def test_shape_tensor(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(np.ones([2, 3, 4])) - assert isinstance(backend.shape_tensor(a), tuple) - actual = backend.shape_tensor(a) - expected = np.array([2, 3, 4]) - np.testing.assert_allclose(expected, actual) - - -def test_shape_tuple(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(np.ones([2, 3, 4])) - actual = backend.shape_tuple(a) - assert actual == (2, 3, 4) - - -def test_shape_prod(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(2 * np.ones([1, 2, 3, 4])) - actual = np.array(backend.shape_prod(a)) - assert actual == 2**24 - - -def test_sqrt(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(np.array([4., 9.])) - actual = backend.sqrt(a) - expected = np.array([2, 3]) - np.testing.assert_allclose(expected, actual) - - -def test_convert_to_tensor(): - backend = jax_backend.JaxBackend() - array = np.ones((2, 3, 4)) - actual = backend.convert_to_tensor(array) - expected = jax.jit(lambda x: x)(array) - assert isinstance(actual, type(expected)) - np.testing.assert_allclose(expected, actual) - -def test_outer_product(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(2 * np.ones((2, 1))) - b = backend.convert_to_tensor(np.ones((1, 2, 2))) - actual = backend.outer_product(a, b) - expected = np.array([[[[[2.0, 2.0], [2.0, 2.0]]]], [[[[2.0, 2.0], [2.0, - 2.0]]]]]) - np.testing.assert_allclose(expected, actual) - - -def test_einsum(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(2 * np.ones((2, 1))) - b = backend.convert_to_tensor(np.ones((1, 2, 2))) - actual = backend.einsum('ij,jil->l', a, b) - expected = np.array([4.0, 4.0]) - np.testing.assert_allclose(expected, actual) - - -def test_convert_bad_test(): - backend = jax_backend.JaxBackend() - with pytest.raises(TypeError, match="Expected"): - backend.convert_to_tensor(tf.ones((2, 2))) - - -def test_norm(): - backend = jax_backend.JaxBackend() - a = backend.convert_to_tensor(np.ones((2, 2))) - assert backend.norm(a) == 2 - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_eye(dtype): - backend = jax_backend.JaxBackend() - a = backend.eye(N=4, M=5, dtype=dtype) - np.testing.assert_allclose(np.eye(N=4, M=5, dtype=dtype), a) - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_ones(dtype): - backend = jax_backend.JaxBackend() - a = backend.ones((4, 4), dtype=dtype) - 
np.testing.assert_allclose(np.ones((4, 4), dtype=dtype), a) - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_zeros(dtype): - backend = jax_backend.JaxBackend() - a = backend.zeros((4, 4), dtype=dtype) - np.testing.assert_allclose(np.zeros((4, 4), dtype=dtype), a) - - -@pytest.mark.parametrize("dtype", np_randn_dtypes) -def test_randn(dtype): - backend = jax_backend.JaxBackend() - a = backend.randn((4, 4), dtype=dtype) - assert a.shape == (4, 4) - - -@pytest.mark.parametrize("dtype", np_randn_dtypes) -def test_random_uniform(dtype): - backend = jax_backend.JaxBackend() - a = backend.random_uniform((4, 4), dtype=dtype) - assert a.shape == (4, 4) - - -@pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) -def test_randn_non_zero_imag(dtype): - backend = jax_backend.JaxBackend() - a = backend.randn((4, 4), dtype=dtype) - assert np.linalg.norm(np.imag(a)) != 0.0 - - -@pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) -def test_random_uniform_non_zero_imag(dtype): - backend = jax_backend.JaxBackend() - a = backend.random_uniform((4, 4), dtype=dtype) - assert np.linalg.norm(np.imag(a)) != 0.0 - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_eye_dtype(dtype): - backend = jax_backend.JaxBackend() - a = backend.eye(N=4, M=4, dtype=dtype) - assert a.dtype == dtype - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_ones_dtype(dtype): - backend = jax_backend.JaxBackend() - a = backend.ones((4, 4), dtype=dtype) - assert a.dtype == dtype - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_zeros_dtype(dtype): - backend = jax_backend.JaxBackend() - a = backend.zeros((4, 4), dtype=dtype) - assert a.dtype == dtype - - -@pytest.mark.parametrize("dtype", np_randn_dtypes) -def test_randn_dtype(dtype): - backend = jax_backend.JaxBackend() - a = backend.randn((4, 4), dtype=dtype) - assert a.dtype == dtype - - -@pytest.mark.parametrize("dtype", np_randn_dtypes) -def test_random_uniform_dtype(dtype): - backend = jax_backend.JaxBackend() - a = backend.random_uniform((4, 4), dtype=dtype) - assert a.dtype == dtype - - -@pytest.mark.parametrize("dtype", np_randn_dtypes) -def test_randn_seed(dtype): - backend = jax_backend.JaxBackend() - a = backend.randn((4, 4), seed=10, dtype=dtype) - b = backend.randn((4, 4), seed=10, dtype=dtype) - np.testing.assert_allclose(a, b) - - -@pytest.mark.parametrize("dtype", np_randn_dtypes) -def test_random_uniform_seed(dtype): - backend = jax_backend.JaxBackend() - a = backend.random_uniform((4, 4), seed=10, dtype=dtype) - b = backend.random_uniform((4, 4), seed=10, dtype=dtype) - np.testing.assert_allclose(a, b) - - -@pytest.mark.parametrize("dtype", np_randn_dtypes) -def test_random_uniform_boundaries(dtype): - lb = 1.2 - ub = 4.8 - backend = jax_backend.JaxBackend() - a = backend.random_uniform((4, 4), seed=10, dtype=dtype) - b = backend.random_uniform((4, 4), (lb, ub), seed=10, dtype=dtype) - assert ((a >= 0).all() and (a <= 1).all() and (b >= lb).all() and - (b <= ub).all()) - - -def test_random_uniform_behavior(): - seed = 10 - key = jax.random.PRNGKey(seed) - backend = jax_backend.JaxBackend() - a = backend.random_uniform((4, 4), seed=seed) - b = jax.random.uniform(key, (4, 4)) - np.testing.assert_allclose(a, b) - - -def test_conj(): - backend = jax_backend.JaxBackend() - real = np.random.rand(2, 2, 2) - imag = np.random.rand(2, 2, 2) - a = backend.convert_to_tensor(real + 1j * imag) - actual = backend.conj(a) - expected = real - 1j * imag - np.testing.assert_allclose(expected, actual) - - 
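(Editorial aside, not part of the patch content above or below.) Several of the deleted docstrings in `jitted_functions.py` state the Arnoldi recurrence relation `matrix @ Vm - Vm @ Hm - fm * em = 0`. A minimal sketch of that identity for a small dense matrix, in plain NumPy with column-stored Krylov vectors; the `arnoldi` helper below is hypothetical and exists only for this illustration, it is not part of the deleted code:

```python
import numpy as np

def arnoldi(A, v0, m):
    # m Arnoldi steps with modified Gram-Schmidt; assumes no early
    # breakdown (i.e. H[i + 1, i] stays well above zero).
    n = A.shape[0]
    V = np.zeros((n, m + 1))
    H = np.zeros((m + 1, m))
    V[:, 0] = v0 / np.linalg.norm(v0)
    for i in range(m):
        w = A @ V[:, i]
        for j in range(i + 1):
            H[j, i] = V[:, j] @ w      # overlap with an earlier Krylov vector
            w = w - H[j, i] * V[:, j]  # orthogonalize against it
        H[i + 1, i] = np.linalg.norm(w)
        V[:, i + 1] = w / H[i + 1, i]
    return V, H

rng = np.random.default_rng(0)
A = rng.normal(size=(8, 8))
m = 5
V, H = arnoldi(A, rng.normal(size=8), m)
Vm, Hm = V[:, :m], H[:m, :m]
fm = V[:, m] * H[m, m - 1]             # unnormalized residual
em = np.eye(m)[:, m - 1]               # last Cartesian basis vector
print(np.allclose(A @ Vm, Vm @ Hm + np.outer(fm, em)))  # expected: True
```

Here `fm` and `em` correspond to the unnormalized residual and the Cartesian basis vector named in the removed `_generate_arnoldi_factorization` docstring; the implicitly restarted routines deleted above compress exactly this factorization with shifted QR steps before re-expanding it.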
-@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -def test_eigsh_valid_init_operator_with_shape(dtype): - backend = jax_backend.JaxBackend() - D = 16 - np.random.seed(10) - init = backend.randn((D,), dtype=dtype, seed=10) - tmp = backend.randn((D, D), dtype=dtype, seed=10) - H = tmp + backend.transpose(backend.conj(tmp), (1, 0)) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta1, U1 = backend.eigsh_lanczos(mv, [H], init) - eta2, U2 = np.linalg.eigh(H) - v2 = U2[:, 0] - v2 = v2 / sum(v2) - v1 = np.reshape(U1[0], (D)) - v1 = v1 / sum(v1) - np.testing.assert_allclose(eta1[0], min(eta2)) - np.testing.assert_allclose(v1, v2) - - -def test_eigsh_small_number_krylov_vectors(): - backend = jax_backend.JaxBackend() - init = np.array([1, 1], dtype=np.float64) - H = np.array([[1, 2], [2, 4]], dtype=np.float64) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta1, _ = backend.eigsh_lanczos(mv, [H], init, numeig=1, num_krylov_vecs=2) - np.testing.assert_almost_equal(eta1, [0]) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -def test_eigsh_lanczos_1(dtype): - backend = jax_backend.JaxBackend() - D = 16 - np.random.seed(10) - init = backend.randn((D,), dtype=dtype, seed=10) - tmp = backend.randn((D, D), dtype=dtype, seed=10) - H = tmp + backend.transpose(backend.conj(tmp), (1, 0)) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta1, U1 = backend.eigsh_lanczos(mv, [H], init) - eta2, U2 = np.linalg.eigh(H) - v2 = U2[:, 0] - v2 = v2 / sum(v2) - v1 = np.reshape(U1[0], (D)) - v1 = v1 / sum(v1) - np.testing.assert_allclose(eta1[0], min(eta2)) - np.testing.assert_allclose(v1, v2) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -def test_eigsh_lanczos_2(dtype): - backend = jax_backend.JaxBackend() - D = 16 - np.random.seed(10) - tmp = backend.randn((D, D), dtype=dtype, seed=10) - H = tmp + backend.transpose(backend.conj(tmp), (1, 0)) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta1, U1 = backend.eigsh_lanczos(mv, [H], shape=(D,), dtype=dtype) - eta2, U2 = np.linalg.eigh(H) - v2 = U2[:, 0] - v2 = v2 / sum(v2) - v1 = np.reshape(U1[0], (D)) - v1 = v1 / sum(v1) - np.testing.assert_allclose(eta1[0], min(eta2)) - np.testing.assert_allclose(v1, v2) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize("numeig", [1, 2, 3, 4]) -def test_eigsh_lanczos_reorthogonalize(dtype, numeig): - backend = jax_backend.JaxBackend() - D = 24 - np.random.seed(10) - tmp = backend.randn((D, D), dtype=dtype, seed=10) - H = tmp + backend.transpose(backend.conj(tmp), (1, 0)) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta1, U1 = backend.eigsh_lanczos( - mv, [H], - shape=(D,), - dtype=dtype, - numeig=numeig, - num_krylov_vecs=D, - reorthogonalize=True, - ndiag=1, - tol=1E-12, - delta=1E-12) - eta2, U2 = np.linalg.eigh(H) - - np.testing.assert_allclose(eta1[0:numeig], eta2[0:numeig]) - for n in range(numeig): - v2 = U2[:, n] - v2 /= np.sum(v2) #fix phases - v1 = np.reshape(U1[n], (D)) - v1 /= np.sum(v1) - - np.testing.assert_allclose(v1, v2, rtol=1E-5, atol=1E-5) - - -def test_eigsh_lanczos_raises(): - backend = jax_backend.JaxBackend() - with pytest.raises( - ValueError, match='`num_krylov_vecs` >= `numeig` required!'): - backend.eigsh_lanczos(lambda x: x, numeig=10, num_krylov_vecs=9) - with pytest.raises( - ValueError, - match="Got numeig = 2 > 1 and `reorthogonalize = False`. 
" - "Use `reorthogonalize=True` for `numeig > 1`"): - backend.eigsh_lanczos(lambda x: x, numeig=2, reorthogonalize=False) - with pytest.raises( - ValueError, - match="if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided"): - backend.eigsh_lanczos(lambda x: x, shape=(10,), dtype=None) - with pytest.raises( - ValueError, - match="if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided"): - backend.eigsh_lanczos(lambda x: x, shape=None, dtype=np.float64) - with pytest.raises( - ValueError, - match="if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided"): - backend.eigsh_lanczos(lambda x: x) - with pytest.raises( - TypeError, match="Expected a `jax.array`. Got "): - backend.eigsh_lanczos(lambda x: x, initial_state=[1, 2, 3]) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -def test_broadcast_right_multiplication(dtype): - backend = jax_backend.JaxBackend() - tensor1 = backend.randn((2, 3), dtype=dtype, seed=10) - tensor2 = backend.randn((3,), dtype=dtype, seed=10) - out = backend.broadcast_right_multiplication(tensor1, tensor2) - np.testing.assert_allclose(out, np.array(tensor1) * np.array(tensor2)) - - -def test_broadcast_right_multiplication_raises(): - backend = jax_backend.JaxBackend() - tensor1 = backend.randn((2, 3)) - tensor2 = backend.randn((3, 3)) - with pytest.raises(ValueError): - backend.broadcast_right_multiplication(tensor1, tensor2) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -def test_broadcast_left_multiplication(dtype): - backend = jax_backend.JaxBackend() - tensor1 = backend.randn((3,), dtype=dtype, seed=10) - tensor2 = backend.randn((3, 4, 2), dtype=dtype, seed=10) - out = backend.broadcast_left_multiplication(tensor1, tensor2) - np.testing.assert_allclose(out, np.reshape(tensor1, (3, 1, 1)) * tensor2) - - -def test_broadcast_left_multiplication_raises(): - dtype = np.float64 - backend = jax_backend.JaxBackend() - tensor1 = backend.randn((3, 3), dtype=dtype, seed=10) - tensor2 = backend.randn((2, 4, 3), dtype=dtype, seed=10) - with pytest.raises(ValueError): - backend.broadcast_left_multiplication(tensor1, tensor2) - - -def test_sparse_shape(): - dtype = np.float64 - backend = jax_backend.JaxBackend() - tensor = backend.randn((2, 3, 4), dtype=dtype, seed=10) - np.testing.assert_allclose(backend.sparse_shape(tensor), tensor.shape) - - -@pytest.mark.parametrize("dtype,method", [(np.float64, "sin"), - (np.complex128, "sin"), - (np.float64, "cos"), - (np.complex128, "cos"), - (np.float64, "exp"), - (np.complex128, "exp"), - (np.float64, "log"), - (np.complex128, "log")]) -def test_elementwise_ops(dtype, method): - backend = jax_backend.JaxBackend() - tensor = backend.randn((4, 3, 2), dtype=dtype, seed=10) - if method == "log": - tensor = np.abs(tensor) - tensor1 = getattr(backend, method)(tensor) - tensor2 = getattr(np, method)(tensor) - np.testing.assert_almost_equal(tensor1, tensor2) - - -@pytest.mark.parametrize("dtype,method", [(np.float64, "expm"), - (np.complex128, "expm")]) -def test_matrix_ops(dtype, method): - backend = jax_backend.JaxBackend() - matrix = backend.randn((4, 4), dtype=dtype, seed=10) - if method == "expm": - matrix1 = backend.expm(matrix) - matrix2 = jax.scipy.linalg.expm(matrix) - np.testing.assert_almost_equal(np.array(matrix1), np.array(matrix2)) - - -@pytest.mark.parametrize("dtype,method", [(np.float64, "expm"), - (np.complex128, "expm")]) -def test_matrix_ops_raises(dtype, method): - backend = jax_backend.JaxBackend() - 
matrix = backend.randn((4, 4, 4), dtype=dtype, seed=10) - with pytest.raises(ValueError, match=r".*Only matrices.*"): - getattr(backend, method)(matrix) - matrix = backend.randn((4, 3), dtype=dtype, seed=10) - with pytest.raises(ValueError, match=r".*N\*N matrix.*"): - getattr(backend, method)(matrix) - - -def test_jit(): - backend = jax_backend.JaxBackend() - - def fun(x, A, y): - return jax.numpy.dot(x, jax.numpy.dot(A, y)) - - fun_jit = backend.jit(fun) - x = jax.numpy.array(np.random.rand(4)) - y = jax.numpy.array(np.random.rand(4)) - A = jax.numpy.array(np.random.rand(4, 4)) - res1 = fun(x, A, y) - res2 = fun_jit(x, A, y) - np.testing.assert_allclose(res1, res2) - - -def test_jit_args(): - backend = jax_backend.JaxBackend() - - def fun(x, A, y): - return jax.numpy.dot(x, jax.numpy.dot(A, y)) - - fun_jit = backend.jit(fun) - x = jax.numpy.array(np.random.rand(4)) - y = jax.numpy.array(np.random.rand(4)) - A = jax.numpy.array(np.random.rand(4, 4)) - - res1 = fun(x, A, y) - res2 = fun_jit(x, A, y) - res3 = fun_jit(x, y=y, A=A) - np.testing.assert_allclose(res1, res2) - np.testing.assert_allclose(res1, res3) - - -def compare_eigvals_and_eigvecs(U, - eta, - U_exact, - eta_exact, - rtol, - atol, - thresh=1E-8): - _, iy = np.nonzero(np.abs(eta[:, None] - eta_exact[None, :]) < thresh) - U_exact_perm = U_exact[:, iy] - U_exact_perm = U_exact_perm / np.expand_dims(np.sum(U_exact_perm, axis=0), 0) - U = U / np.expand_dims(np.sum(U, axis=0), 0) - np.testing.assert_allclose(U_exact_perm, U, atol=atol, rtol=rtol) - np.testing.assert_allclose(eta, eta_exact[iy], atol=atol, rtol=rtol) - - -############################################################## -# eigs and eigsh tests # -############################################################## -def generate_hermitian_matrix(be, dtype, D): - H = be.randn((D, D), dtype=dtype, seed=10) - H += H.T.conj() - return H - - -def generate_matrix(be, dtype, D): - return be.randn((D, D), dtype=dtype, seed=10) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize( - "solver, matrix_generator, exact_decomp, which", - [(jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LM"), - (jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LR"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "SA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LM")]) -def test_eigs_eigsh_all_eigvals_with_init(dtype, solver, matrix_generator, - exact_decomp, which): - backend = jax_backend.JaxBackend() - D = 16 - np.random.seed(10) - init = backend.randn((D,), dtype=dtype, seed=10) - H = matrix_generator(backend, dtype, D) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta, U = solver(mv, [H], init, numeig=D, num_krylov_vecs=D, which=which) - eta_exact, U_exact = exact_decomp(H) - - rtol = 1E-8 - atol = 1E-8 - compare_eigvals_and_eigvecs( - np.stack(U, axis=1), eta, U_exact, eta_exact, rtol, atol, thresh=1E-4) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize( - "solver, matrix_generator, exact_decomp, which", - [(jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LM"), - (jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LR"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "SA"), - 
(jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LM")]) -def test_eigs_eigsh_all_eigvals_no_init(dtype, solver, matrix_generator, - exact_decomp, which): - backend = jax_backend.JaxBackend() - D = 16 - np.random.seed(10) - H = matrix_generator(backend, dtype, D) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta, U = solver( - mv, [H], - shape=(D,), - dtype=dtype, - numeig=D, - num_krylov_vecs=D, - which=which) - eta_exact, U_exact = exact_decomp(H) - rtol = 1E-8 - atol = 1E-8 - compare_eigvals_and_eigvecs( - np.stack(U, axis=1), eta, U_exact, eta_exact, rtol, atol, thresh=1E-4) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize( - "solver, matrix_generator, exact_decomp, which", - [(jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LM"), - (jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LR"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "SA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LM")]) -def test_eigs_eigsh_few_eigvals_with_init(dtype, solver, matrix_generator, - exact_decomp, which): - backend = jax_backend.JaxBackend() - D = 16 - np.random.seed(10) - init = backend.randn((D,), dtype=dtype, seed=10) - H = matrix_generator(backend, dtype, D) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta, U = solver( - mv, [H], init, numeig=4, num_krylov_vecs=16, maxiter=50, which=which) - eta_exact, U_exact = exact_decomp(H) - rtol = 1E-8 - atol = 1E-8 - compare_eigvals_and_eigvecs( - np.stack(U, axis=1), eta, U_exact, eta_exact, rtol, atol, thresh=1E-4) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize( - "solver, matrix_generator, exact_decomp, which", - [(jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LM"), - (jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LR"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "SA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LM")]) -def test_eigs_eigsh_few_eigvals_no_init(dtype, solver, matrix_generator, - exact_decomp, which): - backend = jax_backend.JaxBackend() - D = 16 - np.random.seed(10) - H = matrix_generator(backend, dtype, D) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta, U = solver( - mv, [H], - shape=(D,), - dtype=dtype, - numeig=4, - num_krylov_vecs=16, - which=which) - eta_exact, U_exact = exact_decomp(H) - rtol = 1E-8 - atol = 1E-8 - compare_eigvals_and_eigvecs( - np.stack(U, axis=1), eta, U_exact, eta_exact, rtol, atol, thresh=1E-4) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize( - "solver, matrix_generator, exact_decomp, which", - [(jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LM"), - (jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LR"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "SA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LM")]) -def test_eigs_eigsh_large_ncv_with_init(dtype, solver, matrix_generator, - exact_decomp, which): - backend = jax_backend.JaxBackend() - D = 100 - 
np.random.seed(10) - init = backend.randn((D,), dtype=dtype, seed=10) - H = matrix_generator(backend, dtype, D) - - def mv(x, H): - return jax.numpy.dot(H, x) - - eta, U = solver( - mv, [H], init, numeig=4, num_krylov_vecs=50, maxiter=50, which=which) - eta_exact, U_exact = exact_decomp(H) - rtol = 1E-8 - atol = 1E-8 - compare_eigvals_and_eigvecs( - np.stack(U, axis=1), eta, U_exact, eta_exact, rtol, atol, thresh=1E-4) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize( - "solver, matrix_generator, exact_decomp, which", - [(jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LM"), - (jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LR"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "SA"), - (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, - "LM")]) - -def test_eigs_eigsh_large_matrix_with_init(dtype, solver, matrix_generator, - exact_decomp, which): - backend = jax_backend.JaxBackend() - D = 1000 - np.random.seed(10) - init = backend.randn((D,), dtype=dtype, seed=10) - H = matrix_generator(backend, dtype, D) - - def mv(x, H): - return jax.numpy.dot(H, x, precision=jax.lax.Precision.HIGHEST) - - eta, U = solver( - mv, [H], - init, - numeig=4, - num_krylov_vecs=40, - maxiter=500, - which=which, - tol=1E-10) - eta_exact, U_exact = exact_decomp(H) - - thresh = { - np.complex64: 1E-3, - np.float32: 1E-3, - np.float64: 1E-4, - np.complex128: 1E-4 - } - rtol = 1E-8 - atol = 1E-8 - compare_eigvals_and_eigvecs( - np.stack(U, axis=1), - eta, - U_exact, - eta_exact, - rtol, - atol, - thresh=thresh[dtype]) - - -def get_ham_params(dtype, N, which): - if which == 'uniform': - hop = -jnp.ones(N - 1, dtype=dtype) - pot = jnp.ones(N, dtype=dtype) - if dtype in (np.complex128, np.complex64): - hop -= 1j * jnp.ones(N - 1, dtype) - elif which == 'rand': - hop = (-1) * jnp.array(np.random.rand(N - 1).astype(dtype) - 0.5) - pot = jnp.array(np.random.rand(N).astype(dtype)) - 0.5 - if dtype in (np.complex128, np.complex64): - hop -= 1j * jnp.array(np.random.rand(N - 1).astype(dtype) - 0.5) - return pot, hop - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize("param_type", ['uniform', 'rand']) -@pytest.mark.parametrize("N", [14]) -def test_eigsh_free_fermions(N, dtype, param_type): - """ - Find the lowest eigenvalues and eigenvectors - of a 1d free-fermion Hamiltonian on N sites. - The dimension of the hermitian matrix is - (2**N, 2**N). 
- """ - backend = jax_backend.JaxBackend(precision=jax.lax.Precision.HIGHEST) - np.random.seed(10) - pot, hop = get_ham_params(dtype, N, param_type) - P = jnp.diag(np.array([0, -1])).astype(dtype) - c = jnp.array([[0, 1], [0, 0]], dtype) - n = c.T @ c - eye = jnp.eye(2, dtype=dtype) - neye = jnp.kron(n, eye) - eyen = jnp.kron(eye, n) - ccT = jnp.kron(c @ P, c.T) - cTc = jnp.kron(c.T, c) - - @jax.jit - def matvec(vec): - x = vec.reshape((4, 2**(N - 2))) - out = jnp.zeros(x.shape, x.dtype) - t1 = neye * pot[0] + eyen * pot[1] / 2 - t2 = cTc * hop[0] - ccT * jnp.conj(hop[0]) - out += jnp.einsum('ij,ki -> kj', x, t1 + t2) - x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape((4, 2**(N - 2))) - out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape( - (4, 2**(N - 2))) - for site in range(1, N - 2): - t1 = neye * pot[site] / 2 + eyen * pot[site + 1] / 2 - t2 = cTc * hop[site] - ccT * jnp.conj(hop[site]) - out += jnp.einsum('ij,ki -> kj', x, t1 + t2) - x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape((4, 2**(N - 2))) - out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape( - (4, 2**(N - 2))) - t1 = neye * pot[N - 2] / 2 + eyen * pot[N - 1] - t2 = cTc * hop[N - 2] - ccT * jnp.conj(hop[N - 2]) - out += jnp.einsum('ij,ki -> kj', x, t1 + t2) - x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape((4, 2**(N - 2))) - out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape( - (4, 2**(N - 2))) - - x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(2**N) - out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(2**N) - return out.ravel() - - H = np.diag(pot) + np.diag(hop.conj(), 1) + np.diag(hop, -1) - single_particle_energies = np.linalg.eigh(H)[0] - - many_body_energies = [] - for n in range(2**N): - many_body_energies.append( - np.sum(single_particle_energies[np.nonzero( - np.array(list(bin(n)[2:]), dtype=int)[::-1])[0]])) - many_body_energies = np.sort(many_body_energies) - - init = jnp.array(np.random.randn(2**N)).astype(dtype) - init /= jnp.linalg.norm(init) - - ncv = 20 - numeig = 3 - which = 'SA' - tol = 1E-10 - maxiter = 30 - atol = 1E-8 - eta, _ = backend.eigsh( - A=matvec, - args=[], - initial_state=init, - num_krylov_vecs=ncv, - numeig=numeig, - which=which, - tol=tol, - maxiter=maxiter) - np.testing.assert_allclose( - eta, many_body_energies[:numeig], atol=atol, rtol=atol) - - -@pytest.mark.parametrize( - "solver, whichs", - [(jax_backend.JaxBackend().eigs, ["SM", "SR", "LI", "SI"]), - (jax_backend.JaxBackend().eigsh, ["SM", "BE"])]) -def test_eigs_eigsh_raises(solver, whichs): - with pytest.raises( - ValueError, match='`num_krylov_vecs` >= `numeig` required!'): - solver(lambda x: x, numeig=10, num_krylov_vecs=9) - - with pytest.raises( - ValueError, - match="if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided"): - solver(lambda x: x, shape=(10,), dtype=None) - with pytest.raises( - ValueError, - match="if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided"): - solver(lambda x: x, shape=None, dtype=np.float64) - with pytest.raises( - ValueError, - match="if no `initial_state` is passed, then `shape` and" - "`dtype` have to be provided"): - solver(lambda x: x) - with pytest.raises( - TypeError, match="Expected a `jax.array`. 
Got "): - solver(lambda x: x, initial_state=[1, 2, 3]) - for which in whichs: - with pytest.raises( - ValueError, match=f"which = {which}" - f" is currently not supported."): - solver(lambda x: x, which=which) - - -def test_eigs_dtype_raises(): - solver = jax_backend.JaxBackend().eigs - with pytest.raises(TypeError, match="dtype"): - solver(lambda x: x, shape=(10,), dtype=np.int32, - num_krylov_vecs=10) - -################################################################## -############# This test should just not crash ################ -################################################################## -@pytest.mark.parametrize("dtype", - [np.float64, np.complex128, np.float32, np.complex64]) -def test_eigs_bugfix(dtype): - backend = jax_backend.JaxBackend() - D = 200 - mat = jax.numpy.array(np.random.rand(D, D).astype(dtype)) - x = jax.numpy.array(np.random.rand(D).astype(dtype)) - - def matvec_jax(vector, matrix): - return matrix @ vector - - backend.eigs( - matvec_jax, [mat], - numeig=1, - initial_state=x, - which='LR', - maxiter=10, - num_krylov_vecs=100, - tol=0.0001) - -def test_sum(): - np.random.seed(10) - backend = jax_backend.JaxBackend() - tensor = np.random.rand(2, 3, 4) - a = backend.convert_to_tensor(tensor) - actual = backend.sum(a, axis=(1, 2)) - expected = np.sum(tensor, axis=(1, 2)) - np.testing.assert_allclose(expected, actual) - - actual = backend.sum(a, axis=(1, 2), keepdims=True) - expected = np.sum(a, axis=(1, 2), keepdims=True) - np.testing.assert_allclose(expected, actual) - - -def test_matmul(): - np.random.seed(10) - backend = jax_backend.JaxBackend() - t1 = np.random.rand(10, 2, 3) - t2 = np.random.rand(10, 3, 4) - a = backend.convert_to_tensor(t1) - b = backend.convert_to_tensor(t2) - actual = backend.matmul(a, b) - expected = np.matmul(t1, t2) - np.testing.assert_allclose(expected, actual) - t3 = np.random.rand(10) - t4 = np.random.rand(11) - c = backend.convert_to_tensor(t3) - d = backend.convert_to_tensor(t4) - with pytest.raises(ValueError, match="inputs to"): - backend.matmul(c, d) - -def test_gmres_raises(): - backend = jax_backend.JaxBackend() - dummy_mv = lambda x: x - N = 10 - - b = jax.numpy.zeros((N,)) - x0 = jax.numpy.zeros((N+1),) - diff = "If x0 is supplied, its shape" - with pytest.raises(ValueError, match=diff): # x0, b have different sizes - backend.gmres(dummy_mv, b, x0=x0) - - x0 = jax.numpy.zeros((N,), dtype=jax.numpy.float32) - b = jax.numpy.zeros((N,), dtype=jax.numpy.float64) - diff = (f"If x0 is supplied, its dtype, {x0.dtype}, must match b's" - f", {b.dtype}.") - with pytest.raises(TypeError, match=diff): # x0, b have different dtypes - backend.gmres(dummy_mv, b, x0=x0) - - x0 = jax.numpy.zeros((N,)) - b = jax.numpy.zeros((N,)).reshape(2, N//2) - diff = "If x0 is supplied, its shape" - with pytest.raises(ValueError, match=diff): # x0, b have different shapes - backend.gmres(dummy_mv, b, x0=x0) - - num_krylov_vectors = 0 - diff = (f"num_krylov_vectors must be positive, not" - f"{num_krylov_vectors}.") - with pytest.raises(ValueError, match=diff): # num_krylov_vectors <= 0 - backend.gmres(dummy_mv, b, num_krylov_vectors=num_krylov_vectors) - - tol = -1. - diff = (f"tol = {tol} must be positive.") - with pytest.raises(ValueError, match=diff): # tol < 0 - backend.gmres(dummy_mv, b, tol=tol) - - atol = -1 - diff = (f"atol = {atol} must be positive.") - with pytest.raises(ValueError, match=diff): # atol < 0 - backend.gmres(dummy_mv, b, atol=atol) - - M = lambda x: x - diff = "M is not supported by the Jax backend." 
- with pytest.raises(NotImplementedError, match=diff): - backend.gmres(dummy_mv, b, M=M) - - A_kwargs = {"bee": "honey"} - diff = "A_kwargs is not supported by the Jax backend." - with pytest.raises(NotImplementedError, match=diff): - backend.gmres(dummy_mv, b, A_kwargs=A_kwargs) - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_gmres_on_small_known_problem(dtype): - dummy = jax.numpy.zeros(1, dtype=dtype) - dtype = dummy.dtype - - backend = jax_backend.JaxBackend() - A = jax.numpy.array(([[1, 1], [3, -4]]), dtype=dtype) - b = jax.numpy.array([3, 2], dtype=dtype) - x0 = jax.numpy.ones(2, dtype=dtype) - n_kry = 2 - - def A_mv(x): - return A @ x - tol = 100*jax.numpy.finfo(dtype).eps - x, _ = backend.gmres(A_mv, b, x0=x0, num_krylov_vectors=n_kry, tol=tol) - solution = jax.numpy.array([2., 1.], dtype=dtype) - eps = jax.numpy.linalg.norm(jax.numpy.abs(solution) - jax.numpy.abs(x)) - assert eps < tol - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_gmres_with_args(dtype): - dummy = jax.numpy.zeros(1, dtype=dtype) - dtype = dummy.dtype - - backend = jax_backend.JaxBackend() - A = jax.numpy.zeros((2, 2), dtype=dtype) - B = jax.numpy.array(([[0, 1], [3, 0]]), dtype=dtype) - C = jax.numpy.array(([[1, 0], [0, -4]]), dtype=dtype) - b = jax.numpy.array([3, 2], dtype=dtype) - x0 = jax.numpy.ones(2, dtype=dtype) - n_kry = 2 - - def A_mv(x, B, C): - return (A + B + C) @ x - tol = 100*jax.numpy.finfo(dtype).eps - x, _ = backend.gmres(A_mv, b, A_args=[B, C], x0=x0, num_krylov_vectors=n_kry, - tol=tol) - solution = jax.numpy.array([2., 1.], dtype=dtype) - eps = jax.numpy.linalg.norm(jax.numpy.abs(solution) - jax.numpy.abs(x)) - assert eps < tol - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_gmres_on_larger_random_problem(dtype): - dummy = jax.numpy.zeros(1, dtype=dtype) - dtype = dummy.dtype - backend = jax_backend.JaxBackend() - matshape = (100, 100) - vecshape = (100,) - A = backend.randn(matshape, seed=10, dtype=dtype) - solution = backend.randn(vecshape, seed=10, dtype=dtype) - def A_mv(x): - return A @ x - b = A_mv(solution) - tol = b.size * jax.numpy.finfo(dtype).eps - x, _ = backend.gmres(A_mv, b, tol=tol, num_krylov_vectors=100) - err = jax.numpy.linalg.norm(jax.numpy.abs(x)-jax.numpy.abs(solution)) - rtol = tol*jax.numpy.linalg.norm(b) - atol = tol - assert err < max(rtol, atol) - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_gmres_not_matrix(dtype): - dummy = jax.numpy.zeros(1, dtype=dtype) - dtype = dummy.dtype - backend = jax_backend.JaxBackend() - matshape = (100, 100) - vecshape = (100,) - A = backend.randn(matshape, dtype=dtype, seed=10) - A = backend.reshape(A, (2, 50, 2, 50)) - solution = backend.randn(vecshape, dtype=dtype, seed=10) - solution = backend.reshape(solution, (2, 50)) - def A_mv(x): - return backend.einsum('ijkl,kl', A, x) - b = A_mv(solution) - tol = b.size * np.finfo(dtype).eps - x, _ = backend.gmres(A_mv, b, tol=tol, num_krylov_vectors=100) - err = jax.numpy.linalg.norm(jax.numpy.abs(x)-jax.numpy.abs(solution)) - rtol = tol*jax.numpy.linalg.norm(b) - atol = tol - assert err < max(rtol, atol) - - -@pytest.mark.parametrize("dtype", np_dtypes) -@pytest.mark.parametrize("offset", range(-2, 2)) -@pytest.mark.parametrize("axis1", range(0, 3)) -@pytest.mark.parametrize("axis2", range(0, 3)) -def test_diagonal(dtype, offset, axis1, axis2): - shape = (5, 5, 5, 5) - backend = jax_backend.JaxBackend() - array = backend.randn(shape, dtype=dtype, seed=10) - if axis1 == axis2: - with pytest.raises(ValueError): - actual = 
backend.diagonal(array, offset=offset, axis1=axis1, axis2=axis2) - else: - actual = backend.diagonal(array, offset=offset, axis1=axis1, axis2=axis2) - expected = jax.numpy.diagonal(array, offset=offset, axis1=axis1, - axis2=axis2) - np.testing.assert_allclose(actual, expected) - - -@pytest.mark.parametrize("dtype", np_dtypes) -@pytest.mark.parametrize("offset", range(-2, 2)) -def test_diagflat(dtype, offset): - shape = (5, 5, 5, 5) - backend = jax_backend.JaxBackend() - array = backend.randn(shape, dtype=dtype, seed=10) - actual = backend.diagflat(array, k=offset) - expected = jax.numpy.diag(jax.numpy.ravel(array), k=offset) - np.testing.assert_allclose(actual, expected) - - -@pytest.mark.parametrize("dtype", np_dtypes) -@pytest.mark.parametrize("offset", range(-2, 2)) -@pytest.mark.parametrize("axis1", range(0, 3)) -@pytest.mark.parametrize("axis2", range(0, 3)) -def test_trace(dtype, offset, axis1, axis2): - shape = (5, 5, 5, 5) - backend = jax_backend.JaxBackend() - array = backend.randn(shape, dtype=dtype, seed=10) - if axis1 == axis2: - with pytest.raises(ValueError): - actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2) - else: - actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2) - expected = jax.numpy.trace(array, offset=offset, axis1=axis1, axis2=axis2) - np.testing.assert_allclose(actual, expected) - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_abs(dtype): - shape = (4, 3, 2) - backend = jax_backend.JaxBackend() - tensor = backend.randn(shape, dtype=dtype, seed=10) - actual = backend.abs(tensor) - expected = jax.numpy.abs(tensor) - np.testing.assert_allclose(expected, actual) - - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_sign(dtype): - shape = (4, 3, 2) - backend = jax_backend.JaxBackend() - tensor = backend.randn(shape, dtype=dtype, seed=10) - actual = backend.sign(tensor) - expected = jax.numpy.sign(tensor) - np.testing.assert_allclose(expected, actual) - - -@pytest.mark.parametrize("pivot_axis", [-1, 1, 2]) -@pytest.mark.parametrize("dtype", np_dtypes) -def test_pivot(dtype, pivot_axis): - shape = (4, 3, 2, 8) - pivot_shape = (np.prod(shape[:pivot_axis]), np.prod(shape[pivot_axis:])) - backend = jax_backend.JaxBackend() - tensor = backend.randn(shape, dtype=dtype, seed=10) - expected = tensor.reshape(pivot_shape) - actual = backend.pivot(tensor, pivot_axis=pivot_axis) - np.testing.assert_allclose(expected, actual) - - -@pytest.mark.parametrize("dtype, atol", [(np.float32, 1E-6), - (np.float64, 1E-10), - (np.complex64, 1E-6), - (np.complex128, 1E-10)]) -def test_inv(dtype, atol): - shape = (10, 10) - backend = jax_backend.JaxBackend() - matrix = backend.randn(shape, dtype=dtype, seed=10) - inv = backend.inv(matrix) - np.testing.assert_allclose(inv @ matrix, np.eye(10), atol=atol) - np.testing.assert_allclose(matrix @ inv, np.eye(10), atol=atol) - tensor = backend.randn((10, 10, 10), dtype=dtype, seed=10) - with pytest.raises(ValueError, match="input to"): - backend.inv(tensor) - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_item(dtype): - backend = jax_backend.JaxBackend() - tensor = backend.randn((1,), dtype=dtype, seed=10) - assert backend.item(tensor) == tensor.item() - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_power(dtype): - shape = (4, 3, 2) - backend = jax_backend.JaxBackend() - base_tensor = backend.randn(shape, dtype=dtype, seed=10) - power_tensor = backend.randn(shape, dtype=dtype, seed=10) - actual = backend.power(base_tensor, power_tensor) - expected = 
jax.numpy.power(base_tensor, power_tensor) - np.testing.assert_allclose(expected, actual) - - power = np.random.rand(1)[0] - actual = backend.power(base_tensor, power) - expected = jax.numpy.power(base_tensor, power) - np.testing.assert_allclose(expected, actual) - -@pytest.mark.parametrize("dtype", np_dtypes) -def test_eps(dtype): - backend = jax_backend.JaxBackend() - assert backend.eps(dtype) == np.finfo(dtype).eps \ No newline at end of file diff --git a/jax_eigensolver/tests/test_jitted_functions.py b/jax_eigensolver/tests/test_jitted_functions.py deleted file mode 100644 index 0e78225..0000000 --- a/jax_eigensolver/tests/test_jitted_functions.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright 2019 The TensorNetwork Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import jax -import pytest -from eigensolver import jitted_functions -jax.config.update('jax_enable_x64', True) - -jax_dtypes = [np.float32, np.float64, np.complex64, np.complex128] -precision = jax.lax.Precision.HIGHEST - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize("ncv", [10, 20, 30]) -def test_arnoldi_factorization(dtype, ncv): - np.random.seed(10) - D = 20 - mat = np.random.rand(D, D).astype(dtype) - x = np.random.rand(D).astype(dtype) - - @jax.tree_util.Partial - @jax.jit - def matvec(vector, matrix): - return matrix @ vector - - arnoldi = jitted_functions._generate_arnoldi_factorization(jax) - Vm = jax.numpy.zeros((ncv, D), dtype=dtype) - H = jax.numpy.zeros((ncv, ncv), dtype=dtype) - start = 0 - tol = 1E-5 - Vm, Hm, residual, norm, _, _ = arnoldi(matvec, [mat], x, Vm, H, start, ncv, - tol, precision) - fm = residual * norm - em = np.zeros((1, Vm.shape[0])) - em[0, -1] = 1 - #test arnoldi relation - np.testing.assert_almost_equal(mat @ Vm.T - Vm.T @ Hm - fm[:, None] * em, - np.zeros((D, ncv)).astype(dtype)) - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -def test_LR_sort(dtype): - np.random.seed(10) - x = np.random.rand(20).astype(dtype) - p = 10 - LR_sort = jitted_functions._LR_sort(jax) - actual_x, actual_inds = LR_sort(p, jax.numpy.array(np.real(x))) - exp_inds = np.argsort(x)[::-1] - exp_x = x[exp_inds][-p:] - np.testing.assert_allclose(exp_x, actual_x) - np.testing.assert_allclose(exp_inds, actual_inds) - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -def test_SA_sort(dtype): - np.random.seed(10) - x = np.random.rand(20).astype(dtype) - p = 10 - SA_sort = jitted_functions._SA_sort(jax) - actual_x, actual_inds = SA_sort(p, jax.numpy.array(np.real(x))) - exp_inds = np.argsort(x) - exp_x = x[exp_inds][-p:] - np.testing.assert_allclose(exp_x, actual_x) - np.testing.assert_allclose(exp_inds, actual_inds) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -def test_shifted_QR(dtype): - np.random.seed(10) - D = 20 - ncv = 10 - numeig = 4 - mat = np.random.rand(D, D).astype(dtype) - Ham = mat + mat.T.conj() - x = np.random.rand(D).astype(dtype) - - @jax.tree_util.Partial - @jax.jit - def matvec(vector, 
matrix): - return matrix @ vector - - lanczos = jitted_functions._generate_lanczos_factorization(jax) - shifted_QR = jitted_functions._shifted_QR(jax) - SA_sort = jitted_functions._SA_sort(jax) - - Vm = jax.numpy.zeros((ncv, D), dtype=dtype) - alphas = jax.numpy.zeros(ncv, dtype=dtype) - betas = jax.numpy.zeros(ncv - 1, dtype=dtype) - start = 0 - tol = 1E-5 - Vm, alphas, betas, residual, norm, _, _ = lanczos(matvec, [Ham], x, Vm, - alphas, betas, start, ncv, - tol, precision) - - Hm = jax.numpy.diag(alphas) + jax.numpy.diag(betas, -1) + jax.numpy.diag( - betas.conj(), 1) - fm = residual * norm - em = np.zeros((1, ncv)) - em[0, -1] = 1 - #test arnoldi relation - np.testing.assert_almost_equal(Ham @ Vm.T - Vm.T @ Hm - fm[:, None] * em, - np.zeros((D, ncv)).astype(dtype)) - - evals, _ = jax.numpy.linalg.eigh(Hm) - shifts, _ = SA_sort(numeig, evals) - Vk, Hk, fk = shifted_QR(Vm, Hm, fm, shifts, numeig) - - Vk = Vk.at[numeig:, :].set(0) - Hk = Hk.at[numeig:, :].set(0) - Hk = Hk.at[:, numeig:].set(0) - ek = np.zeros((1, ncv)) - ek[0, numeig - 1] = 1.0 - - np.testing.assert_almost_equal(Ham @ Vk.T - Vk.T @ Hk - fk[:, None] * ek, - np.zeros((D, ncv)).astype(dtype)) - - -@pytest.mark.parametrize("dtype", [np.float64, np.complex128]) -@pytest.mark.parametrize("ncv", [10, 20, 30]) -def test_lanczos_factorization(dtype, ncv): - np.random.seed(10) - D = 20 - mat = np.random.rand(D, D).astype(dtype) - Ham = mat + mat.T.conj() - x = np.random.rand(D).astype(dtype) - - @jax.tree_util.Partial - @jax.jit - def matvec(vector, matrix): - return matrix @ vector - - lanczos = jitted_functions._generate_lanczos_factorization(jax) - - Vm = jax.numpy.zeros((ncv, D), dtype=dtype) - alphas = jax.numpy.zeros(ncv, dtype=dtype) - betas = jax.numpy.zeros(ncv-1, dtype=dtype) - start = 0 - tol = 1E-5 - Vm, alphas, betas, residual, norm, _, _ = lanczos(matvec, [Ham], x, Vm, - alphas, betas, start, ncv, - tol, precision) - Hm = jax.numpy.diag(alphas) + jax.numpy.diag(betas, -1) + jax.numpy.diag( - betas.conj(), 1) - fm = residual * norm - em = np.zeros((1, Vm.shape[0])) - em[0, -1] = 1 - #test arnoldi relation - np.testing.assert_almost_equal(Ham @ Vm.T - Vm.T @ Hm - fm[:, None] * em, - np.zeros((D, ncv)).astype(dtype)) - - -@pytest.mark.parametrize("dtype", jax_dtypes) -def test_gmres_on_small_known_problem(dtype): - """ - GMRES produces the correct result on an analytically solved - linear system. - """ - dummy = jax.numpy.zeros(1, dtype=dtype) - dtype = dummy.dtype - gmres = jitted_functions.gmres_wrapper(jax) - - A = jax.numpy.array(([[1, 1], [3, -4]]), dtype=dtype) - b = jax.numpy.array([3, 2], dtype=dtype) - x0 = jax.numpy.ones(2, dtype=dtype) - n_kry = 2 - maxiter = 1 - - @jax.tree_util.Partial - def A_mv(x): - return A @ x - tol = A.size*jax.numpy.finfo(dtype).eps - x, _, _, _ = gmres.gmres_m(A_mv, [], b, x0, tol, tol, n_kry, maxiter, - precision) - solution = jax.numpy.array([2., 1.], dtype=dtype) - np.testing.assert_allclose(x, solution, atol=tol) - - -@pytest.mark.parametrize("dtype", jax_dtypes) -def test_gmres_krylov(dtype): - """ - gmres_krylov correctly builds the QR-decomposed Arnoldi decomposition. - This function assumes that gmres["kth_arnoldi_step (which is - independently tested) is correct. 
- """ - dummy = jax.numpy.zeros(1, dtype=dtype) - dtype = dummy.dtype - gmres = jitted_functions.gmres_wrapper(jax) - - n = 2 - n_kry = n - np.random.seed(10) - - @jax.tree_util.Partial - def A_mv(x): - return A @ x - A = jax.numpy.array(np.random.rand(n, n).astype(dtype)) - tol = A.size*jax.numpy.finfo(dtype).eps - x0 = jax.numpy.array(np.random.rand(n).astype(dtype)) - b = jax.numpy.array(np.random.rand(n), dtype=dtype) - r, beta = gmres.gmres_residual(A_mv, [], b, x0) - _, V, R, _ = gmres.gmres_krylov(A_mv, [], n_kry, x0, r, beta, - tol, jax.numpy.linalg.norm(b), - precision) - phases = jax.numpy.sign(jax.numpy.diagonal(R[:-1, :])) - R = phases.conj()[:, None] * R[:-1, :] - Vtest = np.zeros((n, n_kry + 1), dtype=x0.dtype) - Vtest[:, 0] = r/beta - Vtest = jax.numpy.array(Vtest) - Htest = jax.numpy.zeros((n_kry + 1, n_kry), dtype=x0.dtype) - for k in range(n_kry): - Vtest, Htest = gmres.kth_arnoldi_step(k, A_mv, [], Vtest, Htest, tol, - precision) - _, Rtest = jax.numpy.linalg.qr(Htest) - phases = jax.numpy.sign(jax.numpy.diagonal(Rtest)) - Rtest = phases.conj()[:, None] * Rtest - np.testing.assert_allclose(V, Vtest, atol=tol) - np.testing.assert_allclose(R, Rtest, atol=tol) - - -@pytest.mark.parametrize("dtype", jax_dtypes) -def test_gmres_arnoldi_step(dtype): - """ - The Arnoldi decomposition within GMRES is correct. - """ - gmres = jitted_functions.gmres_wrapper(jax) - dummy = jax.numpy.zeros(1, dtype=dtype) - dtype = dummy.dtype - n = 4 - n_kry = n - np.random.seed(10) - A = jax.numpy.array(np.random.rand(n, n).astype(dtype)) - x0 = jax.numpy.array(np.random.rand(n).astype(dtype)) - Q = np.zeros((n, n_kry + 1), dtype=x0.dtype) - Q[:, 0] = x0/jax.numpy.linalg.norm(x0) - Q = jax.numpy.array(Q) - H = jax.numpy.zeros((n_kry + 1, n_kry), dtype=x0.dtype) - tol = A.size*jax.numpy.finfo(dtype).eps - @jax.tree_util.Partial - def A_mv(x): - return A @ x - for k in range(n_kry): - Q, H = gmres.kth_arnoldi_step(k, A_mv, [], Q, H, tol, precision) - QAQ = Q[:, :n_kry].conj().T @ A @ Q[:, :n_kry] - np.testing.assert_allclose(H[:n_kry, :], QAQ, atol=tol) - - -@pytest.mark.parametrize("dtype", jax_dtypes) -def test_givens(dtype): - """ - gmres["givens_rotation produces the correct rotation factors. - """ - gmres = jitted_functions.gmres_wrapper(jax) - np.random.seed(10) - v = jax.numpy.array(np.random.rand(2).astype(dtype)) - cs, sn = gmres.givens_rotation(*v) - rot = np.zeros((2, 2), dtype=dtype) - rot[0, 0] = cs - rot[1, 1] = cs - rot[0, 1] = -sn - rot[1, 0] = sn - rot = jax.numpy.array(rot) - result = rot @ v - tol = 4*jax.numpy.finfo(dtype).eps - np.testing.assert_allclose(result[-1], 0., atol=tol) \ No newline at end of file diff --git a/test.py b/test.py index f08ea6e..b801f66 100644 --- a/test.py +++ b/test.py @@ -1,9 +1,15 @@ import jax +import jax.numpy as jnp +from eigensolver import eigs +from jax import config +config.update("jax_enable_x64", True) if __name__ == "__main__": - backend = JaxBackend() - m = 100 - A = jax.random.normal(jax.random.PRNGKey(42),(m,m)) - b = jax.random.normal(jax.random.PRNGKey(41),(m,)) + m = 10 + A = jax.random.uniform(jax.random.PRNGKey(42),(m,m)) + b = jax.random.uniform(jax.random.PRNGKey(41),(m,)) def mapA(x): return A@x - backend.eigs(mapA,initial_state = b) \ No newline at end of file + res = eigs(mapA, initial_state = b, numeig=1, num_krylov_vecs = 5) + print(res[0],res[1][0]) + + A @ res[1][0] / res[1][0] \ No newline at end of file