From e352b9edbcdb20ad43562398b304d473162c8c55 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 6 Feb 2024 19:42:35 +0200 Subject: [PATCH] Add support for logical indexing (#590) --- src/AMDGPU.jl | 1 + src/kernels/indexing.jl | 33 +++++++++++++++++++++++++++++++++ src/kernels/mapreduce.jl | 26 ++++++++++++-------------- test/rocarray/base.jl | 1 + test/rocarray/indexing.jl | 14 ++++++++++++++ 5 files changed, 61 insertions(+), 14 deletions(-) create mode 100644 src/kernels/indexing.jl create mode 100644 test/rocarray/indexing.jl diff --git a/src/AMDGPU.jl b/src/AMDGPU.jl index b05cf1dac..effd364f5 100644 --- a/src/AMDGPU.jl +++ b/src/AMDGPU.jl @@ -105,6 +105,7 @@ include("broadcast.jl") include("exception_handler.jl") include("kernels/mapreduce.jl") +include("kernels/indexing.jl") include("kernels/accumulate.jl") include("kernels/sorting.jl") include("kernels/reverse.jl") diff --git a/src/kernels/indexing.jl b/src/kernels/indexing.jl new file mode 100644 index 000000000..b2fd3dc1b --- /dev/null +++ b/src/kernels/indexing.jl @@ -0,0 +1,33 @@ +Base.to_index(::ROCArray, I::AbstractArray{Bool}) = findall(I) + +# TODO Julia 1.11 specifics + +function Base.findall(bools::AnyROCArray{Bool}) + I = keytype(bools) + indices = cumsum(reshape(bools, prod(size(bools)))) + + n = @allowscalar indices[end] + ys = ROCArray{I}(undef, n) + + if n > 0 + function _ker!(ys, bools, indices) + i = workitemIdx().x + (workgroupIdx().x - Int32(1)) * workgroupDim().x + + @inbounds if i ≤ length(bools) && bools[i] + ii = CartesianIndices(bools)[i] + b = indices[i] # new position + ys[b] = ii + end + return + end + + kernel = @roc launch=false _ker!(ys, bools, indices) + config = launch_configuration(kernel) + groupsize = min(length(indices), config.groupsize) + gridsize = cld(length(indices), groupsize) + kernel(ys, bools, indices; groupsize, gridsize) + end + unsafe_free!(indices) + + return ys +end diff --git a/src/kernels/mapreduce.jl b/src/kernels/mapreduce.jl index 82ed5fe8d..53bad7184 100644 --- a/src/kernels/mapreduce.jl +++ b/src/kernels/mapreduce.jl @@ -6,27 +6,27 @@ # Reduce a value across a group, using local memory for communication @inline function reduce_group(op, val::T, neutral) where T - items::UInt32 = workgroupDim().x - item::UInt32 = workitemIdx().x + items = workgroupDim().x + item = workitemIdx().x # Shared mem for a complete reduction. shared = @ROCDynamicLocalArray(T, items, false) @inbounds shared[item] = val # Perform a reduction. - d::UInt32 = UInt32(1) + d = 1 while d < items sync_workgroup() - index::UInt32 = UInt32(2) * d * (item - UInt32(1)) + UInt32(1) + index = 2 * d * (item - 1) + 1 @inbounds if index ≤ items other_val = (index + d) ≤ items ? shared[index + d] : neutral shared[index] = op(shared[index], other_val) end - d *= UInt32(2) + d *= 2 end # Load the final value on the first item. - if item == UInt32(1) + if item == 1 val = @inbounds shared[item] end @@ -46,9 +46,8 @@ function partial_mapreduce_device(f, op, neutral, Rreduce, Rother, R, As...) localIdx_reduce = workitemIdx().x localDim_reduce = workgroupDim().x - n_elements_other::UInt32 = length(Rother) - groupIdx_reduce, groupIdx_other = fldmod1(workgroupIdx().x, n_elements_other) - groupDim_reduce = gridGroupDim().x ÷ n_elements_other + groupIdx_reduce, groupIdx_other = fldmod1(workgroupIdx().x, length(Rother)) + groupDim_reduce = gridGroupDim().x ÷ length(Rother) # group-based indexing into the values outside of the reduction dimension # (that means we can safely synchronize items within this group) @@ -63,11 +62,10 @@ function partial_mapreduce_device(f, op, neutral, Rreduce, Rother, R, As...) val = op(neutral, neutral) # reduce serially across chunks of input vector that don't fit in a group - ireduce = localIdx_reduce + (groupIdx_reduce - UInt32(1)) * localDim_reduce - n_elements_reduce::UInt32 = length(Rreduce) - while ireduce ≤ n_elements_reduce + ireduce = localIdx_reduce + (groupIdx_reduce - 1) * localDim_reduce + while ireduce ≤ length(Rreduce) Ireduce = Rreduce[ireduce] - J = Base.max(Iother, Ireduce) + J = max(Iother, Ireduce) val = op(val, f(_map_getindex(As, J)...)) ireduce += localDim_reduce * groupDim_reduce end @@ -75,7 +73,7 @@ function partial_mapreduce_device(f, op, neutral, Rreduce, Rother, R, As...) val = reduce_group(op, val, neutral) # write back to memory - if localIdx_reduce == UInt32(1) + if localIdx_reduce == 1 R[Iout] = val end end diff --git a/test/rocarray/base.jl b/test/rocarray/base.jl index 46b68de87..6af9fd0f1 100644 --- a/test/rocarray/base.jl +++ b/test/rocarray/base.jl @@ -202,6 +202,7 @@ end include("sorting.jl") include("reverse.jl") +include("indexing.jl") if length(AMDGPU.devices()) > 1 include("multi_gpu.jl") diff --git a/test/rocarray/indexing.jl b/test/rocarray/indexing.jl new file mode 100644 index 000000000..ab8520117 --- /dev/null +++ b/test/rocarray/indexing.jl @@ -0,0 +1,14 @@ +@testset "Selection" begin + x = rand(Int32, 16) + m = rand(Bool, 16) + xd, md = ROCArray.((x, m)) + @test x[m] == Array(xd[md]) + + x = rand(Int32, 1, 16) + xd = ROCArray(x) + @test x[:, m] == Array(xd[:, md]) + + x = rand(Int32, 3, 1, 16) + xd = ROCArray(x) + @test x[:, :, m] == Array(xd[:, :, md]) +end