Skip to content

Commit c9da012

Browse files
committed
Add DaggerGraphs subpackage
1 parent 9cbc0b4 commit c9da012

File tree

8 files changed

+499
-0
lines changed

8 files changed

+499
-0
lines changed

lib/DaggerGraphs/Project.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
name = "DaggerGraphs"
2+
uuid = "304567ff-242f-4479-af00-6fcdcd11a1dd"
3+
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>", "pszufe <pszufe@gmail.com>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
8+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
9+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
10+
11+
[compat]
12+
Dagger = "0.17, 0.18"
13+
Graphs = "1"
14+
Tables = "1"
15+
julia = "1"

lib/DaggerGraphs/src/DaggerGraphs.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module DaggerGraphs
2+
3+
using Dagger
4+
import Dagger: Chunk
5+
using Graphs
6+
import Tables
7+
8+
export DGraph
9+
10+
include("dgraph.jl")
11+
include("adjlist.jl")
12+
include("edgeiter.jl")
13+
include("tables.jl")
14+
15+
function DGraph{T}(x; kwargs...) where T
16+
if Tables.istable(x)
17+
return fromtable(T, x; kwargs...)
18+
end
19+
throw(ArgumentError("Cannot convert a $(typeof(x)) to a DGraph"))
20+
end
21+
DGraph(x; kwargs...) = DGraph{Int}(x; kwargs...)
22+
23+
end # module DaggerGraphs

lib/DaggerGraphs/src/adjlist.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
struct AdjList{T}
2+
adj::Vector{Tuple{T,T}}
3+
end
4+
AdjList() = AdjList{Int}(Tuple{Int,Int}[])
5+
Base.copy(adj::AdjList) = AdjList(copy(adj.adj))
6+
function has_ext_adj(adj::AdjList, src::Int, dst::Int, directed::Bool)
7+
idx = findfirst(edge->(edge == (src, dst)) ||
8+
(!directed && (edge == (dst, src))),
9+
adj.adj)
10+
if idx !== nothing
11+
return true
12+
end
13+
return false
14+
end
15+
function add_ext_adj!(adj::AdjList, src::Int, dst::Int, directed::Bool)
16+
if has_ext_adj(adj, src, dst, directed)
17+
return false
18+
end
19+
push!(adj.adj, (src, dst))
20+
return true
21+
end
22+
Graphs.edges(adj::AdjList) = adj.adj
23+
function Graphs.inneighbors(adj::AdjList, v::Integer)
24+
neighbors = Int[]
25+
for (src, dst) in adj.adj
26+
if dst == v
27+
push!(neighbors, src)
28+
end
29+
end
30+
return neighbors
31+
end
32+
function Graphs.outneighbors(adj::AdjList, v::Integer)
33+
neighbors = Int[]
34+
for (src, dst) in adj.adj
35+
if src == v
36+
push!(neighbors, dst)
37+
end
38+
end
39+
return neighbors
40+
end

