Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change the DAG to have separate nodes for operations and arrays #337

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 166 additions & 87 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import lru_cache

import networkx as nx
import zarr

from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.primitive.blockwise import can_fuse_pipelines, fuse
Expand All @@ -16,6 +17,14 @@
# A unique ID with sensible ordering, used for making directory names
CONTEXT_ID = f"cubed-{datetime.now().strftime('%Y%m%dT%H%M%S')}-{uuid.uuid4()}"

sym_counter = 0


def gensym(name="op"):
    """Return a new symbol of the form ``<name>-<counter>``.

    Every call increments the module-level ``sym_counter`` by one and embeds
    the new value in the result, zero-padded to a minimum width of three
    digits (for example ``op-001``, ``op-002``).

    Parameters
    ----------
    name : str, optional
        Prefix placed before the counter. Defaults to ``"op"``.

    Returns
    -------
    str
        The prefix and the freshly incremented counter joined by a hyphen.
    """
    # The counter is shared by all prefixes, so symbols stay distinct even
    # when different names are requested.
    global sym_counter
    sym_counter = sym_counter + 1
    return f"{name}-{sym_counter:03}"


class Plan:
"""Deferred computation plan for a graph of arrays.
Expand All @@ -30,6 +39,12 @@ class Plan:
a function with repeated inputs. For example, consider `equals` where the
two arguments are the same array. We need to keep track of these cases, so
we use a NetworkX `MultiDiGraph` rather than just a `DiGraph`.

Compared to a more traditional DAG representing a computation, in Cubed
nodes are not values that are passed to functions; they are instead
"parallel computations" which are run for their side effects. Data does
not flow through the graph - it is written to external storage (Zarr files)
as the output of one pipeline, then read back as the input to later pipelines.
"""

