diff --git a/CHANGELOG.md b/CHANGELOG.md index d0a49ef6e67a..af1abff7ffb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,10 +54,13 @@ Remember to align the itemized text with the first line of an item within a list * `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11. Use the built-in `math.prod` instead. -* Internal deprecations: +* Internal deprecations/removals: * The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype` have been removed. Opaque dtypes have been renamed to Extended dtypes; use `jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14). + * The utility `jax.interpreters.xla.register_collective_primitive` has been + removed. This utility did nothing useful in recent JAX releases and calls + to it can be safely removed. ## jaxlib 0.4.15 diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 6cd693259c96..22d256edc7da 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -266,15 +266,11 @@ def __call__(self, ctx: TranslationContext, _backend_specific_translations: dict[str, dict[core.Primitive, TranslationRule]] _backend_specific_translations = defaultdict(dict) -_collective_primitives: set[core.Primitive] = set() initial_style_primitives: set[core.Primitive] = set() def register_initial_style_primitive(prim: core.Primitive): initial_style_primitives.add(prim) -def register_collective_primitive(prim: core.Primitive): - _collective_primitives.add(prim) - def register_translation(prim: core.Primitive, rule: TranslationRule, *, platform: Optional[str] = None) -> None: if platform is None: diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 1464a1d65c3a..392065b231f6 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -35,7 +35,6 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla -from jax._src.interpreters import xla from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir @@ -805,7 +804,6 @@ def broadcast_positional(ct, arg): psum_p.multiple_results = True psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum)) psum_p.def_abstract_eval(_allreduce_abstract_eval) -xla.register_collective_primitive(psum_p) mlir.register_lowering( psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum)) ad.deflinear2(psum_p, _psum_transpose_rule) @@ -841,7 +839,6 @@ def pos_reduce(x): pmax_p.multiple_results = True pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max)) pmax_p.def_abstract_eval(_allreduce_abstract_eval) -xla.register_collective_primitive(pmax_p) mlir.register_lowering( pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max)) batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p) @@ -854,7 +851,6 @@ def pos_reduce(x): pmin_p.multiple_results = True pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min)) pmin_p.def_abstract_eval(_allreduce_abstract_eval) -xla.register_collective_primitive(pmin_p) mlir.register_lowering( pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min)) batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p) @@ -923,7 +919,6 @@ def _collective_batcher(prim, args, dims, **params): ppermute_p = core.AxisPrimitive('ppermute') ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) ad.deflinear2(ppermute_p, _ppermute_transpose_rule) -xla.register_collective_primitive(ppermute_p) mlir.register_lowering(ppermute_p, _ppermute_lowering) batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p) batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher @@ -1074,7 +1069,6 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_ all_to_all_p = core.AxisPrimitive('all_to_all') all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval) -xla.register_collective_primitive(all_to_all_p) mlir.register_lowering(all_to_all_p, _all_to_all_lowering) ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule) batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher @@ -1286,7 +1280,6 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, all_gather_p = core.AxisPrimitive('all_gather') all_gather_p.def_abstract_eval(_all_gather_abstract_eval) all_gather_p.def_impl(_all_gather_impl) -xla.register_collective_primitive(all_gather_p) mlir.register_lowering(all_gather_p, _all_gather_lowering) ad.deflinear2(all_gather_p, _all_gather_transpose_rule) batching.primitive_batchers[all_gather_p] = _all_gather_batcher @@ -1455,7 +1448,6 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in, ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule) batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective -xla.register_collective_primitive(reduce_scatter_p) mlir.register_lowering( reduce_scatter_p, partial(_reduce_scatter_lowering, lax.add_p, psum)) @@ -1575,7 +1567,6 @@ def _axis_index_abstract_eval(*, axis_name): return ShapedArray((), np.int32, named_shape={axis_name: frame.size}) axis_index_p = core.Primitive('axis_index') -xla.register_collective_primitive(axis_index_p) mlir.register_lowering(axis_index_p, _axis_index_lowering) axis_index_p.def_abstract_eval(_axis_index_abstract_eval) core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name') @@ -1666,7 +1657,6 @@ def _pdot_lowering(x, y, *, axis_name, pos_contract, pos_batch, precision): precision=precision, preferred_element_type=None) return psum(local_out, axis_name) if axis_name is not None else local_out -xla.register_collective_primitive(pdot_p) mlir.register_lowering( pdot_p, mlir.lower_fun(_pdot_lowering, multiple_results=False)) @@ -1756,7 +1746,6 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a pgather_p = core.AxisPrimitive('pgather') pgather_p.def_impl(_pgather_impl) pgather_p.def_abstract_eval(_pgather_abstract_eval) -xla.register_collective_primitive(pgather_p) mlir.register_lowering(pgather_p, _pgather_parallel_lowering) # TODO: Transpose? That requires adding pscatter... batching.primitive_batchers[pgather_p] = _pgather_batcher diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index d6a1b314108b..697ab75f57f3 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -20,7 +20,6 @@ canonicalize_dtype as canonicalize_dtype, canonicalize_dtype_handlers as canonicalize_dtype_handlers, pytype_aval_mappings as pytype_aval_mappings, - register_collective_primitive as register_collective_primitive, register_translation as register_translation, translations as translations, xla_destructure as xla_destructure,