Proposed changes to #541 (#543)

Merged 53 commits on Jul 24, 2024
Commits (53)
a324841
Add type annotation
bwohlberg Jul 18, 2024
1ebf4e0
Remove jax distributed data generation option
bwohlberg Jul 18, 2024
a75ab62
Remove jax distributed data generation option
bwohlberg Jul 18, 2024
997f52a
Clean up
bwohlberg Jul 18, 2024
5f4552f
Clean up
bwohlberg Jul 18, 2024
a0b72ae
Extend docs
bwohlberg Jul 18, 2024
f77e212
Add additional test for exception state
bwohlberg Jul 18, 2024
5ab1d05
Tracer conversion error fix from Cristina
bwohlberg Jul 18, 2024
ecbfca5
Omitted import
bwohlberg Jul 18, 2024
28c828c
Clean up
bwohlberg Jul 18, 2024
dcd358c
Consistent phrasing
bwohlberg Jul 18, 2024
fbb4564
Merge branch 'cristina/issue535' into brendt/issue535_extended
bwohlberg Jul 19, 2024
8f286d0
Clean up some f-strings
bwohlberg Jul 22, 2024
5e85f8c
Add missing ray init
bwohlberg Jul 22, 2024
d69fdd2
Set dtype
bwohlberg Jul 22, 2024
3ef66a6
Merge branch 'cristina/issue535' into brendt/issue535_extended
bwohlberg Jul 22, 2024
0d97b3f
Fix indentation error
bwohlberg Jul 22, 2024
eec3242
Update module docstring
bwohlberg Jul 23, 2024
a7fa89f
Experimental solution to ray/jax failure
bwohlberg Jul 23, 2024
85ded0f
Bug fix
bwohlberg Jul 23, 2024
e7461f0
Improve docstring
bwohlberg Jul 23, 2024
5dac79f
Implement hack to resolve jax/ray conflict
bwohlberg Jul 23, 2024
25f318e
Debug attempt
bwohlberg Jul 23, 2024
9218e4d
Debug attempt
bwohlberg Jul 23, 2024
e73ae7d
Debug attempt
bwohlberg Jul 23, 2024
c9714e4
Debug attempt
bwohlberg Jul 23, 2024
e24ccdd
Debug attempt
bwohlberg Jul 23, 2024
9bbad64
Debug attempt
bwohlberg Jul 23, 2024
325fb9b
New solution attempt
bwohlberg Jul 23, 2024
d521aa3
Debug attempt
bwohlberg Jul 23, 2024
8eef347
Debug attempt
bwohlberg Jul 23, 2024
fa08d8c
Debug attempt
bwohlberg Jul 23, 2024
47b8067
Debug attempt
bwohlberg Jul 23, 2024
931c763
Debug attempt
bwohlberg Jul 23, 2024
a9cafff
Debug attempt
bwohlberg Jul 23, 2024
5f7001e
Debug attempt
bwohlberg Jul 23, 2024
644c189
Debug attempt
bwohlberg Jul 23, 2024
89b4772
Debug attempt
bwohlberg Jul 23, 2024
fdb8520
Debug attempt
bwohlberg Jul 23, 2024
978759e
Return to earlier approach
bwohlberg Jul 23, 2024
fc2315a
Extend comment
bwohlberg Jul 23, 2024
039a970
Clean up and improve function logic
bwohlberg Jul 23, 2024
9dca046
Address some problems
bwohlberg Jul 23, 2024
1fcd82d
Clean up
bwohlberg Jul 23, 2024
6cdf217
Rename function for consistency with related functions
bwohlberg Jul 23, 2024
f2acaf2
Merge branch 'main' into brendt/issue535_extended
bwohlberg Jul 23, 2024
32e7b01
Merge branch 'brendt/issue535_extended' into brendt/issue535_extended…
bwohlberg Jul 23, 2024
aa1467a
Bug fix
bwohlberg Jul 23, 2024
cc678fd
Clean up
bwohlberg Jul 23, 2024
d650f3a
Bug fix
bwohlberg Jul 23, 2024
0fe46a4
Address pylint complaint
bwohlberg Jul 23, 2024
41fa25e
Revert unworkable structure
bwohlberg Jul 24, 2024
2c96637
Merge branch 'cristina/issue535' into brendt/issue535_extended
bwohlberg Jul 24, 2024
Files changed
2 changes: 1 addition & 1 deletion CHANGES.rst
@@ -17,7 +17,7 @@ Version 0.0.6 (unreleased)
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.30.
• Support ``flax`` versions between 0.8.0 and 0.8.3 (inclusive).
• Support ``flax`` versions 0.8.0 to 0.8.3.



6 changes: 6 additions & 0 deletions examples/scripts/ct_astra_datagen_foam2.py
@@ -13,8 +13,14 @@
generated using filtered back projection (FBP).
"""

# isort: off
import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

from scico import plot
from scico.flax.examples import load_ct_data

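The recurring change in this and the following example scripts is an import-ordering workaround for ray-project/ray#44087: ray.init must be called before jax is imported, whether directly or indirectly through scico. A minimal sketch of the resulting pattern, following the ct_astra_datagen_foam2.py diff above (standalone, runnable in an environment with ray and scico installed):

# isort: off
import logging

import ray

# Initialize Ray before anything that imports jax (workaround for ray-project/ray#44087).
ray.init(logging_level=logging.ERROR)

# Imports below this point are free to pull in jax.
from scico import plot
from scico.flax.examples import load_ct_data

The # isort: off directive keeps isort from reordering these imports and thereby moving the jax-dependent imports back above the ray.init call.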
6 changes: 6 additions & 0 deletions examples/scripts/ct_astra_modl_train_foam2.py
@@ -40,12 +40,18 @@
reconstructed images.
"""

# isort: off
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable
15 changes: 10 additions & 5 deletions examples/scripts/ct_astra_odp_train_foam2.py
@@ -44,12 +44,21 @@
term. The output of the final stage is the set of reconstructed images.
"""

# isort: off
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -60,11 +69,7 @@
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform2D

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

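The training scripts also move the XLA_FLAGS assignment ahead of the jax import: XLA reads the flag when the jax backend initializes, so setting it before jax is imported is the safe placement. A small check, not part of the PR, for confirming the setting on a CPU-only host:

import os

# Must be set before jax is first imported for the flag to take effect.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

# With no GPU available, jax should now report 8 host (CPU) devices.
print(jax.local_device_count())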
21 changes: 14 additions & 7 deletions examples/scripts/ct_astra_unet_train_foam2.py
@@ -13,22 +13,29 @@
by :cite:`jin-2017-unet`.
"""

# isort: off
import os
from time import time

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

import numpy as np

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_ct_data

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

@@ -190,7 +197,7 @@
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
jax.numpy.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
x=hist.Epoch,
ptyp="semilogy",
title="Loss function",
@@ -201,7 +208,7 @@
ax=ax[0],
)
plot.plot(
jax.numpy.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
x=hist.Epoch,
title="Metric",
xlbl="Epoch",
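In the plotting section, stacking of the training-history curves switches from jax.numpy.vstack to np.vstack, presumably because the history values are plain host-side sequences and the plotting step has no need for jax arrays. A self-contained illustration with hypothetical values standing in for hist.Train_Loss and hist.Eval_Loss:

import numpy as np

# Hypothetical per-epoch values standing in for hist.Train_Loss / hist.Eval_Loss.
train_loss = [0.9, 0.5, 0.3]
eval_loss = [1.0, 0.6, 0.4]

# Shape (epochs, 2): one column per curve, matching what plot.plot is given here.
curves = np.vstack((train_loss, eval_loss)).T
print(curves.shape)  # (3, 2)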
11 changes: 9 additions & 2 deletions examples/scripts/deconv_datagen_foam1.py
@@ -12,10 +12,17 @@
training neural network models for deconvolution (deblurring). Foam
phantoms from xdesign are used to generate the clean images.
"""

# isort: off
import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

from scico import plot
from scico.flax.examples import load_foam1_blur_data
from scico.flax.examples import load_blur_data

"""
Read data from cache or generate if not available.
@@ -29,7 +36,7 @@
nimg = train_nimg + test_nimg
output_size = 256 # image size

train_ds, test_ds = load_foam1_blur_data(
train_ds, test_ds = load_blur_data(
train_nimg,
test_nimg,
output_size,
19 changes: 12 additions & 7 deletions examples/scripts/deconv_modl_train_foam1.py
@@ -41,27 +41,32 @@
images.
"""

# isort: off
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_foam1_blur_data
from scico.flax.examples import load_blur_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop import CircularConvolve

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

@@ -87,7 +92,7 @@
test_nimg = 64 # number of testing images
nimg = train_nimg + test_nimg

train_ds, test_ds = load_foam1_blur_data(
train_ds, test_ds = load_blur_data(
train_nimg,
test_nimg,
output_size,
19 changes: 12 additions & 7 deletions examples/scripts/deconv_odp_train_foam1.py
@@ -49,27 +49,32 @@
set of deblurred images.
"""

# isort: off
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_foam1_blur_data
from scico.flax.examples import load_blur_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop import CircularConvolve

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

@@ -95,7 +100,7 @@
test_nimg = 64 # number of testing images
nimg = train_nimg + test_nimg

train_ds, test_ds = load_foam1_blur_data(
train_ds, test_ds = load_blur_data(
train_nimg,
test_nimg,
output_size,
10 changes: 5 additions & 5 deletions examples/scripts/denoise_dncnn_train_bsds.py
@@ -13,11 +13,15 @@
with additive Gaussian noise.
"""

# isort: off
import os
from time import time

import numpy as np

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -26,11 +30,7 @@
from scico import metric, plot
from scico.flax.examples import load_image_data

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

4 changes: 2 additions & 2 deletions scico/flax/examples/__init__.py
@@ -8,11 +8,11 @@
"""Data utility functions used by Flax example scripts."""

from .data_preprocessing import PaddedCircularConvolve, build_blur_kernel
from .examples import load_ct_data, load_foam1_blur_data, load_image_data
from .examples import load_blur_data, load_ct_data, load_image_data

__all__ = [
"load_ct_data",
"load_foam1_blur_data",
"load_blur_data",
"load_image_data",
"PaddedCircularConvolve",
"build_blur_kernel",
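For downstream code, the visible effect of this last diff is the rename of load_foam1_blur_data to load_blur_data in scico.flax.examples. A hedged sketch of the updated import and call, using only the arguments visible in the truncated calls above; the values marked hypothetical are not taken from the PR:

from scico.flax.examples import load_blur_data

train_nimg = 400   # hypothetical value; the real setting is outside the visible hunks
test_nimg = 64     # number of testing images (visible in the diffs above)
output_size = 256  # image size (visible in the deconv_datagen_foam1.py diff)

# Positional arguments as in the example scripts; remaining keyword arguments
# are truncated in the diffs above and omitted from this sketch.
train_ds, test_ds = load_blur_data(train_nimg, test_nimg, output_size)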