Skip to content

Commit 0d72173

Browse files
committed
Refactor scheduler and switch to a spinner thread concept for wakeups
This also adds a counter for idle/sleeping threads to avoid checking every thread when everyone is running.
1 parent 3d85309 commit 0d72173

File tree

9 files changed

+238
-74
lines changed

9 files changed

+238
-74
lines changed

base/Base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ const liblapack_name = libblas_name
137137
# Note that `atomics.jl` here should be deprecated
138138
Core.eval(Threads, :(include("atomics.jl")))
139139
include("channels.jl")
140-
include("partr.jl")
140+
include("scheduler/scheduler.jl")
141141
include("task.jl")
142142
include("threads_overloads.jl")
143143
include("weakkeydict.jl")

base/partr.jl renamed to base/scheduler/partr.jl

Lines changed: 23 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -19,63 +19,6 @@ const heap_d = UInt32(8)
1919
const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
2020
const heaps_lock = [SpinLock(), SpinLock()]
2121

22-
23-
"""
24-
cong(max::UInt32)
25-
26-
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
27-
"""
28-
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check
29-
30-
get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ())
31-
32-
set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)
33-
34-
"""
35-
rand_ptls(max::UInt32)
36-
37-
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
38-
state. Max must be greater than 0.
39-
"""
40-
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
41-
rngseed = get_ptls_rng()
42-
val, seed = rand_uniform_max_int32(max, rngseed)
43-
set_ptls_rng(seed)
44-
return val % UInt32
45-
end
46-
47-
# This implementation is based on OpenSSLs implementation of rand_uniform
48-
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
49-
# Comments are vendored from their implementation as well.
50-
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.
51-
52-
# Essentially it boils down to incrementally generating a fixed point
53-
# number on the interval [0, 1) and multiplying this number by the upper
54-
# range limit. Once it is certain what the fractional part contributes to
55-
# the integral part of the product, the algorithm has produced a definitive
56-
# result.
57-
"""
58-
rand_uniform_max_int32(max::UInt32, seed::UInt64)
59-
60-
Return a random UInt32 in the range `0:max-1` using the given seed.
61-
Max must be greater than 0.
62-
"""
63-
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
64-
if max == UInt32(1)
65-
return UInt32(0), seed
66-
end
67-
# We are generating a fixed point number on the interval [0, 1).
68-
# Multiplying this by the range gives us a number on [0, upper).
69-
# The high word of the multiplication result represents the integral part
70-
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
71-
seed = UInt64(69069) * seed + UInt64(362437)
72-
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
73-
i = prod >> 32 % UInt32 # integral part
74-
return i % UInt32, seed
75-
end
76-
77-
78-
7922
function multiq_sift_up(heap::taskheap, idx::Int32)
8023
while idx > Int32(1)
8124
parent = (idx - Int32(2)) ÷ heap_d + Int32(1)
@@ -147,10 +90,10 @@ function multiq_insert(task::Task, priority::UInt16)
14790

14891
task.priority = priority
14992

150-
rn = cong(heap_p)
93+
rn = Base.Scheduler.cong(heap_p)
15194
tpheaps = heaps[tp]
15295
while !trylock(tpheaps[rn].lock)
153-
rn = cong(heap_p)
96+
rn = Base.Scheduler.cong(heap_p)
15497
end
15598

