Commit e5f0603

Merge remote-tracking branch 'jczaja/jczaja/xpu-support'

Silv3S committed Feb 5, 2024
2 parents d2186c9 + 262ae16
Showing 34 changed files with 2,503 additions and 25 deletions.
2 changes: 2 additions & 0 deletions mpi4jax/__init__.py
@@ -20,6 +20,7 @@
send,
sendrecv,
has_cuda_support,
has_sycl_support,
)

__all__ = [
@@ -36,4 +37,5 @@
"send",
"sendrecv",
"has_cuda_support",
"has_sycl_support",
]
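
The new flag can be queried just like its CUDA counterpart. A minimal sketch (not part of this commit), assuming an mpi4jax build compiled with the SYCL extensions:

import mpi4jax

# Both helpers take no arguments and report whether mpi4jax was
# built with the corresponding GPU extensions.
print("CUDA support:", mpi4jax.has_cuda_support())
print("SYCL support:", mpi4jax.has_sycl_support())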
2 changes: 1 addition & 1 deletion mpi4jax/_src/__init__.py
@@ -30,7 +30,7 @@
from .collective_ops.send import send # noqa: F401, E402
from .collective_ops.sendrecv import sendrecv # noqa: F401, E402

-from .utils import has_cuda_support # noqa: F401, E402
+from .utils import has_cuda_support, has_sycl_support # noqa: F401, E402

# sanitize namespace
del jax_compat, xla_bridge, MPI, atexit, flush
51 changes: 50 additions & 1 deletion mpi4jax/_src/collective_ops/allgather.py
@@ -21,7 +21,7 @@
prefer_notoken,
)
from ..jax_compat import custom_call, token_type, ShapedArray
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import translation_rule_cpu, translation_rule_gpu, translation_rule_xpu
from ..validation import enforce_types
from ..comm import get_default_comm

@@ -127,6 +127,54 @@ def mpi_allgather_xla_encode_cpu(ctx, sendbuf, token, comm):
has_side_effect=True,
).results

@translation_rule_xpu
def mpi_allgather_xla_encode_xpu(ctx, sendbuf, token, comm):
from ..xla_bridge.mpi_xla_bridge_xpu import build_allgather_descriptor

comm = unpack_hashable(comm)

sendbuf_aval, *_ = ctx.avals_in
send_nptype = sendbuf_aval.dtype

send_type = ir.RankedTensorType(sendbuf.type)
send_dtype = send_type.element_type
send_dims = send_type.shape

# compute total number of elements in send array
send_nitems = _np.prod(send_dims, dtype=int)
send_dtype_handle = to_dtype_handle(send_nptype)

size = comm.Get_size()
out_shape = (size, *send_dims)

out_types = [
ir.RankedTensorType.get(out_shape, send_dtype),
*token_type(),
]

descriptor = build_allgather_descriptor(
send_nitems,
send_dtype_handle,
# we only support matching input and output arrays
send_nitems,
send_dtype_handle,
#
to_mpi_handle(comm),
)

operands = (sendbuf, token)

return custom_call(
b"mpi_allgather",
result_types=out_types,
operands=operands,
# layout matters here, because the first axis is special
operand_layouts=get_default_layouts(operands, order="c"),
result_layouts=get_default_layouts(out_types, order="c"),
backend_config=descriptor,
has_side_effect=True,
).results


@translation_rule_gpu
def mpi_allgather_xla_encode_gpu(ctx, sendbuf, token, comm):
@@ -194,3 +242,4 @@ def mpi_allgather_abstract_eval(x, token, comm):

mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cpu, platform="cpu")
mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_gpu, platform="cuda")
mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_xpu, platform="xpu")
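
The Python-level API is unchanged by this commit; the new encoder is only selected when JAX runs on an "xpu" platform (e.g. via Intel's OpenXLA PJRT plugin). A usage sketch, assuming a SYCL-enabled build, launched with something like mpirun -n 4 python script.py:

from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
# each rank contributes one row; the result has shape (comm.Get_size(), 3)
x = jnp.full((3,), comm.Get_rank(), dtype=jnp.float32)
gathered, _ = mpi4jax.allgather(x, comm=comm)  # second element is the flow token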
51 changes: 50 additions & 1 deletion mpi4jax/_src/collective_ops/allreduce.py
@@ -22,7 +22,7 @@
prefer_notoken,
)
from ..jax_compat import custom_call, token_type, ShapedArray
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import translation_rule_cpu, translation_rule_gpu, translation_rule_xpu
from ..validation import enforce_types
from ..comm import get_default_comm

@@ -170,6 +170,54 @@ def mpi_allreduce_xla_encode_gpu(ctx, x, token, op, comm, transpose):
backend_config=descriptor,
).results

@translation_rule_xpu
def mpi_allreduce_xla_encode_xpu(ctx, x, token, op, comm, transpose):
from ..xla_bridge.mpi_xla_bridge_xpu import build_allreduce_descriptor

op = unpack_hashable(op)
comm = unpack_hashable(comm)

if transpose:
assert op == _MPI.SUM
return [x, token]

x_aval, *_ = ctx.avals_in
x_nptype = x_aval.dtype

x_type = ir.RankedTensorType(x.type)
dtype = x_type.element_type
dims = x_type.shape

# compute total number of elements in array
nitems = _np.prod(dims, dtype=int)

out_types = [
ir.RankedTensorType.get(dims, dtype),
*token_type(),
]

operands = (
x,
token,
)

descriptor = build_allreduce_descriptor(
_np.intc(nitems),
to_mpi_handle(op),
to_mpi_handle(comm),
to_dtype_handle(x_nptype),
)

return custom_call(
b"mpi_allreduce",
result_types=out_types,
operands=operands,
operand_layouts=get_default_layouts(operands),
result_layouts=get_default_layouts(out_types),
has_side_effect=True,
backend_config=descriptor,
).results


# This function evaluates only the shapes during AST construction
def mpi_allreduce_abstract_eval(xs, token, op, comm, transpose):
@@ -230,3 +278,4 @@ def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm, transpose):
# assign to the primitive the correct encoder
mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cpu, platform="cpu")
mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_gpu, platform="cuda")
mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_xpu, platform="xpu")
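
As with allgather, only the lowering changes; user code stays the same. A minimal sketch (assumptions: SYCL-enabled build, one MPI rank per XPU device):

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

@jax.jit
def global_mean(x):
    # sum across all ranks, then normalize by the communicator size
    x_sum, _ = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
    return x_sum / comm.Get_size()

print(global_mean(jnp.ones(4)))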
52 changes: 51 additions & 1 deletion mpi4jax/_src/collective_ops/alltoall.py
@@ -21,7 +21,7 @@
prefer_notoken,
)
from ..jax_compat import custom_call, token_type, ShapedArray
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import translation_rule_cpu, translation_rule_gpu, translation_rule_xpu
from ..validation import enforce_types
from ..comm import get_default_comm

@@ -128,6 +128,55 @@ def mpi_alltoall_xla_encode_cpu(ctx, x, token, comm):
has_side_effect=True,
).results

@translation_rule_xpu
def mpi_alltoall_xla_encode_xpu(ctx, x, token, comm):
from ..xla_bridge.mpi_xla_bridge_xpu import build_alltoall_descriptor

comm = unpack_hashable(comm)

x_aval, *_ = ctx.avals_in
x_nptype = x_aval.dtype

x_type = ir.RankedTensorType(x.type)
dtype = x_type.element_type
dims = x_type.shape

# compute total number of elements in array
size = comm.Get_size()
assert dims[0] == size
nitems_per_proc = _np.prod(dims[1:], dtype=int)
dtype_handle = to_dtype_handle(x_nptype)

out_types = [
ir.RankedTensorType.get(dims, dtype),
*token_type(),
]

operands = (
x,
token,
)

descriptor = build_alltoall_descriptor(
nitems_per_proc,
dtype_handle,
# we only support matching input and output arrays
nitems_per_proc,
dtype_handle,
#
to_mpi_handle(comm),
)

return custom_call(
b"mpi_alltoall",
result_types=out_types,
operands=operands,
# force c order because first axis is special
operand_layouts=get_default_layouts(operands, order="c"),
result_layouts=get_default_layouts(out_types, order="c"),
has_side_effect=True,
backend_config=descriptor,
).results

@translation_rule_gpu
def mpi_alltoall_xla_encode_gpu(ctx, x, token, comm):
@@ -195,3 +244,4 @@ def mpi_alltoall_abstract_eval(xs, token, comm):
# assign to the primitive the correct encoder
mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cpu, platform="cpu")
mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_gpu, platform="cuda")
mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_xpu, platform="xpu")
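
Note the assert dims[0] == size in the xpu encoder above: alltoall requires the leading axis of the input to match the communicator size, on xpu as on the other backends. A usage sketch under that constraint (not part of this commit):

from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
size = comm.Get_size()
# the leading axis must equal the communicator size
x = jnp.arange(size * 4, dtype=jnp.float32).reshape(size, 4)
shuffled, _ = mpi4jax.alltoall(x, comm=comm)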
24 changes: 23 additions & 1 deletion mpi4jax/_src/collective_ops/barrier.py
@@ -20,7 +20,7 @@
prefer_notoken,
)
from ..jax_compat import custom_call, token_type
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import translation_rule_cpu, translation_rule_gpu, translation_rule_xpu
from ..validation import enforce_types
from ..comm import get_default_comm

@@ -88,6 +88,27 @@ def mpi_barrier_xla_encode_cpu(ctx, token, comm):
has_side_effect=True,
).results

@translation_rule_xpu
def mpi_barrier_xla_encode_xpu(ctx, token, comm):
from ..xla_bridge.mpi_xla_bridge_xpu import build_barrier_descriptor

comm = unpack_hashable(comm)

out_types = token_type()

operands = (token,)

descriptor = build_barrier_descriptor(to_mpi_handle(comm))

return custom_call(
b"mpi_barrier",
result_types=out_types,
operands=operands,
operand_layouts=get_default_layouts(operands),
result_layouts=get_default_layouts(out_types),
has_side_effect=True,
backend_config=descriptor,
).results

@translation_rule_gpu
def mpi_barrier_xla_encode_gpu(ctx, token, comm):
@@ -131,3 +152,4 @@ def mpi_barrier_batch_eval(in_args, batch_axes, comm):
# assign to the primitive the correct encoder
mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cpu, platform="cpu")
mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_gpu, platform="cuda")
mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_xpu, platform="xpu")
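
Barrier carries no array payload, so the descriptor above only wraps the communicator handle. A usage sketch (not part of this commit):

from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD
# synchronizes all ranks; returns only a flow token
token = mpi4jax.barrier(comm=comm)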
51 changes: 50 additions & 1 deletion mpi4jax/_src/collective_ops/bcast.py
@@ -21,7 +21,7 @@
prefer_notoken,
)
from ..jax_compat import custom_call, token_type, ShapedArray
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import translation_rule_cpu, translation_rule_gpu, translation_rule_xpu
from ..validation import enforce_types
from ..comm import get_default_comm

@@ -125,6 +125,54 @@ def mpi_bcast_xla_encode_cpu(ctx, x, token, root, comm):
has_side_effect=True,
).results

@translation_rule_xpu
def mpi_bcast_xla_encode_xpu(ctx, x, token, root, comm):
from ..xla_bridge.mpi_xla_bridge_xpu import build_bcast_descriptor

comm = unpack_hashable(comm)

x_aval, *_ = ctx.avals_in
x_nptype = x_aval.dtype

x_type = ir.RankedTensorType(x.type)
dtype = x_type.element_type
dims = x_type.shape

# compute total number of elements in array
nitems = _np.prod(dims, dtype=int)
dtype_handle = to_dtype_handle(x_nptype)

# output is not used on root, so prevent memory allocation
rank = comm.Get_rank()
if rank == root:
dims = (0,)

out_types = [
ir.RankedTensorType.get(dims, dtype),
*token_type(),
]

operands = (
x,
token,
)

descriptor = build_bcast_descriptor(
nitems,
root,
to_mpi_handle(comm),
dtype_handle,
)

return custom_call(
b"mpi_bcast",
result_types=out_types,
operands=operands,
operand_layouts=get_default_layouts(operands),
result_layouts=get_default_layouts(out_types),
has_side_effect=True,
backend_config=descriptor,
).results

@translation_rule_gpu
def mpi_bcast_xla_encode_gpu(ctx, x, token, root, comm):
@@ -199,3 +247,4 @@ def mpi_bcast_abstract_eval(xs, token, root, comm):
# assign to the primitive the correct encoder
mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cpu, platform="cpu")
mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_gpu, platform="cuda")
mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_xpu, platform="xpu")
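
Note the rank == root branch in the xpu encoder above: on the root the output buffer is unused, so it is shrunk to shape (0,). At the Python level the call stays symmetric across ranks. A sketch (not part of this commit):

from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
x = jnp.zeros((3,), dtype=jnp.float32)
if comm.Get_rank() == 0:
    x = x + 42.0
# broadcast the root's values to every rank
x, _ = mpi4jax.bcast(x, root=0, comm=comm)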