diff --git a/conftest.py b/conftest.py index 4c8b48902..c95848a4c 100644 --- a/conftest.py +++ b/conftest.py @@ -14,6 +14,8 @@ have_ray = True ray.init(num_cpus=1) # call required to be here: see ray-project/ray#44087 +import jax.numpy as jnp + import scico.numpy as snp @@ -24,7 +26,8 @@ def pytest_sessionstart(session): def pytest_sessionfinish(session, exitstatus): """Clean up after end of test session.""" - ray.shutdown() + if have_ray: + ray.shutdown() @pytest.fixture(autouse=True) @@ -33,7 +36,8 @@ def add_modules(doctest_namespace): Necessary because `np` is used in doc strings for jax functions (e.g. `linear_transpose`) that get pulled into `scico/__init__.py`. - Also allow `snp` to be used without explicitly importing. + Also allow `snp` and `jnp` to be used without explicitly importing. """ doctest_namespace["np"] = np doctest_namespace["snp"] = snp + doctest_namespace["jnp"] = jnp diff --git a/requirements.txt b/requirements.txt index 227948c20..1946b4295 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,8 +4,8 @@ scipy>=1.6.0 imageio>=2.17 tifffile matplotlib -jaxlib>=0.4.3,<=0.4.26 -jax>=0.4.3,<=0.4.26 +jaxlib>=0.4.3,<=0.4.28 +jax>=0.4.3,<=0.4.28 orbax-checkpoint<=0.5.7 -flax>=0.8.0,<=0.8.2 +flax>=0.8.0,<=0.8.3 pyabel>=0.9.0 diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index 7a270e1f4..4e59cca5f 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -362,10 +362,12 @@ def angle_to_vector(det_spacing: Tuple[float, float], angles: np.ndarray) -> np. def _ensure_writeable(x): """Ensure that `x.flags.writeable` is ``True``, copying if needed.""" - - if not x.flags.writeable: - try: - x.setflags(write=True) - except ValueError: - x = x.copy() + if hasattr(x, "flags"): # x is a numpy array + if not x.flags.writeable: + try: + x.setflags(write=True) + except ValueError: + x = x.copy() + else: # x is a jax array (which is immutable) + x = np.array(x) return x