Skip to content

Commit

Permalink
Debug attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Jul 23, 2024
1 parent 644c189 commit 89b4772
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions scico/flax/examples/ray_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray:
Returns:
Array of generated data.
"""
import os

os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
np.random.seed(seed)
saux = np.zeros((ndata, size, size, 1), dtype=np.float32)
for i in range(ndata):
Expand Down Expand Up @@ -163,11 +168,8 @@ def distributed_data_generation(
@ray.remote(num_gpus=0.001)
def data_gen(seed, size, ndata, imgf):
import os
import sys

os.environ["JAX_PLATFORMS"] = "cpu"
sys.modules.pop("jax")
sys.modules.pop("scico")
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
if "CUDA_VISIBLE_DEVICES" in os.environ:
Expand Down

0 comments on commit 89b4772

Please sign in to comment.