From 34c0ff4fd41923603c48afa62993f6a5b8c31181 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 18 Dec 2023 14:22:31 -0700 Subject: [PATCH 1/5] Various changes (#487) * Improve script * Improve script * Typo fix * Clean up requirements files * Improve script * Debugging * Update GPU message * Bump max supported jaxlib/jax version * Update jax GPU installation instructions * Update GPU instructions link * Replace jax.experimental.host_callback.call with jax.pure_callback * Coding standards * Resolve test failures on GPU * Improve test level mechanism, apply to very slow svmbir tests * Tests seem to no longer hang on gpu * Set writable flag to avoid copy * Try again to fix the flags.writeable problem * Improve GPU support message --------- Co-authored-by: Brendt Wohlberg Co-authored-by: Michael-T-McCann --- CHANGES.rst | 2 +- conftest.py | 19 +++++- docs/source/install.rst | 24 ++++--- examples/examples_requirements.txt | 1 + examples/notebooks_requirements.txt | 2 +- misc/conda/make_conda_env.sh | 95 +++++++++++++++++----------- requirements.txt | 4 +- scico/denoiser.py | 18 ++++-- scico/flax/inverse.py | 10 +-- scico/linop/xray/astra.py | 35 +++++----- scico/linop/xray/svmbir.py | 15 ++--- scico/solver.py | 11 ++-- scico/test/linop/xray/test_svmbir.py | 7 +- 13 files changed, 146 insertions(+), 97 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index d3f1c9bf2..a1170b37f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -17,7 +17,7 @@ Version 0.0.5 (unreleased) • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. • Rename some ``__init__`` parameters of ``linop.DiagonalStack`` and ``linop.VerticalStack``. -• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.21. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.23. • Support ``flax`` versions up to 0.7.5. • Use ``orbax`` for checkpointing ``flax`` models. diff --git a/conftest.py b/conftest.py index 109657c75..808858931 100644 --- a/conftest.py +++ b/conftest.py @@ -27,8 +27,25 @@ def pytest_addoption(parser, pluginmanager): Level definitions: 1 Critical tests only 2 Skip tests that do have a significant impact on coverage - 3 All tests + 3 All standard tests + 4 Run all tests, including those marked as slow to run """ parser.addoption( "--level", action="store", default=3, type=int, help="Set test level to be run" ) + + +def pytest_configure(config): + """Add marker description.""" + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, items): + """Skip slow tests depending on selected testing level.""" + if config.getoption("--level") >= 4: + # don't skip tests at level 4 or higher + return + level_skip = pytest.mark.skip(reason="test not appropriate for selected level") + for item in items: + if "slow" in item.keywords: + item.add_marker(level_skip) diff --git a/docs/source/install.rst b/docs/source/install.rst index 8eec843e5..2c507303c 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -54,26 +54,32 @@ GPU Support The instructions above install a CPU-only version of SCICO. To install a version with GPU support: -1. Follow the CPU only instructions, above +1. Follow the CPU-only instructions, above 2. Install the version of jaxlib with GPU support, as described in the `JAX installation - instructions `_. + instructions `_. In the simplest case, the appropriate command is :: - pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install --upgrade "jax[cuda11]" + for CUDA 11, or - but it may be necessary to explicitly specify the ``jaxlib`` - version if the most recent release is not yet supported by SCICO - (as specified in the ``requirements.txt`` file), or if using a - version of CUDA older than 11.4, or CuDNN older than 8.2, in which - case the command would be of the form :: + :: + + pip install --upgrade "jax[cuda12]" + + for CUDA 12, but it may be necessary to explicitly specify the + ``jaxlib`` version if the most recent release is not yet supported + by SCICO (as specified in the ``requirements.txt`` file), or if + using a version of CUDA older than 11.4, or CuDNN older than 8.2, + in which case the command would be of the form :: pip install --upgrade "jaxlib==0.4.2+cuda11.cudnn82" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - with appropriate substitution of ``jaxlib``, CUDA, and CuDNN version numbers. + with appropriate substitution of ``jaxlib``, CUDA, and CuDNN version + numbers. diff --git a/examples/examples_requirements.txt b/examples/examples_requirements.txt index ff7a1dcfb..125852f90 100644 --- a/examples/examples_requirements.txt +++ b/examples/examples_requirements.txt @@ -1,3 +1,4 @@ +-r ../requirements.txt astra-toolbox colour_demosaicing xdesign>=0.5.5 diff --git a/examples/notebooks_requirements.txt b/examples/notebooks_requirements.txt index 24d063d04..bcb9e03c9 100644 --- a/examples/notebooks_requirements.txt +++ b/examples/notebooks_requirements.txt @@ -1,7 +1,7 @@ +-r examples-requirements.txt nbformat nbconvert nb_conda_kernels psutil py2jn pypandoc -ray[tune] diff --git a/misc/conda/make_conda_env.sh b/misc/conda/make_conda_env.sh index 241f0f3a4..e424df832 100755 --- a/misc/conda/make_conda_env.sh +++ b/misc/conda/make_conda_env.sh @@ -25,6 +25,8 @@ REPOPATH=$(realpath $(dirname $0)) USAGE=$(cat <<-EOF Usage: $SCRIPT [-h] [-y] [-g] [-p python_version] [-e env_name] [-h] Display usage information + [-v] Verbose operation + [-t] Display actions that would be taken but do nothing [-y] Do not ask for confirmation [-p python_version] Specify Python version (e.g. 3.9) [-e env_name] Specify conda environment name @@ -32,6 +34,8 @@ EOF ) AGREE=no +VERBOSE=no +TEST=no PYVER="3.9" ENVNM=py$(echo $PYVER | sed -e 's/\.//g') @@ -46,13 +50,13 @@ EOF ) # Requirements that cannot be installed via conda (i.e. have to use pip) NOCONDA=$(cat <<-EOF -flax bm3d bm4d faculty-sphinx-theme py2jn colour_demosaicing ray[tune] +flax bm3d bm4d py2jn colour_demosaicing ray[tune,train] EOF ) OPTIND=1 -while getopts ":hyp:e:" opt; do +while getopts ":hvtyp:e:" opt; do case $opt in p|e) if [ -z "$OPTARG" ] || [ "${OPTARG:0:1}" = "-" ] ; then echo "Error: option -$opt requires an argument" >&2 @@ -61,6 +65,8 @@ while getopts ":hyp:e:" opt; do fi ;;& h) echo "$USAGE"; exit 0;; + t) VERBOSE=yes;TEST=yes;; + v) VERBOSE=yes;; y) AGREE=yes;; p) PYVER=$OPTARG;; e) ENVNM=$OPTARG;; @@ -125,31 +131,6 @@ JLVER=$($SED -n 's/^jaxlib>=.*<=\([0-9\.]*\).*/\1/p' \ JXVER=$($SED -n 's/^jax>=.*<=\([0-9\.]*\).*/\1/p' \ $REPOPATH/../../requirements.txt) -CONDAHOME=$(conda info --base) -ENVDIR=$CONDAHOME/envs/$ENVNM -if [ -d "$ENVDIR" ]; then - echo "Error: environment $ENVNM already exists" - exit 9 -fi - -if [ "$AGREE" == "no" ]; then - RSTR="Confirm creation of conda environment $ENVNM with Python $PYVER" - RSTR="$RSTR [y/N] " - read -r -p "$RSTR" CNFRM - if [ "$CNFRM" != 'y' ] && [ "$CNFRM" != 'Y' ]; then - echo "Cancelling environment creation" - exit 10 - fi -else - echo "Creating conda environment $ENVNM with Python $PYVER" -fi - -if [ "$AGREE" == "yes" ]; then - CONDA_FLAGS="-y" -else - CONDA_FLAGS="" -fi - # Construct merged list of all requirements if [ "$OS" == "Darwin" ]; then ALLREQUIRE=$(/usr/bin/mktemp -t condaenv) @@ -177,15 +158,56 @@ sort $ALLREQUIRE | uniq | $SED -E 's/(>|<|\|)/\\\1/g' \ | $SED -E 's/\#.*$//g' \ | $SED -E '/^-r.*|^jaxlib.*|^jax.*/d' > $FLTREQUIRE # Remove requirements that cannot be installed via conda +PIPREQ="" for nc in $NOCONDA; do # Escape [ and ] for use in regex - nc=$(echo $nc | sed -E 's/(\[|\])/\\\1/g') + nc=$(echo $nc | $SED -E 's/(\[|\])/\\\1/g') + # Add package to pip package list + PIPREQ="$PIPREQ "$(grep "$nc" $FLTREQUIRE | $SED 's/\\//g') # Remove package $nc from conda package list $SED -i "/^$nc.*\$/d" $FLTREQUIRE done # Get list of requirements to be installed via conda CONDAREQ=$(cat $FLTREQUIRE | xargs) +if [ "$VERBOSE" == "yes" ]; then + echo "Create python $PYVER environment $ENVNM in conda installation" + echo " $CONDAHOME" + echo "Packages to be installed via conda:" + echo " $CONDAREQ" | fmt -w 79 + echo "Packages to be installed via pip:" + echo " jaxlib==$JLVER jax==$JXVER $PIPREQ" | fmt -w 79 + if [ "$TEST" == "yes" ]; then + exit 0 + fi +fi + +CONDAHOME=$(conda info --base) +ENVDIR=$CONDAHOME/envs/$ENVNM +if [ -d "$ENVDIR" ]; then + echo "Error: environment $ENVNM already exists" + exit 9 +fi + +if [ "$AGREE" == "no" ]; then + RSTR="Confirm creation of conda environment $ENVNM with Python $PYVER" + RSTR="$RSTR [y/N] " + read -r -p "$RSTR" CNFRM + if [ "$CNFRM" != 'y' ] && [ "$CNFRM" != 'Y' ]; then + echo "Cancelling environment creation" + exit 10 + fi +else + echo "Creating conda environment $ENVNM with Python $PYVER" +fi + +if [ "$AGREE" == "yes" ]; then + CONDA_FLAGS="-y" +else + CONDA_FLAGS="" +fi + + # Update conda, create new environment, and activate it conda update $CONDA_FLAGS -n base conda conda create $CONDA_FLAGS -n $ENVNM python=$PYVER @@ -215,7 +237,7 @@ fi pip install --upgrade jaxlib==$JLVER jax==$JXVER # Install other packages that require installation via pip -pip install $NOCONDA +pip install $PIPREQ # Warn if libopenblas-dev not installed on debian/ubuntu if [ "$(which dpkg 2>/dev/null)" ]; then @@ -233,13 +255,12 @@ echo " conda activate $ENVNM" echo "The environment can be deactivated with the command" echo " conda deactivate" echo -echo "Jax installed without GPU support. To avoid warning messages," -echo "add the following to your .bashrc or .bash_aliases file" -echo " export JAX_PLATFORM_NAME=cpu" -echo "To include GPU support, reinstall the astra-toolbox conda" -echo "package on a host with GPUs, and see the instructions at" -echo " https://github.com/google/jax#pip-installation-gpu-cuda" -echo "for additional steps required after running this script and" -echo "activating the environment created by it." +echo "JAX installed without GPU support. To enable GPU support, install a" +echo "version of jaxlib with CUDA support following the instructions at" +echo " https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu" +echo "ASTRA Toolbox installed without GPU support if this script was" +echo "run on a host without CUDA drivers installed. To enable GPU support," +echo "uninstall and then reinstall the astra-toolbox conda package on a" +echo "host with CUDA drivers installed." exit 0 diff --git a/requirements.txt b/requirements.txt index 5ac3c6904..97f0263e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,8 @@ scipy>=1.6.0 tifffile imageio>=2.17 matplotlib -jaxlib>=0.4.3,<=0.4.21 -jax>=0.4.3,<=0.4.21 +jaxlib>=0.4.3,<=0.4.23 +jax>=0.4.3,<=0.4.23 flax>=0.6.1,<=0.7.5 svmbir>=0.3.3 pyabel>=0.9.0 diff --git a/scico/denoiser.py b/scico/denoiser.py index d3151c76f..574aa2596 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -12,7 +12,7 @@ import numpy as np -from jax.experimental import host_callback as hcb +import jax try: import bm3d as tubm3d @@ -84,7 +84,7 @@ def bm3d_eval(x: snp.Array, sigma: float): "BM3D requires two-dimensional or three dimensional inputs; got ndim = {x.ndim}." ) - # This check is also performed inside the BM3D call, but due to the host_callback, + # This check is also performed inside the BM3D call, but due to the callback, # no exception is raised and the program will crash with no traceback. # NOTE: if BM3D is extended to allow for different profiles, the block size must be # updated; this presumes 'np' profile (bs=8) @@ -103,7 +103,11 @@ def bm3d_eval(x: snp.Array, sigma: float): " the additional axes are singletons." ) - y = hcb.call(lambda args: bm3d_eval(*args).astype(x.dtype), (x, sigma), result_shape=x) + y = jax.pure_callback( + lambda args: bm3d_eval(*args).astype(x.dtype), + jax.ShapeDtypeStruct(x.shape, x.dtype), + (x, sigma), + ) # undo squeezing, if neccessary y = y.reshape(x_in_shape) @@ -145,7 +149,7 @@ def bm4d_eval(x: snp.Array, sigma: float): if isinstance(x.ndim, tuple) or x.ndim < 3: raise ValueError(f"BM4D requires three-dimensional inputs; got ndim = {x.ndim}.") - # This check is also performed inside the BM4D call, but due to the host_callback, + # This check is also performed inside the BM4D call, but due to the callback, # no exception is raised and the program will crash with no traceback. # NOTE: if BM4D is extended to allow for different profiles, the block size must be # updated; this presumes 'np' profile (bs=8) @@ -164,7 +168,11 @@ def bm4d_eval(x: snp.Array, sigma: float): " the additional axes are singletons." ) - y = hcb.call(lambda args: bm4d_eval(*args).astype(x.dtype), (x, sigma), result_shape=x) + y = jax.pure_callback( + lambda args: bm4d_eval(*args).astype(x.dtype), + jax.ShapeDtypeStruct(x.shape, x.dtype), + (x, sigma), + ) # undo squeezing, if neccessary y = y.reshape(x_in_shape) diff --git a/scico/flax/inverse.py b/scico/flax/inverse.py index db4a17372..96e34d087 100644 --- a/scico/flax/inverse.py +++ b/scico/flax/inverse.py @@ -120,11 +120,11 @@ def cg_solver(A: Callable, b: Array, x0: Array = None, maxiter: int = 50) -> Arr version constructed to be differentiable with the autograd functionality from jax. Therefore, (i) it uses :meth:`jax.lax.scan` to execute a fixed number of iterations and (ii) it assumes that the - linear operator may use :meth:`jax.experimental.host_callback`. Due - to the utilization of a while cycle, :meth:`scico.cg` is not - differentiable by jax and :meth:`jax.scipy.sparse.linalg.cg` does not - support functions using :meth:`jax.experimental.host_callback` - explaining why an additional conjugate gradient function is implemented. + linear operator may use :meth:`jax.pure_callback`. Due to the + utilization of a while cycle, :meth:`scico.cg` is not differentiable + by jax and :meth:`jax.scipy.sparse.linalg.cg` does not support + functions using :meth:`jax.pure_callback`, which is why an additional + conjugate gradient function has been implemented. Args: A: Function implementing linear operator :math:`A`, should be diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index b877891ad..9f39d994e 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -22,7 +22,6 @@ import numpy as np import jax -import jax.experimental.host_callback as hcb try: import astra @@ -183,11 +182,10 @@ def __init__( ) def _proj(self, x: jax.Array) -> jax.Array: - # Applies the forward projector and generates a sinogram + # apply the forward projector and generate a sinogram def f(x): - if x.flags.writeable == False: - x.flags.writeable = True + x = ensure_writeable(x) if self.num_dims == 2: proj_id, result = astra.create_sino(x, self.proj_id) astra.data2d.delete(proj_id) @@ -196,15 +194,12 @@ def f(x): astra.data3d.delete(proj_id) return result - return hcb.call( - f, x, result_shape=jax.ShapeDtypeStruct(self.output_shape, self.output_dtype) - ) + return jax.pure_callback(f, jax.ShapeDtypeStruct(self.output_shape, self.output_dtype), x) def _bproj(self, y: jax.Array) -> jax.Array: - # applies backprojector + # apply backprojector def f(y): - if y.flags.writeable == False: - y.flags.writeable = True + y = ensure_writeable(y) if self.num_dims == 2: proj_id, result = astra.create_backprojection(y, self.proj_id) astra.data2d.delete(proj_id) @@ -215,7 +210,7 @@ def f(y): astra.data3d.delete(proj_id) return result - return hcb.call(f, y, result_shape=jax.ShapeDtypeStruct(self.input_shape, self.input_dtype)) + return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), y) def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array: """Filtered back projection (FBP) reconstruction. @@ -235,8 +230,7 @@ def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array: # Just use the CPU FBP alg for now; hitting memory issues with GPU one. def f(sino): - if sino.flags.writeable == False: - sino.flags.writeable = True + sino = ensure_writeable(sino) sino_id = astra.data2d.create("-sino", self.proj_geom, sino) # create memory for result @@ -262,6 +256,15 @@ def f(sino): astra.data2d.delete(sino_id) return out - return hcb.call( - f, sino, result_shape=jax.ShapeDtypeStruct(self.input_shape, self.input_dtype) - ) + return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), sino) + + +def ensure_writeable(x): + """Ensure that `x.flags.writeable` is ``True``, copying if needed.""" + + if not x.flags.writeable: + try: + x.setflags(write=True) + except ValueError: + x = x.copy() + return x diff --git a/scico/linop/xray/svmbir.py b/scico/linop/xray/svmbir.py index 8e757da84..1660b8d63 100644 --- a/scico/linop/xray/svmbir.py +++ b/scico/linop/xray/svmbir.py @@ -18,7 +18,6 @@ import numpy as np import jax -import jax.experimental.host_callback import scico.numpy as snp from scico.loss import Loss, SquaredL2Loss @@ -175,7 +174,6 @@ def __init__( self.delta_pixel = delta_pixel elif self.geometry == "parallel": - self.magnification = 1.0 if delta_pixel is None: self.delta_pixel = self.delta_channel @@ -232,8 +230,8 @@ def _proj( def _proj_hcb(self, x): x = x.reshape(self.svmbir_input_shape) - # host callback wrapper for _proj - y = jax.experimental.host_callback.call( + # callback wrapper for _proj + y = jax.pure_callback( lambda x: self._proj( x, self.angles, @@ -246,8 +244,8 @@ def _proj_hcb(self, x): delta_channel=self.delta_channel, delta_pixel=self.delta_pixel, ), + jax.ShapeDtypeStruct(self.svmbir_output_shape, self.output_dtype), x, - result_shape=jax.ShapeDtypeStruct(self.svmbir_output_shape, self.output_dtype), ) return y.reshape(self.output_shape) @@ -284,8 +282,8 @@ def _bproj( def _bproj_hcb(self, y): y = y.reshape(self.svmbir_output_shape) - # host callback wrapper for _bproj - x = jax.experimental.host_callback.call( + # callback wrapper for _bproj + x = jax.pure_callback( lambda y: self._bproj( y, self.angles, @@ -299,8 +297,8 @@ def _bproj_hcb(self, y): delta_channel=self.delta_channel, delta_pixel=self.delta_pixel, ), + jax.ShapeDtypeStruct(self.svmbir_input_shape, self.input_dtype), y, - result_shape=jax.ShapeDtypeStruct(self.svmbir_input_shape, self.input_dtype), ) return x.reshape(self.input_shape) @@ -389,7 +387,6 @@ def __init__( raise TypeError(f"Parameter W must be None or a linop.Diagonal, got {type(W)}.") def __call__(self, x: snp.Array) -> float: - if self.positivity and snp.sum(x < 0) > 0: return snp.inf else: diff --git a/scico/solver.py b/scico/solver.py index c897a55eb..02a76b487 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -53,7 +53,6 @@ import numpy as np import jax -import jax.experimental.host_callback as hcb import jax.numpy as jnp import jax.scipy.linalg as jsl @@ -239,14 +238,14 @@ def fun(x0): jac=jac, method=method, options=options, - ) # Returns OptimizeResult with x0 as ndarray + ) # Return OptimizeResult with x0 as ndarray return res.x.astype(x0_dtype) - # HCB call with side effects to get the OptimizeResult on the same device it was called - res.x = hcb.call( + # callback with side effects to get the OptimizeResult on the same device it was called + res.x = jax.pure_callback( fun, - arg=x0, - result_shape=x0, # From Jax-docs: This can be an object that has .shape and .dtype attributes + jax.ShapeDtypeStruct(x0.shape, x0_dtype), + x0, ) # un-vectorize the output array from spopt.minimize diff --git a/scico/test/linop/xray/test_svmbir.py b/scico/test/linop/xray/test_svmbir.py index 9674269c0..46077f3d5 100644 --- a/scico/test/linop/xray/test_svmbir.py +++ b/scico/test/linop/xray/test_svmbir.py @@ -26,8 +26,6 @@ BIG_INPUT = (32, 33, 50, 51, 125, 1.2) SMALL_INPUT = (4, 5, 7, 8, 16, 1.2) -device = jax.devices()[0] - def pytest_generate_tests(metafunc): param_ranges = { @@ -88,7 +86,6 @@ def make_A( delta_channel=None, delta_pixel=None, ): - angles = make_angles(num_angles) A = XRayTransform( im.shape, @@ -154,7 +151,7 @@ def test_adjoint( adjoint_test(A) -@pytest.mark.skipif(device.platform != "cpu", reason="test hangs on gpu") +@pytest.mark.slow def test_prox( is_3d, center_offset_small, @@ -185,7 +182,7 @@ def test_prox( prox_test(v, f, f.prox, alpha=0.25, rtol=5e-4) -@pytest.mark.skipif(device.platform != "cpu", reason="test hangs on gpu") +@pytest.mark.slow def test_prox_weights( is_3d, center_offset_small, From bdd68d9b1c8060f051c361ee212e57544d9ee094 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 18 Dec 2023 19:09:14 -0700 Subject: [PATCH 2/5] Release 0.0.5 (#488) * Remove XRayTransform warning (projection is accurate now) * Increase timeout value * Set release version and date --- CHANGES.rst | 2 +- examples/scriptcheck.sh | 2 +- scico/__init__.py | 2 +- scico/linop/xray/_xray.py | 5 ----- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index a1170b37f..59c359fc3 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,7 +3,7 @@ SCICO Release Notes =================== -Version 0.0.5 (unreleased) +Version 0.0.5 (2023-12-18) ---------------------------- • New functionals ``functional.AnisotropicTVNorm`` and diff --git a/examples/scriptcheck.sh b/examples/scriptcheck.sh index 459ba6c65..498668585 100755 --- a/examples/scriptcheck.sh +++ b/examples/scriptcheck.sh @@ -97,7 +97,7 @@ for f in $SCRIPTPATH/scripts/*.py; do sed -E -e "$re1$re2$re3$re4$re5$re6$re7$re8" $f > $g # Run temporary script and print status message. - if output=$(timeout 60s python $g 2>&1); then + if output=$(timeout 180s python $g 2>&1); then printf "%s\n" succeeded else printf "%s\n" FAILED diff --git a/scico/__init__.py b/scico/__init__.py index e75a6df94..a8e951b4f 100644 --- a/scico/__init__.py +++ b/scico/__init__.py @@ -8,7 +8,7 @@ solving the inverse problems that arise in scientific imaging applications. """ -__version__ = "0.0.5.dev0" +__version__ = "0.0.5" import logging import sys diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 7c1399daf..c93482563 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -26,11 +26,6 @@ class XRayTransform(LinearOperator): """X-ray transform operator. Wrap an X-ray projector object in a SCICO :class:`LinearOperator`. - **Warning:** Note that the only X-ray projector object currently - supported, :class:`.Parallel2dProjector`, is not a very accurate - approximation of the integral transform representing real projection - imaging, and may therefore not be suitable for real imaging - applications. """ def __init__(self, projector): From 2565d27831faacbe336a2e33ffc4bc74402784cd Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 18 Dec 2023 19:33:30 -0700 Subject: [PATCH 3/5] Post-release preparation for next release (#489) --- CHANGES.rst | 7 +++++++ scico/__init__.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index 59c359fc3..fc4563ee6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,13 @@ SCICO Release Notes =================== +Version 0.0.6 (unreleased) +---------------------------- + +• No significant changes. + + + Version 0.0.5 (2023-12-18) ---------------------------- diff --git a/scico/__init__.py b/scico/__init__.py index a8e951b4f..84dea5c0c 100644 --- a/scico/__init__.py +++ b/scico/__init__.py @@ -8,7 +8,7 @@ solving the inverse problems that arise in scientific imaging applications. """ -__version__ = "0.0.5" +__version__ = "0.0.6.dev0" import logging import sys From 6fe11c7a369a11bc12c94963fd4adbbe5d88ef47 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 22 Dec 2023 10:10:13 -0700 Subject: [PATCH 4/5] Resolve some packaging issues (#490) * Resolve build warnings * Include tests in package * Include main requirements file * Move tifffile dependency to examples dependency list * Include orbax-checkpoint dependency * Add typing_extensions dependency * Docs improvement * Move svmbir dependency to examples dependency list * Update submodule --- .github/workflows/pytest_macos.yml | 4 +++- .github/workflows/pytest_ubuntu.yml | 3 ++- data | 2 +- dev_requirements.txt | 1 + examples/examples_requirements.txt | 4 +++- requirements.txt | 4 ++-- scico/flax/train/checkpoints.py | 19 +++++++++++-------- setup.py | 7 +++++-- 8 files changed, 28 insertions(+), 16 deletions(-) diff --git a/.github/workflows/pytest_macos.yml b/.github/workflows/pytest_macos.yml index 0b097029c..82cf74b95 100644 --- a/.github/workflows/pytest_macos.yml +++ b/.github/workflows/pytest_macos.yml @@ -59,11 +59,13 @@ jobs: pip install pytest-split pip install -r requirements.txt pip install -r dev_requirements.txt + mamba install -c conda-forge svmbir>=0.3.3 mamba install -c astra-toolbox astra-toolbox mamba install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version + pip install bm3d>=4.0.0 pip install bm4d>=4.0.0 - pip install "ray[tune]>=2.0.0" + pip install "ray[tune]>=2.5.0" pip install hyperopt # Install package to be tested - name: Install package to be tested diff --git a/.github/workflows/pytest_ubuntu.yml b/.github/workflows/pytest_ubuntu.yml index db24830e5..f60c9d4ba 100644 --- a/.github/workflows/pytest_ubuntu.yml +++ b/.github/workflows/pytest_ubuntu.yml @@ -62,12 +62,13 @@ jobs: pip install pytest-split pip install -r requirements.txt pip install -r dev_requirements.txt + mamba install -c conda-forge svmbir>=0.3.3 mamba install -c astra-toolbox astra-toolbox mamba install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version pip install bm3d>=4.0.0 pip install bm4d>=4.2.2 - pip install "ray[tune]>=2.0.0" + pip install "ray[tune]>=2.5.0" pip install hyperopt # Install package to be tested - name: Install package to be tested diff --git a/data b/data index a33ca716e..d5aff70f9 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit a33ca716e42ba7593d6120752053869fce8b1abb +Subproject commit d5aff70f95d33abf72e785fb945cc556db16cd12 diff --git a/dev_requirements.txt b/dev_requirements.txt index aa21af84b..213aa591e 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,3 +1,4 @@ +-r requirements.txt pylint pytest>=7.3.0 pytest-runner diff --git a/examples/examples_requirements.txt b/examples/examples_requirements.txt index 125852f90..15d0a2a1e 100644 --- a/examples/examples_requirements.txt +++ b/examples/examples_requirements.txt @@ -1,6 +1,8 @@ -r ../requirements.txt -astra-toolbox +tifffile colour_demosaicing +svmbir>=0.3.3 +astra-toolbox xdesign>=0.5.5 ray[tune,train]>=2.5.0 hyperopt diff --git a/requirements.txt b/requirements.txt index 97f0263e5..76722063a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ +typing_extensions numpy>=1.20.0 scipy>=1.6.0 -tifffile imageio>=2.17 matplotlib jaxlib>=0.4.3,<=0.4.23 jax>=0.4.3,<=0.4.23 +orbax-checkpoint flax>=0.6.1,<=0.7.5 -svmbir>=0.3.3 pyabel>=0.9.0 diff --git a/scico/flax/train/checkpoints.py b/scico/flax/train/checkpoints.py index 6ef233d80..dac4fd872 100644 --- a/scico/flax/train/checkpoints.py +++ b/scico/flax/train/checkpoints.py @@ -1,17 +1,19 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022 by SCICO Developers +# Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utilities for checkpointing Flax models.""" + + from pathlib import Path from typing import Union import jax -import orbax +import orbax.checkpoint from flax.training import orbax_utils @@ -29,13 +31,14 @@ def checkpoint_restore( parameters. workdir: Checkpoint file or directory of checkpoints to restore from. - ok_no_ckpt: Flag to indicate if a checkpoint is expected. Default: - False, a checkpoint is expected and an error is generated. + ok_no_ckpt: Flag to indicate if a checkpoint is expected. If + ``False``, an error is generated if a checkpoint is not + found. Returns: - A restored Flax train state updated from checkpoint file is returned. - If no checkpoint files are present and checkpoints are not strictly - expected it returns the passed-in `state` unchanged. + A restored Flax train state updated from checkpoint file is + returned. If no checkpoint files are present and checkpoints are + not strictly expected it returns the passed-in `state` unchanged. Raises: FileNotFoundError: If a checkpoint is expected and is not found. @@ -68,7 +71,7 @@ def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, P state: Flax train state which includes model and optimiser parameters. config: Python dictionary including model train configuration. - workdir: str or pathlib-like path to store checkpoint files in. + workdir: Path in which to store checkpoint files. """ if jax.process_index() == 0: orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() diff --git a/setup.py b/setup.py index 30e37e751..b7941b5ec 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ import site import sys -from setuptools import find_packages, setup +from setuptools import find_namespace_packages, setup # Import module scico._version without executing __init__.py spec = importlib.util.spec_from_file_location("_version", os.path.join("scico", "_version.py")) @@ -20,7 +20,10 @@ name = "scico" version = package_version() -packages = find_packages() +# Add argument exclude=["test", "test.*"] to exclude test subpackage +packages = find_namespace_packages(where="scico") +packages = [f"scico.{m}" for m in packages] + longdesc = """ SCICO is a Python package for solving the inverse problems that arise in scientific imaging applications. Its primary focus is providing methods for solving ill-posed inverse problems by using an appropriate prior model of the reconstruction space. SCICO includes a growing suite of operators, cost functionals, regularizers, and optimization routines that may be combined to solve a wide range of problems, and is designed so that it is easy to add new building blocks. SCICO is built on top of JAX, which provides features such as automatic gradient calculation and GPU acceleration. From 85c1a1cc089e93fba53740701373af334d5d9602 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 22 Dec 2023 11:05:46 -0700 Subject: [PATCH 5/5] Resolve packaging error (#491) * Resolve missing top-level scico package * Improve version number construction function --- scico/_version.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scico/_version.py b/scico/_version.py index 6ee1a4e2e..cdfd89aa6 100644 --- a/scico/_version.py +++ b/scico/_version.py @@ -96,7 +96,8 @@ def package_version(split: bool = False) -> Union[str, Tuple[str, str]]: # prag Package version string or tuple of strings. """ version = init_variable_assign_value("__version_") - if re.match(r"^[0-9\.]+$", version): # don't extend purely numeric version numbers + # don't extend purely numeric version numbers, possibly ending with post + if re.match(r"^[0-9\.]+(post[0-9]+)?$", version): git_hash = None else: git_hash = current_git_hash() diff --git a/setup.py b/setup.py index b7941b5ec..dd1345a00 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ version = package_version() # Add argument exclude=["test", "test.*"] to exclude test subpackage packages = find_namespace_packages(where="scico") -packages = [f"scico.{m}" for m in packages] +packages = ["scico"] + [f"scico.{m}" for m in packages] longdesc = """