Skip to content

Commit

Permalink
Merge branch 'main' into brendt/iso-tv-norm
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Dec 17, 2023
2 parents 8349bf3 + 4eb21c1 commit 5ea6b2b
Show file tree
Hide file tree
Showing 53 changed files with 1,120 additions and 609 deletions.
17 changes: 11 additions & 6 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,27 @@ Version 0.0.5 (unreleased)
``functional.IsotropicTVNorm`` and ``functional.ProximalAverage`` with
proximal operator approximations.
• New integrated Radon/X-ray transform ``linop.XRayTransform``.
• New operators ``operator.DiagonalStack`` and ``operator.VerticalStack``.
• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
to ``XRayTransform``.
• Rename ``AbelProjector`` to ``AbelTransform``.
• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.20.
• Rename some ``__init__`` parameters of ``linop.DiagonalStack`` and
``linop.VerticalStack``.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.21.
• Support ``flax`` versions up to 0.7.5.
• Use ``orbax`` for checkpointing ``flax`` models.



Version 0.0.4 (2023-08-03)
----------------------------

• Add new `Function` class for representing array-to-array mappings with more
• Add new ``Function`` class for representing array-to-array mappings with more
than one input.
• Add new methods and a function for computing Jacobian-vector products for
`Operator` objects.
``Operator`` objects.
• Add new proximal ADMM solvers.
• Add new ADMM subproblem solvers for problems involving a sum-of-convolutions
operator.
Expand All @@ -34,7 +39,7 @@ Version 0.0.4 (2023-08-03)
• Enable diagnostics for ML training loops.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.14.
• Change required packages and version numbers, including more recent version
for `flax`.
for ``flax``.
• Drop support for Python 3.7.
• Add support for 3D tomographic projection with the ASTRA Toolbox.

Expand All @@ -44,8 +49,8 @@ Version 0.0.3 (2022-09-21)
----------------------------

• Change required packages and version numbers, including more recent version
requirements for `numpy`, `scipy`, `svmbir`, and `ray`.
• Package `bm4d` removed from main requirements list due to issue #342.
requirements for ``numpy``, ``scipy``, ``svmbir``, and ``ray``.
• Package ``bm4d`` removed from main requirements list due to issue #342.
• Support ``jaxlib`` versions 0.3.0 to 0.3.15 and ``jax`` versions
0.3.0 to 0.3.17.
• Rename linear operators in ``radon_astra`` and ``radon_svmbir`` modules
Expand Down
4 changes: 2 additions & 2 deletions docs/source/include/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ in terms of the proximal operators of the :math:`f_i`
.. math::
\mathrm{prox}_f(\mb{x}, \lambda)
=
\begin{bmatrix}
\begin{pmatrix}
\mathrm{prox}_{f_1}(\mb{x}_1, \lambda) \\
\vdots \\
\mathrm{prox}_{f_N}(\mb{x}_N, \lambda) \\
\end{bmatrix} \;.
\end{pmatrix} \;.
Separable Functionals are implemented in the :class:`.SeparableFunctional` class. Separable functionals naturally accept :class:`.BlockArray` inputs and return the prox as a :class:`.BlockArray`.

Expand Down
2 changes: 2 additions & 0 deletions docs/source/team.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ Contributors
- `Saurav Maheshkar <https://github.com/SauravMaheshkar>`_ (Improvements to pre-commit configuration)
- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (ASTRA interface improvements)
- `Li-Ta (Ollie) Lo <https://github.com/ollielo>`_ (ASTRA interface improvements)
- `Renat Sibgatulin <https://github.com/Sibgatulin>`_ (Docs corrections)
- `Salman Naqvi <https://github.com/shnaqvi>`_ (Contributions to approximate TV norm prox and proximal average implementation)
33 changes: 20 additions & 13 deletions examples/scripts/ct_astra_modl_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@
"depth": 10,
"num_filters": 64,
"block_depth": 4,
"cg_iter": 3,
"cg_iter_1": 3,
"cg_iter_2": 8,
}
# training configuration
train_conf: sflax.ConfigDict = {
Expand All @@ -132,6 +133,7 @@
"warmup_epochs": 0,
"log_every_steps": 40,
"log": True,
"checkpointing": True,
}


Expand Down Expand Up @@ -166,10 +168,11 @@
)

stats_object_ini = None
stats_object = None

checkpoint_files = []
for dirpath, dirnames, filenames in os.walk(workdir2):
checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"]
checkpoint_files = [fn for fn in filenames]

if len(checkpoint_files) > 0:
model = sflax.MoDLNet(
Expand All @@ -178,11 +181,14 @@
channels=channels,
num_filters=model_conf["num_filters"],
block_depth=model_conf["block_depth"],
cg_iter=model_conf["cg_iter"],
cg_iter=model_conf["cg_iter_2"],
)