15699
heap = tpheaps[rn]
@@ -190,8 +133,8 @@ function multiq_deletemin()
190133
if i == heap_p
191134
return nothing
192135
end
193-
rn1 = cong(heap_p)
194-
rn2 = cong(heap_p)
136+
rn1 = Base.Scheduler.cong(heap_p)
137+
rn2 = Base.Scheduler.cong(heap_p)
195138
prio1 = tpheaps[rn1].priority
196139
prio2 = tpheaps[rn2].priority
197140
if prio1 > prio2
@@ -211,7 +154,21 @@ function multiq_deletemin()
211154
heap = tpheaps[rn1]
212155
task = heap.tasks[1]
213156
if ccall(:jl_set_task_tid, Cint, (Any, Cint), task, tid-1) == 0
157+
# This task is stuck to a thread that's likely sleeping, move the task to it's private queue and wake it up
158+
# We move this out of the queue to avoid spinning on it
159+
ntasks = heap.ntasks
160+
@atomic :monotonic heap.ntasks = ntasks - Int32(1)
161+
heap.tasks[1] = heap.tasks[ntasks]
162+
Base._unsetindex!(heap.tasks, Int(ntasks))
163+
prio1 = typemax(UInt16)
164+
if ntasks > 1
165+
multiq_sift_down(heap, Int32(1))
166+
prio1 = heap.tasks[1].priority
167+
end
168+
@atomic :monotonic heap.priority = prio1
169+
push!(workqueue_for(tid), t)
214170
unlock(heap.lock)
171+
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
215172
@goto retry
216173
end
217174
ntasks = heap.ntasks
@@ -243,4 +200,9 @@ function multiq_check_empty()
243200
return true
244201
end
245202

203+
204+
enqueue!(t::Task) = multiq_insert(t, t.priority)
205+
dequeue!() = multiq_deletemin()
206+
checktaskempty() = multiq_check_empty()
207+
246208
end

base/scheduler/scheduler.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
module Scheduler
4+
5+
"""
6+
cong(max::UInt32)
7+
8+
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
9+
"""
10+
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check
11+
12+
get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ())
13+
14+
set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)
15+
16+
"""
17+
rand_ptls(max::UInt32)
18+
19+
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
20+
state. Max must be greater than 0.
21+
"""
22+
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
23+
rngseed = get_ptls_rng()
24+
val, seed = rand_uniform_max_int32(max, rngseed)
25+
set_ptls_rng(seed)
26+
return val % UInt32
27+
end
28+
29+
# This implementation is based on OpenSSLs implementation of rand_uniform
30+
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
31+
# Comments are vendored from their implementation as well.
32+
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.
33+
34+
# Essentially it boils down to incrementally generating a fixed point
35+
# number on the interval [0, 1) and multiplying this number by the upper
36+
# range limit. Once it is certain what the fractional part contributes to
37+
# the integral part of the product, the algorithm has produced a definitive
38+
# result.
39+
"""
40+
rand_uniform_max_int32(max::UInt32, seed::UInt64)
41+
42+
Return a random UInt32 in the range `0:max-1` using the given seed.
43+
Max must be greater than 0.
44+
"""
45+
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
46+
if max == UInt32(1)
47+
return UInt32(0), seed
48+
end
49+
# We are generating a fixed point number on the interval [0, 1).
50+
# Multiplying this by the range gives us a number on [0, upper).
51+
# The high word of the multiplication result represents the integral part
52+
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
53+
seed = UInt64(69069) * seed + UInt64(362437)
54+
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
55+
i = prod >> 32 % UInt32 # integral part
56+
return i % UInt32, seed
57+
end
58+
59+
include("scheduler/partr.jl")
60+
61+
const ChosenScheduler = Partr
62+
63+
64+
65+
# Scheduler interface:
66+
# enqueue! which pushes a runnable Task into it
67+
# dequeue! which pops a runnable Task from it
68+
# checktaskempty which returns true if the scheduler has no available Tasks
69+
70+
enqueue!(t::Task) = ChosenScheduler.enqueue!(t)
71+
dequeue!() = ChosenScheduler.dequeue!()
72+
checktaskempty() = ChosenScheduler.checktaskempty()
73+
74+
end

base/task.jl

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,6 @@ end
937937

