Skip to content

Commit 322cb2d

Browse files
committed
Unifies the with_graph and with_actions/with_transitions code paths in
application builder Implementation of #526. Changes: 1. Adds with_graph(graph: Graph) to GraphBuilder -- this enables you to absorb another graph 2. Loosens up the mutually exclusive paths of with_graph and with_actions/transitions -- simplify to have one internal variable, allowing you to mix and match 3. Add tests for the API 4. Adds validation to ensure you don't have duplicated action names Note that this introduces the caveat of multiple duplicated action names, although the validation + fact that this is closer to a power-user feature makes it less problematic.
1 parent 56e74ed commit 322cb2d

File tree

3 files changed

+81
-27
lines changed

3 files changed

+81
-27
lines changed

burr/core/application.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,7 +2065,6 @@ def __init__(self):
20652065
self.loaded_from_fork: bool = False
20662066
self.tracker = None
20672067
self.graph_builder = None
2068-
self.prebuilt_graph = None
20692068
self.typing_system = None
20702069
self.parallel_executor_factory = None
20712070
self.state_persister = None
@@ -2143,9 +2142,21 @@ def with_state(
21432142
self.state = State(kwargs)
21442143
return self
21452144

2145+
def with_graphs(self, *graphs) -> "ApplicationBuilder[StateType]":
2146+
"""Adds multiple prebuilt graphs -- this just calls :py:meth:`with_graph <burr.core.application.ApplicationBuilder.with_graph>`
2147+
in a loop! See caveats in :py:meth:`with_graph <burr.core.application.ApplicationBuilder.with_graph>`.
2148+
2149+
:param graphs: Graphs to add to the application
2150+
:return: The application builder for future chaining.
2151+
"""
2152+
for graph in graphs:
2153+
self.with_graph(graph)
2154+
return self
2155+
21462156
def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]":
2147-
"""Adds a prebuilt graph -- this is an alternative to using the with_actions and with_transitions methods.
2148-
While you will likely use with_actions and with_transitions, you may want this in a few cases:
2157+
"""Adds a prebuilt graph -- this can work in addition to using with_actions and with_transitions methods.
2158+
This will add all nodes + edges from a prebuilt graph to the current graph. Note that if you add two
2159+
graphs (or a combination of graphs/nodes/edges), you will need to ensure that there are no node name conflicts.
21492160
21502161
1. You want to reuse the same graph object for different applications
21512162
2. You want the logic that constructs the graph to be separate from that which constructs the application
@@ -2154,13 +2165,8 @@ def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]":
21542165
:param graph: Graph object built with the :py:class:`GraphBuilder <burr.core.graph.GraphBuilder>`
21552166
:return: The application builder for future chaining.
21562167
"""
2157-
if self.graph_builder is not None:
2158-
raise ValueError(
2159-
BASE_ERROR_MESSAGE
2160-
+ "You have already called `with_actions`, or `with_transitions` -- you currently "
2161-
"cannot use the with_graph method along with that. Use `with_graph` or the other methods, not both"
2162-
)
2163-
self.prebuilt_graph = graph
2168+
self._initialize_graph_builder()
2169+
self.graph_builder = self.graph_builder.with_graph(graph)
21642170
return self
21652171

21662172
def with_parallel_executor(self, executor_factory: lambda: Executor):
@@ -2190,15 +2196,6 @@ def with_parallel_executor(self, executor_factory: lambda: Executor):
21902196
self.parallel_executor_factory = executor_factory
21912197
return self
21922198

2193-
def _ensure_no_prebuilt_graph(self):
2194-
if self.prebuilt_graph is not None:
2195-
raise ValueError(
2196-
BASE_ERROR_MESSAGE + "You have already called `with_graph` -- you currently "
2197-
"cannot use the with_actions, or with_transitions method along with that. "
2198-
"Use `with_graph` or the other methods, not both."
2199-
)
2200-
return self
2201-
22022199
def _initialize_graph_builder(self):
22032200
if self.graph_builder is None:
22042201
self.graph_builder = GraphBuilder()
@@ -2233,7 +2230,6 @@ def with_actions(
22332230
:param action_dict: Actions to add, keyed by name
22342231
:return: The application builder for future chaining.
22352232
"""
2236-
self._ensure_no_prebuilt_graph()
22372233
self._initialize_graph_builder()
22382234
self.graph_builder = self.graph_builder.with_actions(*action_list, **action_dict)
22392235
return self
@@ -2256,7 +2252,6 @@ def with_transitions(
22562252
:param transitions: Transitions to add
22572253
:return: The application builder for future chaining.
22582254
"""
2259-
self._ensure_no_prebuilt_graph()
22602255
self._initialize_graph_builder()
22612256
self.graph_builder = self.graph_builder.with_transitions(*transitions)
22622257
return self
@@ -2583,15 +2578,13 @@ def reset_to_entrypoint(self):
25832578
self.state = self.state.wipe(delete=[PRIOR_STEP])
25842579

25852580
def _get_built_graph(self) -> Graph:
2586-
if self.graph_builder is None and self.prebuilt_graph is None:
2581+
if self.graph_builder is None:
25872582
raise ValueError(
25882583
BASE_ERROR_MESSAGE
2589-
+ "You must set the graph using with_graph, or use with_entrypoint, with_actions, and with_transitions"
2590-
" to build the graph."
2584+
+ "No graph constructs exist. You must call some combination of with_graph, with_entrypoint, "
2585+
"with_actions, and with_transitions"
25912586
)
2592-
if self.graph_builder is not None:
2593-
return self.graph_builder.build()
2594-
return self.prebuilt_graph
2587+
return self.graph_builder.build()
25952588

25962589
def _build_common(self) -> Application:
25972590
graph = self._get_built_graph()

burr/core/graph.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ def _validate_actions(actions: Optional[List[Action]]):
2626
assert_set(actions, "_actions", "with_actions")
2727
if len(actions) == 0:
2828
raise ValueError("Must have at least one action in the application!")
29+
seen_action_names = set()
30+
for action in actions:
31+
if action.name in seen_action_names:
32+
raise ValueError(
33+
f"Action: {action.name} is duplicated in the actions list. "
34+
"Please ensure all actions have unique names. This could happen"
35+
"if you add two actions with the same name or add a graph that"
36+
"has actions with the same name as any that already exist."
37+
)
38+
seen_action_names.add(action.name)
2939

3040

3141
def _validate_transitions(
@@ -321,6 +331,25 @@ def with_transitions(
321331
self.transitions.append((action, to_, condition))
322332
return self
323333

334+
def with_graph(self, graph: Graph) -> "GraphBuilder":
335+
"""Adds an existing graph to the builder. Note that if you have any name clashes
336+
this will error out. This would happen if you add actions with the same name as actions
337+
that already exist.
338+
339+
:param graph: The graph to add
340+
:return: The application builder for future chaining.
341+
"""
342+
if self.actions is None:
343+
self.actions = []
344+
if self.transitions is None:
345+
self.transitions = []
346+
self.actions.extend(graph.actions)
347+
self.transitions.extend(
348+
(transition.from_.name, transition.to.name, transition.condition)
349+
for transition in graph.transitions
350+
)
351+
return self
352+
324353
def build(self) -> Graph:
325354
"""Builds/finalizes the graph.
326355

tests/core/test_graph.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ def test__validate_actions_empty():
8787
_validate_actions([])
8888

8989

90+
def test__validate_actions_duplicated():
91+
with pytest.raises(ValueError, match="duplicated"):
92+
_validate_actions([Result("test"), Result("test")])
93+
94+
9095
base_counter_action = PassedInAction(
9196
reads=["count"],
9297
writes=["count"],
@@ -110,6 +115,33 @@ def test_graph_builder_builds():
110115
assert len(graph.transitions) == 2
111116

112117

118+
def test_graph_builder_with_graph():
119+
graph1 = (
120+
GraphBuilder()
121+
.with_actions(counter=base_counter_action)
122+
.with_transitions(("counter", "counter", Condition.expr("count < 10")))
123+
.build()
124+
)
125+
graph2 = (
126+
GraphBuilder()
127+
.with_actions(counter2=base_counter_action)
128+
.with_transitions(("counter2", "counter2", Condition.expr("count < 20")))
129+
.build()
130+
)
131+
graph = (
132+
GraphBuilder()
133+
.with_graph(graph1)
134+
.with_graph(graph2)
135+
.with_actions(result=Result("count"))
136+
.with_transitions(
137+
("counter", "counter2"),
138+
("counter2", "result"),
139+
)
140+
)
141+
assert len(graph.actions) == 3
142+
assert len(graph.transitions) == 4
143+
144+
113145
def test_graph_builder_get_next_node():
114146
graph = (
115147
GraphBuilder()

0 commit comments

Comments
 (0)