def __init__(self, dag):
Expand All @@ -56,28 +71,50 @@ def _new(
frame = inspect.currentframe().f_back # go back one in the stack
stack_summaries = extract_stack_summaries(frame, limit=10)

op_name_unique = gensym()

if pipeline is None:
# op
dag.add_node(
name,
name=name,
op_name_unique,
name=op_name_unique,
op_name=op_name,
target=target,
type="op",
stack_summaries=stack_summaries,
hidden=hidden,
)
else:
# array (when multiple outputs are supported there could be more than one)
dag.add_node(
name,
name=name,
op_name=op_name,
type="array",
target=target,
hidden=hidden,
)
dag.add_edge(op_name_unique, name)
else:
# op
dag.add_node(
op_name_unique,
name=op_name_unique,
op_name=op_name,
type="op",
stack_summaries=stack_summaries,
hidden=hidden,
pipeline=pipeline,
)
# array (when multiple outputs are supported there could be more than one)
dag.add_node(
name,
name=name,
type="array",
target=target,
hidden=hidden,
)
dag.add_edge(op_name_unique, name)
for x in source_arrays:
if hasattr(x, "name"):
dag.add_edge(x.name, name)
dag.add_edge(x.name, op_name_unique)

return Plan(dag)

Expand All @@ -94,42 +131,62 @@ def optimize(self):
nodes = {n: d for (n, d) in dag.nodes(data=True)}

def can_fuse(n):
# node must have a single predecessor
# - not multiple edges pointing to a single predecessor
# node must be the single successor to the predecessor
# and both must have pipelines that can be fused
if dag.in_degree(n) != 1:
# fuse a single chain looking like this:
# op1 -> op2_input -> op2

op2 = n

# if node (op2) does not have a pipeline then it can't be fused
if "pipeline" not in nodes[op2]:
return False
pre = next(dag.predecessors(n))
if dag.out_degree(pre) != 1:

# if node (op2) does not have exactly one input then don't fuse
# (it could have no inputs or multiple inputs)
if dag.in_degree(op2) != 1:
return False
if "pipeline" not in nodes[pre] or "pipeline" not in nodes[n]:

# if input is used by another node then don't fuse
op2_input = next(dag.predecessors(op2))
if dag.out_degree(op2_input) != 1:
return False

# if node producing input (op1) has more than one output then don't fuse
op1 = next(dag.predecessors(op2_input))
if dag.out_degree(op1) != 1:
return False
return can_fuse_pipelines(nodes[pre]["pipeline"], nodes[n]["pipeline"])

# op1 and op2 must have pipelines that can be fused
if "pipeline" not in nodes[op1]:
return False
return can_fuse_pipelines(nodes[op1]["pipeline"], nodes[op2]["pipeline"])

for n in list(dag.nodes()):
if can_fuse(n):
pre = next(dag.predecessors(n))
pipeline = fuse(nodes[pre]["pipeline"], nodes[n]["pipeline"])
nodes[n]["pipeline"] = pipeline
assert nodes[n]["target"] == pipeline.target_array
op2 = n
op2_input = next(dag.predecessors(op2))
op1 = next(dag.predecessors(op2_input))
op1_inputs = list(dag.predecessors(op1))

pipeline = fuse(nodes[op1]["pipeline"], nodes[op2]["pipeline"])
nodes[op2]["pipeline"] = pipeline

for p in dag.predecessors(pre):
dag.add_edge(p, n)
dag.remove_node(pre)
for n in op1_inputs:
dag.add_edge(n, op2)
dag.remove_node(op2_input)
dag.remove_node(op1)

return Plan(dag)

def _create_lazy_zarr_arrays(self, dag):
# find all lazy zarr arrays in dag
all_array_nodes = []
all_pipeline_nodes = []
lazy_zarr_arrays = []
reserved_mem_values = []
for n, d in dag.nodes(data=True):
if "pipeline" in d and d["pipeline"].reserved_mem is not None:
reserved_mem_values.append(d["pipeline"].reserved_mem)
if isinstance(d["target"], LazyZarrArray):
all_array_nodes.append(n)
all_pipeline_nodes.append(n)
if "target" in d and isinstance(d["target"], LazyZarrArray):
lazy_zarr_arrays.append(d["target"])

reserved_mem = max(reserved_mem_values, default=0)
Expand All @@ -143,14 +200,20 @@ def _create_lazy_zarr_arrays(self, dag):
name,
name=name,
op_name=op_name,
target=None,
type="op",
pipeline=pipeline,
projected_mem=pipeline.projected_mem,
num_tasks=pipeline.num_tasks,
)
# make create arrays node a dependency of all lazy array nodes
for n in all_array_nodes:
dag.add_edge(name, n)
dag.add_node(
"arrays",
name="arrays",
target=None,
)
dag.add_edge(name, "arrays")
# make create arrays node a predecessor of all pipeline nodes so it runs first
for n in all_pipeline_nodes:
dag.add_edge("arrays", n)

return dag

Expand Down Expand Up @@ -197,7 +260,7 @@ def num_tasks(self, optimize_graph=True, resume=None):
def num_arrays(self, optimize_graph: bool = True) -> int:
"""Return the number of arrays in this plan."""
dag = self._finalize_dag(optimize_graph=optimize_graph)
return sum(n != "create-arrays" for n in dag.nodes())
return sum(d.get("type") == "array" for _, d in dag.nodes(data=True))