lib/DaggerGraphs/src/dgraph.jl

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
const ELTYPE = Union{Dagger.EagerThunk, Chunk}
2+
3+
struct DGraphState{T,D}
4+
# A set of locally-connected SimpleDiGraphs
5+
parts::Vector{ELTYPE}
6+
# The range of vertices within each of `parts`
7+
parts_nv::Vector{UnitRange{T}}
8+
# The number of edges in each of `parts`
9+
parts_ne::Vector{T}
10+
# The maximum number of nodes for each of `parts`
11+
parts_v_max::Int
12+
13+
# A set of `AdjList` for each of `parts`
14+
# An edge is present here if either src or dst (but not both) is in
15+
# the respective `parts` graph
16+
ext_adjs::Vector{ELTYPE}
17+
# The number of edges in each of `ext_adjs`
18+
ext_adjs_ne::Vector{T}
19+
# The number of edges in each of `ext_adjs` where the source is this partition
20+
ext_adjs_ne_src::Vector{T}
21+
end
22+
struct DGraph{T,D,F} <: Graphs.AbstractGraph{T}
23+
state::Dagger.Chunk{DGraphState{T,D}}
24+
function DGraph{T}(; chunksize::Integer=8, directed::Bool=false) where {T}
25+
D = directed
26+
state = DGraphState{T,D}(ELTYPE[],
27+
UnitRange{T}[],
28+
T[],
29+
chunksize,
30+
ELTYPE[],
31+
T[],
32+
T[])
33+
return new{T,D}(Dagger.tochunk(state))
34+
end
35+
end
36+
DGraph(; kwargs...) = DGraph{Int}(; kwargs...)
37+
function DGraph{T}(n::Integer; kwargs...) where T
38+
g = DGraph{T}(; kwargs...)
39+
add_vertices!(g, n)
40+
return g
41+
end
42+
DGraph(n::Integer; kwargs...) = DGraph{Int}(n; kwargs...)
43+
function DGraph(sg::AbstractGraph{T}; directed::Bool=is_directed(sg), kwargs...) where T
44+
g = DGraph{T}(nv(sg); directed, kwargs...)
45+
foreach(edges(sg)) do edge
46+
add_edge!(g, edge)
47+
if !is_directed(sg) && directed
48+
add_edge!(g, dst(edge), src(edge))
49+
end
50+
end
51+
return g
52+
end
53+
function DGraph(dg::DGraph{T,D,F}; directed::Bool=D, freeze::Bool=false, chunksize::Integer=0) where {T,D}
54+
state = fetch(dg.state)
55+
# FIXME: Create g.state on same node as dg.state
56+
if chunksize == 0
57+
chunksize = state.parts_v_max
58+
end
59+
g = DGraph{T}(; directed, chunksize)
60+
@assert g.state.handle.owner == dg.state.handle.owner
61+
new_state = fetch(g.state)
62+
# TODO: Use streaming
63+
# FIXME: Support directed != D
64+
@assert directed == D "Changing directedness not yet supported"
65+
for part in 1:length(state.parts)
66+
# FIXME: Create on same nodes
67+
push!(new_state.parts, Dagger.@spawn copy(state.parts[part]))
68+
push!(new_state.parts_nv, state.parts_nv[part])
69+
push!(new_state.parts_ne, state.parts_ne[part])
70+
71+
push!(new_state.ext_adjs, Dagger.@spawn copy(state.ext_adjs[part]))
72+
push!(new_state.ext_adjs_ne, state.ext_adjs_ne[part])
73+
push!(new_state.ext_adjs_ne_src, state.ext_adjs_ne_src[part])
74+
end
75+
#=
76+
foreach(edges(dg)) do edge
77+
add_edge!(g, edge)
78+
if !is_directed(dg) && directed
79+
add_edge!(g, dst(edge), src(edge))
80+
end
81+
end
82+
=#
83+
return g
84+
end
85+
86+
freeze(g::DGraph{T,D,false}) where {T,D} = DGraph(g; freeze=true)
87+
88+
function Base.show(io::IO, g::DGraph{T,D}) where {T,D}
89+
print(io, "{$(nv(g)), $(ne(g))} $(D ? "" : "un")directed Dagger $T graph")
90+
end
91+
92+
Base.eltype(::DGraph{T}) where T = T
93+
Graphs.edgetype(::DGraph{T}) where T = Tuple{T,T}
94+
Graphs.nv(g::DGraph) = fetch(Dagger.@spawn nv(g.state))::Int
95+
function Graphs.nv(g::DGraphState)
96+
if !isempty(g.parts_nv)
97+
return last(g.parts_nv).stop
98+
else
99+
return 0
100+
end
101+
end
102+
Graphs.ne(g::DGraph) = fetch(Dagger.@spawn ne(g.state))::Int
103+
Graphs.ne(g::DGraphState) = sum(g.parts_ne; init=0) + sum(g.ext_adjs_ne_src; init=0)
104+
Graphs.has_vertex(g::DGraph, v::Integer) = 1 <= v <= nv(g)
105+
Graphs.has_edge(g::DGraph, src::Integer, dst::Integer) =
106+
fetch(Dagger.@spawn has_edge(g.state, src, dst))::Bool
107+
function Graphs.has_edge(g::DGraphState{T,D}, src::Integer, dst::Integer) where {T,D}
108+
src_part_idx = findfirst(span->src in span, g.parts_nv)
109+
src_part_idx !== nothing || return false
110+
dst_part_idx = findfirst(span->dst in span, g.parts_nv)
111+
dst_part_idx !== nothing || return false
112+
113+
if src_part_idx == dst_part_idx
114+
# The edge will be within a graph partition
115+
part = g.parts[src_part_idx]
116+
return fetch(Dagger.@spawn has_edge(part, src, dst))
117+
else
118+
# The edge will be in an AdjList
119+
adj = g.ext_adjs[src_part_idx]
120+
return fetch(Dagger.@spawn has_ext_adj(adj, src, dst, D))
121+
end
122+
end
123+
Graphs.is_directed(::DGraph{T,D}) where {T,D} = D
124+
Graphs.vertices(g::DGraph) = Base.OneTo(nv(g))
125+
Graphs.edges(g::DGraph) = DGraphEdgeIter(g)
126+
Graphs.zero(::Type{<:DGraph}) = DGraph()
127+
function Graphs.add_vertex!(g::DGraph)
128+
fetch(Dagger.@spawn add_vertices!(g.state, 1))
129+
return
130+
end
131+
Graphs.add_vertices!(g::DGraph, n::Integer) =
132+
fetch(Dagger.@spawn add_vertices!(g.state, n))
133+
function Graphs.add_vertices!(g::DGraphState, n::Integer)
134+
for _ in 1:n
135+
if fld(nv(g), g.parts_v_max) == length(g.parts)
136+
# We need to create a new partition for this vertex
137+
add_partition!(g, 1)
138+
else
139+
# We will add this vertex to the last partition
140+
part = last(g.parts)
141+
fetch(Dagger.@spawn add_vertex!(part))
142+
span = g.parts_nv[end]
143+
g.parts_nv[end] = UnitRange{Int}(span.start, span.stop+1)
144+
end
145+
end
146+
return n
147+
end
148+
add_partition!(g::DGraph, n::Integer) =
149+
fetch(Dagger.@spawn add_partition!(g.state, n))
150+
function add_partition!(g::DGraphState{T,D}, n::Integer) where {T,D}
151+
if n < 1
152+
throw(ArgumentError("n must be >= 1"))
153+
end
154+
push!(g.parts, Dagger.spawn(n) do n
155+
g = D ? SimpleDiGraph() : SimpleGraph()
156+
add_vertices!(g, n)
157+
g
158+
end)
159+
num_v = nv(g)
160+
push!(g.parts_nv, (num_v+1):(num_v+n))
161+
push!(g.parts_ne, 0)
162+
push!(g.ext_adjs, Dagger.@spawn AdjList())
163+
push!(g.ext_adjs_ne, 0)
164+
push!(g.ext_adjs_ne_src, 0)
165+
return length(g.parts)
166+
end
167+
Graphs.add_edge!(g::DGraph, src::Integer, dst::Integer) =
168+
fetch(Dagger.@spawn add_edge!(g.state, src, dst))
169+
Graphs.add_edge!(g::DGraph, edge::Edge) =
170+
add_edge!(g, src(edge), dst(edge))
171+
function Graphs.add_edge!(g::DGraphState{T,D}, src::Integer, dst::Integer) where {T,D}
172+
src_part_idx = findfirst(span->src in span, g.parts_nv)
173+
@assert src_part_idx !== nothing "Source vertex $src does not exist"
174+
175+
dst_part_idx = findfirst(span->dst in span, g.parts_nv)
176+
@assert dst_part_idx !== nothing "Destination vertex $dst does not exist"
177+
178+
if src_part_idx == dst_part_idx
179+
# Edge exists within a single partition
180+
part = g.parts[src_part_idx]
181+
src_shift = src - (g.parts_nv[src_part_idx].start - 1)
182+
dst_shift = dst - (g.parts_nv[dst_part_idx].start - 1)
183+
if fetch(Dagger.@spawn add_edge!(part, src_shift, dst_shift))
184+
g.parts_ne[src_part_idx] += 1
185+
else
186+
return false
187+
end
188+
else
189+
# Edge spans two partitions
190+
src_ext_adj = g.ext_adjs[src_part_idx]
191+
dst_ext_adj = g.ext_adjs[dst_part_idx]
192+
src_t = Dagger.@spawn add_ext_adj!(src_ext_adj, src, dst, D)
193+
dst_t = Dagger.@spawn add_ext_adj!(dst_ext_adj, src, dst, D)
194+
if !fetch(src_t) || !fetch(dst_t)
195+
return false
196+
end
197+
if D
198+
# TODO: This will cause imbalance for many outgoing edges from a few vertices
199+
g.ext_adjs_ne_src[src_part_idx] += 1
200+
else
201+
owner_part_idx = edge_owner(src, dst, src_part_idx, dst_part_idx)
202+
g.ext_adjs_ne_src[owner_part_idx] += 1
203+
end
204+
g.ext_adjs_ne[src_part_idx] += 1
205+
g.ext_adjs_ne[dst_part_idx] += 1
206+
end
207+
208+
return true
209+
end
210+
edge_owner(src::Int, dst::Int, src_part_idx::Int, dst_part_idx::Int) =
211+
iseven(hash(Base.unsafe_trunc(UInt, src+dst))) ? src_part_idx : dst_part_idx
212+
Graphs.inneighbors(g::DGraph, v::Integer) =
213+
fetch(Dagger.@spawn inneighbors(g.state, v))
214+
function Graphs.inneighbors(g::DGraphState, v::Integer)
215+
part_idx = findfirst(span->v in span, g.parts_nv)
216+
if part_idx === nothing
217+
throw(BoundsError(g, v))
218+
end
219+
220+
neighbors = Int[]
221+
shift = g.parts_nv[part_idx].start - 1
222+
223+
# Check against local edges
224+
v_shift = v - shift
225+
for local_neigh in fetch(Dagger.@spawn inneighbors(g.parts[part_idx], v_shift))
226+
push!(neighbors, local_neigh + shift)
227+
end
228+
229+
# Check against external edges
230+
for ext_neigh in fetch(Dagger.@spawn inneighbors(g.ext_adjs[part_idx], v))
231+
push!(neighbors, ext_neigh)
232+
end
233+
234+
return neighbors
235+
end
236+
Graphs.outneighbors(g::DGraph, v::Integer) =
237+
fetch(Dagger.@spawn outneighbors(g.state, v))
238+
function Graphs.outneighbors(g::DGraphState, v::Integer)
239+
part_idx = findfirst(span->v in span, g.parts_nv)
240+
if part_idx === nothing
241+
throw(BoundsError(g, v))
242+
end
243+
244+
neighbors = Int[]
245+
shift = g.parts_nv[part_idx].start - 1
246+
247+
# Check against local edges
248+
v_shift = v - shift
249+
for local_neigh in fetch(Dagger.@spawn outneighbors(g.parts[part_idx], v_shift))
250+
push!(neighbors, local_neigh + shift)
251+
end
252+
253+
# Check against external edges
254+
for ext_neigh in fetch(Dagger.@spawn outneighbors(g.ext_adjs[part_idx], v))
255+
push!(neighbors, ext_neigh)
256+
end
257+
258+
return neighbors
259+
end

0 commit comments

Comments
 (0)