Skip to content

Commit

Permalink
Address some problems
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Jul 23, 2024
1 parent 039a970 commit 9dca046
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions examples/scripts/ct_astra_unet_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

import numpy as np

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_ct_data

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

Expand Down Expand Up @@ -196,7 +197,7 @@
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
jax.numpy.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
x=hist.Epoch,
ptyp="semilogy",
title="Loss function",
Expand All @@ -207,7 +208,7 @@
ax=ax[0],
)
plot.plot(
jax.numpy.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
x=hist.Epoch,
title="Metric",
xlbl="Epoch",
Expand Down

0 comments on commit 9dca046

Please sign in to comment.