938938
function enq_work(t::Task)
939939
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")
940-
941940
# Sticky tasks go into their thread's work queue.
942941
if t.sticky
943942
tid = Threads.threadid(t)
@@ -968,19 +967,44 @@ function enq_work(t::Task)
968967
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
969968
push!(workqueue_for(tid), t)
970969
else
971-
# Otherwise, put the task in the multiqueue.
972-
Partr.multiq_insert(t, t.priority)
970+
# Otherwise, push the task to the scheduler
971+
Scheduler.enqueue!(t)
973972
tid = 0
974973
end
975974
end
976-
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
975+
976+
if (tid == 0)
977+
Core.Intrinsics.atomic_fence(:sequentially_consistent)
978+
n_spinning = Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads_spinning, Cint), :monotonic)
979+
n_spinning == 0 && ccall(:jl_add_spinner, Cvoid, ())
980+
else
981+
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
982+
end
983+
# n_spinning = Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads, Cint), :acquire)
984+
# n_spinning == 0 && ccall(:jl_add_spinner, Cvoid, ())
977985
return t
978986
end
979987

988+
const ChildFirst = false
989+
980990
function schedule(t::Task)
981991
# [task] created -scheduled-> wait_time
982992
maybe_record_enqueued!(t)
983-
enq_work(t)
993+
if ChildFirst
994+
ct = current_task()
995+
if ct.sticky || t.sticky
996+
maybe_record_enqueued!(t)
997+
enq_work(t)
998+
else
999+
maybe_record_enqueued!(t)
1000+
enq_work(ct)
1001+
yieldto(t)
1002+
end
1003+
else
1004+
maybe_record_enqueued!(t)
1005+
enq_work(t)
1006+
end
1007+
return t
9841008
end
9851009

9861010
"""
@@ -1176,10 +1200,10 @@ function trypoptask(W::StickyWorkqueue)
11761200
end
11771201
return t
11781202
end
1179-
return Partr.multiq_deletemin()
1203+
return Scheduler.dequeue!()
11801204
end
11811205

1182-
checktaskempty = Partr.multiq_check_empty
1206+
checktaskempty = Scheduler.checktaskempty
11831207

11841208
@noinline function poptask(W::StickyWorkqueue)
11851209
task = trypoptask(W)

src/jl_exported_data.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
#define JL_EXPORTED_DATA_SYMBOLS(XX) \
158158
XX(jl_n_threadpools, int) \
159159
XX(jl_n_threads, _Atomic(int)) \
160+
XX(jl_n_threads_spinning, _Atomic(int)) \
160161
XX(jl_n_gcthreads, int) \
161162
XX(jl_options, jl_options_t) \
162163
XX(jl_task_gcstack_offset, int) \

src/jl_exported_funcs.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@
449449
XX(jl_tagged_gensym) \
450450
XX(jl_take_buffer) \
451451
XX(jl_task_get_next) \
452+
XX(jl_add_spinner) \
452453
XX(jl_task_stack_buffer) \
453454
XX(jl_termios_size) \
454455
XX(jl_test_cpu_feature) \

src/julia.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2055,6 +2055,7 @@ JL_DLLEXPORT jl_sym_t *jl_get_ARCH(void) JL_NOTSAFEPOINT;
20552055
JL_DLLIMPORT jl_value_t *jl_get_libllvm(void) JL_NOTSAFEPOINT;
20562056
extern JL_DLLIMPORT int jl_n_threadpools;
20572057
extern JL_DLLIMPORT _Atomic(int) jl_n_threads;
2058+
extern JL_DLLIMPORT _Atomic(int) jl_n_threads_spinning; // Scheduler internal counter
20582059
extern JL_DLLIMPORT int jl_n_gcthreads;
20592060
extern int jl_n_markthreads;
20602061
extern int jl_n_sweepthreads;

src/julia_threads.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ typedef struct _jl_tls_states_t {
209209
uint64_t uv_run_leave;
210210
uint64_t sleep_enter;
211211
uint64_t sleep_leave;
212+
uint64_t woken_up;
212213
)
213214

214215
// some hidden state (usually just because we don't have the type's size declaration)

0 commit comments

Comments
 (0)