diff --git a/src/mpi-base.jl b/src/mpi-base.jl index 2afd46d5a..4617c7a33 100644 --- a/src/mpi-base.jl +++ b/src/mpi-base.jl @@ -1,439 +1,378 @@ -typealias MpiDatatype Union(Float32, Float64, Complex64, Complex128, Char, - Int8, Uint8, Int16, Uint16, Int32, Uint32, Int64, - Uint64) - -const _mpi_datatype_map = { - Float32 => MPI_REAL4, - Float64 => MPI_REAL8, - Complex64 => MPI_COMPLEX, - Complex128 => MPI_DOUBLE_COMPLEX, - Char => MPI_WCHAR, - Int8 => MPI_INT8_T, - Uint8 => MPI_UINT8_T, - Int16 => MPI_INT16_T, - Uint16 => MPI_UINT16_T, - Int32 => MPI_INT32_T, - Uint32 => MPI_UINT32_T, - Int64 => MPI_INT64_T, - Uint64 => MPI_UINT64_T, -# The following doesn't not seem to work with my openmpi build -# Int128 => MPI_INTEGER16, -# Unfortunately mpich2 doesn't support 8-bits ints -# Bool => MPI_LOGICAL1, -} - -for (elty, elfn, elnl, elex) in - ((:Comm, :MPI_COMM_FREE, :MPI_COMM_NULL , None), - (:Request, :MPI_REQUEST_FREE, :MPI_REQUEST_NULL, Any ), - (:Operation, :MPI_OP_FREE, :MPI_OP_NULL , None)) - @eval begin - type ($elty) - fval::Int32 - extra::$elex - - function ($elty)(b::Int32) - a = new(b) - # TODO: Call MPI_Finalized in the free routine, then - # re-enable this - # finalizer(a, free) - a - end - end - - function ($elty)(b::Int32, extra) - a = ($elty)(b) - a.extra = extra - a - end - - convert(::Type{$elty}, x::Int32) = ($elty)(x) - convert(::Type{Int32}, x::($elty)) = Int32(x.fval) - - function isequal(a::($elty), b::($elty)) - isequal(a.fval, b.fval) - end - - function free(el::($elty)) - if el.fval != $elnl - ierr = Array(Int32, 1) - ccall($elfn, Void, (Ptr{Int32},Ptr{Int32},), &el.fval, ierr) - - if ierr[1] != MPI_SUCCESS - elfn_str = $string(elfn) - warn("$elfn_str: error $(ierr[1])") - end - end - end - end +typealias MPIDatatype Union(Char, + Int8, Uint8, Int16, Uint16, Int32, Uint32, Int64, + Uint64, + Float32, Float64, Complex64, Complex128) + +const datatypes = + {Char => MPI_WCHAR, + Int8 => MPI_INT8_T, + Uint8 => MPI_UINT8_T, + Int16 => MPI_INT16_T, + Uint16 => MPI_UINT16_T, + Int32 => MPI_INT32_T, + Uint32 => MPI_UINT32_T, + Int64 => MPI_INT64_T, + Uint64 => MPI_UINT64_T, + Float32 => MPI_REAL4, + Float64 => MPI_REAL8, + Complex64 => MPI_COMPLEX8, + Complex128 => MPI_COMPLEX16} + +type Comm + val::Int32 end - +const COMM_NULL = Comm(MPI_COMM_NULL) const COMM_SELF = Comm(MPI_COMM_SELF) const COMM_WORLD = Comm(MPI_COMM_WORLD) -const COMM_NULL = Comm(MPI_COMM_NULL) -const OP_NULL = Operation(MPI_OP_NULL) -const MAX = Operation(MPI_MAX ) -const MIN = Operation(MPI_MIN ) -const SUM = Operation(MPI_SUM ) -const PROD = Operation(MPI_PROD ) -const LAND = Operation(MPI_LAND ) -const BAND = Operation(MPI_BAND ) -const LOR = Operation(MPI_LOR ) -const BOR = Operation(MPI_BOR ) -const LXOR = Operation(MPI_LXOR ) -const BXOR = Operation(MPI_BXOR ) -##The following are not supported yet -#const MAXLOC = Operation(MPI_MAXLOC ) -#const MINLOC = Operation(MPI_MINLOC ) -#const REPLACE = Operation(MPI_REPLACE) +type Op + val::Int32 +end +const OP_NULL = Op(MPI_OP_NULL) +const BAND = Op(MPI_BAND) +const BOR = Op(MPI_BOR) +const BXOR = Op(MPI_BXOR) +const LAND = Op(MPI_LAND) +const LOR = Op(MPI_LOR) +const LXOR = Op(MPI_LXOR) +const MAX = Op(MPI_MAX) +const MIN = Op(MPI_MIN) +const PROD = Op(MPI_PROD) +const SUM = Op(MPI_SUM) + +type Request + val::Int32 + buffer +end +const REQUEST_NULL = Request(MPI_REQUEST_NULL, nothing) + +type Status + val::Array{Int32,1} + Status() = new(Array(Int32, MPI_STATUS_SIZE)) +end +Get_error(stat::Status) = stat.val[MPI_ERROR] +Get_source(stat::Status) = stat.val[MPI_SOURCE] +Get_tag(stat::Status) = stat.val[MPI_TAG] const ANY_SOURCE = MPI_ANY_SOURCE const ANY_TAG = MPI_ANY_TAG const TAG_UB = MPI_TAG_UB +const UNDEFINED = MPI_UNDEFINED -const STATUS_SIZE = MPI_STATUS_SIZE -const SOURCE = MPI_SOURCE -const TAG = MPI_TAG -const ERROR = MPI_ERROR -const REQUEST_NULL = Request(MPI_REQUEST_NULL) - -macro _mpi_error_check(ierr, fname) - # By default, MPI aborts if there is an error, so we skip the error check - # :($(ierr) == MPI_SUCCESS ? nothing : error("MPI error in:", $fname, - # "---", $ierr)) -end - -takebuf_array(s::IOStream) = - ccall(:jl_takebuf_array, Vector{Uint8}, (Ptr{Void},), s.ios) - -function _mpi_serialize(x) +function serialize(x) s = IOBuffer() - serialize(s, x) + Base.serialize(s, x) Base.takebuf_array(s) end -function _mpi_deserialize(x) - s = IOBuffer() - write(s, x) - seek(s, 0) - y = deserialize(s) - y -end - -for (fn, ff) in ((:init, :MPI_INIT), - (:finalize, :MPI_FINALIZE)) - @eval begin - function ($fn)() - ierr = Array(Int32, 1) - ccall($ff, Void, (Ptr{Int32},), ierr) - @_mpi_error_check ierr[1] $string(ff) - end - end +function deserialize(x) + s = IOBuffer(x) + Base.deserialize(s) end -function abort(c::Comm, errc::Integer) - ierr = Array(Int32, 1) - ccall(MPI_ABORT, Void, (Ptr{Int32},Ptr{Int32},Ptr{Int32},), - &c.fval, &errc, ierr) - @_mpi_error_check ierr[1] "MPI_ABORT" -end - -for (fn, ff) in ((:rank, :MPI_COMM_RANK), - (:size, :MPI_COMM_SIZE)) - @eval begin - function ($fn)(c::Comm) - ierr = Array(Int32, 1) - valu = Array(Int32, 1) - ccall($ff, Void, (Ptr{Int32}, Ptr{Int32}, Ptr{Int32},), - &c.fval, valu, ierr) - @_mpi_error_check ierr[1] $string(ff) - valu[1] - end - end -end +# Administrative functions -function barrier(c::Comm) - ierr = Array(Int32, 1) - ccall(MPI_BARRIER, Void, (Ptr{Int32},Ptr{Int32},), &c.fval, ierr) - @_mpi_error_check ierr[1] "MPI_BARRIER" +function Init() + ccall(MPI_INIT, Void, (Ptr{Int32},), &0) end -function Bcast!{T<:MpiDatatype}(A::Union(Ptr{T},Array{T}), count::Integer, - root::Integer, c::Comm) - ierr = Array(Int32, 1) - - ccall(MPI_BCAST, Void, - (Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32},), - A, &count, &_mpi_datatype_map[T], &root, &c.fval, ierr) - - @_mpi_error_check ierr[1] "MPI_BCAST" - A +function Finalize() + ccall(MPI_FINALIZE, Void, (Ptr{Int32},), &0) end -function Bcast!{T<:MpiDatatype}(A::Array{T}, root::Integer, c::Comm) - Bcast!(A, length(A), root, c) +function Abort(comm::Comm, errcode::Integer) + ccall(MPI_ABORT, Void, (Ptr{Int32},Ptr{Int32},Ptr{Int32}), + &comm.val, &errcode, &0) end -function bcast(A, root::Integer, c::Comm) - ierr = Array(Int32, 1) - len = Array(Int32, 1) - - if rank(c) == root - buf = _mpi_serialize(A) - len[1] = length(buf) - end - - ccall(MPI_BCAST, Void, - (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32},), - len, &sizeof(Int32), &MPI_BYTE, &root, &c.fval, ierr) - - @_mpi_error_check ierr[1] "MPI_BCAST" - - if rank(c) != root - buf = Array(Uint8, len[1]) - end - - ccall(MPI_BCAST, Void, - (Ptr{Uint8}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32},), - buf, len, &MPI_BYTE, &root, &c.fval, ierr) - - @_mpi_error_check ierr[1] "MPI_BCAST" +function Initialized() + flag = Array(Int32, 1) + ccall(MPI_INITIALIZED, Void, (Ptr{Int32},Ptr{Int32}), flag, &0) + bool(flag[1]) +end - if rank(c) != root - _mpi_deserialize(buf) - else - A - end +function Finalized() + flag = Array(Int32, 1) + ccall(MPI_FINALIZED, Void, (Ptr{Int32},Ptr{Int32}), flag, &0) + bool(flag[1]) end -function Reduce{T<:MpiDatatype}(A::Union(Ptr{T},Array{T}), count::Integer, - op::Operation, root::Integer, c::Comm) - ierr = Array(Int32, 1) +function Comm_rank(comm::Comm) + rank = Array(Int32, 1) + ccall(MPI_COMM_RANK, Void, (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), + &comm.val, rank, &0) + rank[1] +end - if MPI.rank(c) == root - B = Array(T, count) - else - B = Array(T, 0) - end +function Comm_size(comm::Comm) + size = Array(Int32, 1) + ccall(MPI_COMM_SIZE, Void, (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), + &comm.val, size, &0) + size[1] +end - ccall(MPI_REDUCE, Void, - (Ptr{T}, Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32}, Ptr{Int32}, Ptr{Int32},), - A, B, &count, &_mpi_datatype_map[T], &op.fval, &root, &c.fval, ierr) +# Point-to-point communication - @_mpi_error_check ierr[1] "MPI_REDUCE" +function Probe(src::Integer, tag::Integer, comm::Comm) + stat = Status() + ccall(MPI_PROBE, Void, + (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), + &src, &tag, &comm.val, stat.val, &0) + stat +end - if MPI.rank(c) != root - B = nothing +function Iprobe(src::Integer, tag::Integer, comm::Comm) + flag = Array(Int32, 1) + stat = Status() + ccall(MPI_IPROBE, Void, + (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, + Ptr{Int32}), + &src, &tag, &comm.val, flag, stat.val, &0) + flag = bool(flag[1]) + if !flag + return (false, nothing) end - B + (true, stat) end -function Reduce{T<:MpiDatatype}(A::Array{T}, op::Operation, root::Integer, - c::Comm) - Reduce(A, length(A), op, root, c) +function Get_count{T<:MPIDatatype}(stat::Status, ::Type{T}) + count = Array(Int32, 1) + ccall(MPI_GET_COUNT, Void, (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), + stat.val, &datatypes[T], count, &0) + count[1] end -function Reduce{T<:MpiDatatype}(A::T, op::Operation, root::Integer, c::Comm) - A1 = T[A] - B1 = Reduce(A1, op, root, c) - if MPI.rank(c) == root - B1[1] - else - nothing - end +function Send{T<:MPIDatatype}(buf::Union(Ptr{T},Array{T}), count::Integer, + dest::Integer, tag::Integer, comm::Comm) + ccall(MPI_ISEND, Void, + (Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, + Ptr{Int32}), + buf, &count, &datatypes[T], &dest, &tag, &comm.val, &0) end -for (fnnm, ffnm) in ((:Isend!, :MPI_ISEND), (:Irecv!, :MPI_IRECV)) - @eval begin - function ($fnnm){T<:MpiDatatype}(A::Union(Ptr{T},Array{T}), - count::Integer, srcdest::Integer, - tag::Integer, c::Comm) - ierr = Array(Int32, 1) - req = Array(Int32, 1) - - ccall(($ffnm), Void, - (Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), - A, &count, &_mpi_datatype_map[T], &srcdest, &tag, &c.fval, - req, ierr) - - @_mpi_error_check ierr[1] $string(ffnm) - Request(req[1]) - end - - function ($fnnm){T<:MpiDatatype}(A::Array{T}, srcdest::Integer, - tag::Integer, c::Comm) - ($fnnm)(A, length(A), srcdest, tag, c) - end - end +function Send{T<:MPIDatatype}(buf::Array{T}, dest::Integer, tag::Integer, + comm::Comm) + Send(buf, length(buf), dest, tag, comm) end -function send(A, dest::Integer, tag::Integer, c::Comm) - ierr = Array(Int32, 1) - - buf = _mpi_serialize(A) - - ccall(MPI_SEND, Void, - (Ptr{Uint8}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32}, Ptr{Int32}), - buf, &length(buf), &MPI_BYTE, &dest, &tag, &c.fval, ierr) - - @_mpi_error_check ierr[1] "MPI_SEND" +#= +function Send{T<:MPIDatatype}(obj::T, dest::Integer, tag::Integer, comm::Comm) + buf = [obj] + Send(buf, dest, tag, comm) end +=# -function isend(A, dest::Integer, tag::Integer, c::Comm) - ierr = Array(Int32, 1) - req = Array(Int32, 1) - - buf = _mpi_serialize(A) +function send(obj, dest::Integer, tag::Integer, comm::Comm) + buf = serialize(obj) + Send(buf, dest, tag, comm) +end +function Isend{T<:MPIDatatype}(buf::Union(Ptr{T},Array{T}), count::Integer, + dest::Integer, tag::Integer, comm::Comm) + rval = Array(Int32, 1) ccall(MPI_ISEND, Void, - (Ptr{Uint8}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32}, Ptr{Int32}), - buf, &length(buf), &MPI_BYTE, &dest, &tag, &c.fval, req, ierr) + (Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, + Ptr{Int32}, Ptr{Int32}), + buf, &count, &datatypes[T], &dest, &tag, &comm.val, rval, &0) + Request(rval[1], buf) +end - @_mpi_error_check ierr[1] "MPI_ISEND" - Request(req[1], buf) +function Isend{T<:MPIDatatype}(buf::Array{T}, dest::Integer, tag::Integer, + comm::Comm) + Isend(buf, length(buf), dest, tag, comm) end -function probe(source::Integer, tag::Integer, c::Comm) - ierr = Array(Int32, 1) - stat = Array(Int32, MPI_STATUS_SIZE) +#= +function Isend{T<:MPIDatatype}(obj::T, dest::Integer, tag::Integer, comm::Comm) + buf = [obj] + Isend(buf, dest, tag, comm) +end +=# - ccall(MPI_PROBE, Void, - (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), - &source, &tag, &c.fval, stat, ierr) +function isend(obj, dest::Integer, tag::Integer, comm::Comm) + buf = serialize(obj) + Isend(buf, dest, tag, comm) +end - @_mpi_error_check ierr[1] "MPI_PROBE" +function Recv!{T<:MPIDatatype}(buf::Union(Ptr{T},Array{T}), count::Integer, + src::Integer, tag::Integer, comm::Comm) + stat = Status() + ccall(MPI_RECV, Void, + (Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, + Ptr{Int32}, Ptr{Int32}), + buf, &count, &datatypes[T], &src, &tag, &comm.val, stat.val, + &0) stat end -function iprobe(source::Integer, tag::Integer, c::Comm) - ierr = Array(Int32, 1) - flag = Array(Int32, 1) - stat = Array(Int32, MPI_STATUS_SIZE) - - ccall(MPI_IPROBE, Void, - (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32}), - &source, &tag, &c.fval, flag, stat, ierr) - - @_mpi_error_check ierr[1] "MPI_IPROBE" - (bool(flag[1]), stat) +function Recv!{T<:MPIDatatype}(buf::Array{T}, src::Integer, tag::Integer, + comm::Comm) + Recv!(buf, length(buf), src, tag, comm) end -function get_count{T<:MpiDatatype}(status::Array{Int32,1}, ::Type{T}) - @assert length(status) >= MPI_STATUS_SIZE - ierr = Array(Int32, 1) - count = Array(Int32, 1) - - ccall(MPI_GET_COUNT, Void, (Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), - status, &_mpi_datatype_map[T], count, ierr) - - @_mpi_error_check ierr[1] "MPI_GET_COUNT" - count[1] +#= +function Recv{T<:MPIDatatype}(::Type{T}, src::Integer, tag::Integer, comm::Comm) + buf = Array(T, 1) + Recv!(buf, src, tag, comm) + buf[1] end +=# -function recv!(source::Integer, tag::Integer, c::Comm, status::Array{Int32,1}) - @assert length(status) >= MPI_STATUS_SIZE - ierr = Array(Int32, 1) - pstat = probe(source, tag, c) - count = get_count(pstat, Uint8) - +function recv(src::Integer, tag::Integer, comm::Comm) + stat = Probe(src, tag, comm) + count = Get_count(stat, Uint8) buf = Array(Uint8, count) + stat = Recv!(buf, Get_source(stat), Get_tag(stat), comm) + (deserialize(buf), stat) +end - ccall(MPI_RECV, Void, - (Ptr{Uint8}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), - buf, &length(buf), &MPI_BYTE, &source, &tag, &c.fval, status, ierr) +function Irecv!{T<:MPIDatatype}(buf::Union(Ptr{T},Array{T}), count::Integer, + src::Integer, tag::Integer, comm::Comm) + rval = Array(Int32, 1) + ccall(MPI_IRECV, Void, + (Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, + Ptr{Int32}, Ptr{Int32}), + buf, &count, &datatypes[T], &src, &tag, &comm.val, rval, &0) + Request(rval[1], buf) +end - @_mpi_error_check ierr[1] "MPI_RECV" +function Irecv!{T<:MPIDatatype}(buf::Array{T}, src::Integer, tag::Integer, + comm::Comm) + Irecv!(buf, length(buf), src, tag, comm) +end - _mpi_deserialize(buf) +function irecv(src::Integer, tag::Integer, comm::Comm) + (flag, stat) = Iprobe(src, tag, comm) + if !flag + return (false, nothing, nothing) + end + count = Get_count(stat, Uint8) + buf = Array(Uint8, count) + stat = Recv!(buf, Get_source(stat), Get_tag(stat), comm) + (true, deserialize(buf), stat) end -function recv(source::Integer, tag::Integer, c::Comm) - stat = Array(Int32, MPI_STATUS_SIZE) - recv!(source, tag, c, stat) +function Wait!(req::Request) + stat = Status() + ccall(MPI_WAIT, Void, (Ptr{Int32},Ptr{Int32},Ptr{Int32}), + &req.val, stat.val, &0) + req.buffer = nothing + stat end -function iprobe_recv!(source::Integer, tag::Integer, c::Comm, - status::Array{Int32,1}) - @assert length(status) >= MPI_STATUS_SIZE - ierr = Array(Int32, 1) - (flag,pstat) = iprobe(source, tag, c) +function Test!(req::Request) + flag = Array(Int32, 1) + stat = Status() + ccall(MPI_TEST, Void, (Ptr{Int32},Ptr{Int32},Ptr{Int32},Ptr{Int32}), + &req.val, flag, stat.val, &0) + flag = bool(flag[1]) if !flag return (false, nothing) end - count = get_count(pstat, Uint8) - - buf = Array(Uint8, count[1]) - - ccall(MPI_RECV, Void, - (Ptr{Uint8}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, - Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), - buf, &length(buf), &MPI_BYTE, &source, &tag, &c.fval, status, ierr) - - @_mpi_error_check ierr[1] "MPI_RECV" - - (true, _mpi_deserialize(buf)) + req.buffer = nothing + (true, stat) end -function iprobe_recv(source::Integer, tag::Integer, c::Comm) - stat = Array(Int32, MPI_STATUS_SIZE) - iprobe_recv!(source, tag, c, stat) +function Waitall!(reqs::Array{Request,1}) + count = length(reqs) + reqvals = [reqs[i].val for i in 1:count] + statvals = Array(Int32, MPI_STATUS_SIZE, count) + ccall(MPI_WAITALL, Void, (Ptr{Int32},Ptr{Int32},Ptr{Int32},Ptr{Int32}), + &count, reqvals, statvals, &0) + stats = Array(Status, count) + for i in 1:count + reqs[i].val = reqvals[i] + reqs[i].buffer = nothing + stats[i] = Status() + stats[i].val[:] = statvals[:,i] + end + stats end -function test!(req::Request) +function Testany!(reqs::Array{Request,1}) + count = length(reqs) + reqvals = [reqs[i].val for i in 1:count] + index = Array(Int32, 1) flag = Array(Int32, 1) - stat = Array(Int32, MPI_STATUS_SIZE) - ierr = Array(Int32, 1) - - ccall(MPI_TEST, Void, (Ptr{Int32},Ptr{Int32},Ptr{Int32},Ptr{Int32},), - &req.fval, flag, stat, ierr) - - @_mpi_error_check ierr[1] "MPI_WAIT" - + stat = Status() + ccall(MPI_TESTANY, Void, + (Ptr{Int32},Ptr{Int32},Ptr{Int32},Ptr{Int32},Ptr{Int32},Ptr{Int32}), + &count, reqvals, index, flag, stat.val, &0) flag = bool(flag[1]) - if flag - req.extra = nothing + if !flag + return (false, index, nothing) end - (flag[1], stat) + reqs[index].val = reqvals[index] + reqs[index].buffer = nothing + (true, index, stat) end -function wait!(req::Request) - stat = Array(Int32, MPI_STATUS_SIZE) - ierr = Array(Int32, 1) - - ccall(MPI_WAIT, Void, (Ptr{Int32},Ptr{Int32},Ptr{Int32},), - &fval.req, stat, ierr) - - @_mpi_error_check ierr[1] "MPI_WAIT" +# Collective communication - req.extra = nothing - stat +function Barrier(comm::Comm) + ccall(MPI_BARRIER, Void, (Ptr{Int32},Ptr{Int32}), &comm.val, &0) end -function waitall!(reqs::Array{Request}) - ierr = Array(Int32, 1) - freqs = int32([r.fval for r in reqs]) - count = length(freqs) +function Bcast!{T<:MPIDatatype}(buffer::Union(Ptr{T},Array{T}), count::Integer, + root::Integer, comm::Comm) + ccall(MPI_BCAST, Void, + (Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}), + buffer, &count, &datatypes[T], &root, &comm.val, &0) + buffer +end - stats = Array(Int32, MPI_STATUS_SIZE, count) +function Bcast!{T<:MPIDatatype}(buffer::Array{T}, root::Integer, comm::Comm) + Bcast!(buffer, length(buffer), root, comm) +end - ccall(MPI_WAITALL, Void, (Ptr{Int32},Ptr{Int32},Ptr{Int32},Ptr{Int32},), - &count, freqs, stats, ierr) +#= +function Bcast{T<:MPIDatatype}(obj::T, root::Integer, comm::Comm) + buf = [T] + Bcast!(buf, root, comm) + buf[1] +end +=# - @_mpi_error_check ierr[1] "MPI_WAITALL" +function bcast(obj, root::Integer, comm::Comm) + isroot = Comm_rank(comm) == root + count = Array(Int32, 1) + if isroot + buf = serialize(obj) + count[1] = length(buf) + end + Bcast!(count, root, comm) + if !isroot + buf = Array(Uint8, count[1]) + end + Bcast!(buf, root, comm) + if !isroot + obj = deserialize(buf) + end + obj +end - map((x,y)->(x.fval=y; x.extra=nothing), reqs, freqs) +function Reduce{T<:MPIDatatype}(sendbuf::Union(Ptr{T},Array{T}), count::Integer, + op::Op, root::Integer, comm::Comm) + isroot = Comm_rank(comm) == root + recvbuf = Array(T, isroot ? count : 0) + ccall(MPI_REDUCE, Void, + (Ptr{T}, Ptr{T}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, + Ptr{Int32}, Ptr{Int32}), + sendbuf, recvbuf, &count, &datatypes[T], &op.val, &root, &comm.val, + &0) + isroot ? recvbuf : nothing +end - stats +function Reduce{T<:MPIDatatype}(sendbuf::Array{T}, op::Op, root::Integer, + comm::Comm) + Reduce(sendbuf, length(sendbuf), op, root, comm) end -waitall!(req::Request) = wait!(req) +function Reduce{T<:MPIDatatype}(object::T, op::Op, root::Integer, comm::Comm) + isroot = Comm_rank(comm) == root + sendbuf = T[object] + recvbuf = Reduce(sendbuf, op, root, comm) + isroot ? recvbuf[1] : nothing +end