def max_projected_mem(self, optimize_graph=True, resume=None):
"""Return the maximum projected memory across all tasks to execute this plan."""
Expand All @@ -214,8 +277,8 @@ def visualize(
dag = self._finalize_dag(optimize_graph=optimize_graph)
dag = dag.copy() # make a copy since we mutate the DAG below

# remove edges from create-arrays node to avoid cluttering the diagram
dag.remove_edges_from(list(dag.out_edges("create-arrays")))
# remove edges from create-arrays output node to avoid cluttering the diagram
dag.remove_edges_from(list(dag.out_edges("arrays")))

# remove hidden nodes
dag.remove_nodes_from(
Expand Down Expand Up @@ -253,69 +316,85 @@ def visualize(

# now set node attributes with visualization info
for n, d in dag.nodes(data=True):
if d["op_name"] == "blockwise":
d["style"] = "filled"
d["fillcolor"] = "#dcbeff"
op_name_summary = "(bw)"
elif d["op_name"] == "rechunk":
d["style"] = "filled"
d["fillcolor"] = "#aaffc3"
op_name_summary = "(rc)"
else: # creation function
op_name_summary = ""
target = d["target"]
if target is not None:
tooltip = f"name: {n}\n"
node_type = d.get("type", None)
if node_type == "op":
op_name = d["op_name"]
if op_name == "blockwise":
d["style"] = '"rounded,filled"'
d["fillcolor"] = "#dcbeff"
op_name_summary = "(bw)"
elif op_name == "rechunk":
d["style"] = '"rounded,filled"'
d["fillcolor"] = "#aaffc3"
op_name_summary = "(rc)"
else:
# creation function
d["style"] = "rounded"
op_name_summary = ""
tooltip += f"op: {op_name}"

if "pipeline" in d:
pipeline = d["pipeline"]
tooltip += (
f"\nprojected memory: {memory_repr(pipeline.projected_mem)}"
)
tooltip += f"\ntasks: {pipeline.num_tasks}"
if pipeline.write_chunks is not None:
tooltip += f"\nwrite chunks: {pipeline.write_chunks}"

# remove pipeline attribute since it is a long string that causes graphviz to fail
del d["pipeline"]

if "stack_summaries" in d and d["stack_summaries"] is not None:
# add call stack information
stack_summaries = d["stack_summaries"]

first_cubed_i = min(
i for i, s in enumerate(stack_summaries) if s.is_cubed()
)
first_cubed_summary = stack_summaries[first_cubed_i]
caller_summary = stack_summaries[first_cubed_i - 1]

d["label"] = f"{first_cubed_summary.name} {op_name_summary}"

calls = " -> ".join(
[
s.name
for s in stack_summaries
if not s.is_on_python_lib_path()
]
)

line = f"{caller_summary.lineno} in {caller_summary.name}"

tooltip += f"\ncalls: {calls}"
tooltip += f"\nline: {line}"
del d["stack_summaries"]

elif node_type == "array":
target = d["target"]
chunkmem = memory_repr(chunk_memory(target.dtype, target.chunks))
tooltip = (
f"name: {n}\n"
f"shape: {target.shape}\n"
f"chunks: {target.chunks}\n"
f"dtype: {target.dtype}\n"
f"chunk memory: {chunkmem}\n"
)
else:
tooltip = ""
if "pipeline" in d:
pipeline = d["pipeline"]
tooltip += f"\nprojected memory: {memory_repr(pipeline.projected_mem)}"
tooltip += f"\ntasks: {pipeline.num_tasks}"
if pipeline.write_chunks is not None:
tooltip += f"\nwrite chunks: {pipeline.write_chunks}"
if "stack_summaries" in d and d["stack_summaries"] is not None:
# add call stack information
stack_summaries = d["stack_summaries"]

first_cubed_i = min(
i for i, s in enumerate(stack_summaries) if s.is_cubed()
)
first_cubed_summary = stack_summaries[first_cubed_i]
caller_summary = stack_summaries[first_cubed_i - 1]

# materialized arrays are light orange, virtual arrays are white
if isinstance(target, (LazyZarrArray, zarr.Array)):
d["style"] = "filled"
d["fillcolor"] = "#ffd8b1"
if n in array_display_names:
var_name = f" ({array_display_names[n]})"
var_name = array_display_names[n]
d["label"] = f"{n} ({var_name})"
tooltip += f"variable: {var_name}\n"
else:
var_name = ""
d[
"label"
] = f"{n}{var_name}\n{first_cubed_summary.name} {op_name_summary}"
d["label"] = n
tooltip += f"shape: {target.shape}\n"
tooltip += f"chunks: {target.chunks}\n"
tooltip += f"dtype: {target.dtype}\n"
tooltip += f"chunk memory: {chunkmem}\n"

calls = " -> ".join(
[s.name for s in stack_summaries if not s.is_on_python_lib_path()]
)

line = f"{caller_summary.lineno} in {caller_summary.name}"

tooltip += f"\ncalls: {calls}"
tooltip += f"\nline: {line}"
del d["stack_summaries"]
del d["target"]

d["tooltip"] = tooltip.strip()

# remove pipeline attribute since it is a long string that causes graphviz to fail
if "pipeline" in d:
del d["pipeline"]
if "target" in d:
del d["target"]
if "name" in d: # pydot already has name
del d["name"]
gv = nx.drawing.nx_pydot.to_pydot(dag)
Expand Down
Loading
Loading