diff --git a/src/kernels.jl b/src/kernels.jl index 13e2283..522d80b 100644 --- a/src/kernels.jl +++ b/src/kernels.jl @@ -37,13 +37,21 @@ function transduce_impl(rf::F, init, arrays...) where {F} end # @info "ys, = _transduce!(nothing, rf, ...)" Text(summary(ys)) # @info "ys, = _transduce!(nothing, rf, ...)" collect(ys) - length(ys) == 1 && return @allowscalar ys[1] + if length(ys) == 1 + y = @allowscalar ys[1] + CUDA.unsafe_free!(parent(dest)) + return y + end rf2 = AlwaysCombine(rf) while true ys, = _transduce!(buf, rf2, CombineInit(), ys) # @info "ys, = _transduce!(buf, rf2, ...)" Text(summary(ys)) # @info "ys, = _transduce!(buf, rf2, ...)" collect(ys) - length(ys) == 1 && return @allowscalar ys[1] + if length(ys) == 1 + y = @allowscalar ys[1] + CUDA.unsafe_free!(parent(dest)) + return y + end dest, buf = buf, dest # reusing buffer; is it useful? end diff --git a/src/shfl.jl b/src/shfl.jl index 4a70051..c9975c1 100644 --- a/src/shfl.jl +++ b/src/shfl.jl @@ -42,13 +42,21 @@ function transduce_shfl_impl(rf::F, init, arrays...) where {F} end # @info "ys, = transduce_shfl!(nothing, rf, ...)" Text(summary(ys)) # @info "ys, = transduce_shfl!(nothing, rf, ...)" collect(ys) - length(ys) == 1 && return @allowscalar ys[1] + if length(ys) == 1 + y = @allowscalar ys[1] + CUDA.unsafe_free!(parent(dest)) + return y + end rf2 = AlwaysCombine(rf) while true ys, = transduce_shfl!(buf, rf2, init, ys) # @info "ys, = transduce_shfl!(buf, rf2, ...)" Text(summary(ys)) # @info "ys, = transduce_shfl!(buf, rf2, ...)" collect(ys) - length(ys) == 1 && return @allowscalar ys[1] + if length(ys) == 1 + y = @allowscalar ys[1] + CUDA.unsafe_free!(parent(dest)) + return y + end dest, buf = buf, dest # reusing buffer; is it useful? end