diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 92ee7de0..8853b3ba 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -89,7 +89,7 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} check_eltype(T) maxsize = prod(dims) * sizeof(T) - function _alloc_f() + GPUArrays.cached_alloc((JLArray, T, dims)) do data = Vector{UInt8}(undef, maxsize) ref = DataRef(data) do data resize!(data, 0) @@ -97,13 +97,6 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} obj = new{T, N}(ref, 0, dims) return finalizer(unsafe_free!, obj) end - - cache = GPUArrays.ALLOC_CACHE[] - return if cache ≡ nothing - _alloc_f() - else - GPUArrays.alloc!(_alloc_f, cache, (JLArray, T, dims))::JLArray{T, N} - end end # low-level constructor for wrapping existing data diff --git a/src/host/alloc_cache.jl b/src/host/alloc_cache.jl index 899f2e5d..6c9a1200 100644 --- a/src/host/alloc_cache.jl +++ b/src/host/alloc_cache.jl @@ -30,13 +30,18 @@ function get_pool!(cache::AllocCache{T}, pool::Symbol, uid::UInt64) where {T <: return uid_pool end -function alloc!(alloc_f, cache::AllocCache, key) +function cached_alloc(f, key) + cache = ALLOC_CACHE[] + if cache === nothing + return f() + end + x = nothing uid = hash(key) busy_pool = get_pool!(cache, :busy, uid) free_pool = get_pool!(cache, :free, uid) - isempty(free_pool) && (x = alloc_f()) + isempty(free_pool) && (x = f()) while !isempty(free_pool) && x ≡ nothing tmp = Base.@lock cache.lock pop!(free_pool) @@ -45,7 +50,7 @@ function alloc!(alloc_f, cache::AllocCache, key) x = tmp end - x ≡ nothing && (x = alloc_f()) + x ≡ nothing && (x = f()) Base.@lock cache.lock push!(busy_pool, x) return x end