Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore IJulia heartbeat thread #65

Merged
merged 1 commit into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions src/pinning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function pinthreads end
function _nthreadsarg(threadpool)
@static if VERSION >= v"1.9-"
if threadpool == :all
return Threads.maxthreadid()
return Threads.nthreads(:default) + Threads.nthreads(:interactive)
else
return Threads.nthreads(threadpool)
end
Expand All @@ -112,17 +112,8 @@ function pinthreads(cpuids::AbstractVector{<:Integer};
if force || first_pin_attempt()
warn && _check_environment()
_check_cpuids(cpuids)
tids = threadids(threadpool)
limit = min(length(cpuids), nthreads)
@static if VERSION >= v"1.9-"
if threadpool == :all
tids = 1:Threads.maxthreadid()
else
tids = filter(i -> Threads.threadpool(i) == threadpool,
1:Threads.maxthreadid())
end
else
tids = 1:limit
end
@debug("pinthreads", limit, nthreads, tids)
for (i, tid) in pairs(@view(tids[1:limit]))
pinthread(tid, cpuids[i]; warn = false)
Expand Down
24 changes: 4 additions & 20 deletions src/querying.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,14 @@ See `getcpuid` for more information.
"""
function getcpuids(; threadpool = :default)::Vector{Int}
@static if VERSION >= v"1.9-"
if threadpool == :all
nt = Threads.maxthreadid()
tids_pool = 1:nt
@assert nt == Threads.nthreads(:default) + Threads.nthreads(:interactive)
elseif threadpool in (:default, :interactive)
nt = nthreads(threadpool)
tids_pool = filter(i -> Threads.threadpool(i) == threadpool,
1:Threads.maxthreadid())
else
if !(threadpool in (:all, :default, :interactive))
throw(ArgumentError("Unknown value for `threadpool` keyword argument. " *
"Supported values are `:all`, `:default`, and " *
"`:interactive`."))
end
tids_pool = threadids(threadpool)
nt = length(tids_pool)
cpuids = zeros(Int, nt)
@assert length(tids_pool) == nt
for (i, tid) in pairs(tids_pool)
cpuids[i] = fetch(@tspawnat tid getcpuid())
end
Expand All @@ -51,16 +44,7 @@ end
Print the affinity masks of all Julia threads.
"""
function print_affinity_masks(io = getstdout(); threadpool = :default, kwargs...)
@static if VERSION >= v"1.9-"
if threadpool == :all
tids = 1:Threads.maxthreadid()
else
tids = filter(i -> Threads.threadpool(i) == threadpool,
1:Threads.maxthreadid())
end
else
tids = 1:nthreads()
end
tids = threadids(threadpool)
for tid in tids
mask = uv_thread_getaffinity(tid)
str = _affinity_mask_to_string(mask; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/threadinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function threadinfo(io = getstdout(); blas = false, hints = false, color = true,
end
thread_cpuids = getcpuids(; threadpool)
elseif threadpool == :all
njlthreads = Threads.maxthreadid()
njlthreads = Threads.nthreads(:default) + Threads.nthreads(:interactive)
thread_cpuids = getcpuids(; threadpool = :all)
else
throw(ArgumentError("Unknown value for `threadpool` keyword argument. Supported " *
Expand Down
25 changes: 25 additions & 0 deletions src/utility.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
function threadids(threadpool = :default)::Vector{Int}
@static if VERSION < v"1.9-"
return collect(1:nthreads())
else
if threadpool == :all
nt = nthreads(:default) + nthreads(:interactive)
tids = collect(1:Threads.maxthreadid())
else
nt = nthreads(threadpool)
tids = filter(i -> Threads.threadpool(i) == threadpool, 1:Threads.maxthreadid())
end

if nt != length(tids)
# IJulia manually adds a heartbeat thread that mus be ignored...
# see https://github.com/JuliaLang/IJulia.jl/issues/1072
# Currently, we just assume that it is the last thread.
# Might not be safe, in particular not once users can dynamically add threads
# in the future.
pop!(tids)
end

return tids
end
end

"""
@tspawnat tid -> task
Mimics `Threads.@spawn`, but assigns the task to thread `tid` (with `sticky = true`).
Expand Down
13 changes: 13 additions & 0 deletions test/tests_utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ using Test
[1, 5, 2, 6, 3, 7, 4, 8, 9, 10, 11, 12]
end

@testset "threadids" begin
@static if VERSION < v"1.9-"
@test ThreadPinning.threadids() == 1:Threads.nthreads()
else
@test ThreadPinning.threadids(:all) == 1:Threads.maxthreadid() # no IJulia here :)
# :default threads first, then :interactive threads
@test ThreadPinning.threadids(:default) == 1:Threads.nthreads(:default)
if Threads.nthreads(:interactive) > 0
@test ThreadPinning.threadids(:interactive) == (1:Threads.nthreads(:interactive)) .+ Threads.nthreads(:default)
end
end
end

@testset "tspawnat" begin
@static if VERSION < v"1.9-"
for tid in 1:Threads.nthreads()
Expand Down