Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reject invalid None in jax.NamedSharding(spec=None). #21828

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions xla/python/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
parsed_pspec_(std::move(parsed_pspec)),
manual_axes_(std::move(manual_axes)),
logical_device_ids_(std::move(logical_device_ids)) {
if (spec_.is_none()) {
throw nb::type_error(
"Unexpected None passed as spec for NamedSharding. Did you mean P()?");
}
nb::object idl = nb::object(mesh_.attr("_internal_device_list"));
if (idl.is_none()) {
internal_device_list_ = std::nullopt;
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 308
_version = 309

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down
Loading