WIP: TrixiMPIArray
ranocha committed Mar 30, 2022
1 parent 141508f commit a178dea
Showing 3 changed files with 268 additions and 1 deletion.
Expand Up @@ -4,6 +4,7 @@ authors = ["Michael Schlottke-Lakemper <[email protected]>", "Gregor
version = "0.4.28-pre"

ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
Expand Down Expand Up @@ -41,6 +42,7 @@ TriplotRecipes = "808ab39a-a642-4abf-81ff-4cb34ebbffa3"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

ArrayInterface = "3"
CodeTracking = "1.0.5"
ConstructionBase = "1.3"
CPUSummary = "=0.1.8" # see
Expand Up @@ -25,6 +25,7 @@ using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, sparse, dropt
# import @reexport now to make it available for further imports/exports
using Reexport: @reexport

using ArrayInterface: static_length
using SciMLBase: CallbackSet, DiscreteCallback,
ODEProblem, ODESolution, ODEFunction
import SciMLBase: get_du, get_tmp_cache, u_modified!,
Expand All @@ -39,7 +40,6 @@ using HDF5: h5open, attributes
using IfElse: ifelse
using LinearMaps: LinearMap
using LoopVectorization: LoopVectorization, @turbo, indices
using LoopVectorization.ArrayInterface: static_length
using MPI: MPI
using MuladdMacro: @muladd
using GeometryBasics: GeometryBasics
Expand Down Expand Up @@ -99,6 +99,7 @@ include("basic_types.jl")
# Include all top-level source files
# TODO: MPI. Keep this module inside Trixi or move it to another repo as
# external dependency with simple test suite and documentation?
module TrixiMPIArrays

using ArrayInterface: ArrayInterface
using MPI: MPI

import ..Trixi: mpi_comm, mpi_rank

export TrixiMPIArray

"""
    TrixiMPIArray{T, N} <: AbstractArray{T, N}

A thin wrapper of arrays distributed via MPI used in Trixi.jl. The idea is that
these arrays behave as much as possible as plain arrays would in an SPMD-style
distributed MPI setting, with the exception of reductions, which are performed
globally. This allows these arrays to be used in ODE solvers such as the ones
from OrdinaryDiffEq.jl, since vector space operations, broadcasting, and
reductions are the only operations required for explicit time integration
methods with fixed step sizes or adaptive step sizes based on CFL or error
estimates.

!!! warning "Experimental code"
    This code is experimental and may be changed or removed in any future release.
"""
struct TrixiMPIArray{T, N, Parent<:AbstractArray{T, N}} <: AbstractArray{T, N}
    # Rank-local array wrapped by this object (returned by `Base.parent`).
    u_local::Parent
    # MPI communicator used for the global reductions.
    mpi_comm::MPI.Comm
    # Rank of the current process within `mpi_comm`, cached at construction.
    mpi_rank::Int
    # TODO: MPI. Shall we also include something like the following fields
    #       and remove them from the global state? Do we ever need something
    #       from the global MPI state without having a state vector `u`? Does
    #       including these fields here have a performance impact since it
    #       increases the size of these arrays?
    # mpi_size::Int
    # mpi_isroot::Bool
    # mpi_isparallel::Bool

    # Inner constructor: wraps `u_local` and records the communicator together
    # with the rank of the calling process.
    function TrixiMPIArray{T, N, Parent}(u_local::Parent) where {T, N,
                                                                 Parent<:AbstractArray{T, N}}
        # TODO: MPI. Hard-coded to MPI.COMM_WORLD for now
        # The locals are deliberately not called `mpi_comm`/`mpi_rank` so that
        # they do not shadow the accessor functions imported from `Trixi`.
        comm = MPI.COMM_WORLD
        # Query the rank from the communicator that is actually stored.
        rank = MPI.Comm_rank(comm)
        return new{T, N, Parent}(u_local, comm, rank)
    end
end

# Convenience constructor: infers the element type `T`, the dimension `N`, and
# the parent array type from `u_local` and forwards to the inner constructor.
function TrixiMPIArray(u_local::AbstractArray{T, N}) where {T, N}
    return TrixiMPIArray{T, N, typeof(u_local)}(u_local)
end

# TODO: MPI. Adapt
# - wrap_array
# - wrap_array_native
# - return type of initialization stuff when setting an IC
# - dispatch on this array type instead of parallel trees etc. and use
# `parent(u)` to get local versions instead of `invoke`

# Custom interface and general Base interface not covered by other parts below
# Return the rank-local array wrapped by `u`.
function Base.parent(u::TrixiMPIArray)
    return u.u_local
end

# Return the MPI communicator stored in `u`.
function mpi_comm(u::TrixiMPIArray)
    return u.mpi_comm
end

# Return the MPI rank stored in `u`.
function mpi_rank(u::TrixiMPIArray)
    return u.mpi_rank
end
# TODO: MPI. What about the following interface functions?
# mpi_nranks(u::TrixiMPIArray) = MPI_SIZE[]
# mpi_isparallel(u::TrixiMPIArray) = MPI_IS_PARALLEL[]

# Implementation of the abstract array interface of Base
# See the "Interfaces" chapter of the Julia manual (section "Abstract Arrays").
# Shape queries are answered by the rank-local parent array.
Base.size(u::TrixiMPIArray) = size(parent(u))
Base.axes(u::TrixiMPIArray) = axes(parent(u))

# Access with a single index object is forwarded unchanged to the parent array.
Base.getindex(u::TrixiMPIArray, idx) = getindex(parent(u), idx)
Base.setindex!(u::TrixiMPIArray, v, idx) = setindex!(parent(u), v, idx)

# Scalar Cartesian access `u[i, j, ...]` with exactly `N` integer indices.
# The index style is inherited from the parent type (see `IndexStyle` below),
# so a parent using `IndexCartesian` requires these methods; without them,
# Base reports `getindex`/`setindex!` as not defined for the wrapper.
function Base.getindex(u::TrixiMPIArray{T, N}, idx::Vararg{Int, N}) where {T, N}
    return getindex(parent(u), idx...)
end
function Base.setindex!(u::TrixiMPIArray{T, N}, v, idx::Vararg{Int, N}) where {T, N}
    return setindex!(parent(u), v, idx...)
end

# Use the same index style as the wrapped array type.
function Base.IndexStyle(::Type{TrixiMPIArray{T, N, Parent}}) where {T, N, Parent}
    return IndexStyle(Parent)
end

# Allocate a similar local array and wrap it again so that the result is still
# a `TrixiMPIArray`.
function Base.similar(u::TrixiMPIArray, ::Type{S}, dims::NTuple{N, Int}) where {S, N}
    return TrixiMPIArray(similar(parent(u), S, dims))
end

# Implementation of the strided array interface of Base
# See the "Interfaces" chapter of the Julia manual (section "Strided Arrays").
# Memory layout queries are forwarded to the rank-local parent array.
Base.strides(u::TrixiMPIArray) = strides(parent(u))

# Pointer to the first element of the parent array (for `ccall`s etc.).
function Base.unsafe_convert(::Type{Ptr{T}}, u::TrixiMPIArray{T}) where {T}
    return Base.unsafe_convert(Ptr{T}, parent(u))
end

# `elsize` is not exported from `Base` and is not imported into this module,
# so the call has to be qualified; an unqualified `elsize(Parent)` would throw
# an `UndefVarError` when this method is called.
function Base.elsize(::Type{TrixiMPIArray{T, N, Parent}}) where {T, N, Parent}
    return Base.elsize(Parent)
end

# TODO: MPI. Do we need customized broadcasting?
#       See the "Interfaces" chapter of the Julia manual (section
#       "Customizing broadcasting").

# Implementation of methods from ArrayInterface.jl for use with
# LoopVectorization.jl etc.
# See the ArrayInterface.jl documentation.
# Report the type of the wrapped array, i.e., the third type parameter of
# `TrixiMPIArray`.
function ArrayInterface.parent_type(::Type{TrixiMPIArray{T, N, P}}) where {T, N, P}
    return P
end

# TODO: MPI. Do we need LinearAlgebra methods such as `norm` or `dot`?

# `mapreduce` functionality from Base using global reductions via MPI communication
# Global `mapreduce`: every rank first reduces its local data and the partial
# results are then combined across all ranks of `mpi_comm(u)` with the same
# operator `op`. All ranks obtain the same result.
#
# NOTE(review): `kwargs` (e.g., `init`, `dims`) are passed to the local
# reduction only. An `init` value therefore enters the result once per rank
# before the partial results are combined - confirm that this is intended.
function Base.mapreduce(f::F, op::Op, u::TrixiMPIArray; kwargs...) where {F, Op}
    # Step 1: reduction over the rank-local array
    local_value = mapreduce(f, op, parent(u); kwargs...)
    # Step 2: combine the partial results of all ranks
    return MPI.Allreduce(local_value, op, mpi_comm(u))
end

# TODO: MPI. Default settings of OrdinaryDiffEq etc.
# Interesting options could be
# See

# TODO: MPI. How shall we handle `length`? We want `TrixiMPIArray`s to behave
# like regular `Array`s in most code, e.g., for `eachindex` etc.
# However, we need to divide by the `length` of the global array
# for `ODE_DEFAULT_NORM`. We could specialize `ODE_DEFAULT_NORM`
# accordingly, but that requires depending on DiffEqBase (instead of
# SciMLBase). Alternatively, we could implement this via Requires.jl,
# but that will prevent precompilation and maybe trigger invalidations.
# Alternatively, we could specialize `length` to return a global
# length and make sure that all local behavior is still working as
# expected (if we use `eachindex` instead of `1:length` etc.).

end # module

using .TrixiMPIArrays

