From a32484172bb14be017d84d3c8089ed20562136c2 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jul 2024 19:02:05 -0600 Subject: [PATCH 01/48] Add type annotation --- scico/linop/xray/_xray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index adcfa23a..645f5ca4 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -106,7 +106,7 @@ def __init__( Pdx = np.stack((dx[0] * jnp.cos(angles), dx[1] * jnp.sin(angles))) Pdiag1 = np.abs(Pdx[0] + Pdx[1]) Pdiag2 = np.abs(Pdx[0] - Pdx[1]) - max_width = np.max(np.maximum(Pdiag1, Pdiag2)) + max_width: float = np.max(np.maximum(Pdiag1, Pdiag2)) if max_width > 1: warn( From 1ebf4e0fd95b0ba74b786c69770cf44c21f6b1f8 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jul 2024 19:04:42 -0600 Subject: [PATCH 02/48] Remove jax distributed data generation option --- scico/flax/examples/data_generation.py | 131 ++++++++----------------- scico/flax/examples/examples.py | 8 -- 2 files changed, 39 insertions(+), 100 deletions(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index a339435e..2dea9dbb 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -47,8 +47,6 @@ have_astra = False else: have_astra = True - -if have_astra: from scico.linop.xray.astra import XRayTransform2D @@ -98,11 +96,11 @@ def __init__( ) -def generate_foam2_images(seed: float, size: int, ndata: int) -> Array: - """Generate batch of foam2 structures. +def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: + """Generate batch of xdesign foam-like structures. - Generate batch of images with :class:`Foam2` structure - (foam-like material with two different attenuations). + Generate batch of images with `xdesign` foam-like structure, which + uses one attenuation. Args: seed: Seed for data generation. @@ -115,22 +113,20 @@ def generate_foam2_images(seed: float, size: int, ndata: int) -> Array: if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") - # np.random.seed(seed) - saux = jnp.zeros((ndata, size, size, 1)) + np.random.seed(seed) + saux = np.zeros((ndata, size, size, 1)) for i in range(ndata): - foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size)) - # normalize - saux = saux / jnp.max(saux, axis=(1, 2), keepdims=True) + foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) + saux[i, ..., 0] = discrete_phantom(foam, size=size) return saux -def generate_foam1_images(seed: float, size: int, ndata: int) -> Array: - """Generate batch of xdesign foam-like structures. +def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: + """Generate batch of foam2 structures. - Generate batch of images with `xdesign` foam-like structure, which - uses one attenuation. + Generate batch of images with :class:`Foam2` structure + (foam-like material with two different attenuations). Args: seed: Seed for data generation. @@ -143,11 +139,13 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> Array: if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") - # np.random.seed(seed) - saux = jnp.zeros((ndata, size, size, 1)) + np.random.seed(seed) + saux = np.zeros((ndata, size, size, 1)) for i in range(ndata): - foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size)) + foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) + saux[i, ..., 0] = discrete_phantom(foam, size=size) + # normalize + saux /= np.max(saux, axis=(1, 2), keepdims=True) return saux @@ -191,8 +189,7 @@ def generate_ct_data( imgfunc: Callable = generate_foam2_images, seed: int = 1234, verbose: bool = False, - prefer_ray: bool = True, -) -> Tuple[Array, ...]: +) -> Tuple[Array, Array, Array]: """Generate batch of computed tomography (CT) data. Generate batch of CT data for training of machine learning network @@ -205,9 +202,6 @@ def generate_ct_data( imgfunc: Function for generating input images (e.g. foams). seed: Seed for data generation. verbose: Flag indicating whether to print status messages. - Default: ``False``. - prefer_ray: Use ray for distributed processing if available. - Default: ``True``. Returns: tuple: A tuple (img, sino, fbp) containing: @@ -220,14 +214,9 @@ def generate_ct_data( raise RuntimeError("Package astra is required for use of this function.") # Generate input data. - if have_ray and prefer_ray: - start_time = time() - img = ray_distributed_data_generation(imgfunc, size, nimg, seed) - time_dtgen = time() - start_time - else: - start_time = time() - img = distributed_data_generation(imgfunc, size, nimg, False) - time_dtgen = time() - start_time + start_time = time() + img = distributed_data_generation(imgfunc, size, nimg, seed) + time_dtgen = time() - start_time # Clip to [0,1] range. img = jnp.clip(img, 0, 1) @@ -236,7 +225,7 @@ def generate_ct_data( # Configure a CT projection operator to generate synthetic measurements. angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles gt_sh = (size, size) - detector_spacing = 1 + detector_spacing = 1.0 A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator # Compute sinograms in parallel. @@ -284,8 +273,7 @@ def generate_blur_data( imgfunc: Callable, seed: int = 4321, verbose: bool = False, - prefer_ray: bool = True, -) -> Tuple[Array, ...]: +) -> Tuple[Array, Array]: """Generate batch of blurred data. Generate batch of blurred data for training of machine learning @@ -299,9 +287,6 @@ def generate_blur_data( imgfunc: Function to generate foams. seed: Seed for data generation. verbose: Flag indicating whether to print status messages. - Default: ``False``. - prefer_ray: Use ray for distributed processing if available. - Default: ``True``. Returns: tuple: A tuple (img, blurn) containing: @@ -309,14 +294,9 @@ def generate_blur_data( - **img** : Generated foam images. - **blurn** : Corresponding blurred and noisy images. """ - if have_ray and prefer_ray: - start_time = time() - img = ray_distributed_data_generation(imgfunc, size, nimg, seed) - time_dtgen = time() - start_time - else: - start_time = time() - img = distributed_data_generation(imgfunc, size, nimg, False) - time_dtgen = time() - start_time + start_time = time() + img = distributed_data_generation(imgfunc, size, nimg, seed) + time_dtgen = time() - start_time # Clip to [0,1] range. img = jnp.clip(img, 0, 1) @@ -356,44 +336,12 @@ def generate_blur_data( def distributed_data_generation( - imgenf: Callable, size: int, nimg: int, sharded: bool = True -) -> Array: - """Data generation distributed among processes using jax. - - Args: - imagenf: Function for batch-data generation. - size: Size of image to generate. - ndata: Number of images to generate. - sharded: Flag to indicate if data is to be returned as the - chunks generated by each process or consolidated. - Default: ``True``. - - Returns: - Array of generated data. - """ - nproc = jax.device_count() - seeds = jnp.arange(nproc) - if nproc > 1 and nimg % nproc > 0: - raise ValueError("Number of images to generate must be divisible by the number of devices") - - ndata_per_proc = int(nimg // nproc) - - idx = np.arange(nproc) - imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc) - - # imgs = jax.pmap(imgenf, static_broadcasted_argnums=(1, 2))(seeds, size, ndata_per_proc) - - if not sharded: - imgs = imgs.reshape((-1, size, size, 1)) - - return imgs - - -def ray_distributed_data_generation( imgenf: Callable, size: int, nimg: int, seedg: float = 123 -) -> Array: +) -> np.ndarray: """Data generation distributed among processes using ray. + *Warning:* + Args: imagenf: Function for batch-data generation. size: Size of image to generate. @@ -406,26 +354,25 @@ def ray_distributed_data_generation( if not have_ray: raise RuntimeError("Package ray is required for use of this function.") - @ray.remote - def data_gen(seed, size, ndata, imgf): - return imgf(seed, size, ndata) - - # Use half of available CPU resources. + # Use half of available CPU resources ar = ray.available_resources() if "CPU" not in ar: - warnings.warn("No CPU key in ray.available_resources() output") - nproc = max(int(ar.get("CPU", "1")) // 2, 1) - # nproc = max(int(ar["CPU"]) // 2, 1) + warnings.warn("No CPU key in ray.available_resources() output.") + nproc = max(int(ar.get("CPU", 1)) // 2, 1) if nproc > nimg: nproc = nimg if nproc > 1 and nimg % nproc > 0: raise ValueError( - f"Number of images to generate ({nimg}) " - f"must be divisible by the number of available devices ({nproc})" + f"Number of images to generate ({nimg}) must be divisible by " + f"the number of available devices ({nproc})." ) ndata_per_proc = int(nimg // nproc) + @ray.remote + def data_gen(seed, size, ndata, imgf): + return imgf(seed, size, ndata) + ray_return = ray.get( [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)] ) diff --git a/scico/flax/examples/examples.py b/scico/flax/examples/examples.py index 0bf428a3..6af0129d 100644 --- a/scico/flax/examples/examples.py +++ b/scico/flax/examples/examples.py @@ -49,7 +49,6 @@ def load_ct_data( nproj: int, cache_path: Optional[str] = None, verbose: bool = False, - prefer_ray: bool = True, ) -> Tuple[CTDataSetDict, ...]: # pragma: no cover """ Load or generate CT data. @@ -77,8 +76,6 @@ def load_ct_data( Default: ``None``. verbose: Flag indicating whether to print status messages. Default: ``False``. - prefer_ray: Use ray for distributed processing if available. - Default: ``True``. Returns: tuple: A tuple (trdt, ttdt) containing: @@ -146,7 +143,6 @@ def load_ct_data( size, nproj, verbose=verbose, - prefer_ray=prefer_ray, ) # Separate training and testing partitions. trdt = {"img": img[:train_nimg], "sino": sino[:train_nimg], "fbp": fbp[:train_nimg]} @@ -186,7 +182,6 @@ def load_foam1_blur_data( noise_sigma: float, cache_path: Optional[str] = None, verbose: bool = False, - prefer_ray: bool = True, ) -> Tuple[DataSetDict, ...]: # pragma: no cover """Load or generate blurred data based on xdesign foam structures. @@ -214,8 +209,6 @@ def load_foam1_blur_data( Default: ``None``. verbose: Flag indicating whether to print status messages. Default: ``False``. - prefer_ray: Use ray for distributed processing if available. - Default: ``True``. Returns: tuple: A tuple (train_ds, test_ds) containing: @@ -297,7 +290,6 @@ def load_foam1_blur_data( noise_sigma, imgfunc=generate_foam1_images, verbose=verbose, - prefer_ray=prefer_ray, ) # Separate training and testing partitions. train_ds = {"image": blrn[:train_nimg], "label": img[:train_nimg]} From a75ab622016c9c1cf3a247783fa43ce919f32d5a Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jul 2024 19:05:04 -0600 Subject: [PATCH 03/48] Remove jax distributed data generation option --- scico/test/flax/test_examples_flax.py | 84 +++------------------------ 1 file changed, 8 insertions(+), 76 deletions(-) diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py index e393732c..fd766342 100644 --- a/scico/test/flax/test_examples_flax.py +++ b/scico/test/flax/test_examples_flax.py @@ -3,8 +3,6 @@ import numpy as np -import jax - import pytest from scico import random @@ -15,7 +13,6 @@ have_astra, have_ray, have_xdesign, - ray_distributed_data_generation, ) from scico.flax.examples.data_preprocessing import ( CenterCrop, @@ -67,43 +64,17 @@ def fake_data_gen(seed, N, ndata): return dt -def test_distdatagen(): - N = 16 - nimg = 8 - dt = distributed_data_generation(fake_data_gen, N, nimg) - assert dt.ndim == 5 - assert dt.shape[0] * dt.shape[1] == nimg - assert dt.shape[2:] == (N, N, 1) - - -def test_distdatagen_flag(): - N = 16 - nimg = 8 - dt = distributed_data_generation(fake_data_gen, N, nimg, False) - assert dt.ndim == 4 - assert dt.shape == (nimg, N, N, 1) - - -@pytest.mark.skipif( - jax.device_count() == 1, reason="no processes for checking failure of distributed computing" -) -def test_distdatagen_exception(): - N = 16 - nimg = 15 - with pytest.raises(ValueError): - distributed_data_generation(fake_data_gen, N, nimg) - - @pytest.mark.skipif(not have_ray, reason="ray package not installed") -def test_ray_distdatagen(): +def test_distdatagen(): N = 16 nimg = 8 def random_data_gen(seed, N, ndata): - dt, key = random.randn((ndata, N, N, 1), seed=seed) + np.random.seed(seed) + dt = np.random.randn(ndata, N, N, 1) return dt - dt = ray_distributed_data_generation(random_data_gen, N, nimg) + dt = distributed_data_generation(random_data_gen, N, nimg) assert dt.ndim == 4 assert dt.shape == (nimg, N, N, 1) @@ -115,10 +86,9 @@ def test_ct_data_generation(): nproj = 45 def random_img_gen(seed, size, ndata): - key = jax.random.PRNGKey(seed) - key, subkey = jax.random.split(key) + np.random.seed(seed) shape = (ndata, size, size, 1) - return jax.random.normal(subkey, shape) + return np.random.randn(*shape) img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen) assert img.shape == (nimg, N, N, 1) @@ -126,24 +96,6 @@ def random_img_gen(seed, size, ndata): assert fbp.shape == (nimg, N, N, 1) -@pytest.mark.skipif(not have_astra, reason="astra package not installed") -def test_ct_data_generation_jax(): - N = 32 - nimg = 8 - nproj = 45 - - def random_img_gen(seed, size, ndata): - key = jax.random.PRNGKey(seed) - key, subkey = jax.random.split(key) - shape = (ndata, size, size, 1) - return jax.random.normal(subkey, shape) - - img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen, prefer_ray=False) - assert img.shape == (nimg, N, N, 1) - assert sino.shape == (nimg, nproj, N, 1) - assert fbp.shape == (nimg, N, N, 1) - - def test_blur_data_generation(): N = 32 nimg = 8 @@ -151,35 +103,15 @@ def test_blur_data_generation(): blur_kernel = np.ones((n, n)) / (n * n) def random_img_gen(seed, size, ndata): - key = jax.random.PRNGKey(seed) - key, subkey = jax.random.split(key) + np.random.seed(seed) shape = (ndata, size, size, 1) - return jax.random.normal(subkey, shape) + return np.random.randn(*shape) img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen) assert img.shape == (nimg, N, N, 1) assert blurn.shape == (nimg, N, N, 1) -def test_blur_data_generation_jax(): - N = 32 - nimg = 8 - n = 3 # convolution kernel size - blur_kernel = np.ones((n, n)) / (n * n) - - def random_img_gen(seed, size, ndata): - key = jax.random.PRNGKey(seed) - key, subkey = jax.random.split(key) - shape = (ndata, size, size, 1) - return jax.random.normal(subkey, shape) - - img, blurn = generate_blur_data( - nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen, prefer_ray=False - ) - assert img.shape == (nimg, N, N, 1) - assert blurn.shape == (nimg, N, N, 1) - - def test_rotation90(): N = 128 x, key = random.randn((N, N), seed=4321) From 997f52a66fd6d13e31fa3c60d2cb9994750b2ee2 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jul 2024 19:22:02 -0600 Subject: [PATCH 04/48] Clean up --- scico/flax/examples/data_generation.py | 2 +- scico/test/flax/test_examples_flax.py | 21 ++++++++------------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 2dea9dbb..106e62ef 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -86,7 +86,7 @@ def __init__( attn2: Mass attenuation parameter for material 2. Default: 10. """ - super(Foam2, self).__init__(radius=0.5, material=SimpleMaterial(attn1)) + super().__init__(radius=0.5, material=SimpleMaterial(attn1)) if porosity < 0 or porosity > 1: raise ValueError("Porosity must be in the range [0,1).") self.sprinkle( diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py index fd766342..c259526a 100644 --- a/scico/test/flax/test_examples_flax.py +++ b/scico/test/flax/test_examples_flax.py @@ -38,32 +38,27 @@ @pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") -def test_foam2_gen(): - seed = 4321 +def test_foam1_gen(): + seed = 4444 N = 32 ndata = 2 - from scico.flax.examples.data_generation import generate_foam2_images + from scico.flax.examples.data_generation import generate_foam1_images - dt = generate_foam2_images(seed, N, ndata) + dt = generate_foam1_images(seed, N, ndata) assert dt.shape == (ndata, N, N, 1) @pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") -def test_foam_gen(): - seed = 4444 +def test_foam2_gen(): + seed = 4321 N = 32 ndata = 2 - from scico.flax.examples.data_generation import generate_foam1_images + from scico.flax.examples.data_generation import generate_foam2_images - dt = generate_foam1_images(seed, N, ndata) + dt = generate_foam2_images(seed, N, ndata) assert dt.shape == (ndata, N, N, 1) -def fake_data_gen(seed, N, ndata): - dt, key = random.randn((ndata, N, N, 1), seed=seed) - return dt - - @pytest.mark.skipif(not have_ray, reason="ray package not installed") def test_distdatagen(): N = 16 From 5f4552f79e00ab320b483db4f14a39f95477381b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jul 2024 19:42:38 -0600 Subject: [PATCH 05/48] Clean up --- scico/flax/examples/data_generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 106e62ef..72652ff4 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -217,7 +217,7 @@ def generate_ct_data( start_time = time() img = distributed_data_generation(imgfunc, size, nimg, seed) time_dtgen = time() - start_time - # Clip to [0,1] range. + # clip to [0,1] range img = jnp.clip(img, 0, 1) nproc = jax.device_count() @@ -231,7 +231,7 @@ def generate_ct_data( # Compute sinograms in parallel. start_time = time() if nproc > 1: - # Shard array + # shard array imgshd = img.reshape((nproc, -1, size, size, 1)) sinoshd = batched_f(A, imgshd) sino = sinoshd.reshape((-1, nproj, size, 1)) From a0b72ae90b1a445853dd412bd81160c11432bc30 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jul 2024 19:47:14 -0600 Subject: [PATCH 06/48] Extend docs --- scico/flax/examples/data_generation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 72652ff4..b06a287a 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -340,7 +340,11 @@ def distributed_data_generation( ) -> np.ndarray: """Data generation distributed among processes using ray. - *Warning:* + *Warning:* callable `imgenf` should not make use of any jax functions + to avoid the risk of errors when running with GPU devices, in which + case jax is initialized to expect the availability of GPUs, which are + then not available within the `ray.remote` function due to the absence + of any declared GPUs as a `num_gpus` parameter of `@ray.remote`. Args: imagenf: Function for batch-data generation. From f77e212f804565926ce0778cf9ddee3d22d85c22 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jul 2024 10:14:20 -0600 Subject: [PATCH 07/48] Add additional test for exception state --- scico/flax/examples/data_generation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index b06a287a..a1f93cbf 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -357,9 +357,12 @@ def distributed_data_generation( """ if not have_ray: raise RuntimeError("Package ray is required for use of this function.") + if not ray.is_initialized(): + raise RuntimeError("Ray must be initialized via ray.init() before using this function.") # Use half of available CPU resources ar = ray.available_resources() + print(ar) if "CPU" not in ar: warnings.warn("No CPU key in ray.available_resources() output.") nproc = max(int(ar.get("CPU", 1)) // 2, 1) From 5ab1d0557e7729c7e9a770cedbeae60f9fa6b8b4 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jul 2024 11:40:27 -0600 Subject: [PATCH 08/48] Tracer conversion error fix from Cristina --- scico/flax/examples/data_generation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index a1f93cbf..420e363d 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -178,7 +178,11 @@ def batched_f(f_: Callable, vr: Array) -> Array: evaluation preserves the batch axis. """ nproc = jax.device_count() - res = jax.pmap(lambda i: vector_f(f_, vr[i]))(jnp.arange(nproc)) + if vr.shape[0] != nproc: + vrr = vr.reshape((nproc, -1, *vr.shape[:1])) + else: + vrr = vr + res = jax.pmap(partial(vector_f, f_))(vrr) return res From ecbfca5b7699b7c6ef267c2eef82ee65edc43167 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jul 2024 11:42:22 -0600 Subject: [PATCH 09/48] Omitted import --- scico/flax/examples/data_generation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 420e363d..74c93000 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -13,6 +13,7 @@ import os import warnings +from functools import partial from time import time from typing import Callable, List, Tuple, Union From 28c828c14e4bdec8f6d0d8b94b6494ed4e09582f Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jul 2024 11:44:58 -0600 Subject: [PATCH 10/48] Clean up --- scico/flax/examples/data_generation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 74c93000..c1fd9417 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -367,7 +367,6 @@ def distributed_data_generation( # Use half of available CPU resources ar = ray.available_resources() - print(ar) if "CPU" not in ar: warnings.warn("No CPU key in ray.available_resources() output.") nproc = max(int(ar.get("CPU", 1)) // 2, 1) From dcd358ccd3b922199871e87d5567226b7e1f8a52 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jul 2024 12:39:21 -0600 Subject: [PATCH 11/48] Consistent phrasing --- CHANGES.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index d7793dd3..96b4185e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -15,7 +15,7 @@ Version 0.0.6 (unreleased) ``scico.flax.save_variables`` and ``scico.flax.load_variables`` respectively. • Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.30. -• Support ``flax`` versions between 0.8.0 and 0.8.3 (inclusive). +• Support ``flax`` versions 0.8.0 to 0.8.3. From 8f286d0eb5d8960c02d68b41143ce6ec31ca531e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 22 Jul 2024 12:17:32 -0600 Subject: [PATCH 12/48] Clean up some f-strings --- scico/flax/examples/examples.py | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/scico/flax/examples/examples.py b/scico/flax/examples/examples.py index 6af0129d..3b29e83f 100644 --- a/scico/flax/examples/examples.py +++ b/scico/flax/examples/examples.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -580,7 +580,7 @@ def print_input_path(path_display: str): # pragma: no cover Args: path_display: Path for loading data. """ - print(f"{'Data read from path':26s}{':':4s}{path_display}") + print(f"Data read from path: {path_display}") def print_output_path(path_display: str): # pragma: no cover @@ -589,7 +589,7 @@ def print_output_path(path_display: str): # pragma: no cover Args: path_display: Path for storing data. """ - print(f"{'Storing data in path':26s}{':':4s}{path_display}") + print(f"Storing data in path: {path_display}") def print_data_range(idstring: str, data: Array): # pragma: no cover @@ -599,11 +599,7 @@ def print_data_range(idstring: str, data: Array): # pragma: no cover idstring: Data descriptive string. data: Array to compute min and max. """ - print( - f"{'Data range --':10s}{idstring}{'--':5s}{':':5s}" - f"{'Min:':6s}{data.min():>5.2f}" - f"{', Max:':6s}{data.max():>5.2f}" - ) + print(f"Data range --{idstring}-- Min: {data.min():>5.2f} " f"Max: {data.max():>5.2f}") def print_data_size(idstring: str, size: int): # pragma: no cover @@ -613,7 +609,7 @@ def print_data_size(idstring: str, size: int): # pragma: no cover idstring: Data descriptive string. size: Integer representing size of a set. """ - print(f"{'Set --':3s}{idstring}{'--':12s}{':':4s}{'Size:':8s}{size}") + print(f"Set --{idstring}-- size: {size}") def print_info( @@ -648,9 +644,8 @@ def print_data_warning(idstring: str, requested: int, available: int): # pragma available: Size of data set available. """ print( - f"{'Not enough images sampled in ':10s}{idstring}" - f"{' file':6s}{'Requested :':14s}{requested}" - f"{' Available :':14s}{available}" + f"Not enough images sampled in {idstring} file. " + f"Requested: {requested} Available: {available}" ) @@ -669,10 +664,9 @@ def runtime_error_scalar( available: Parameter value available in data. """ raise RuntimeError( - f"{'Requested parameter --':15s}{type}{'-- :':7s}{requested}" - f"{' does not match parameter read from '}" - f"{idstring}{' file :':10s}{available}." - f"\nDelete cache and check data source." + f"Requested value of parameter --{type}-- does not match value " + f"read from {idstring} file. Requested: {requested} Available: " + f"{available}.\nDelete cache and check data source." ) @@ -689,8 +683,7 @@ def runtime_error_array(type: str, idstring: str, maxdiff: float): entries. """ raise RuntimeError( - f"{'Requested parameter --':15s}{type}{'--'}" - f"{' does not match parameter read from '}" - f"{idstring}{' file'}. Maximum array difference: {maxdiff:>5.3f}." - f"\nDelete cache and check data source." + f"Requested value of parameter --{type}-- does not match value " + f"read from {idstring} file. Maximum array difference: " + f"{maxdiff:>5.3f}.\nDelete cache and check data source." ) From 5e85f8cbf0002bb6bc4d93e9b54ded8edb667049 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 22 Jul 2024 12:18:14 -0600 Subject: [PATCH 13/48] Add missing ray init --- examples/scripts/ct_astra_datagen_foam2.py | 6 ++++++ examples/scripts/ct_astra_modl_train_foam2.py | 6 ++++++ examples/scripts/ct_astra_odp_train_foam2.py | 6 ++++++ examples/scripts/ct_astra_unet_train_foam2.py | 6 ++++++ examples/scripts/deconv_datagen_foam1.py | 7 +++++++ examples/scripts/deconv_modl_train_foam1.py | 6 ++++++ examples/scripts/deconv_odp_train_foam1.py | 6 ++++++ 7 files changed, 43 insertions(+) diff --git a/examples/scripts/ct_astra_datagen_foam2.py b/examples/scripts/ct_astra_datagen_foam2.py index 76c19da5..4e6fb97c 100644 --- a/examples/scripts/ct_astra_datagen_foam2.py +++ b/examples/scripts/ct_astra_datagen_foam2.py @@ -13,8 +13,14 @@ generated using filtered back projection (FBP). """ +# isort: off import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + from scico import plot from scico.flax.examples import load_ct_data diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_astra_modl_train_foam2.py index dc25ebb3..be218d5e 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_astra_modl_train_foam2.py @@ -40,12 +40,18 @@ reconstructed images. """ +# isort: off import os from functools import partial from time import time import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_astra_odp_train_foam2.py index 8c5d9ad6..42fc09cb 100644 --- a/examples/scripts/ct_astra_odp_train_foam2.py +++ b/examples/scripts/ct_astra_odp_train_foam2.py @@ -44,12 +44,18 @@ term. The output of the final stage is the set of reconstructed images. """ +# isort: off import os from functools import partial from time import time import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_astra_unet_train_foam2.py index 72e82e81..ab651215 100644 --- a/examples/scripts/ct_astra_unet_train_foam2.py +++ b/examples/scripts/ct_astra_unet_train_foam2.py @@ -13,9 +13,15 @@ by :cite:`jin-2017-unet`. """ +# isort: off import os from time import time +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable diff --git a/examples/scripts/deconv_datagen_foam1.py b/examples/scripts/deconv_datagen_foam1.py index 99914e53..e0e6e492 100644 --- a/examples/scripts/deconv_datagen_foam1.py +++ b/examples/scripts/deconv_datagen_foam1.py @@ -12,8 +12,15 @@ training neural network models for deconvolution (deblurring). Foam phantoms from xdesign are used to generate the clean images. """ + +# isort: off import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + from scico import plot from scico.flax.examples import load_foam1_blur_data diff --git a/examples/scripts/deconv_modl_train_foam1.py b/examples/scripts/deconv_modl_train_foam1.py index 2916d7a1..dd310c41 100644 --- a/examples/scripts/deconv_modl_train_foam1.py +++ b/examples/scripts/deconv_modl_train_foam1.py @@ -41,12 +41,18 @@ images. """ +# isort: off import os from functools import partial from time import time import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable diff --git a/examples/scripts/deconv_odp_train_foam1.py b/examples/scripts/deconv_odp_train_foam1.py index ffe852e7..5cb19cbe 100644 --- a/examples/scripts/deconv_odp_train_foam1.py +++ b/examples/scripts/deconv_odp_train_foam1.py @@ -49,12 +49,18 @@ set of deblurred images. """ +# isort: off import os from functools import partial from time import time import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable From d69fdd28c4ba096fdd16fcbb73a049d14ce84830 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 22 Jul 2024 14:16:07 -0600 Subject: [PATCH 14/48] Set dtype --- scico/flax/examples/data_generation.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index c1fd9417..cba99b33 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -100,22 +100,22 @@ def __init__( def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: """Generate batch of xdesign foam-like structures. - Generate batch of images with `xdesign` foam-like structure, which - uses one attenuation. + Generate batch of images with `xdesign` foam-like structure, which + uses one attenuation. Args: - seed: Seed for data generation. - size: Size of image to generate. - ndata: Number of images to generate. + seed: Seed for data generation. + size: Size of image to generate. + ndata: Number of images to generate. - Returns: - Array of generated data. + Returns: + Array of generated data. """ if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") np.random.seed(seed) - saux = np.zeros((ndata, size, size, 1)) + saux = np.zeros((ndata, size, size, 1), dtype=np.float32) for i in range(ndata): foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) saux[i, ..., 0] = discrete_phantom(foam, size=size) @@ -141,7 +141,7 @@ def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: raise RuntimeError("Package xdesign is required for use of this function.") np.random.seed(seed) - saux = np.zeros((ndata, size, size, 1)) + saux = np.zeros((ndata, size, size, 1), dtype=np.float32) for i in range(ndata): foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) saux[i, ..., 0] = discrete_phantom(foam, size=size) From 0d97b3f019b68d58d7325636a0307f3924eeb267 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 22 Jul 2024 17:14:44 -0600 Subject: [PATCH 15/48] Fix indentation error --- scico/flax/examples/data_generation.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index cba99b33..e1703ae4 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -100,16 +100,16 @@ def __init__( def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: """Generate batch of xdesign foam-like structures. - Generate batch of images with `xdesign` foam-like structure, which - uses one attenuation. + Generate batch of images with `xdesign` foam-like structure, which + uses one attenuation. Args: - seed: Seed for data generation. - size: Size of image to generate. - ndata: Number of images to generate. + seed: Seed for data generation. + size: Size of image to generate. + ndata: Number of images to generate. - Returns: - Array of generated data. + Returns: + Array of generated data. """ if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") From eec3242b43d9a49641ee658133453f71a7232340 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 22 Jul 2024 18:30:28 -0600 Subject: [PATCH 16/48] Update module docstring --- scico/flax/examples/data_generation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index e1703ae4..ef2f0a4f 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -7,8 +7,7 @@ """Functionality to generate training data for Flax example scripts. -Computation is distributed via ray (if available) or JAX or to reduce -processing time. +Computation is distributed via ray to reduce processing time. """ import os From a7fa89f771224a57698a6d5f5cd24b7298654c2d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 22 Jul 2024 18:41:21 -0600 Subject: [PATCH 17/48] Experimental solution to ray/jax failure --- scico/flax/examples/data_generation.py | 179 +++---------------------- scico/flax/examples/ray_functions.py | 164 ++++++++++++++++++++++ 2 files changed, 179 insertions(+), 164 deletions(-) create mode 100644 scico/flax/examples/ray_functions.py diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index ef2f0a4f..502e8008 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -11,29 +11,24 @@ """ import os -import warnings from functools import partial from time import time -from typing import Callable, List, Tuple, Union +from typing import Callable, Tuple import numpy as np -try: - import ray # noqa: F401 -except ImportError: - have_ray = False -else: - have_ray = True - try: import xdesign # noqa: F401 + + import ray # noqa: F401 except ImportError: - have_xdesign = False + have_ray_and_xdesign = False else: - have_xdesign = True - -if have_xdesign: - from xdesign import Foam, SimpleMaterial, UnitCircle, discrete_phantom + have_ray_and_xdesign = True + from .ray_functions import ( + generate_foam2_images, + distributed_data_generation, + ) import jax import jax.numpy as jnp @@ -54,102 +49,6 @@ os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" -if have_xdesign: - - class Foam2(UnitCircle): - """Foam-like material with two attenuations. - - Define functionality to generate phantom with structure similar - to foam with two different attenuation properties.""" - - def __init__( - self, - size_range: Union[float, List[float]] = [0.05, 0.01], - gap: float = 0, - porosity: float = 1, - attn1: float = 1.0, - attn2: float = 10.0, - ): - """Foam-like structure with two different attenuations. - Circles for material 1 are more sparse than for material 2 - by design. - - Args: - size_range: The radius, or range of radius, of the - circles to be added. Default: [0.05, 0.01]. - gap: Minimum distance between circle boundaries. - Default: 0. - porosity: Target porosity. Must be a value between - [0, 1]. Default: 1. - attn1: Mass attenuation parameter for material 1. - Default: 1. - attn2: Mass attenuation parameter for material 2. - Default: 10. - """ - super().__init__(radius=0.5, material=SimpleMaterial(attn1)) - if porosity < 0 or porosity > 1: - raise ValueError("Porosity must be in the range [0,1).") - self.sprinkle( - 300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0 - ) + self.sprinkle( - 300, size_range, gap, material=SimpleMaterial(20), max_density=porosity - ) - - -def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: - """Generate batch of xdesign foam-like structures. - - Generate batch of images with `xdesign` foam-like structure, which - uses one attenuation. - - Args: - seed: Seed for data generation. - size: Size of image to generate. - ndata: Number of images to generate. - - Returns: - Array of generated data. - """ - if not have_xdesign: - raise RuntimeError("Package xdesign is required for use of this function.") - - np.random.seed(seed) - saux = np.zeros((ndata, size, size, 1), dtype=np.float32) - for i in range(ndata): - foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux[i, ..., 0] = discrete_phantom(foam, size=size) - - return saux - - -def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: - """Generate batch of foam2 structures. - - Generate batch of images with :class:`Foam2` structure - (foam-like material with two different attenuations). - - Args: - seed: Seed for data generation. - size: Size of image to generate. - ndata: Number of images to generate. - - Returns: - Array of generated data. - """ - if not have_xdesign: - raise RuntimeError("Package xdesign is required for use of this function.") - - np.random.seed(seed) - saux = np.zeros((ndata, size, size, 1), dtype=np.float32) - for i in range(ndata): - foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux[i, ..., 0] = discrete_phantom(foam, size=size) - # normalize - saux /= np.max(saux, axis=(1, 2), keepdims=True) - - return saux - - def vector_f(f_: Callable, v: Array) -> Array: """Vectorize application of operator. @@ -214,8 +113,10 @@ def generate_ct_data( - **sino** : (:class:`jax.Array`): Corresponding sinograms. - **fbp** : (:class:`jax.Array`) Corresponding filtered back projections. """ - if not have_astra: - raise RuntimeError("Package astra is required for use of this function.") + if not have_ray_and_xdesign and have_astra: + raise RuntimeError( + "Packages ray, xdesign, and astra are required for use of this function." + ) # Generate input data. start_time = time() @@ -298,6 +199,8 @@ def generate_blur_data( - **img** : Generated foam images. - **blurn** : Corresponding blurred and noisy images. """ + if not have_ray_and_xdesign: + raise RuntimeError("Packages ray and xdesign are required for use of this function.") start_time = time() img = distributed_data_generation(imgfunc, size, nimg, seed) time_dtgen = time() - start_time @@ -337,55 +240,3 @@ def generate_blur_data( print(f"{'Blur generation':19s}{'time[s]:':10s}{time_blur:>7.2f}") return img, blurn - - -def distributed_data_generation( - imgenf: Callable, size: int, nimg: int, seedg: float = 123 -) -> np.ndarray: - """Data generation distributed among processes using ray. - - *Warning:* callable `imgenf` should not make use of any jax functions - to avoid the risk of errors when running with GPU devices, in which - case jax is initialized to expect the availability of GPUs, which are - then not available within the `ray.remote` function due to the absence - of any declared GPUs as a `num_gpus` parameter of `@ray.remote`. - - Args: - imagenf: Function for batch-data generation. - size: Size of image to generate. - ndata: Number of images to generate. - seedg: Base seed for data generation. Default: 123. - - Returns: - Array of generated data. - """ - if not have_ray: - raise RuntimeError("Package ray is required for use of this function.") - if not ray.is_initialized(): - raise RuntimeError("Ray must be initialized via ray.init() before using this function.") - - # Use half of available CPU resources - ar = ray.available_resources() - if "CPU" not in ar: - warnings.warn("No CPU key in ray.available_resources() output.") - nproc = max(int(ar.get("CPU", 1)) // 2, 1) - if nproc > nimg: - nproc = nimg - if nproc > 1 and nimg % nproc > 0: - raise ValueError( - f"Number of images to generate ({nimg}) must be divisible by " - f"the number of available devices ({nproc})." - ) - - ndata_per_proc = int(nimg // nproc) - - @ray.remote - def data_gen(seed, size, ndata, imgf): - return imgf(seed, size, ndata) - - ray_return = ray.get( - [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)] - ) - imgs = np.vstack([t for t in ray_return]) - - return imgs diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py new file mode 100644 index 00000000..70fec12e --- /dev/null +++ b/scico/flax/examples/ray_functions.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2024 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Functionality to generate training data for Flax example scripts. + +Computation is distributed via ray (if available) or JAX or to reduce +processing time. +""" + +import warnings +from typing import Callable, List, Union + +import numpy as np + +try: + import ray # noqa: F401 +except ImportError: + raise RuntimeError("Package ray is required for use of this module.") + +try: + import xdesign # noqa: F401 +except ImportError: + raise RuntimeError("Package xdesign is required for use of this module.") +from xdesign import Foam, SimpleMaterial, UnitCircle, discrete_phantom + + +class Foam2(UnitCircle): + """Foam-like material with two attenuations. + + Define functionality to generate phantom with structure similar + to foam with two different attenuation properties.""" + + def __init__( + self, + size_range: Union[float, List[float]] = [0.05, 0.01], + gap: float = 0, + porosity: float = 1, + attn1: float = 1.0, + attn2: float = 10.0, + ): + """Foam-like structure with two different attenuations. + Circles for material 1 are more sparse than for material 2 + by design. + + Args: + size_range: The radius, or range of radius, of the + circles to be added. Default: [0.05, 0.01]. + gap: Minimum distance between circle boundaries. + Default: 0. + porosity: Target porosity. Must be a value between + [0, 1]. Default: 1. + attn1: Mass attenuation parameter for material 1. + Default: 1. + attn2: Mass attenuation parameter for material 2. + Default: 10. + """ + super().__init__(radius=0.5, material=SimpleMaterial(attn1)) + if porosity < 0 or porosity > 1: + raise ValueError("Porosity must be in the range [0,1).") + self.sprinkle( + 300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0 + ) + self.sprinkle(300, size_range, gap, material=SimpleMaterial(20), max_density=porosity) + + +def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: + """Generate batch of xdesign foam-like structures. + + Generate batch of images with `xdesign` foam-like structure, which + uses one attenuation. + + Args: + seed: Seed for data generation. + size: Size of image to generate. + ndata: Number of images to generate. + + Returns: + Array of generated data. + """ + np.random.seed(seed) + saux = np.zeros((ndata, size, size, 1), dtype=np.float32) + for i in range(ndata): + foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) + saux[i, ..., 0] = discrete_phantom(foam, size=size) + + return saux + + +def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: + """Generate batch of foam2 structures. + + Generate batch of images with :class:`Foam2` structure + (foam-like material with two different attenuations). + + Args: + seed: Seed for data generation. + size: Size of image to generate. + ndata: Number of images to generate. + + Returns: + Array of generated data. + """ + np.random.seed(seed) + saux = np.zeros((ndata, size, size, 1), dtype=np.float32) + for i in range(ndata): + foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) + saux[i, ..., 0] = discrete_phantom(foam, size=size) + # normalize + saux /= np.max(saux, axis=(1, 2), keepdims=True) + + return saux + + +def distributed_data_generation( + imgenf: Callable, size: int, nimg: int, seedg: float = 123 +) -> np.ndarray: + """Data generation distributed among processes using ray. + + *Warning:* callable `imgenf` should not make use of any jax functions + to avoid the risk of errors when running with GPU devices, in which + case jax is initialized to expect the availability of GPUs, which are + then not available within the `ray.remote` function due to the absence + of any declared GPUs as a `num_gpus` parameter of `@ray.remote`. + + Args: + imagenf: Function for batch-data generation. + size: Size of image to generate. + ndata: Number of images to generate. + seedg: Base seed for data generation. Default: 123. + + Returns: + Array of generated data. + """ + if not ray.is_initialized(): + raise RuntimeError("Ray must be initialized via ray.init() before calling this function.") + + # Use half of available CPU resources + ar = ray.available_resources() + if "CPU" not in ar: + warnings.warn("No CPU key in ray.available_resources() output.") + nproc = max(int(ar.get("CPU", 1)) // 2, 1) + if nproc > nimg: + nproc = nimg + if nproc > 1 and nimg % nproc > 0: + raise ValueError( + f"Number of images to generate ({nimg}) must be divisible by " + f"the number of available devices ({nproc})." + ) + + ndata_per_proc = int(nimg // nproc) + + @ray.remote + def data_gen(seed, size, ndata, imgf): + return imgf(seed, size, ndata) + + ray_return = ray.get( + [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)] + ) + imgs = np.vstack([t for t in ray_return]) + + return imgs From 85ded0fc9b21d3d0bc0c5a444508b8f12bde33b0 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 22 Jul 2024 18:48:26 -0600 Subject: [PATCH 18/48] Bug fix --- scico/flax/examples/data_generation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 502e8008..b86245b4 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -26,6 +26,7 @@ else: have_ray_and_xdesign = True from .ray_functions import ( + generate_foam1_images, # noqa generate_foam2_images, distributed_data_generation, ) From e7461f0b5949ca7090070b41bceef7af2ec3f807 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 10:56:01 -0600 Subject: [PATCH 19/48] Improve docstring --- scico/flax/examples/ray_functions.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 70fec12e..94399666 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -5,12 +5,13 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Functionality to generate training data for Flax example scripts. +"""Generate training data for Flax example scripts using ray. -Computation is distributed via ray (if available) or JAX or to reduce -processing time. +Functions for generating xdesign foam phantoms and generation in parallel +using ray. """ + import warnings from typing import Callable, List, Union From 5dac79fd681879490bb0cf96fa86c0de593d4da5 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 11:44:54 -0600 Subject: [PATCH 20/48] Implement hack to resolve jax/ray conflict --- scico/flax/examples/ray_functions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 94399666..1c80ad38 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -12,7 +12,6 @@ """ -import warnings from typing import Callable, List, Union import numpy as np @@ -140,8 +139,6 @@ def distributed_data_generation( # Use half of available CPU resources ar = ray.available_resources() - if "CPU" not in ar: - warnings.warn("No CPU key in ray.available_resources() output.") nproc = max(int(ar.get("CPU", 1)) // 2, 1) if nproc > nimg: nproc = nimg @@ -153,7 +150,11 @@ def distributed_data_generation( ndata_per_proc = int(nimg // nproc) - @ray.remote + # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that + # is expected to be quite brittle. + num_gpus = 0.0001 if "GPU" in ar else 0 + + @ray.remote(num_gpus=num_gpus) def data_gen(seed, size, ndata, imgf): return imgf(seed, size, ndata) From 25f318e83dd532937fa1b2da80a438aa9a357bf4 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 11:47:08 -0600 Subject: [PATCH 21/48] Debug attempt --- scico/flax/examples/ray_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 1c80ad38..914b508c 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -152,7 +152,9 @@ def distributed_data_generation( # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that # is expected to be quite brittle. - num_gpus = 0.0001 if "GPU" in ar else 0 + # num_gpus = 0.0001 if "GPU" in ar else 0 + num_gpus = 1 if "GPU" in ar else 0 + print(num_gpus) @ray.remote(num_gpus=num_gpus) def data_gen(seed, size, ndata, imgf): From 9218e4ddb999c26872c1637549d158b93a232076 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 11:48:59 -0600 Subject: [PATCH 22/48] Debug attempt --- scico/flax/examples/ray_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 914b508c..3cc287f1 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -153,7 +153,7 @@ def distributed_data_generation( # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that # is expected to be quite brittle. # num_gpus = 0.0001 if "GPU" in ar else 0 - num_gpus = 1 if "GPU" in ar else 0 + num_gpus = 0.0001 if "GPU" in ar else 0 print(num_gpus) @ray.remote(num_gpus=num_gpus) From e73ae7d2993b4af2eeebf03343982c5d0ef6e4fc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 11:51:59 -0600 Subject: [PATCH 23/48] Debug attempt --- scico/flax/examples/ray_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 3cc287f1..c6804060 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -129,7 +129,7 @@ def distributed_data_generation( imagenf: Function for batch-data generation. size: Size of image to generate. ndata: Number of images to generate. - seedg: Base seed for data generation. Default: 123. + seedg: Base seed for data generation. Returns: Array of generated data. @@ -153,7 +153,7 @@ def distributed_data_generation( # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that # is expected to be quite brittle. # num_gpus = 0.0001 if "GPU" in ar else 0 - num_gpus = 0.0001 if "GPU" in ar else 0 + num_gpus = 0.1 if "GPU" in ar else 0 print(num_gpus) @ray.remote(num_gpus=num_gpus) From c9714e4d5effa021f19fee813179364277dafe50 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 11:53:17 -0600 Subject: [PATCH 24/48] Debug attempt --- scico/flax/examples/ray_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index c6804060..bd6dcc0e 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -153,7 +153,7 @@ def distributed_data_generation( # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that # is expected to be quite brittle. # num_gpus = 0.0001 if "GPU" in ar else 0 - num_gpus = 0.1 if "GPU" in ar else 0 + num_gpus = 1 if "GPU" in ar else 0 print(num_gpus) @ray.remote(num_gpus=num_gpus) From e24ccdde71e4f0521a0d39da8272c0a31217e44c Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 11:55:33 -0600 Subject: [PATCH 25/48] Debug attempt --- scico/flax/examples/ray_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index bd6dcc0e..e88679e7 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -158,6 +158,9 @@ def distributed_data_generation( @ray.remote(num_gpus=num_gpus) def data_gen(seed, size, ndata, imgf): + import os + + print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") return imgf(seed, size, ndata) ray_return = ray.get( From 9bbad64a72ea6e93cf6be2a2444fe6b695a85551 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 11:56:42 -0600 Subject: [PATCH 26/48] Debug attempt --- scico/flax/examples/ray_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index e88679e7..f2dc228e 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -153,7 +153,7 @@ def distributed_data_generation( # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that # is expected to be quite brittle. # num_gpus = 0.0001 if "GPU" in ar else 0 - num_gpus = 1 if "GPU" in ar else 0 + num_gpus = 0.1 if "GPU" in ar else 0 print(num_gpus) @ray.remote(num_gpus=num_gpus) From 325fb9b84036b9dfdd994a2cf13f678967cb8c70 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 12:22:44 -0600 Subject: [PATCH 27/48] New solution attempt --- scico/flax/examples/ray_functions.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index f2dc228e..54cda28d 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -11,7 +11,7 @@ using ray. """ - +import os from typing import Callable, List, Union import numpy as np @@ -150,17 +150,15 @@ def distributed_data_generation( ndata_per_proc = int(nimg // nproc) - # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that - # is expected to be quite brittle. - # num_gpus = 0.0001 if "GPU" in ar else 0 - num_gpus = 0.1 if "GPU" in ar else 0 - print(num_gpus) + # Attempt to avoid ray/jax conflicts. + if "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES" in os.environ: + ray_noset_cuda = os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] + else: + ray_noset_cuda = None + os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" - @ray.remote(num_gpus=num_gpus) + @ray.remote def data_gen(seed, size, ndata, imgf): - import os - - print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") return imgf(seed, size, ndata) ray_return = ray.get( @@ -168,4 +166,7 @@ def data_gen(seed, size, ndata, imgf): ) imgs = np.vstack([t for t in ray_return]) + if ray_noset_cuda is not None: + os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = ray_noset_cuda + return imgs From d521aa3b66f865b083f4024c861a138e55e92deb Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 12:27:07 -0600 Subject: [PATCH 28/48] Debug attempt --- scico/flax/examples/ray_functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 54cda28d..039a81e7 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -159,6 +159,10 @@ def distributed_data_generation( @ray.remote def data_gen(seed, size, ndata, imgf): + import os + + if "CUDA_VISIBLE_DEVICES" in os.environ: + print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") return imgf(seed, size, ndata) ray_return = ray.get( From 8eef347821a151f99612bf84dade8a09fe85fff3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 12:29:53 -0600 Subject: [PATCH 29/48] Debug attempt --- scico/flax/examples/ray_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 039a81e7..d7c63bfe 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -163,6 +163,7 @@ def data_gen(seed, size, ndata, imgf): if "CUDA_VISIBLE_DEVICES" in os.environ: print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") + del os.environ["CUDA_VISIBLE_DEVICES"] return imgf(seed, size, ndata) ray_return = ray.get( From fa08d8c5c3f3c3d15cd844d8350002e00b6b2bfc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 12:37:50 -0600 Subject: [PATCH 30/48] Debug attempt --- scico/flax/examples/ray_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index d7c63bfe..0e45b5df 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -157,7 +157,7 @@ def distributed_data_generation( ray_noset_cuda = None os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" - @ray.remote + @ray.remote(num_gpus=0.001) def data_gen(seed, size, ndata, imgf): import os From 47b80671468657a6b07d139ec1674c88a8de1c62 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 13:48:07 -0600 Subject: [PATCH 31/48] Debug attempt --- scico/flax/examples/ray_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 0e45b5df..3e011580 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -156,6 +156,7 @@ def distributed_data_generation( else: ray_noset_cuda = None os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" @ray.remote(num_gpus=0.001) def data_gen(seed, size, ndata, imgf): From 931c763b91aead02dd6f3c04fdde0c5c18a441d3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 13:49:14 -0600 Subject: [PATCH 32/48] Debug attempt --- scico/flax/examples/ray_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 3e011580..5bb450c1 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -162,6 +162,7 @@ def distributed_data_generation( def data_gen(seed, size, ndata, imgf): import os + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" if "CUDA_VISIBLE_DEVICES" in os.environ: print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") del os.environ["CUDA_VISIBLE_DEVICES"] From a9cafff7283536f4599841f7eeb1a2bf80b39bdc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 13:54:06 -0600 Subject: [PATCH 33/48] Debug attempt --- scico/flax/examples/ray_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 5bb450c1..1d0ede6e 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -157,12 +157,14 @@ def distributed_data_generation( ray_noset_cuda = None os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" @ray.remote(num_gpus=0.001) def data_gen(seed, size, ndata, imgf): import os os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" if "CUDA_VISIBLE_DEVICES" in os.environ: print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") del os.environ["CUDA_VISIBLE_DEVICES"] From 5f7001e766bd4ca4d1b3eeefd5724c4164ce4ef3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 13:55:52 -0600 Subject: [PATCH 34/48] Debug attempt --- scico/flax/examples/ray_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 1d0ede6e..282446fb 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -158,11 +158,13 @@ def distributed_data_generation( os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" + os.environ["JAX_PLATFORMS"] = "cpu" @ray.remote(num_gpus=0.001) def data_gen(seed, size, ndata, imgf): import os + os.environ["JAX_PLATFORMS"] = "cpu" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" if "CUDA_VISIBLE_DEVICES" in os.environ: From 644c189278de2ec5009b36f7d4581fd9e812647c Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 14:00:07 -0600 Subject: [PATCH 35/48] Debug attempt --- scico/flax/examples/ray_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 282446fb..275bf1c4 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -163,8 +163,11 @@ def distributed_data_generation( @ray.remote(num_gpus=0.001) def data_gen(seed, size, ndata, imgf): import os + import sys os.environ["JAX_PLATFORMS"] = "cpu" + sys.modules.pop("jax") + sys.modules.pop("scico") os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" if "CUDA_VISIBLE_DEVICES" in os.environ: From 89b4772587ba83da0ca0abb2ef4b02ebd83c0fcb Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 14:02:46 -0600 Subject: [PATCH 36/48] Debug attempt --- scico/flax/examples/ray_functions.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 275bf1c4..68beed0d 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -80,6 +80,11 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: Returns: Array of generated data. """ + import os + + os.environ["JAX_PLATFORMS"] = "cpu" + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" np.random.seed(seed) saux = np.zeros((ndata, size, size, 1), dtype=np.float32) for i in range(ndata): @@ -163,11 +168,8 @@ def distributed_data_generation( @ray.remote(num_gpus=0.001) def data_gen(seed, size, ndata, imgf): import os - import sys os.environ["JAX_PLATFORMS"] = "cpu" - sys.modules.pop("jax") - sys.modules.pop("scico") os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" if "CUDA_VISIBLE_DEVICES" in os.environ: From fdb8520d9fd856db294191f8964b2e5532e356ee Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 14:05:55 -0600 Subject: [PATCH 37/48] Debug attempt --- scico/flax/examples/ray_functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 68beed0d..032edf9b 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -81,7 +81,11 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: Array of generated data. """ import os + import sys + if "jax" in sys.modules: + print("jax loaded") + sys.modules.pop("jax") os.environ["JAX_PLATFORMS"] = "cpu" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" From 978759e5e72ef4a11e1a4f3ad2c21df1ff43cc04 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 14:20:32 -0600 Subject: [PATCH 38/48] Return to earlier approach --- scico/flax/examples/ray_functions.py | 37 ++++------------------------ 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 032edf9b..4535d725 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -11,7 +11,6 @@ using ray. """ -import os from typing import Callable, List, Union import numpy as np @@ -80,15 +79,6 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: Returns: Array of generated data. """ - import os - import sys - - if "jax" in sys.modules: - print("jax loaded") - sys.modules.pop("jax") - os.environ["JAX_PLATFORMS"] = "cpu" - os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" - os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" np.random.seed(seed) saux = np.zeros((ndata, size, size, 1), dtype=np.float32) for i in range(ndata): @@ -159,26 +149,12 @@ def distributed_data_generation( ndata_per_proc = int(nimg // nproc) - # Attempt to avoid ray/jax conflicts. - if "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES" in os.environ: - ray_noset_cuda = os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] - else: - ray_noset_cuda = None - os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" - os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" - os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" - os.environ["JAX_PLATFORMS"] = "cpu" - - @ray.remote(num_gpus=0.001) + # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that + # is expected to be rather brittle. + num_gpus = 1 if "GPU" in ar else 0 + + @ray.remote(num_gpus=num_gpus) def data_gen(seed, size, ndata, imgf): - import os - - os.environ["JAX_PLATFORMS"] = "cpu" - os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" - os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" - if "CUDA_VISIBLE_DEVICES" in os.environ: - print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") - del os.environ["CUDA_VISIBLE_DEVICES"] return imgf(seed, size, ndata) ray_return = ray.get( @@ -186,7 +162,4 @@ def data_gen(seed, size, ndata, imgf): ) imgs = np.vstack([t for t in ray_return]) - if ray_noset_cuda is not None: - os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = ray_noset_cuda - return imgs From fc2315abf16e22f36efdbfd578e313c020d1e945 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 14:26:14 -0600 Subject: [PATCH 39/48] Extend comment --- scico/flax/examples/ray_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index 4535d725..a90f402b 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -150,7 +150,9 @@ def distributed_data_generation( ndata_per_proc = int(nimg // nproc) # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that - # is expected to be rather brittle. + # can severely limit parallel execution (since ray will ensure that only + # as many actors as availble GPUs are created), and is expected to be rather + # brittle. num_gpus = 1 if "GPU" in ar else 0 @ray.remote(num_gpus=num_gpus) From 039a970c47e31f849fc5e715470e0b9219907d1c Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 14:54:05 -0600 Subject: [PATCH 40/48] Clean up and improve function logic --- scico/flax/examples/data_generation.py | 5 ----- scico/flax/examples/ray_functions.py | 25 ++++++++++++++----------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index b86245b4..422c078e 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -10,7 +10,6 @@ Computation is distributed via ray to reduce processing time. """ -import os from functools import partial from time import time from typing import Callable, Tuple @@ -46,10 +45,6 @@ from scico.linop.xray.astra import XRayTransform2D -# Arbitrary process count: only applies if GPU is not available. -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" - - def vector_f(f_: Callable, v: Array) -> Array: """Vectorize application of operator. diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/ray_functions.py index a90f402b..08665ee3 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/ray_functions.py @@ -139,22 +139,25 @@ def distributed_data_generation( # Use half of available CPU resources ar = ray.available_resources() nproc = max(int(ar.get("CPU", 1)) // 2, 1) + + # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that + # can severely limit parallel execution (since ray will ensure that only + # as many actors as available GPUs are created), and is expected to be + # rather brittle. + if "GPU" in ar: + num_gpus = 1 + nproc = min(nproc, int(ar.get("GPU"))) + else: + num_gpus = 0 + if nproc > nimg: nproc = nimg - if nproc > 1 and nimg % nproc > 0: - raise ValueError( - f"Number of images to generate ({nimg}) must be divisible by " - f"the number of available devices ({nproc})." - ) + if nimg % nproc > 0: + # Increase nimg to be a multiple of nproc if it isn't already + nimg = (nimg // nproc + 1) * nproc ndata_per_proc = int(nimg // nproc) - # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that - # can severely limit parallel execution (since ray will ensure that only - # as many actors as availble GPUs are created), and is expected to be rather - # brittle. - num_gpus = 1 if "GPU" in ar else 0 - @ray.remote(num_gpus=num_gpus) def data_gen(seed, size, ndata, imgf): return imgf(seed, size, ndata) From 9dca046628e08dfd1d96614f182d81559a8a92c6 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 15:31:20 -0600 Subject: [PATCH 41/48] Address some problems --- examples/scripts/ct_astra_unet_train_foam2.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_astra_unet_train_foam2.py index ab651215..bae623b3 100644 --- a/examples/scripts/ct_astra_unet_train_foam2.py +++ b/examples/scripts/ct_astra_unet_train_foam2.py @@ -22,19 +22,20 @@ ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax +import numpy as np + from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax from scico import metric, plot from scico.flax.examples import load_ct_data -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) @@ -196,7 +197,7 @@ hist = stats_object.history(transpose=True) fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( - jax.numpy.vstack((hist.Train_Loss, hist.Eval_Loss)).T, + np.vstack((hist.Train_Loss, hist.Eval_Loss)).T, x=hist.Epoch, ptyp="semilogy", title="Loss function", @@ -207,7 +208,7 @@ ax=ax[0], ) plot.plot( - jax.numpy.vstack((hist.Train_SNR, hist.Eval_SNR)).T, + np.vstack((hist.Train_SNR, hist.Eval_SNR)).T, x=hist.Epoch, title="Metric", xlbl="Epoch", From 1fcd82daab2eebde939aed18be4a95021cdf4672 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 16:16:58 -0600 Subject: [PATCH 42/48] Clean up --- examples/scripts/ct_astra_odp_train_foam2.py | 9 ++++----- examples/scripts/deconv_modl_train_foam1.py | 9 ++++----- examples/scripts/deconv_odp_train_foam1.py | 9 ++++----- examples/scripts/denoise_dncnn_train_bsds.py | 10 +++++----- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_astra_odp_train_foam2.py index 42fc09cb..4a8355e3 100644 --- a/examples/scripts/ct_astra_odp_train_foam2.py +++ b/examples/scripts/ct_astra_odp_train_foam2.py @@ -56,6 +56,9 @@ ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -66,11 +69,7 @@ from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop.xray.astra import XRayTransform2D -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/deconv_modl_train_foam1.py b/examples/scripts/deconv_modl_train_foam1.py index dd310c41..4f142071 100644 --- a/examples/scripts/deconv_modl_train_foam1.py +++ b/examples/scripts/deconv_modl_train_foam1.py @@ -53,6 +53,9 @@ ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -63,11 +66,7 @@ from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop import CircularConvolve -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/deconv_odp_train_foam1.py b/examples/scripts/deconv_odp_train_foam1.py index 5cb19cbe..f79d77e6 100644 --- a/examples/scripts/deconv_odp_train_foam1.py +++ b/examples/scripts/deconv_odp_train_foam1.py @@ -61,6 +61,9 @@ ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -71,11 +74,7 @@ from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop import CircularConvolve -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/denoise_dncnn_train_bsds.py b/examples/scripts/denoise_dncnn_train_bsds.py index a55df76d..ac9fcb75 100644 --- a/examples/scripts/denoise_dncnn_train_bsds.py +++ b/examples/scripts/denoise_dncnn_train_bsds.py @@ -13,11 +13,15 @@ with additive Gaussian noise. """ +# isort: off import os from time import time import numpy as np +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -26,11 +30,7 @@ from scico import metric, plot from scico.flax.examples import load_image_data -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) From 6cdf217e20c0646c5da974c263a6b5e45229341d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 16:20:26 -0600 Subject: [PATCH 43/48] Rename function for consistency with related functions --- examples/scripts/deconv_datagen_foam1.py | 4 ++-- examples/scripts/deconv_modl_train_foam1.py | 4 ++-- examples/scripts/deconv_odp_train_foam1.py | 4 ++-- scico/flax/examples/__init__.py | 4 ++-- scico/flax/examples/examples.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/scripts/deconv_datagen_foam1.py b/examples/scripts/deconv_datagen_foam1.py index e0e6e492..f80bee5f 100644 --- a/examples/scripts/deconv_datagen_foam1.py +++ b/examples/scripts/deconv_datagen_foam1.py @@ -22,7 +22,7 @@ ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 from scico import plot -from scico.flax.examples import load_foam1_blur_data +from scico.flax.examples import load_blur_data """ Read data from cache or generate if not available. @@ -36,7 +36,7 @@ nimg = train_nimg + test_nimg output_size = 256 # image size -train_ds, test_ds = load_foam1_blur_data( +train_ds, test_ds = load_blur_data( train_nimg, test_nimg, output_size, diff --git a/examples/scripts/deconv_modl_train_foam1.py b/examples/scripts/deconv_modl_train_foam1.py index 4f142071..69b19d93 100644 --- a/examples/scripts/deconv_modl_train_foam1.py +++ b/examples/scripts/deconv_modl_train_foam1.py @@ -62,7 +62,7 @@ from scico import flax as sflax from scico import metric, plot -from scico.flax.examples import load_foam1_blur_data +from scico.flax.examples import load_blur_data from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop import CircularConvolve @@ -92,7 +92,7 @@ test_nimg = 64 # number of testing images nimg = train_nimg + test_nimg -train_ds, test_ds = load_foam1_blur_data( +train_ds, test_ds = load_blur_data( train_nimg, test_nimg, output_size, diff --git a/examples/scripts/deconv_odp_train_foam1.py b/examples/scripts/deconv_odp_train_foam1.py index f79d77e6..9887fe89 100644 --- a/examples/scripts/deconv_odp_train_foam1.py +++ b/examples/scripts/deconv_odp_train_foam1.py @@ -70,7 +70,7 @@ from scico import flax as sflax from scico import metric, plot -from scico.flax.examples import load_foam1_blur_data +from scico.flax.examples import load_blur_data from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop import CircularConvolve @@ -100,7 +100,7 @@ test_nimg = 64 # number of testing images nimg = train_nimg + test_nimg -train_ds, test_ds = load_foam1_blur_data( +train_ds, test_ds = load_blur_data( train_nimg, test_nimg, output_size, diff --git a/scico/flax/examples/__init__.py b/scico/flax/examples/__init__.py index 5a8d0d45..8ca7b182 100644 --- a/scico/flax/examples/__init__.py +++ b/scico/flax/examples/__init__.py @@ -8,11 +8,11 @@ """Data utility functions used by Flax example scripts.""" from .data_preprocessing import PaddedCircularConvolve, build_blur_kernel -from .examples import load_ct_data, load_foam1_blur_data, load_image_data +from .examples import load_blur_data, load_ct_data, load_image_data __all__ = [ "load_ct_data", - "load_foam1_blur_data", + "load_blur_data", "load_image_data", "PaddedCircularConvolve", "build_blur_kernel", diff --git a/scico/flax/examples/examples.py b/scico/flax/examples/examples.py index 3b29e83f..faa2daa0 100644 --- a/scico/flax/examples/examples.py +++ b/scico/flax/examples/examples.py @@ -174,7 +174,7 @@ def load_ct_data( return trdt, ttdt -def load_foam1_blur_data( +def load_blur_data( train_nimg: int, test_nimg: int, size: int, From aa1467a947cbfcbe8c5d56a4527ea5cab5a954fc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 16:31:42 -0600 Subject: [PATCH 44/48] Bug fix --- scico/test/flax/test_examples_flax.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py index c259526a..10387440 100644 --- a/scico/test/flax/test_examples_flax.py +++ b/scico/test/flax/test_examples_flax.py @@ -11,8 +11,7 @@ generate_blur_data, generate_ct_data, have_astra, - have_ray, - have_xdesign, + have_ray_and_xdesign, ) from scico.flax.examples.data_preprocessing import ( CenterCrop, @@ -37,7 +36,7 @@ # These tests are for the scico.flax.examples module, NOT the example scripts -@pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") +@pytest.mark.skipif(not have_ray_and_xdesign, reason="ray or xdesign package not installed") def test_foam1_gen(): seed = 4444 N = 32 @@ -48,7 +47,7 @@ def test_foam1_gen(): assert dt.shape == (ndata, N, N, 1) -@pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") +@pytest.mark.skipif(not have_ray_and_xdesign, reason="ray or xdesign package not installed") def test_foam2_gen(): seed = 4321 N = 32 @@ -59,7 +58,7 @@ def test_foam2_gen(): assert dt.shape == (ndata, N, N, 1) -@pytest.mark.skipif(not have_ray, reason="ray package not installed") +@pytest.mark.skipif(not have_ray_and_xdesign, reason="ray or xdesign package not installed") def test_distdatagen(): N = 16 nimg = 8 From cc678fdd4b3c382a272c949c6afc781d111750b1 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 16:51:33 -0600 Subject: [PATCH 45/48] Clean up --- scico/flax/examples/data_generation.py | 4 ++-- scico/flax/examples/examples.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 422c078e..8c5ff276 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -25,7 +25,7 @@ else: have_ray_and_xdesign = True from .ray_functions import ( - generate_foam1_images, # noqa + generate_foam1_images, generate_foam2_images, distributed_data_generation, ) @@ -171,7 +171,7 @@ def generate_blur_data( size: int, blur_kernel: Array, noise_sigma: float, - imgfunc: Callable, + imgfunc: Callable = generate_foam1_images, seed: int = 4321, verbose: bool = False, ) -> Tuple[Array, Array]: diff --git a/scico/flax/examples/examples.py b/scico/flax/examples/examples.py index faa2daa0..82c710aa 100644 --- a/scico/flax/examples/examples.py +++ b/scico/flax/examples/examples.py @@ -16,7 +16,7 @@ from scico.numpy import Array from scico.typing import Shape -from .data_generation import generate_blur_data, generate_ct_data, generate_foam1_images +from .data_generation import generate_blur_data, generate_ct_data from .data_preprocessing import ConfigImageSetDict, build_image_dataset, get_bsds_data from .typed_dict import CTDataSetDict @@ -288,7 +288,6 @@ def load_blur_data( size, blur_kernel, noise_sigma, - imgfunc=generate_foam1_images, verbose=verbose, ) # Separate training and testing partitions. From d650f3ac011e80f1920445842bf97c2c36186e51 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 17:47:21 -0600 Subject: [PATCH 46/48] Bug fix --- scico/flax/examples/data_generation.py | 85 +++++++++++++++++-- .../{ray_functions.py => xdesign_func.py} | 69 +-------------- scico/test/flax/test_examples_flax.py | 19 +++-- 3 files changed, 91 insertions(+), 82 deletions(-) rename scico/flax/examples/{ray_functions.py => xdesign_func.py} (60%) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 8c5ff276..5a17e286 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -18,17 +18,25 @@ try: import xdesign # noqa: F401 +except ImportError: + have_xdesign = False + + def generate_foam1_images(): + raise RunTimeError("xdesign package required.") + + def generate_foam2_images(): + raise RunTimeError("xdesign package required.") +else: + have_xdesign = True + from .xdesign_func import generate_foam1_images, generate_foam2_images + +try: import ray # noqa: F401 except ImportError: - have_ray_and_xdesign = False + have_ray = False else: - have_ray_and_xdesign = True - from .ray_functions import ( - generate_foam1_images, - generate_foam2_images, - distributed_data_generation, - ) + have_ray = True import jax import jax.numpy as jnp @@ -109,7 +117,7 @@ def generate_ct_data( - **sino** : (:class:`jax.Array`): Corresponding sinograms. - **fbp** : (:class:`jax.Array`) Corresponding filtered back projections. """ - if not have_ray_and_xdesign and have_astra: + if not (have_ray and have_xdesign and have_astra): raise RuntimeError( "Packages ray, xdesign, and astra are required for use of this function." ) @@ -195,7 +203,7 @@ def generate_blur_data( - **img** : Generated foam images. - **blurn** : Corresponding blurred and noisy images. """ - if not have_ray_and_xdesign: + if not (have_ray and have_xdesign): raise RuntimeError("Packages ray and xdesign are required for use of this function.") start_time = time() img = distributed_data_generation(imgfunc, size, nimg, seed) @@ -236,3 +244,62 @@ def generate_blur_data( print(f"{'Blur generation':19s}{'time[s]:':10s}{time_blur:>7.2f}") return img, blurn + + +def distributed_data_generation( + imgenf: Callable, size: int, nimg: int, seedg: float = 123 +) -> np.ndarray: + """Data generation distributed among processes using ray. + + *Warning:* callable `imgenf` should not make use of any jax functions + to avoid the risk of errors when running with GPU devices, in which + case jax is initialized to expect the availability of GPUs, which are + then not available within the `ray.remote` function due to the absence + of any declared GPUs as a `num_gpus` parameter of `@ray.remote`. + + Args: + imagenf: Function for batch-data generation. + size: Size of image to generate. + ndata: Number of images to generate. + seedg: Base seed for data generation. + + Returns: + Array of generated data. + """ + if not have_ray: + raise RuntimeError("Package ray is required for use of this function.") + if not ray.is_initialized(): + raise RuntimeError("Ray must be initialized via ray.init() before calling this function.") + + # Use half of available CPU resources + ar = ray.available_resources() + nproc = max(int(ar.get("CPU", 1)) // 2, 1) + + # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that + # can severely limit parallel execution (since ray will ensure that only + # as many actors as available GPUs are created), and is expected to be + # rather brittle. + if "GPU" in ar: + num_gpus = 1 + nproc = min(nproc, int(ar.get("GPU"))) + else: + num_gpus = 0 + + if nproc > nimg: + nproc = nimg + if nimg % nproc > 0: + # Increase nimg to be a multiple of nproc if it isn't already + nimg = (nimg // nproc + 1) * nproc + + ndata_per_proc = int(nimg // nproc) + + @ray.remote(num_gpus=num_gpus) + def data_gen(seed, size, ndata, imgf): + return imgf(seed, size, ndata) + + ray_return = ray.get( + [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)] + ) + imgs = np.vstack([t for t in ray_return]) + + return imgs diff --git a/scico/flax/examples/ray_functions.py b/scico/flax/examples/xdesign_func.py similarity index 60% rename from scico/flax/examples/ray_functions.py rename to scico/flax/examples/xdesign_func.py index 08665ee3..bbe41173 100644 --- a/scico/flax/examples/ray_functions.py +++ b/scico/flax/examples/xdesign_func.py @@ -5,21 +5,15 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Generate training data for Flax example scripts using ray. +"""Generate training data for Flax example scripts using xdesign. -Functions for generating xdesign foam phantoms and generation in parallel -using ray. +Functions for generating xdesign foam phantoms. """ -from typing import Callable, List, Union +from typing import List, Union import numpy as np -try: - import ray # noqa: F401 -except ImportError: - raise RuntimeError("Package ray is required for use of this module.") - try: import xdesign # noqa: F401 except ImportError: @@ -111,60 +105,3 @@ def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: saux /= np.max(saux, axis=(1, 2), keepdims=True) return saux - - -def distributed_data_generation( - imgenf: Callable, size: int, nimg: int, seedg: float = 123 -) -> np.ndarray: - """Data generation distributed among processes using ray. - - *Warning:* callable `imgenf` should not make use of any jax functions - to avoid the risk of errors when running with GPU devices, in which - case jax is initialized to expect the availability of GPUs, which are - then not available within the `ray.remote` function due to the absence - of any declared GPUs as a `num_gpus` parameter of `@ray.remote`. - - Args: - imagenf: Function for batch-data generation. - size: Size of image to generate. - ndata: Number of images to generate. - seedg: Base seed for data generation. - - Returns: - Array of generated data. - """ - if not ray.is_initialized(): - raise RuntimeError("Ray must be initialized via ray.init() before calling this function.") - - # Use half of available CPU resources - ar = ray.available_resources() - nproc = max(int(ar.get("CPU", 1)) // 2, 1) - - # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that - # can severely limit parallel execution (since ray will ensure that only - # as many actors as available GPUs are created), and is expected to be - # rather brittle. - if "GPU" in ar: - num_gpus = 1 - nproc = min(nproc, int(ar.get("GPU"))) - else: - num_gpus = 0 - - if nproc > nimg: - nproc = nimg - if nimg % nproc > 0: - # Increase nimg to be a multiple of nproc if it isn't already - nimg = (nimg // nproc + 1) * nproc - - ndata_per_proc = int(nimg // nproc) - - @ray.remote(num_gpus=num_gpus) - def data_gen(seed, size, ndata, imgf): - return imgf(seed, size, ndata) - - ray_return = ray.get( - [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)] - ) - imgs = np.vstack([t for t in ray_return]) - - return imgs diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py index 10387440..72c084dd 100644 --- a/scico/test/flax/test_examples_flax.py +++ b/scico/test/flax/test_examples_flax.py @@ -10,8 +10,11 @@ distributed_data_generation, generate_blur_data, generate_ct_data, + generate_foam1_images, + generate_foam2_images, have_astra, - have_ray_and_xdesign, + have_ray, + have_xdesign, ) from scico.flax.examples.data_preprocessing import ( CenterCrop, @@ -36,29 +39,27 @@ # These tests are for the scico.flax.examples module, NOT the example scripts -@pytest.mark.skipif(not have_ray_and_xdesign, reason="ray or xdesign package not installed") +@pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") def test_foam1_gen(): seed = 4444 N = 32 ndata = 2 - from scico.flax.examples.data_generation import generate_foam1_images dt = generate_foam1_images(seed, N, ndata) assert dt.shape == (ndata, N, N, 1) -@pytest.mark.skipif(not have_ray_and_xdesign, reason="ray or xdesign package not installed") +@pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") def test_foam2_gen(): seed = 4321 N = 32 ndata = 2 - from scico.flax.examples.data_generation import generate_foam2_images dt = generate_foam2_images(seed, N, ndata) assert dt.shape == (ndata, N, N, 1) -@pytest.mark.skipif(not have_ray_and_xdesign, reason="ray or xdesign package not installed") +@pytest.mark.skipif(not have_ray, reason="ray package not installed") def test_distdatagen(): N = 16 nimg = 8 @@ -73,7 +74,10 @@ def random_data_gen(seed, N, ndata): assert dt.shape == (nimg, N, N, 1) -@pytest.mark.skipif(not have_astra, reason="astra package not installed") +@pytest.mark.skipif( + not have_astra or not have_ray or not have_xdesign, + reason="astra, ray, or xdesign package not installed", +) def test_ct_data_generation(): N = 32 nimg = 8 @@ -90,6 +94,7 @@ def random_img_gen(seed, size, ndata): assert fbp.shape == (nimg, N, N, 1) +@pytest.mark.skipif(not have_ray or not have_xdesign, reason="ray or xdesign package not installed") def test_blur_data_generation(): N = 32 nimg = 8 From 0fe46a4ac20308260c57c15927d32410635b2f7d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 17:52:33 -0600 Subject: [PATCH 47/48] Address pylint complaint --- scico/flax/examples/data_generation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 5a17e286..2d0c3685 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -20,6 +20,7 @@ import xdesign # noqa: F401 except ImportError: have_xdesign = False + # pylint: disable=missing-function-docstring def generate_foam1_images(): raise RunTimeError("xdesign package required.") @@ -27,6 +28,7 @@ def generate_foam1_images(): def generate_foam2_images(): raise RunTimeError("xdesign package required.") + # pylint: enable=missing-function-docstring else: have_xdesign = True from .xdesign_func import generate_foam1_images, generate_foam2_images From 41fa25eff6ddb8a70c233290583520faae295762 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 23 Jul 2024 18:07:54 -0600 Subject: [PATCH 48/48] Revert unworkable structure --- scico/flax/examples/data_generation.py | 105 +++++++++++++++++++++--- scico/flax/examples/xdesign_func.py | 107 ------------------------- 2 files changed, 96 insertions(+), 116 deletions(-) delete mode 100644 scico/flax/examples/xdesign_func.py diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index 2d0c3685..bc5df270 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -12,7 +12,7 @@ from functools import partial from time import time -from typing import Callable, Tuple +from typing import Callable, List, Tuple, Union import numpy as np @@ -20,18 +20,15 @@ import xdesign # noqa: F401 except ImportError: have_xdesign = False - # pylint: disable=missing-function-docstring - def generate_foam1_images(): - raise RunTimeError("xdesign package required.") + # pylint: disable=missing-class-docstring + class UnitCircle: + pass - def generate_foam2_images(): - raise RunTimeError("xdesign package required.") - - # pylint: enable=missing-function-docstring + # pylint: enable=missing-class-docstring else: have_xdesign = True - from .xdesign_func import generate_foam1_images, generate_foam2_images + from xdesign import Foam, SimpleMaterial, UnitCircle, discrete_phantom try: import ray # noqa: F401 @@ -55,6 +52,96 @@ def generate_foam2_images(): from scico.linop.xray.astra import XRayTransform2D +class Foam2(UnitCircle): + """Foam-like material with two attenuations. + + Define functionality to generate phantom with structure similar + to foam with two different attenuation properties.""" + + def __init__( + self, + size_range: Union[float, List[float]] = [0.05, 0.01], + gap: float = 0, + porosity: float = 1, + attn1: float = 1.0, + attn2: float = 10.0, + ): + """Foam-like structure with two different attenuations. + Circles for material 1 are more sparse than for material 2 + by design. + + Args: + size_range: The radius, or range of radius, of the + circles to be added. Default: [0.05, 0.01]. + gap: Minimum distance between circle boundaries. + Default: 0. + porosity: Target porosity. Must be a value between + [0, 1]. Default: 1. + attn1: Mass attenuation parameter for material 1. + Default: 1. + attn2: Mass attenuation parameter for material 2. + Default: 10. + """ + super().__init__(radius=0.5, material=SimpleMaterial(attn1)) + if porosity < 0 or porosity > 1: + raise ValueError("Porosity must be in the range [0,1).") + self.sprinkle( + 300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0 + ) + self.sprinkle(300, size_range, gap, material=SimpleMaterial(20), max_density=porosity) + + +def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: + """Generate batch of xdesign foam-like structures. + + Generate batch of images with `xdesign` foam-like structure, which + uses one attenuation. + + Args: + seed: Seed for data generation. + size: Size of image to generate. + ndata: Number of images to generate. + + Returns: + Array of generated data. + """ + if not have_xdesign: + raise RuntimeError("Package xdesign is required for use of this module.") + np.random.seed(seed) + saux = np.zeros((ndata, size, size, 1), dtype=np.float32) + for i in range(ndata): + foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) + saux[i, ..., 0] = discrete_phantom(foam, size=size) + + return saux + + +def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: + """Generate batch of foam2 structures. + + Generate batch of images with :class:`Foam2` structure + (foam-like material with two different attenuations). + + Args: + seed: Seed for data generation. + size: Size of image to generate. + ndata: Number of images to generate. + + Returns: + Array of generated data. + """ + if not have_xdesign: + raise RuntimeError("Package xdesign is required for use of this module.") + np.random.seed(seed) + saux = np.zeros((ndata, size, size, 1), dtype=np.float32) + for i in range(ndata): + foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) + saux[i, ..., 0] = discrete_phantom(foam, size=size) + # normalize + saux /= np.max(saux, axis=(1, 2), keepdims=True) + + return saux + + def vector_f(f_: Callable, v: Array) -> Array: """Vectorize application of operator. diff --git a/scico/flax/examples/xdesign_func.py b/scico/flax/examples/xdesign_func.py deleted file mode 100644 index bbe41173..00000000 --- a/scico/flax/examples/xdesign_func.py +++ /dev/null @@ -1,107 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2024 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Generate training data for Flax example scripts using xdesign. - -Functions for generating xdesign foam phantoms. -""" - -from typing import List, Union - -import numpy as np - -try: - import xdesign # noqa: F401 -except ImportError: - raise RuntimeError("Package xdesign is required for use of this module.") -from xdesign import Foam, SimpleMaterial, UnitCircle, discrete_phantom - - -class Foam2(UnitCircle): - """Foam-like material with two attenuations. - - Define functionality to generate phantom with structure similar - to foam with two different attenuation properties.""" - - def __init__( - self, - size_range: Union[float, List[float]] = [0.05, 0.01], - gap: float = 0, - porosity: float = 1, - attn1: float = 1.0, - attn2: float = 10.0, - ): - """Foam-like structure with two different attenuations. - Circles for material 1 are more sparse than for material 2 - by design. - - Args: - size_range: The radius, or range of radius, of the - circles to be added. Default: [0.05, 0.01]. - gap: Minimum distance between circle boundaries. - Default: 0. - porosity: Target porosity. Must be a value between - [0, 1]. Default: 1. - attn1: Mass attenuation parameter for material 1. - Default: 1. - attn2: Mass attenuation parameter for material 2. - Default: 10. - """ - super().__init__(radius=0.5, material=SimpleMaterial(attn1)) - if porosity < 0 or porosity > 1: - raise ValueError("Porosity must be in the range [0,1).") - self.sprinkle( - 300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0 - ) + self.sprinkle(300, size_range, gap, material=SimpleMaterial(20), max_density=porosity) - - -def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: - """Generate batch of xdesign foam-like structures. - - Generate batch of images with `xdesign` foam-like structure, which - uses one attenuation. - - Args: - seed: Seed for data generation. - size: Size of image to generate. - ndata: Number of images to generate. - - Returns: - Array of generated data. - """ - np.random.seed(seed) - saux = np.zeros((ndata, size, size, 1), dtype=np.float32) - for i in range(ndata): - foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux[i, ..., 0] = discrete_phantom(foam, size=size) - - return saux - - -def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: - """Generate batch of foam2 structures. - - Generate batch of images with :class:`Foam2` structure - (foam-like material with two different attenuations). - - Args: - seed: Seed for data generation. - size: Size of image to generate. - ndata: Number of images to generate. - - Returns: - Array of generated data. - """ - np.random.seed(seed) - saux = np.zeros((ndata, size, size, 1), dtype=np.float32) - for i in range(ndata): - foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux[i, ..., 0] = discrete_phantom(foam, size=size) - # normalize - saux /= np.max(saux, axis=(1, 2), keepdims=True) - - return saux