diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index ea8d66c1473c..e71d7d4df8a5 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -106,17 +106,17 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: nonzero_indices = np.flatnonzero(local_slices) start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices)) subcube_indices.append(slice(start, end + 1)) - subcube_indices = tuple(subcube_indices) + subcube_indices_tuple = tuple(subcube_indices) # We only end up with all conditions being true if the local devices formed a # subcube of the full array. This is because we were biased towards taking a # "hull" spanned by the devices, and in case the local devices don't form a # subcube that hull will contain non-local devices. - if not is_local_device[subcube_indices].all(): + if not is_local_device[subcube_indices_tuple].all(): raise ValueError( "When passing host local inputs to pjit or xmap, devices " "connected to a single host must form a contiguous subcube of the " "global device mesh") - return Mesh(global_mesh.devices[subcube_indices], global_mesh.axis_names) + return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names) _mesh_object_dict = {} # type: ignore