Skip to content

Commit

Permalink
Merge pull request jax-ml#17351 from jakevdp:mypy-fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561136741
  • Loading branch information
jax authors committed Aug 29, 2023
2 parents 5a578cb + f1fc2ad commit 289ccad
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/_src/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,17 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
nonzero_indices = np.flatnonzero(local_slices)
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
subcube_indices.append(slice(start, end + 1))
subcube_indices = tuple(subcube_indices)
subcube_indices_tuple = tuple(subcube_indices)
# We only end up with all conditions being true if the local devices formed a
# subcube of the full array. This is because we were biased towards taking a
# "hull" spanned by the devices, and in case the local devices don't form a
# subcube that hull will contain non-local devices.
if not is_local_device[subcube_indices].all():
if not is_local_device[subcube_indices_tuple].all():
raise ValueError(
"When passing host local inputs to pjit or xmap, devices "
"connected to a single host must form a contiguous subcube of the "
"global device mesh")
return Mesh(global_mesh.devices[subcube_indices], global_mesh.axis_names)
return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)


_mesh_object_dict = {} # type: ignore
Expand Down

0 comments on commit 289ccad

Please sign in to comment.