Skip to content

Commit

Permalink
Add support for logical indexing (#590)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Feb 6, 2024
1 parent fbb562e commit e352b9e
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 14 deletions.
1 change: 1 addition & 0 deletions src/AMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ include("broadcast.jl")
include("exception_handler.jl")

include("kernels/mapreduce.jl")
include("kernels/indexing.jl")
include("kernels/accumulate.jl")
include("kernels/sorting.jl")
include("kernels/reverse.jl")
Expand Down
33 changes: 33 additions & 0 deletions src/kernels/indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
Base.to_index(::ROCArray, I::AbstractArray{Bool}) = findall(I)

# TODO Julia 1.11 specifics

function Base.findall(bools::AnyROCArray{Bool})
I = keytype(bools)
indices = cumsum(reshape(bools, prod(size(bools))))

n = @allowscalar indices[end]
ys = ROCArray{I}(undef, n)

if n > 0
function _ker!(ys, bools, indices)
i = workitemIdx().x + (workgroupIdx().x - Int32(1)) * workgroupDim().x

@inbounds if i length(bools) && bools[i]
ii = CartesianIndices(bools)[i]
b = indices[i] # new position
ys[b] = ii
end
return
end

kernel = @roc launch=false _ker!(ys, bools, indices)
config = launch_configuration(kernel)
groupsize = min(length(indices), config.groupsize)
gridsize = cld(length(indices), groupsize)
kernel(ys, bools, indices; groupsize, gridsize)
end
unsafe_free!(indices)

return ys
end
26 changes: 12 additions & 14 deletions src/kernels/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@

# Reduce a value across a group, using local memory for communication
@inline function reduce_group(op, val::T, neutral) where T
items::UInt32 = workgroupDim().x
item::UInt32 = workitemIdx().x
items = workgroupDim().x
item = workitemIdx().x

# Shared mem for a complete reduction.
shared = @ROCDynamicLocalArray(T, items, false)
@inbounds shared[item] = val

# Perform a reduction.
d::UInt32 = UInt32(1)
d = 1
while d < items
sync_workgroup()
index::UInt32 = UInt32(2) * d * (item - UInt32(1)) + UInt32(1)
index = 2 * d * (item - 1) + 1
@inbounds if index items
other_val = (index + d) items ? shared[index + d] : neutral
shared[index] = op(shared[index], other_val)
end
d *= UInt32(2)
d *= 2
end

# Load the final value on the first item.
if item == UInt32(1)
if item == 1
val = @inbounds shared[item]
end

Expand All @@ -46,9 +46,8 @@ function partial_mapreduce_device(f, op, neutral, Rreduce, Rother, R, As...)
localIdx_reduce = workitemIdx().x
localDim_reduce = workgroupDim().x

n_elements_other::UInt32 = length(Rother)
groupIdx_reduce, groupIdx_other = fldmod1(workgroupIdx().x, n_elements_other)
groupDim_reduce = gridGroupDim().x ÷ n_elements_other
groupIdx_reduce, groupIdx_other = fldmod1(workgroupIdx().x, length(Rother))
groupDim_reduce = gridGroupDim().x ÷ length(Rother)

# group-based indexing into the values outside of the reduction dimension
# (that means we can safely synchronize items within this group)
Expand All @@ -63,19 +62,18 @@ function partial_mapreduce_device(f, op, neutral, Rreduce, Rother, R, As...)
val = op(neutral, neutral)

# reduce serially across chunks of input vector that don't fit in a group
ireduce = localIdx_reduce + (groupIdx_reduce - UInt32(1)) * localDim_reduce
n_elements_reduce::UInt32 = length(Rreduce)
while ireduce n_elements_reduce
ireduce = localIdx_reduce + (groupIdx_reduce - 1) * localDim_reduce
while ireduce length(Rreduce)
Ireduce = Rreduce[ireduce]
J = Base.max(Iother, Ireduce)
J = max(Iother, Ireduce)
val = op(val, f(_map_getindex(As, J)...))
ireduce += localDim_reduce * groupDim_reduce
end

val = reduce_group(op, val, neutral)

# write back to memory
if localIdx_reduce == UInt32(1)
if localIdx_reduce == 1
R[Iout] = val
end
end
Expand Down
1 change: 1 addition & 0 deletions test/rocarray/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ end

include("sorting.jl")
include("reverse.jl")
include("indexing.jl")

if length(AMDGPU.devices()) > 1
include("multi_gpu.jl")
Expand Down
14 changes: 14 additions & 0 deletions test/rocarray/indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@testset "Selection" begin
x = rand(Int32, 16)
m = rand(Bool, 16)
xd, md = ROCArray.((x, m))
@test x[m] == Array(xd[md])

x = rand(Int32, 1, 16)
xd = ROCArray(x)
@test x[:, m] == Array(xd[:, md])

x = rand(Int32, 3, 1, 16)
xd = ROCArray(x)
@test x[:, :, m] == Array(xd[:, :, md])
end

0 comments on commit e352b9e

Please sign in to comment.