diff --git a/xla/python/sharding.cc b/xla/python/sharding.cc index d9d509cd95a5bc..0779618d5454b9 100644 --- a/xla/python/sharding.cc +++ b/xla/python/sharding.cc @@ -182,6 +182,9 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, parsed_pspec_(std::move(parsed_pspec)), manual_axes_(std::move(manual_axes)), logical_device_ids_(std::move(logical_device_ids)) { + if (spec_.is_none()) { + throw nb::type_error("Unexpected None passed as spec_ for NamedSharding."); + } nb::object idl = nb::object(mesh_.attr("_internal_device_list")); if (idl.is_none()) { internal_device_list_ = std::nullopt; diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index eaa783ccb4005a..430d3d8c35fedb 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 305 +_version = 306 # Version number for MLIR:Python components. mlir_api_version = 57