Add backend to use unified memory and multiple GPUs
efaulhaber committed Jan 10, 2025
1 parent 9560a82 commit 6282767
Showing 3 changed files with 70 additions and 5 deletions.
59 changes: 59 additions & 0 deletions ext/PointNeighborsCUDAExt.jl
@@ -0,0 +1,59 @@
module PointNeighborsCUDAExt

using PointNeighbors: PointNeighbors, generic_kernel, CUDAMultiGPUBackend, KernelAbstractions
using CUDA: CUDA, CuArray, CUDABackend

const UnifiedCuArray = CuArray{<:Any, <:Any, CUDA.UnifiedMemory}

# This is needed because TrixiParticles passes `get_backend(coords)` to distinguish between
# `nothing` (Polyester.jl) and `KernelAbstractions.CPU`.
PointNeighbors.get_backend(x::UnifiedCuArray) = CUDAMultiGPUBackend()

# Convert input array to `CuArray` with unified memory
function PointNeighbors.Adapt.adapt_structure(to::CUDAMultiGPUBackend, array::Array)
return CuArray{eltype(array), ndims(array), CUDA.UnifiedMemory}(array)
end
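# For illustration (my addition, not part of this commit): adapting a CPU array
# this way yields a unified-memory `CuArray` that every device can access, e.g.
#   PointNeighbors.Adapt.adapt(CUDAMultiGPUBackend(), zeros(3, 10)) isa UnifiedCuArray  # true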

@inline function PointNeighbors.parallel_foreach(f, iterator, x::UnifiedCuArray)
PointNeighbors.parallel_foreach(f, iterator, CUDAMultiGPUBackend())
end

# On GPUs, execute `f` inside a GPU kernel with KernelAbstractions.jl
@inline function PointNeighbors.parallel_foreach(f, iterator, x::CUDAMultiGPUBackend)
# On the GPU, we can only loop over `1:N`. Therefore, we loop over `1:length(iterator)`
# and index with `iterator[eachindex(iterator)[i]]`.
# Note that this only works with vector-like iterators that support arbitrary indexing.
indices = eachindex(iterator)

# Skip empty loops
length(indices) == 0 && return

# Partition the indices across the GPUs
n_gpus = length(CUDA.devices())
indices_split = Iterators.partition(indices, ceil(Int, length(indices) / n_gpus))
@assert length(indices_split) <= n_gpus
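# Illustration (assumed values, not part of this commit): with 10 indices and
# 4 GPUs, the chunk length is `ceil(Int, 10 / 4) = 3`, so `indices_split`
# contains the four chunks `1:3`, `4:6`, `7:9` and `10:10`.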

backend = CUDABackend()

# Spawn kernel on each device
for (i, indices_) in enumerate(indices_split)
# Select the correct device for this partition
CUDA.device!(i - 1)

# Call the generic kernel, which only calls a function with the global GPU index
generic_kernel(backend)(ndrange = length(indices_)) do j
@inbounds @inline f(iterator[indices_[j]])
end
end

# Synchronize each device
for i in 1:length(indices_split)
CUDA.device!(i - 1)
KernelAbstractions.synchronize(backend)
end

# Select first device again
CUDA.device!(0)
end

end # module
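A minimal usage sketch for this extension (my illustration, not part of the commit; it assumes a machine with one or more CUDA GPUs and uses a made-up `coords` array):

using PointNeighbors, CUDA

# Hypothetical input: coordinates of 1000 points on the CPU.
coords = rand(3, 1000)

# Convert to a `CuArray` with unified memory, accessible from all devices.
backend = PointNeighbors.CUDAMultiGPUBackend()
coords_gpu = PointNeighbors.Adapt.adapt(backend, coords)

# `get_backend(coords_gpu)` now returns `CUDAMultiGPUBackend()`, so this loop
# is split across all available GPUs by the `parallel_foreach` defined above.
PointNeighbors.parallel_foreach(axes(coords_gpu, 2), coords_gpu) do i
    @inbounds coords_gpu[1, i] += 1
end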
3 changes: 0 additions & 3 deletions src/gpu.jl
@@ -31,6 +31,3 @@ function Adapt.adapt_structure(to, nhs::GridNeighborhoodSearch)
return GridNeighborhoodSearch(cell_list, search_radius, periodic_box, n_cells,
                              cell_size, update_buffer, nhs.update_strategy)
end
-
-# This is useful to pass the backend directly to `@threaded`
-KernelAbstractions.get_backend(backend::KernelAbstractions.Backend) = backend
13 changes: 11 additions & 2 deletions src/util.jl
@@ -60,7 +60,9 @@ with `Threads.@threads :static`.
"""
struct ThreadsStaticBackend <: AbstractThreadingBackend end

-const ParallelizationBackend = Union{AbstractThreadingBackend, KernelAbstractions.Backend}
+struct CUDAMultiGPUBackend end
+
+const ParallelizationBackend = Union{AbstractThreadingBackend, KernelAbstractions.Backend, CUDAMultiGPUBackend}

"""
@threaded x for ... end
@@ -140,7 +142,9 @@ end
# Skip empty loops
ndrange == 0 && return

-backend = KernelAbstractions.get_backend(x)
+# Use our `get_backend`, which usually forwards to `KernelAbstractions.get_backend`,
+# but also works when `x` is already a backend.
+backend = get_backend(x)

# Call the generic kernel that is defined below, which only calls a function with
# the global GPU index.
@@ -155,3 +159,8 @@ end
i = @index(Global)
@inline f(i)
end
+
+get_backend(x) = KernelAbstractions.get_backend(x)
+
+# This is useful to pass the backend directly to `@threaded`
+get_backend(backend::KernelAbstractions.Backend) = backend
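A hedged sketch of what the identity method enables (my illustration, not from the commit; it assumes a CUDA device is available): a `KernelAbstractions.Backend` can be passed to `@threaded` in place of an array.

using PointNeighbors, CUDA

v = CUDA.zeros(100)

# `get_backend(CUDABackend())` returns the backend itself, so the loop below
# runs as a GPU kernel even though the first argument is not an array.
PointNeighbors.@threaded CUDABackend() for i in 1:100
    @inbounds v[i] = i
end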
