Skip to content

Commit

Permalink
Merge pull request #65 from carstenbauer/cb/ijuliafix
Browse files Browse the repository at this point in the history
  • Loading branch information
carstenbauer authored Apr 20, 2023
2 parents 4f46865 + 7041f22 commit 1c9ed8a
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 32 deletions.
13 changes: 2 additions & 11 deletions src/pinning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function pinthreads end
function _nthreadsarg(threadpool)
@static if VERSION >= v"1.9-"
if threadpool == :all
return Threads.maxthreadid()
return Threads.nthreads(:default) + Threads.nthreads(:interactive)
else
return Threads.nthreads(threadpool)
end
Expand All @@ -112,17 +112,8 @@ function pinthreads(cpuids::AbstractVector{<:Integer};
if force || first_pin_attempt()
warn && _check_environment()
_check_cpuids(cpuids)
tids = threadids(threadpool)
limit = min(length(cpuids), nthreads)
@static if VERSION >= v"1.9-"
if threadpool == :all
tids = 1:Threads.maxthreadid()
else
tids = filter(i -> Threads.threadpool(i) == threadpool,
1:Threads.maxthreadid())
end
else
tids = 1:limit
end
@debug("pinthreads", limit, nthreads, tids)
for (i, tid) in pairs(@view(tids[1:limit]))
pinthread(tid, cpuids[i]; warn = false)
Expand Down
24 changes: 4 additions & 20 deletions src/querying.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,14 @@ See `getcpuid` for more information.
"""
function getcpuids(; threadpool = :default)::Vector{Int}
@static if VERSION >= v"1.9-"
if threadpool == :all
nt = Threads.maxthreadid()
tids_pool = 1:nt
@assert nt == Threads.nthreads(:default) + Threads.nthreads(:interactive)
elseif threadpool in (:default, :interactive)
nt = nthreads(threadpool)
tids_pool = filter(i -> Threads.threadpool(i) == threadpool,
1:Threads.maxthreadid())
else
if !(threadpool in (:all, :default, :interactive))
throw(ArgumentError("Unknown value for `threadpool` keyword argument. " *
"Supported values are `:all`, `:default`, and " *
"`:interactive`."))
end
tids_pool = threadids(threadpool)
nt = length(tids_pool)
cpuids = zeros(Int, nt)
@assert length(tids_pool) == nt
for (i, tid) in pairs(tids_pool)
cpuids[i] = fetch(@tspawnat tid getcpuid())
end
Expand All @@ -51,16 +44,7 @@ end
Print the affinity masks of all Julia threads.
"""
function print_affinity_masks(io = getstdout(); threadpool = :default, kwargs...)
@static if VERSION >= v"1.9-"
if threadpool == :all
tids = 1:Threads.maxthreadid()
else
tids = filter(i -> Threads.threadpool(i) == threadpool,
1:Threads.maxthreadid())
end
else
tids = 1:nthreads()
end
tids = threadids(threadpool)
for tid in tids
mask = uv_thread_getaffinity(tid)
str = _affinity_mask_to_string(mask; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/threadinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function threadinfo(io = getstdout(); blas = false, hints = false, color = true,
end
thread_cpuids = getcpuids(; threadpool)
elseif threadpool == :all
njlthreads = Threads.maxthreadid()
njlthreads = Threads.nthreads(:default) + Threads.nthreads(:interactive)
thread_cpuids = getcpuids(; threadpool = :all)
else
throw(ArgumentError("Unknown value for `threadpool` keyword argument. Supported " *
Expand Down
25 changes: 25 additions & 0 deletions src/utility.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
function threadids(threadpool = :default)::Vector{Int}
@static if VERSION < v"1.9-"
return collect(1:nthreads())
else
if threadpool == :all
nt = nthreads(:default) + nthreads(:interactive)
tids = collect(1:Threads.maxthreadid())
else
nt = nthreads(threadpool)
tids = filter(i -> Threads.threadpool(i) == threadpool, 1:Threads.maxthreadid())
end

if nt != length(tids)
# IJulia manually adds a heartbeat thread that mus be ignored...
# see https://github.com/JuliaLang/IJulia.jl/issues/1072
# Currently, we just assume that it is the last thread.
# Might not be safe, in particular not once users can dynamically add threads
# in the future.
pop!(tids)
end

return tids
end
end

"""
@tspawnat tid -> task
Mimics `Threads.@spawn`, but assigns the task to thread `tid` (with `sticky = true`).
Expand Down
13 changes: 13 additions & 0 deletions test/tests_utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ using Test
[1, 5, 2, 6, 3, 7, 4, 8, 9, 10, 11, 12]
end

@testset "threadids" begin
@static if VERSION < v"1.9-"
@test ThreadPinning.threadids() == 1:Threads.nthreads()
else
@test ThreadPinning.threadids(:all) == 1:Threads.maxthreadid() # no IJulia here :)
# :default threads first, then :interactive threads
@test ThreadPinning.threadids(:default) == 1:Threads.nthreads(:default)
if Threads.nthreads(:interactive) > 0
@test ThreadPinning.threadids(:interactive) == (1:Threads.nthreads(:interactive)) .+ Threads.nthreads(:default)
end
end
end

@testset "tspawnat" begin
@static if VERSION < v"1.9-"
for tid in 1:Threads.nthreads()
Expand Down

0 comments on commit 1c9ed8a

Please sign in to comment.