|
| 1 | +const ELTYPE = Union{Dagger.EagerThunk, Chunk} |
| 2 | + |
| 3 | +struct DGraphState{T,D} |
| 4 | + # A set of locally-connected SimpleDiGraphs |
| 5 | + parts::Vector{ELTYPE} |
| 6 | + # The range of vertices within each of `parts` |
| 7 | + parts_nv::Vector{UnitRange{T}} |
| 8 | + # The number of edges in each of `parts` |
| 9 | + parts_ne::Vector{T} |
| 10 | + # The maximum number of nodes for each of `parts` |
| 11 | + parts_v_max::Int |
| 12 | + |
| 13 | + # A set of `AdjList` for each of `parts` |
| 14 | + # An edge is present here if either src or dst (but not both) is in |
| 15 | + # the respective `parts` graph |
| 16 | + ext_adjs::Vector{ELTYPE} |
| 17 | + # The number of edges in each of `ext_adjs` |
| 18 | + ext_adjs_ne::Vector{T} |
| 19 | + # The number of edges in each of `ext_adjs` where the source is this partition |
| 20 | + ext_adjs_ne_src::Vector{T} |
| 21 | +end |
| 22 | +struct DGraph{T,D,F} <: Graphs.AbstractGraph{T} |
| 23 | + state::Dagger.Chunk{DGraphState{T,D}} |
| 24 | + function DGraph{T}(; chunksize::Integer=8, directed::Bool=false) where {T} |
| 25 | + D = directed |
| 26 | + state = DGraphState{T,D}(ELTYPE[], |
| 27 | + UnitRange{T}[], |
| 28 | + T[], |
| 29 | + chunksize, |
| 30 | + ELTYPE[], |
| 31 | + T[], |
| 32 | + T[]) |
| 33 | + return new{T,D}(Dagger.tochunk(state)) |
| 34 | + end |
| 35 | +end |
| 36 | +DGraph(; kwargs...) = DGraph{Int}(; kwargs...) |
| 37 | +function DGraph{T}(n::Integer; kwargs...) where T |
| 38 | + g = DGraph{T}(; kwargs...) |
| 39 | + add_vertices!(g, n) |
| 40 | + return g |
| 41 | +end |
| 42 | +DGraph(n::Integer; kwargs...) = DGraph{Int}(n; kwargs...) |
| 43 | +function DGraph(sg::AbstractGraph{T}; directed::Bool=is_directed(sg), kwargs...) where T |
| 44 | + g = DGraph{T}(nv(sg); directed, kwargs...) |
| 45 | + foreach(edges(sg)) do edge |
| 46 | + add_edge!(g, edge) |
| 47 | + if !is_directed(sg) && directed |
| 48 | + add_edge!(g, dst(edge), src(edge)) |
| 49 | + end |
| 50 | + end |
| 51 | + return g |
| 52 | +end |
| 53 | +function DGraph(dg::DGraph{T,D,F}; directed::Bool=D, freeze::Bool=false, chunksize::Integer=0) where {T,D} |
| 54 | + state = fetch(dg.state) |
| 55 | + # FIXME: Create g.state on same node as dg.state |
| 56 | + if chunksize == 0 |
| 57 | + chunksize = state.parts_v_max |
| 58 | + end |
| 59 | + g = DGraph{T}(; directed, chunksize) |
| 60 | + @assert g.state.handle.owner == dg.state.handle.owner |
| 61 | + new_state = fetch(g.state) |
| 62 | + # TODO: Use streaming |
| 63 | + # FIXME: Support directed != D |
| 64 | + @assert directed == D "Changing directedness not yet supported" |
| 65 | + for part in 1:length(state.parts) |
| 66 | + # FIXME: Create on same nodes |
| 67 | + push!(new_state.parts, Dagger.@spawn copy(state.parts[part])) |
| 68 | + push!(new_state.parts_nv, state.parts_nv[part]) |
| 69 | + push!(new_state.parts_ne, state.parts_ne[part]) |
| 70 | + |
| 71 | + push!(new_state.ext_adjs, Dagger.@spawn copy(state.ext_adjs[part])) |
| 72 | + push!(new_state.ext_adjs_ne, state.ext_adjs_ne[part]) |
| 73 | + push!(new_state.ext_adjs_ne_src, state.ext_adjs_ne_src[part]) |
| 74 | + end |
| 75 | + #= |
| 76 | + foreach(edges(dg)) do edge |
| 77 | + add_edge!(g, edge) |
| 78 | + if !is_directed(dg) && directed |
| 79 | + add_edge!(g, dst(edge), src(edge)) |
| 80 | + end |
| 81 | + end |
| 82 | + =# |
| 83 | + return g |
| 84 | +end |
| 85 | + |
| 86 | +freeze(g::DGraph{T,D,false}) where {T,D} = DGraph(g; freeze=true) |
| 87 | + |
| 88 | +function Base.show(io::IO, g::DGraph{T,D}) where {T,D} |
| 89 | + print(io, "{$(nv(g)), $(ne(g))} $(D ? "" : "un")directed Dagger $T graph") |
| 90 | +end |
| 91 | + |
| 92 | +Base.eltype(::DGraph{T}) where T = T |
| 93 | +Graphs.edgetype(::DGraph{T}) where T = Tuple{T,T} |
| 94 | +Graphs.nv(g::DGraph) = fetch(Dagger.@spawn nv(g.state))::Int |
| 95 | +function Graphs.nv(g::DGraphState) |
| 96 | + if !isempty(g.parts_nv) |
| 97 | + return last(g.parts_nv).stop |
| 98 | + else |
| 99 | + return 0 |
| 100 | + end |
| 101 | +end |
| 102 | +Graphs.ne(g::DGraph) = fetch(Dagger.@spawn ne(g.state))::Int |
| 103 | +Graphs.ne(g::DGraphState) = sum(g.parts_ne; init=0) + sum(g.ext_adjs_ne_src; init=0) |
| 104 | +Graphs.has_vertex(g::DGraph, v::Integer) = 1 <= v <= nv(g) |
| 105 | +Graphs.has_edge(g::DGraph, src::Integer, dst::Integer) = |
| 106 | + fetch(Dagger.@spawn has_edge(g.state, src, dst))::Bool |
| 107 | +function Graphs.has_edge(g::DGraphState{T,D}, src::Integer, dst::Integer) where {T,D} |
| 108 | + src_part_idx = findfirst(span->src in span, g.parts_nv) |
| 109 | + src_part_idx !== nothing || return false |
| 110 | + dst_part_idx = findfirst(span->dst in span, g.parts_nv) |
| 111 | + dst_part_idx !== nothing || return false |
| 112 | + |
| 113 | + if src_part_idx == dst_part_idx |
| 114 | + # The edge will be within a graph partition |
| 115 | + part = g.parts[src_part_idx] |
| 116 | + return fetch(Dagger.@spawn has_edge(part, src, dst)) |
| 117 | + else |
| 118 | + # The edge will be in an AdjList |
| 119 | + adj = g.ext_adjs[src_part_idx] |
| 120 | + return fetch(Dagger.@spawn has_ext_adj(adj, src, dst, D)) |
| 121 | + end |
| 122 | +end |
| 123 | +Graphs.is_directed(::DGraph{T,D}) where {T,D} = D |
| 124 | +Graphs.vertices(g::DGraph) = Base.OneTo(nv(g)) |
| 125 | +Graphs.edges(g::DGraph) = DGraphEdgeIter(g) |
| 126 | +Graphs.zero(::Type{<:DGraph}) = DGraph() |
| 127 | +function Graphs.add_vertex!(g::DGraph) |
| 128 | + fetch(Dagger.@spawn add_vertices!(g.state, 1)) |
| 129 | + return |
| 130 | +end |
| 131 | +Graphs.add_vertices!(g::DGraph, n::Integer) = |
| 132 | + fetch(Dagger.@spawn add_vertices!(g.state, n)) |
| 133 | +function Graphs.add_vertices!(g::DGraphState, n::Integer) |
| 134 | + for _ in 1:n |
| 135 | + if fld(nv(g), g.parts_v_max) == length(g.parts) |
| 136 | + # We need to create a new partition for this vertex |
| 137 | + add_partition!(g, 1) |
| 138 | + else |
| 139 | + # We will add this vertex to the last partition |
| 140 | + part = last(g.parts) |
| 141 | + fetch(Dagger.@spawn add_vertex!(part)) |
| 142 | + span = g.parts_nv[end] |
| 143 | + g.parts_nv[end] = UnitRange{Int}(span.start, span.stop+1) |
| 144 | + end |
| 145 | + end |
| 146 | + return n |
| 147 | +end |
| 148 | +add_partition!(g::DGraph, n::Integer) = |
| 149 | + fetch(Dagger.@spawn add_partition!(g.state, n)) |
| 150 | +function add_partition!(g::DGraphState{T,D}, n::Integer) where {T,D} |
| 151 | + if n < 1 |
| 152 | + throw(ArgumentError("n must be >= 1")) |
| 153 | + end |
| 154 | + push!(g.parts, Dagger.spawn(n) do n |
| 155 | + g = D ? SimpleDiGraph() : SimpleGraph() |
| 156 | + add_vertices!(g, n) |
| 157 | + g |
| 158 | + end) |
| 159 | + num_v = nv(g) |
| 160 | + push!(g.parts_nv, (num_v+1):(num_v+n)) |
| 161 | + push!(g.parts_ne, 0) |
| 162 | + push!(g.ext_adjs, Dagger.@spawn AdjList()) |
| 163 | + push!(g.ext_adjs_ne, 0) |
| 164 | + push!(g.ext_adjs_ne_src, 0) |
| 165 | + return length(g.parts) |
| 166 | +end |
| 167 | +Graphs.add_edge!(g::DGraph, src::Integer, dst::Integer) = |
| 168 | + fetch(Dagger.@spawn add_edge!(g.state, src, dst)) |
| 169 | +Graphs.add_edge!(g::DGraph, edge::Edge) = |
| 170 | + add_edge!(g, src(edge), dst(edge)) |
| 171 | +function Graphs.add_edge!(g::DGraphState{T,D}, src::Integer, dst::Integer) where {T,D} |
| 172 | + src_part_idx = findfirst(span->src in span, g.parts_nv) |
| 173 | + @assert src_part_idx !== nothing "Source vertex $src does not exist" |
| 174 | + |
| 175 | + dst_part_idx = findfirst(span->dst in span, g.parts_nv) |
| 176 | + @assert dst_part_idx !== nothing "Destination vertex $dst does not exist" |
| 177 | + |
| 178 | + if src_part_idx == dst_part_idx |
| 179 | + # Edge exists within a single partition |
| 180 | + part = g.parts[src_part_idx] |
| 181 | + src_shift = src - (g.parts_nv[src_part_idx].start - 1) |
| 182 | + dst_shift = dst - (g.parts_nv[dst_part_idx].start - 1) |
| 183 | + if fetch(Dagger.@spawn add_edge!(part, src_shift, dst_shift)) |
| 184 | + g.parts_ne[src_part_idx] += 1 |
| 185 | + else |
| 186 | + return false |
| 187 | + end |
| 188 | + else |
| 189 | + # Edge spans two partitions |
| 190 | + src_ext_adj = g.ext_adjs[src_part_idx] |
| 191 | + dst_ext_adj = g.ext_adjs[dst_part_idx] |
| 192 | + src_t = Dagger.@spawn add_ext_adj!(src_ext_adj, src, dst, D) |
| 193 | + dst_t = Dagger.@spawn add_ext_adj!(dst_ext_adj, src, dst, D) |
| 194 | + if !fetch(src_t) || !fetch(dst_t) |
| 195 | + return false |
| 196 | + end |
| 197 | + if D |
| 198 | + # TODO: This will cause imbalance for many outgoing edges from a few vertices |
| 199 | + g.ext_adjs_ne_src[src_part_idx] += 1 |
| 200 | + else |
| 201 | + owner_part_idx = edge_owner(src, dst, src_part_idx, dst_part_idx) |
| 202 | + g.ext_adjs_ne_src[owner_part_idx] += 1 |
| 203 | + end |
| 204 | + g.ext_adjs_ne[src_part_idx] += 1 |
| 205 | + g.ext_adjs_ne[dst_part_idx] += 1 |
| 206 | + end |
| 207 | + |
| 208 | + return true |
| 209 | +end |
| 210 | +edge_owner(src::Int, dst::Int, src_part_idx::Int, dst_part_idx::Int) = |
| 211 | + iseven(hash(Base.unsafe_trunc(UInt, src+dst))) ? src_part_idx : dst_part_idx |
| 212 | +Graphs.inneighbors(g::DGraph, v::Integer) = |
| 213 | + fetch(Dagger.@spawn inneighbors(g.state, v)) |
| 214 | +function Graphs.inneighbors(g::DGraphState, v::Integer) |
| 215 | + part_idx = findfirst(span->v in span, g.parts_nv) |
| 216 | + if part_idx === nothing |
| 217 | + throw(BoundsError(g, v)) |
| 218 | + end |
| 219 | + |
| 220 | + neighbors = Int[] |
| 221 | + shift = g.parts_nv[part_idx].start - 1 |
| 222 | + |
| 223 | + # Check against local edges |
| 224 | + v_shift = v - shift |
| 225 | + for local_neigh in fetch(Dagger.@spawn inneighbors(g.parts[part_idx], v_shift)) |
| 226 | + push!(neighbors, local_neigh + shift) |
| 227 | + end |
| 228 | + |
| 229 | + # Check against external edges |
| 230 | + for ext_neigh in fetch(Dagger.@spawn inneighbors(g.ext_adjs[part_idx], v)) |
| 231 | + push!(neighbors, ext_neigh) |
| 232 | + end |
| 233 | + |
| 234 | + return neighbors |
| 235 | +end |
| 236 | +Graphs.outneighbors(g::DGraph, v::Integer) = |
| 237 | + fetch(Dagger.@spawn outneighbors(g.state, v)) |
| 238 | +function Graphs.outneighbors(g::DGraphState, v::Integer) |
| 239 | + part_idx = findfirst(span->v in span, g.parts_nv) |
| 240 | + if part_idx === nothing |
| 241 | + throw(BoundsError(g, v)) |
| 242 | + end |
| 243 | + |
| 244 | + neighbors = Int[] |
| 245 | + shift = g.parts_nv[part_idx].start - 1 |
| 246 | + |
| 247 | + # Check against local edges |
| 248 | + v_shift = v - shift |
| 249 | + for local_neigh in fetch(Dagger.@spawn outneighbors(g.parts[part_idx], v_shift)) |
| 250 | + push!(neighbors, local_neigh + shift) |
| 251 | + end |
| 252 | + |
| 253 | + # Check against external edges |
| 254 | + for ext_neigh in fetch(Dagger.@spawn outneighbors(g.ext_adjs[part_idx], v)) |
| 255 | + push!(neighbors, ext_neigh) |
| 256 | + end |
| 257 | + |
| 258 | + return neighbors |
| 259 | +end |
0 commit comments