diff --git a/base/Base.jl b/base/Base.jl index afa5a3d93d27c..710be9d098f75 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -149,7 +149,7 @@ const liblapack_name = libblas_name # Note that `atomics.jl` here should be deprecated Core.eval(Threads, :(include("atomics.jl"))) include("channels.jl") -include("partr.jl") +include("scheduler/scheduler.jl") include("task.jl") include("threads_overloads.jl") include("weakkeydict.jl") diff --git a/base/scheduler/CDLL.jl b/base/scheduler/CDLL.jl new file mode 100644 index 0000000000000..f71277ad42181 --- /dev/null +++ b/base/scheduler/CDLL.jl @@ -0,0 +1,266 @@ + +module ConcurrentList #Concurrent Doubly Linked List + +mutable struct Node{T} + const value::Union{T, Nothing} + @atomic next::Union{Node{T}, Nothing} + @atomic prev::Union{Node{T}, Nothing} + + Node{T}(value, next, prev) where T = new{T}(value, next, prev) + function Node(next::Node{T}) where T # Marker + this = new{T}(nothing, next, nothing) + @atomic :release this.prev = this + return this + end +end + +Node(value::T, next, prev) where T = Node{T}(value, next, prev) + +get_next(node::Node) = @atomic :acquire node.next +set_next(node::Node, next) = @atomic :release node.next = next +get_prev(node::Node) = @atomic :acquire node.prev +set_prev(node::Node, prev) = @atomic :release node.prev = prev +function cas_next(node::Node, exp::Node, desired::Node) + _,success = @atomicreplace :acquire_release :monotonic node.next exp => desired + return success +end +is_special(node::Node) = node.value === nothing +is_trailer(node::Node) = get_next(node) === nothing +is_header(node::Node) = get_prev(node) === nothing +is_marker(node::Node) = get_prev(node) === node + +function is_deleted(node::Node) + f = get_next(node) + return f !== nothing && is_marker(f) +end + +function next_nonmarker(node::Node) + f = get_next(node) + return (f === nothing || !is_marker(f)) ? 
f : get_next(f) +end + +function Base.show(io::IO, node::Node) + if is_special(node) + if is_marker(node) + print(io, "MarkerNode") + return + elseif is_header(node) + next = get_next(node) + if next === nothing + print(io, "BrokenNode()") + return + elseif is_marker(node) + print(io, "HeaderNode(next: MarkerNode)") + return + elseif is_trailer(next) + print(io, "HeaderNode(next: TrailerNode)") + return + end + print(io, "HeaderNode(next: ", next,")") + return + elseif is_trailer(node) + prev = get_prev(node) + if prev === nothing + print(io, "BrokenNode()") + return + elseif is_marker(node) + print(io, "TrailerNode(prev: MarkerNode)") + return + elseif is_header(prev) + print(io, "TrailerNode(prev: HeaderNode)") + return + end + print(io, "TrailerNode(prev: ", prev,")") + return + end + end + print(io, "Node(", node.value,")") +end + +function successor(node::Node) + f = next_nonmarker(node) + while true + if f === nothing + return nothing + end + if !is_deleted(f) + if get_prev(f) !== node && !is_deleted(node) + set_prev(f, node) # relink f to node + end + return f + end + s = next_nonmarker(f) + if f === get_next(node) + cas_next(node, f, s) + end + f = s + end +end + +function find_predecessor_of(node::Node{T}, target::Node{T}) where {T} + n = node + while true + f = successor(n) + if (f === target) + return n + end + if (f === nothing) + return nothing + end + n = f + end +end + +function predecessor(node::Node) + n = node + while true + b = get_prev(n) + if (b === nothing) + return find_predecessor_of(n, node) + end + s = get_next(b) + if (s === node) + return b + end + if (s === nothing || !is_marker(s)) + p = find_predecessor_of(b, node) + if (p !== nothing) + return p + end + end + n = b + end +end + +function forward(node::Node) + f = successor(node) + return (f === nothing || is_special(f)) ? nothing : f +end + +function back(node::Node) + f = predecessor(node) + return (f === nothing || is_special(f)) ? 
nothing : f +end + +function append!(node::Node{T}, val::T) where {T} + while true + f = get_next(node) + if (f === nothing || is_marker(f)) + return nothing + end + x = Node(val, f, node) + if cas_next(node, f, x) + set_prev(f, x) + return x + end + end +end + +function prepend!(node::Node{T}, val::T) where {T} + while true + b = predecessor(node) + if b === nothing + return nothing + end + x = Node(val, node, b) + if cas_next(b, node, x) + set_prev(node, x) + return x + end + end +end + +function delete!(node::Node) + b = get_prev(node) + f = get_next(node) + if (b !== nothing && f !== nothing && !is_marker(f) && cas_next(node, f, Node(f))) + if (cas_next(b, node, f)) + set_prev(f, b) + end + return true + end + return false +end + +function replace!(node::Node{T}, val::T) where {T} + while true + b = get_prev(node) + f = get_next(node) + if (b === nothing || f === nothing || is_marker(f)) + return nothing + end + x = Node(val, f, b) + if cas_next(node, f, Node(x)) + successor(b) + successor(x) + return x + end + end +end + +function usable(node::Node) + return node !== nothing && !is_special(node) +end + +const _PADDING_TUPLE = ntuple(zero, 15) +mutable struct ConcurrentDoublyLinkedList{T} + @atomic header::Union{Node{T}, Nothing} # 8 bytes + padding::NTuple{15,UInt64} # 120 bytes + @atomic trailer::Union{Node{T}, Nothing} + padding2::NTuple{15,UInt64} + function ConcurrentDoublyLinkedList{T}(header::Union{Node{T}, Nothing}, trailer::Union{Node{T}, Nothing}) where {T} + new{T}(header, _PADDING_TUPLE, trailer, _PADDING_TUPLE) + end +end + +function ConcurrentDoublyLinkedList{T}() where {T} + h = Node{T}(nothing, nothing, nothing) + t = Node{T}(nothing, nothing, h) + set_next(h, t) + ConcurrentDoublyLinkedList{T}(h, t) +end + +const CDLL = ConcurrentDoublyLinkedList + +function Base.pushfirst!(cdll::CDLL{T}, val::T) where {T} + while (append!((@atomic :acquire cdll.header), val) === nothing) + end +end + +function pushlast!(cdll::CDLL{T}, val::T) where {T} + while (prepend!((@atomic :acquire cdll.trailer), val) === nothing) + end +end + +function Base.popfirst!(cdll::CDLL) + while true + n = successor((@atomic :acquire cdll.header)) + if !usable(n) + return nothing + end + if delete!(n) + return n.value + end + end +end + +function poplast!(cdll::CDLL) + while true + n = predecessor((@atomic :acquire cdll.trailer)) + if !usable(n) + return nothing + end + if delete!(n) + return n.value + end + end +end + +Base.push!(cdll::CDLL{T}, val::T) where {T} = pushfirst!(cdll, val) +Base.pop!(cdll::CDLL) = poplast!(cdll) +steal!(cdll::CDLL) = popfirst!(cdll) +Base.isempty(cdll::CDLL) = !usable(successor(@atomic :acquire cdll.header)) + +const Queue = CDLL + +end diff --git a/base/scheduler/CLL.jl b/base/scheduler/CLL.jl new file mode 100644 index 0000000000000..fdd8b4b6aeebe --- /dev/null +++ b/base/scheduler/CLL.jl @@ -0,0 +1,188 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +# Also see `work-stealing-queue.h` this is a pure Julia re-implementation + +# ======= +# Chase and Lev's work-stealing queue, optimized for +# weak memory models by Le et al. +# +# * Chase D., Lev Y. Dynamic Circular Work-Stealing queue +# * Le N. M. et al. 
Correct and Efficient Work-Stealing for +# Weak Memory Models +# ======= + +module CLL + + +if Sys.ARCH == :x86_64 + # https://github.com/llvm/llvm-project/pull/106555 + fence() = Base.llvmcall( + (raw""" + define void @fence() #0 { + entry: + tail call void asm sideeffect "lock orq $$0 , (%rsp)", ""(); should this have ~{memory} + ret void + } + attributes #0 = { alwaysinline } + """, "fence"), Nothing, Tuple{}) +else + fence() = Core.Intrinsics.atomic_fence(:sequentially_consistent) +end + +# mutable so that we don't get a mutex in WSQueue +mutable struct WSBuffer{T} + const buffer::AtomicMemory{T} + const capacity::Int64 + const mask::Int64 + @noinline function WSBuffer{T}(capacity::Int64) where T + if __unlikely(capacity == 0) + throw(ArgumentError("Capacity can't be zero")) + end + if __unlikely(count_ones(capacity) != 1) + throw(ArgumentError("Capacity must be a power of two")) + end + buffer = AtomicMemory{T}(undef, capacity) + mask = capacity - 1 + return new(buffer, capacity, mask) + end +end + +function Base.getindex_atomic(buf::WSBuffer{T}, order::Symbol, idx::Int64) where T + @inbounds Base.getindex_atomic(buf.buffer, order, ((idx - 1) & buf.mask) + 1) +end + +function Base.setindex_atomic!(buf::WSBuffer{T}, order::Symbol, val::T, idx::Int64) where T + @inbounds Base.setindex_atomic!(buf.buffer, order, val,((idx - 1) & buf.mask) + 1) +end + +function Base.modifyindex_atomic!(buf::WSBuffer{T}, order::Symbol, op, val::T, idx::Int64) where T + @inbounds Base.modifyindex_atomic!(buf.buffer, order, op, val, ((idx - 1) & buf.mask) + 1) +end + +function Base.swapindex_atomic!(buf::WSBuffer{T}, order::Symbol, val::T, idx::Int64) where T + @inbounds Base.swapindex_atomic!(buf.buffer, order, val, ((idx - 1) & buf.mask) + 1) +end + +function Base.replaceindex_atomic!(buf::WSBuffer{T}, success_order::Symbol, fail_order::Symbol, expected::T, desired::T, idx::Int64) where T + @inbounds Base.replaceindex_atomic!(buf.buffer, success_order, fail_order, expected, desired, ((idx - 1) & buf.mask) + 1) +end + +function Base.copyto!(dst::WSBuffer{T}, src::WSBuffer{T}, top, bottom) where T + # must use queue indexes. When the queue is in state top=3, bottom=18, capacity=16 + # the real index of element 18 in the queue is 2, after growing in the new buffer it must be 18 + @assert dst.capacity >= src.capacity + @assert top <= bottom + # TODO overflow of bottom? + for i in top:bottom + @atomic :monotonic dst[i] = @atomic :monotonic src[i] + end +end + +const CACHE_LINE=64 # hardware_destructive_interference + + +""" + WSQueue{T} + +Work-stealing queue after Chase & Le. + +!!! note + popfirst! and push! are only allowed to be called from owner. 
+""" +mutable struct WSQueue{T} + @atomic top::Int64 # 8 bytes + __align::NTuple{CACHE_LINE-sizeof(Int64), UInt8} + @atomic bottom::Int64 + __align2::NTuple{CACHE_LINE-sizeof(Int64), UInt8} + @atomic buffer::WSBuffer{T} + function WSQueue{T}(capacity = 64) where T + new(1, ntuple(Returns(UInt8(0)), Val(CACHE_LINE-sizeof(Int64))), + 1, ntuple(Returns(UInt8(0)), Val(CACHE_LINE-sizeof(Int64))), + WSBuffer{T}(capacity)) + end +end +@assert Base.fieldoffset(WSQueue{Int64}, 1) == 0 +@assert Base.fieldoffset(WSQueue{Int64}, 3) == CACHE_LINE +@assert Base.fieldoffset(WSQueue{Int64}, 5) == 2*CACHE_LINE + +@noinline function grow!(q::WSQueue{T}, buffer, top, bottom) where T + new_buffer = WSBuffer{T}(2*buffer.capacity) + copyto!(new_buffer, buffer, top, bottom) + @atomic :release q.buffer = new_buffer + return new_buffer +end + +# accessing q.buffer requires a GC frame :/ + +# pushBottom +function Base.push!(q::WSQueue{T}, v::T) where T + bottom = @atomic :monotonic q.bottom + top = @atomic :acquire q.top + buffer = @atomic :monotonic q.buffer + + size = bottom-top + if __unlikely(size > (buffer.capacity - 1)) # Chase-Lev has size >= (buf.capacity - 1) || Le has size > (buf.capacity - 1) + buffer = grow!(q, buffer, top, bottom) # Le does buffer = @atomic :monotonic q.buffer + end + @atomic :monotonic buffer[bottom] = v + fence() + @atomic :monotonic q.bottom = bottom + 1 + return nothing +end + +# popBottom / take +function Base.popfirst!(q::WSQueue{T}) where T + bottom = (@atomic :monotonic q.bottom) - 1 + buffer = @atomic :monotonic q.buffer + @atomic :monotonic q.bottom = bottom + fence() + top = @atomic :monotonic q.top + + size = bottom - top + 1 + if __likely(size > 0) + # Non-empty queue + v = @atomic :monotonic buffer[bottom] + if size == 1 + # Single last element in queue + _, success = @atomicreplace :sequentially_consistent :monotonic q.top top => top + 1 + @atomic :monotonic q.bottom = bottom + 1 + if !success + # Failed race + return nothing + end + end + return v + else + # Empty queue + @atomic :monotonic q.bottom = bottom + 1 + return nothing + end +end + +function steal!(q::WSQueue{T}) where T + top = @atomic :acquire q.top + fence() + bottom = @atomic :acquire q.bottom + size = bottom - top + if __likely(size > 0) + # Non-empty queue + buffer = @atomic :acquire q.buffer # consume in Le + v = @atomic :monotonic buffer[top] + _, success = @atomicreplace :sequentially_consistent :monotonic q.top top => top+1 + if !success + # Failed race + return nothing + end + return v + end + return nothing # failed +end + +Base.pop!(q::WSQueue{T}) where T = popfirst!(q) +@inline __likely(cond::Bool) = ccall("llvm.expect", llvmcall, Bool, (Bool, Bool), cond, true) +@inline __unlikely(cond::Bool) = ccall("llvm.expect", llvmcall, Bool, (Bool, Bool), cond, false) +Base.isempty(q::WSQueue) = (q.bottom - q.top) == 0 + +const Queue = WSQueue + +end #module \ No newline at end of file diff --git a/base/partr.jl b/base/scheduler/partr.jl similarity index 66% rename from base/partr.jl rename to base/scheduler/partr.jl index 6053a584af5ba..fd80065056e62 100644 --- a/base/partr.jl +++ b/base/scheduler/partr.jl @@ -19,63 +19,6 @@ const heap_d = UInt32(8) const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)] const heaps_lock = [SpinLock(), SpinLock()] - -""" - cong(max::UInt32) - -Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0. -""" -cong(max::UInt32) = iszero(max) ? 
UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check - -get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ()) - -set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed) - -""" - rand_ptls(max::UInt32) - -Return a random UInt32 in the range `0:max-1` using the thread-local RNG -state. Max must be greater than 0. -""" -Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32) - rngseed = get_ptls_rng() - val, seed = rand_uniform_max_int32(max, rngseed) - set_ptls_rng(seed) - return val % UInt32 -end - -# This implementation is based on OpenSSLs implementation of rand_uniform -# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99 -# Comments are vendored from their implementation as well. -# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143. - -# Essentially it boils down to incrementally generating a fixed point -# number on the interval [0, 1) and multiplying this number by the upper -# range limit. Once it is certain what the fractional part contributes to -# the integral part of the product, the algorithm has produced a definitive -# result. -""" - rand_uniform_max_int32(max::UInt32, seed::UInt64) - -Return a random UInt32 in the range `0:max-1` using the given seed. -Max must be greater than 0. -""" -Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64) - if max == UInt32(1) - return UInt32(0), seed - end - # We are generating a fixed point number on the interval [0, 1). - # Multiplying this by the range gives us a number on [0, upper). - # The high word of the multiplication result represents the integral part - # This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes - seed = UInt64(69069) * seed + UInt64(362437) - prod = (UInt64(max)) * (seed % UInt32) # 64 bit product - i = prod >> 32 % UInt32 # integral part - return i % UInt32, seed -end - - - function multiq_sift_up(heap::taskheap, idx::Int32) while idx > Int32(1) parent = (idx - Int32(2)) รท heap_d + Int32(1) @@ -147,10 +90,10 @@ function multiq_insert(task::Task, priority::UInt16) task.priority = priority - rn = cong(heap_p) + rn = Base.Scheduler.cong(heap_p) tpheaps = heaps[tp] while !trylock(tpheaps[rn].lock) - rn = cong(heap_p) + rn = Base.Scheduler.cong(heap_p) end heap = tpheaps[rn] @@ -190,8 +133,8 @@ function multiq_deletemin() if i == heap_p return nothing end - rn1 = cong(heap_p) - rn2 = cong(heap_p) + rn1 = Base.Scheduler.cong(heap_p) + rn2 = Base.Scheduler.cong(heap_p) prio1 = tpheaps[rn1].priority prio2 = tpheaps[rn2].priority if prio1 > prio2 @@ -235,6 +178,9 @@ function multiq_check_empty() if tp == 0 # Foreign thread return true end + if !isempty(Base.workqueue_for(tid)) + return false + end for i = UInt32(1):length(heaps[tp]) if heaps[tp][i].ntasks != 0 return false @@ -243,4 +189,9 @@ function multiq_check_empty() return true end + +enqueue!(t::Task) = multiq_insert(t, t.priority) +dequeue!() = multiq_deletemin() +checktaskempty() = multiq_check_empty() + end diff --git a/base/scheduler/scheduler.jl b/base/scheduler/scheduler.jl new file mode 100644 index 0000000000000..b12fd931f65e7 --- /dev/null +++ b/base/scheduler/scheduler.jl @@ -0,0 +1,75 @@ +# This file is a part of Julia. 
License is MIT: https://julialang.org/license + +module Scheduler + +""" + cong(max::UInt32) + +Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0. +""" +cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check + +get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ()) + +set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed) + +""" + rand_ptls(max::UInt32) + +Return a random UInt32 in the range `0:max-1` using the thread-local RNG +state. Max must be greater than 0. +""" +Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32) + rngseed = get_ptls_rng() + val, seed = rand_uniform_max_int32(max, rngseed) + set_ptls_rng(seed) + return val % UInt32 +end + +# This implementation is based on OpenSSLs implementation of rand_uniform +# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99 +# Comments are vendored from their implementation as well. +# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143. + +# Essentially it boils down to incrementally generating a fixed point +# number on the interval [0, 1) and multiplying this number by the upper +# range limit. Once it is certain what the fractional part contributes to +# the integral part of the product, the algorithm has produced a definitive +# result. +""" + rand_uniform_max_int32(max::UInt32, seed::UInt64) + +Return a random UInt32 in the range `0:max-1` using the given seed. +Max must be greater than 0. +""" +Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64) + if max == UInt32(1) + return UInt32(0), seed + end + # We are generating a fixed point number on the interval [0, 1). + # Multiplying this by the range gives us a number on [0, upper). + # The high word of the multiplication result represents the integral part + # This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes + seed = UInt64(69069) * seed + UInt64(362437) + prod = (UInt64(max)) * (seed % UInt32) # 64 bit product + i = prod >> 32 % UInt32 # integral part + return i % UInt32, seed +end + +include("scheduler/partr.jl") +include("scheduler/workstealing.jl") + +const ChosenScheduler = Workstealing + + + +# Scheduler interface: + # enqueue! which pushes a runnable Task into it + # dequeue! which pops a runnable Task from it + # checktaskempty which returns true if the scheduler has no available Tasks + +enqueue!(t::Task) = ChosenScheduler.enqueue!(t) +dequeue!() = ChosenScheduler.dequeue!() +checktaskempty() = ChosenScheduler.checktaskempty() + +end diff --git a/base/scheduler/workstealing.jl b/base/scheduler/workstealing.jl new file mode 100644 index 0000000000000..4828aaa1a2e0f --- /dev/null +++ b/base/scheduler/workstealing.jl @@ -0,0 +1,160 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +module Workstealing + +# Expected interface for a work-stealing queue: +# push!(queue, task) +# pop!(queue) # Only legal if you are the queues owner. 
+# steal!(queue) +include("scheduler/CLL.jl") +include("scheduler/CDLL.jl") + +# Threadpool utilities +function cur_threadpoolid() + return ccall(:jl_cur_threadpoolid, Int8, ()) + 1 +end + +function cur_threadpool_tid() + return ccall(:jl_cur_threadpool_tid, Int16, ()) + 1 +end + +function get_task_tpid(task::Task) + return ccall(:jl_get_task_threadpoolid, Int8, (Any,), task) + 1 +end + +# Logic for threadpools: +# Each thread has a global thread id, always called tid and unique per thread. +# Accesed via Threads.threadid() for a thread and Threads.threadid(task) for a task. + +# Each threadpool has an id, called tpid and unique per threadpool. +# Accessed via cur_threadpoolid() for the current thread and task_tpid for a task. + +# Each thread also has a threadpool_tid, called tp_tid, its use is to index into the array of queues for the threadpool. +# Accessed via cur_threadpool_tid() for the current thread. Or by checking if it's in the Threads.threadpooltids array. +# It's calculated by doing Threads.threadid() - Threads.threadpooltids(tpid)[1], though we store in the thread ptls for performance. + +# The calls return 1 based indexed numbers so threadpool 1 is :interactive and 2 is :default +# When a thread has either a tp_tid of 0 or a tpid of 0 it means that they aren't associated with a threadpool and should be inserted in the index 1 of the tasks tpid + + +function release_copyto!(dest::AtomicMemory{T}, src::AbstractArray{T,1}) where T + Base._checkaxs(axes(dest), axes(src)) + for i in eachindex(src) + @atomic :monotonic dest[i] = src[i] + end + Core.Intrinsics.atomic_fence(:release) + return dest +end + +make_atomic(x::AbstractArray{T,1}) where {T} = release_copyto!(AtomicMemory{T}(undef, size(x)), x) + +const QueueModule = ConcurrentList +const Queue = QueueModule.Queue{Task} +const Queues_lock = Threads.SpinLock() +global Queues::AtomicMemory{Memory{Queue}} = make_atomic([Memory{Queue}([Queue()]) for _ in 1:Threads.nthreadpools()]) # One array of queues per threadpool + +function queue_for(tp_tid::Int, tpid::Int) + @assert tp_tid >= 0 + qs = @atomic :monotonic Queues[tpid] + if (tp_tid == 0) + queue_index = 1 # We always have a queue for someone that isn't us to push to + else + queue_index = tp_tid + 1 + end + if length(qs) >= queue_index && isassigned(qs, queue_index) + return qs[queue_index] + end + # slow path to allocate it + # TODO: outline this + l = Queues_lock + @lock l begin + qs = @atomic :monotonic Queues[tpid] + if length(qs) < queue_index + nt = Threads._nthreads_in_pool(Int8(tpid - 1)) + 1 + @assert queue_index <= nt + new_q = copyto!(typeof(qs)(undef, length(qs) + nt - 1), qs) + qs = new_q + @atomic :monotonic Queues[tpid] = new_q + end + if !isassigned(qs, queue_index) + qs[queue_index] = Queue() + end + return qs[queue_index] + end +end + +function enqueue!(t::Task) + task_tpid = get_task_tpid(t) + thread_tpid = cur_threadpoolid() + + if task_tpid == thread_tpid + + push!(queue_for(Int(cur_threadpool_tid()), Int(thread_tpid)), t) + else + push!(queue_for(0, Int(task_tpid)), t) + end + return nothing +end + +function dequeue!() + tpid = cur_threadpoolid() + tp_tid = cur_threadpool_tid() + tid = Threads.threadid() + q = queue_for(Int(tp_tid), Int(tpid)) + t = pop!(q) # Check own queue first + if t !== nothing + if ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1) == 0 + push!(q, t) # Is there a way to avoid popping the same unrunnable task over and over? 
+ ccall(:jl_wakeup_thread, Cvoid, (Int16,), (Threads.threadid(t) - 1) % Int16) + else + return t + end + end + t = attempt_steal!(Int(tp_tid), Int(tpid)) # Otherwise try to steal from others + return t +end + +function attempt_steal!(tp_tid::Int, tpid::Int) + tid = Threads.threadid() + nt = Threads._nthreads_in_pool(Int8(tpid - 1)) + for _ in 1:(4*nt) # Try to steal 4x nthread times + tp_tid2 = Base.Scheduler.cong(UInt32(nt + 1)) - 1 # From 0 to nt since queue_for uses 0 for the foreign queue + tp_tid == tp_tid2 && continue + t = QueueModule.steal!(queue_for(Int(tp_tid2), Int(tpid))) #TODO: Change types of things to avoid the convert + if t !== nothing + if ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1) == 0 + push!(queue_for(0, Int(get_task_tpid(t))), t) + ccall(:jl_wakeup_thread, Cvoid, (Int16,), (Threads.threadid(t) - 1) % Int16) + else + return t + end + end + end + for i in 0:(nt) # Try to steal from other threads round robin + t = QueueModule.steal!(queue_for(Int(i), Int(tpid))) #TODO: Change types of things to avoid the convert + if t !== nothing + if ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1) == 0 + push!(queue_for(0, Int(get_task_tpid(t))), t) + ccall(:jl_wakeup_thread, Cvoid, (Int16,), (Threads.threadid(t) - 1) % Int16) + else + return t + end + end + end + return nothing +end + +function checktaskempty() + qs = @atomic :monotonic Queues[cur_threadpoolid()] + for i in eachindex(qs) + if isassigned(qs, i) + q = qs[i] + if !isempty(q) + return false + end + end + end + return true +end + +end \ No newline at end of file diff --git a/base/task.jl b/base/task.jl index cddf1fc854f4c..48463ef115135 100644 --- a/base/task.jl +++ b/base/task.jl @@ -937,7 +937,6 @@ end function enq_work(t::Task) (t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable") - # Sticky tasks go into their thread's work queue. if t.sticky tid = Threads.threadid(t) @@ -968,19 +967,40 @@ function enq_work(t::Task) ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1) push!(workqueue_for(tid), t) else - # Otherwise, put the task in the multiqueue. 
- Partr.multiq_insert(t, t.priority) + # Otherwise, push the task to the scheduler + Scheduler.enqueue!(t) tid = 0 end end - ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16) + + if (tid == 0 && Threads.threadpool(t) == :default) + ccall(:jl_wake_any_thread, Cvoid, (Any,), current_task()) + else + ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16) + end return t end +const ChildFirst = false + function schedule(t::Task) # [task] created -scheduled-> wait_time maybe_record_enqueued!(t) - enq_work(t) + if ChildFirst + ct = current_task() + if ct.sticky || t.sticky + maybe_record_enqueued!(t) + enq_work(t) + else + maybe_record_enqueued!(t) + enq_work(ct) + yieldto(t) + end + else + maybe_record_enqueued!(t) + enq_work(t) + end + return t end """ @@ -1186,10 +1206,11 @@ function trypoptask(W::StickyWorkqueue) end return t end - return Partr.multiq_deletemin() + t = Scheduler.dequeue!() + return t end -checktaskempty = Partr.multiq_check_empty +checktaskempty = Scheduler.checktaskempty function wait() ct = current_task() diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 7b204066b8c28..385612ca8b6b2 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -441,6 +441,7 @@ XX(jl_tagged_gensym) \ XX(jl_take_buffer) \ XX(jl_task_get_next) \ + XX(jl_wake_any_thread) \ XX(jl_termios_size) \ XX(jl_test_cpu_feature) \ XX(jl_threadid) \ diff --git a/src/julia_threads.h b/src/julia_threads.h index dbe9166f288a9..4e5d74b42edde 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -137,6 +137,7 @@ struct _jl_bt_element_t; typedef struct _jl_tls_states_t { int16_t tid; int8_t threadpoolid; + int16_t threadpool_tid; uint64_t rngseed; _Atomic(volatile size_t *) safepoint; // may be changed to the suspend page by any thread _Atomic(int8_t) sleep_check_state; // read/write from foreign threads @@ -214,6 +215,7 @@ typedef struct _jl_tls_states_t { uint64_t uv_run_leave; uint64_t sleep_enter; uint64_t sleep_leave; + uint64_t woken_up; ) // some hidden state (usually just because we don't have the type's size declaration) diff --git a/src/scheduler.c b/src/scheduler.c index 731a0c5146605..d3abba8fc7ca0 100644 --- a/src/scheduler.c +++ b/src/scheduler.c @@ -32,7 +32,8 @@ static const int16_t sleeping_like_the_dead JL_UNUSED = 2; // plus a running count of the number of in-flight wake-ups // n.b. this may temporarily exceed jl_n_threads _Atomic(int) n_threads_running = 0; - +// Number of threads sleeping in the scheduler, this number may be lower than the actual number +_Atomic(int) n_threads_idle = 0; // invariant: No thread is ever asleep unless sleep_check_state is sleeping (or we have a wakeup signal pending). // invariant: Any particular thread is not asleep unless that thread's sleep_check_state is sleeping. // invariant: The transition of a thread state to sleeping must be followed by a check that there wasn't work pending for it. 
@@ -220,6 +221,7 @@ static int wake_thread(int16_t tid) JL_NOTSAFEPOINT if (jl_atomic_load_relaxed(&ptls2->sleep_check_state) != not_sleeping) { int8_t state = sleeping; if (jl_atomic_cmpswap_relaxed(&ptls2->sleep_check_state, &state, not_sleeping)) { + JULIA_DEBUG_SLEEPWAKE( ptls2->woken_up = cycleclock() ); int wasrunning = jl_atomic_fetch_add_relaxed(&n_threads_running, 1); // increment in-flight wakeup count assert(wasrunning); (void)wasrunning; JL_PROBE_RT_SLEEP_CHECK_WAKE(ptls2, state); @@ -359,6 +361,43 @@ static int may_sleep(jl_ptls_t ptls) JL_NOTSAFEPOINT } +STATIC_INLINE void wake_any(jl_task_t *ct) JL_NOTSAFEPOINT +{ + // Find sleeping thread to wake up in the default thread pool + int self = jl_atomic_load_relaxed(&ct->tid); + int nthreads = jl_atomic_load_acquire(&jl_n_threads); + int idle_threads = jl_atomic_load_relaxed(&n_threads_idle); + jl_task_t *uvlock = jl_atomic_load_relaxed(&jl_uv_mutex.owner); + if (uvlock == ct) + uv_stop(jl_global_event_loop()); + if (idle_threads > 0) { + int anysleep = 0; + for (int tid = self + 1; tid < nthreads; tid++) { + if ((tid != self) && wake_thread(tid)) { + anysleep = 1; + break; + } + } + for (int tid = jl_n_threads_per_pool[JL_THREADPOOL_ID_INTERACTIVE]; tid < self; tid++) { + if ((tid != self) && wake_thread(tid)) { + anysleep = 1; + break; + } + } + if (anysleep) { + jl_fence(); // This fence is expensive but needed for libuv to do RUN_ONCE + if (uvlock != ct && jl_atomic_load_relaxed(&jl_uv_mutex.owner) != NULL) + wake_libuv(); + } + } +} + +JL_DLLEXPORT void jl_wake_any_thread(jl_task_t *ct) +{ + wake_any(ct); +} + + JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, jl_value_t *checkempty) { jl_task_t *ct = jl_current_task; @@ -366,8 +405,10 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, while (1) { jl_task_t *task = get_next_task(trypoptask, q); - if (task) + if (task) { + wake_any(ct); return task; + } // quick, race-y check to see if there seems to be any stuff in there jl_cpu_pause(); @@ -382,12 +423,14 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, // acquire sleep-check lock assert(jl_atomic_load_relaxed(&ptls->sleep_check_state) == not_sleeping); jl_atomic_store_relaxed(&ptls->sleep_check_state, sleeping); + jl_atomic_fetch_add_relaxed(&n_threads_idle, 1); jl_fence(); // [^store_buffering_1] JL_PROBE_RT_SLEEP_CHECK_SLEEP(ptls); if (!check_empty(checkempty)) { // uses relaxed loads if (set_not_sleeping(ptls)) { JL_PROBE_RT_SLEEP_CHECK_TASKQ_WAKE(ptls); } + jl_atomic_fetch_add_relaxed(&n_threads_idle, -1); continue; } volatile int isrunning = 1; @@ -399,12 +442,14 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, if (set_not_sleeping(ptls)) { JL_PROBE_RT_SLEEP_CHECK_TASK_WAKE(ptls); } + jl_atomic_fetch_add_relaxed(&n_threads_idle, -1); continue; // jump to JL_CATCH } if (task) { if (set_not_sleeping(ptls)) { JL_PROBE_RT_SLEEP_CHECK_TASK_WAKE(ptls); } + jl_atomic_fetch_add_relaxed(&n_threads_idle, -1); continue; // jump to JL_CATCH } @@ -437,7 +482,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, // responsibility, so need to make sure thread 0 will take care // of us. 
if (jl_atomic_load_relaxed(&jl_uv_mutex.owner) == NULL) // aka trylock - jl_wakeup_thread(jl_atomic_load_relaxed(&io_loop_tid)); + wakeup_thread(ct, jl_atomic_load_relaxed(&io_loop_tid)); } if (uvlock) { @@ -451,7 +496,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, uv_loop_t *loop = jl_global_event_loop(); loop->stop_flag = 0; JULIA_DEBUG_SLEEPWAKE( ptls->uv_run_enter = cycleclock() ); - active = uv_run(loop, UV_RUN_ONCE); + active = uv_run(loop, UV_RUN_NOWAIT); JULIA_DEBUG_SLEEPWAKE( ptls->uv_run_leave = cycleclock() ); jl_gc_safepoint(); } @@ -465,6 +510,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, if (set_not_sleeping(ptls)) { JL_PROBE_RT_SLEEP_CHECK_UV_WAKE(ptls); } + jl_atomic_fetch_add_relaxed(&n_threads_idle, -1); start_cycles = 0; continue; // jump to JL_CATCH } @@ -474,6 +520,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, if (set_not_sleeping(ptls)) { JL_PROBE_RT_SLEEP_CHECK_UV_WAKE(ptls); } + jl_atomic_fetch_add_relaxed(&n_threads_idle, -1); start_cycles = 0; continue; // jump to JL_CATCH } @@ -519,6 +566,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, // else should we warn the user of certain deadlock here if tid == 0 && n_threads_running == 0? uv_cond_wait(&ptls->wake_signal, &ptls->sleep_lock); } + jl_atomic_fetch_add_relaxed(&n_threads_idle, -1); assert(jl_atomic_load_relaxed(&ptls->sleep_check_state) == not_sleeping); assert(jl_atomic_load_relaxed(&n_threads_running)); start_cycles = 0; @@ -533,13 +581,17 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q, } JL_CATCH { // probably SIGINT, but possibly a user mistake in trypoptask - if (!isrunning) + if (!isrunning) { + jl_atomic_fetch_add_relaxed(&n_threads_idle, -1); jl_atomic_fetch_add_relaxed(&n_threads_running, 1); + } set_not_sleeping(ptls); jl_rethrow(); } - if (task) + if (task) { + wake_any(ct); return task; + } } else { // maybe check the kernel for new messages too diff --git a/src/threading.c b/src/threading.c index 4256115214fc2..d6da8f28917b9 100644 --- a/src/threading.c +++ b/src/threading.c @@ -320,6 +320,37 @@ JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT jl_current_task->ptls->rngseed = new_seed; } + +JL_DLLEXPORT int8_t jl_cur_threadpoolid(void) JL_NOTSAFEPOINT +{ + return jl_current_task->ptls->threadpoolid; +} + +JL_DLLEXPORT int16_t jl_cur_threadpool_tid(void) JL_NOTSAFEPOINT +{ + return jl_current_task->ptls->threadpool_tid; +} + +STATIC_INLINE void set_ptls_tpid(jl_ptls_t ptls) JL_NOTSAFEPOINT +{ + int16_t tid = ptls->tid; + int nthreads = jl_atomic_load_acquire(&jl_n_threads); + if (tid < 0 || tid >= nthreads) + jl_error("invalid tid"); + int n = 0; + for (int i = 0; i < jl_n_threadpools; i++) { + int old_n = n; + n += jl_n_threads_per_pool[i]; + if (tid < n) { + ptls->threadpoolid = i; + ptls->threadpool_tid = tid - old_n; + return; + } + } + ptls->threadpoolid = -1; // everything else uses threadpool -1 (does not belong to any threadpool) + ptls->threadpool_tid = -1; +} + jl_ptls_t jl_init_threadtls(int16_t tid) { #ifndef _OS_WINDOWS_ @@ -378,6 +409,7 @@ jl_ptls_t jl_init_threadtls(int16_t tid) if (tid == -1) tid = jl_atomic_load_relaxed(&jl_n_threads); ptls->tid = tid; + set_ptls_tpid(ptls); jl_ptls_t *allstates = jl_atomic_load_relaxed(&jl_all_tls_states); if (jl_all_tls_states_size <= tid) { int i, newsize = jl_all_tls_states_size + tid + 2; diff --git a/test.jl b/test.jl new file mode 100644 
index 0000000000000..7a31d00c6b827 --- /dev/null +++ b/test.jl @@ -0,0 +1,7 @@ +function fib(n::Int) + n <= 1 && return n + t = Threads.@spawn fib(n - 2) + return fib(n - 1) + fetch(t)::Int + end + +fib(34) diff --git a/test/threads.jl b/test/threads.jl index fa0b33a6352f3..747452906114c 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -338,7 +338,7 @@ end end @testset "rand_ptls underflow" begin - @test Base.Partr.cong(UInt32(0)) == 0 + @test Base.Scheduler.cong(UInt32(0)) == 0 end @testset "num_stack_mappings metric" begin
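Illustrative usage sketch (not part of the patch above). The work-stealing backend added in base/scheduler/workstealing.jl selects `ConcurrentList.Queue` as its queue type and drives it through `push!` (owner enqueue), `pop!` (owner dequeue) and `steal!` (thieves); the snippet below exercises that same interface directly, assuming a Julia build that includes the changes in base/scheduler/. The module paths follow the nesting introduced by this diff (Base.Scheduler.Workstealing.ConcurrentList) and are internal, undocumented API; treat this as a sketch, not a supported entry point.

    # Hedged smoke test of the queue interface expected by workstealing.jl,
    # run single-threaded so owner-only operations are trivially legal.
    const WS = Base.Scheduler.Workstealing
    q = WS.ConcurrentList.Queue{Int}()        # the scheduler instantiates Queue{Task} the same way

    push!(q, 1); push!(q, 2); push!(q, 3)     # owner-side enqueue (inserts at the head of the list)
    @assert pop!(q) == 1                      # owner-side dequeue takes from the opposite end
    @assert WS.ConcurrentList.steal!(q) == 3  # thieves take from the head (most recently pushed)
    @assert !isempty(q)                       # element 2 is still queued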