Skip to content

Commit

Permalink
Avoid mismatch with declared linop dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 18, 2024
1 parent c4a7a49 commit 38ba3ae
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def _back_project(
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0
)

return HTy
return HTy.astype(jnp.float32)

@staticmethod
@partial(jax.jit, static_argnames=["nx"])
Expand Down
2 changes: 1 addition & 1 deletion scico/test/flax/test_inv.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def setup_method(self, method):
xt, key = random.randn((2 * self.bsize, self.N, self.N, self.chn), seed=4321)

self.nproj = 60 # number of projections
angles = np.linspace(0, np.pi, self.nproj) # evenly spaced projection angles
angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32)
self.opCT = XRayTransform2D(
input_shape=(self.N, self.N),
det_count=self.N,
Expand Down

0 comments on commit 38ba3ae

Please sign in to comment.