Skip to content

Commit

Permalink
Remove jax.interpreters.xla.register_collective_primitive.
Browse files Browse the repository at this point in the history
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.

PiperOrigin-RevId: 561150259
  • Loading branch information
hawkinsp authored and jax authors committed Aug 29, 2023
1 parent 289ccad commit 9390024
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 17 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,13 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
Use the built-in `math.prod` instead.

* Internal deprecations:
* Internal deprecations/removals:
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`
have been removed. Opaque dtypes have been renamed to Extended dtypes; use
`jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14).
* The utility `jax.interpreters.xla.register_collective_primitive` has been
removed. This utility did nothing useful in recent JAX releases and calls
to it can be safely removed.

## jaxlib 0.4.15

Expand Down
4 changes: 0 additions & 4 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,11 @@ def __call__(self, ctx: TranslationContext,
_backend_specific_translations: dict[str, dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)

_collective_primitives: set[core.Primitive] = set()
initial_style_primitives: set[core.Primitive] = set()

def register_initial_style_primitive(prim: core.Primitive):
initial_style_primitives.add(prim)

def register_collective_primitive(prim: core.Primitive):
_collective_primitives.add(prim)

def register_translation(prim: core.Primitive, rule: TranslationRule, *,
platform: Optional[str] = None) -> None:
if platform is None:
Expand Down
11 changes: 0 additions & 11 deletions jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -805,7 +804,6 @@ def broadcast_positional(ct, arg):
psum_p.multiple_results = True
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
psum_p.def_abstract_eval(_allreduce_abstract_eval)
xla.register_collective_primitive(psum_p)
mlir.register_lowering(
psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
ad.deflinear2(psum_p, _psum_transpose_rule)
Expand Down Expand Up @@ -841,7 +839,6 @@ def pos_reduce(x):
pmax_p.multiple_results = True
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
xla.register_collective_primitive(pmax_p)
mlir.register_lowering(
pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
Expand All @@ -854,7 +851,6 @@ def pos_reduce(x):
pmin_p.multiple_results = True
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
xla.register_collective_primitive(pmin_p)
mlir.register_lowering(
pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
Expand Down Expand Up @@ -923,7 +919,6 @@ def _collective_batcher(prim, args, dims, **params):
ppermute_p = core.AxisPrimitive('ppermute')
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
xla.register_collective_primitive(ppermute_p)
mlir.register_lowering(ppermute_p, _ppermute_lowering)
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher
Expand Down Expand Up @@ -1074,7 +1069,6 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_

all_to_all_p = core.AxisPrimitive('all_to_all')
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
xla.register_collective_primitive(all_to_all_p)
mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
Expand Down Expand Up @@ -1286,7 +1280,6 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
all_gather_p = core.AxisPrimitive('all_gather')
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
all_gather_p.def_impl(_all_gather_impl)
xla.register_collective_primitive(all_gather_p)
mlir.register_lowering(all_gather_p, _all_gather_lowering)
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
batching.primitive_batchers[all_gather_p] = _all_gather_batcher
Expand Down Expand Up @@ -1455,7 +1448,6 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in,
ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
xla.register_collective_primitive(reduce_scatter_p)
mlir.register_lowering(
reduce_scatter_p,
partial(_reduce_scatter_lowering, lax.add_p, psum))
Expand Down Expand Up @@ -1575,7 +1567,6 @@ def _axis_index_abstract_eval(*, axis_name):
return ShapedArray((), np.int32, named_shape={axis_name: frame.size})

axis_index_p = core.Primitive('axis_index')
xla.register_collective_primitive(axis_index_p)
mlir.register_lowering(axis_index_p, _axis_index_lowering)
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
Expand Down Expand Up @@ -1666,7 +1657,6 @@ def _pdot_lowering(x, y, *, axis_name, pos_contract, pos_batch, precision):
precision=precision, preferred_element_type=None)
return psum(local_out, axis_name) if axis_name is not None else local_out

xla.register_collective_primitive(pdot_p)
mlir.register_lowering(
pdot_p,
mlir.lower_fun(_pdot_lowering, multiple_results=False))
Expand Down Expand Up @@ -1756,7 +1746,6 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a
pgather_p = core.AxisPrimitive('pgather')
pgather_p.def_impl(_pgather_impl)
pgather_p.def_abstract_eval(_pgather_abstract_eval)
xla.register_collective_primitive(pgather_p)
mlir.register_lowering(pgather_p, _pgather_parallel_lowering)
# TODO: Transpose? That requires adding pscatter...
batching.primitive_batchers[pgather_p] = _pgather_batcher
Expand Down
1 change: 0 additions & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
canonicalize_dtype as canonicalize_dtype,
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
pytype_aval_mappings as pytype_aval_mappings,
register_collective_primitive as register_collective_primitive,
register_translation as register_translation,
translations as translations,
xla_destructure as xla_destructure,
Expand Down

0 comments on commit 9390024

Please sign in to comment.