train_conf["workdir"] = workdir2
train_conf["post_lst"] = [lmbdapos]
# Parameters for 2nd stage
train_conf["workdir"] = workdir2
train_conf["opt_type"] = "ADAM"
train_conf["num_epochs"] = 150
# Construct training object
trainer = sflax.BasicFlaxTrainer(
train_conf,
Expand All @@ -203,7 +209,7 @@
channels=channels,
num_filters=model_conf["num_filters"],
block_depth=model_conf["block_depth"],
cg_iter=model_conf["cg_iter"],
cg_iter=model_conf["cg_iter_1"],
)
# First stage: initialization training loop.
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out")
Expand All @@ -230,8 +236,7 @@

# Second stage: depth iterations training loop.
model.depth = model_conf["depth"]
model.cg_iter = 8
train_conf["base_learning_rate"] = 1e-2
model.cg_iter = model_conf["cg_iter_2"]
train_conf["opt_type"] = "ADAM"
train_conf["num_epochs"] = 150
train_conf["workdir"] = workdir2
Expand Down Expand Up @@ -265,7 +270,7 @@


"""
Compare trained model in terms of reconstruction time
Evaluate trained model in terms of reconstruction time
and data fidelity.
"""
total_epochs = epochs_init + train_conf["num_epochs"]
Expand All @@ -281,7 +286,9 @@
f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)

# Plot comparison
"""
Plot comparison.
"""
np.random.seed(123)
indx = np.random.randint(0, high=maxn)

Expand Down Expand Up @@ -311,10 +318,10 @@


"""
Plot convergence statistics. Statistics only generated if a training
cycle was done (i.e. not reading final epoch results from checkpoint).
Plot convergence statistics. Statistics are generated only if a training
cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
if stats_object is not None:
if stats_object is not None and len(stats_object.iterations) > 0:
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
Expand All @@ -341,7 +348,7 @@
fig.show()

# Stats for initialization loop
if stats_object_ini is not None:
if stats_object_ini is not None and len(stats_object_ini.iterations) > 0:
hist = stats_object_ini.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
Expand Down
13 changes: 8 additions & 5 deletions examples/scripts/ct_astra_odp_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
"warmup_epochs": 0,
"log_every_steps": 160,
"log": True,
"checkpointing": True,
}


Expand Down Expand Up @@ -208,7 +209,7 @@


"""
Compare trained model in terms of reconstruction time and data fidelity.
Evaluate trained model in terms of reconstruction time and data fidelity.
"""
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
Expand All @@ -221,7 +222,9 @@
f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)

# Plot comparison
"""
Plot comparison.
"""
np.random.seed(123)
indx = np.random.randint(0, high=maxn)

Expand Down Expand Up @@ -251,10 +254,10 @@


"""
Plot convergence statistics. Statistics only generated if a training
cycle was done (i.e. not reading final epoch results from checkpoint).
Plot convergence statistics. Statistics are generated only if a training
cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
if stats_object is not None:
if stats_object is not None and len(stats_object.iterations) > 0:
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
Expand Down
27 changes: 17 additions & 10 deletions examples/scripts/ct_astra_unet_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
Read data from cache or generate if not available.
"""
N = 256 # phantom size
train_nimg = 536 # number of training images
test_nimg = 64 # number of testing images
train_nimg = 498 # number of training images
test_nimg = 32 # number of testing images
nimg = train_nimg + test_nimg
n_projection = 45 # CT views

Expand Down Expand Up @@ -83,6 +83,7 @@
"warmup_epochs": 0,
"log_every_steps": 1000,
"log": True,
"checkpointing": True,
}


Expand Down Expand Up @@ -123,18 +124,24 @@
"""
Evaluate on testing data.
"""
start_time = time()
del train_ds["image"]
del train_ds["label"]

fmap = sflax.FlaxMap(model, modvar)
output = fmap(test_ds["image"])
del model, modvar

maxn = test_nimg // 2
start_time = time()
output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
output = jax.numpy.clip(output, a_min=0, a_max=1.0)


"""
Compare trained model in terms of reconstruction time and data fidelity.
Evaluate trained model in terms of reconstruction time and data fidelity.
"""
snr_eval = metric.snr(test_ds["label"], output)
psnr_eval = metric.psnr(test_ds["label"], output)
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
f"{'UNet training':15s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
Expand Down Expand Up @@ -181,10 +188,10 @@


"""
Plot convergence statistics. Statistics only generated if a training
cycle was done (i.e. not reading final epoch results from checkpoint).
Plot convergence statistics. Statistics are generated only if a training
cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
if stats_object is not None:
if stats_object is not None and len(stats_object.iterations) > 0:
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
Expand Down
Loading

0 comments on commit 5ea6b2b

Please sign in to comment.