Skip to content

Parallel streams WIP not ready #533

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,7 @@ async def astream_result(
which will be empty. Thus ``halt_after`` takes precedence -- if it is met, the streaming result container will contain the result of the
halt_after condition.
The :py:class:`AsyncStreamingResultContainer <burr.core.action.AsyncStreamingResultContainer>` is meant as a convenience -- specifically this allows for
The :py:class:`StreamingResultContainer <burr.core.action.StreamingResultContainer>` is meant as a convenience -- specifically this allows for
hooks, callbacks, etc... so you can take the control flow and still have state updated afterwards. Hooks/state update will be called after an exception
is thrown during streaming, or the stream is completed. Note that it is undefined behavior to attempt to execute another action while a stream is in progress.
Expand Down Expand Up @@ -1842,6 +1842,85 @@ async def callback(
generator, self._state, process_result, callback
)

@telemetry.capture_function_usage
@_call_execute_method_pre_post(ExecuteMethod.stream_iterate)
def stream_iterate(
    self,
    halt_after: Optional[Union[str, List[str]]] = None,
    halt_before: Optional[Union[str, List[str]]] = None,
    inputs: Optional[Dict[str, Any]] = None,
) -> Generator[
    Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]], None, None
]:
    """Produces an iterator that iterates through intermediate streams. You may want
    to use this in something like deep research mode in which:

    1. The user queries for a result
    2. The application runs through a workflow
    3. For each step of the workflow, the application
        a. Streams out an intermediate result
        b. Selects the next action/transition when the intermediate results is complete

    Note that there are control-flow complexities involved here -- we need to ensure that the iterator
    is properly pushed along and the prior streaming results containers are all finished before going to the next one.

    :param halt_after: Action names/tags after which to halt -- checked once each action's stream has completed
    :param halt_before: Action names/tags before which to halt, defaults to None
    :param inputs: Inputs to the action(s) run -- NOTE(review): these are passed on every
        iteration, not just the first action; confirm this is intended
    :yield: Tuples of (action, streaming result container), one per action executed
    """
    self.validate_correct_async_use()
    halt_before, halt_after, inputs = self._process_control_flow_params(
        halt_before, halt_after, inputs
    )
    self._validate_halt_conditions(halt_before, halt_after)
    while self.has_next_action():
        next_action = self.get_next_action()
        # Run exactly one action as a stream -- halting after it returns control here
        _, streaming_result = self.stream_result(
            halt_after=[next_action.name], halt_before=None, inputs=inputs
        )
        yield next_action, streaming_result
        # We need to ensure it's fully exhausted before going to the next action --
        # get() drains the stream even if the consumer did not
        streaming_result.get()
        if self._should_halt_iterate(halt_before, halt_after, next_action):
            break

@telemetry.capture_function_usage
@_call_execute_method_pre_post(ExecuteMethod.astream_iterate)
async def astream_iterate(
    self,
    halt_after: Optional[Union[str, List[str]]] = None,
    halt_before: Optional[Union[str, List[str]]] = None,
    inputs: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[
    Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]], None
]:
    """Async version of stream_iterate. Produces an async generator that iterates
    through intermediate streams. See stream_iterate for more details.

    :param halt_after: Action names/tags after which to halt -- checked once each action's stream has completed
    :param halt_before: Action names/tags before which to halt
    :param inputs: Inputs to the action(s) run -- NOTE(review): these are passed on every
        iteration, not just the first action; confirm this is intended
    :yield: Tuples of (action, async streaming result container), one per action executed
    """
    self.validate_correct_async_use()
    halt_before, halt_after, inputs = self._process_control_flow_params(
        halt_before, halt_after, inputs
    )
    self._validate_halt_conditions(halt_before, halt_after)
    while self.has_next_action():
        next_action = self.get_next_action()
        # Run exactly one action as an async stream -- halting after it returns control here
        _, streaming_result = await self.astream_result(
            halt_after=[next_action.name], halt_before=None, inputs=inputs
        )
        yield next_action, streaming_result
        # We need to ensure it's fully exhausted before going to the next action --
        # awaiting get() drains the stream even if the consumer did not
        await streaming_result.get()
        if self._should_halt_iterate(halt_before, halt_after, next_action):
            break

@telemetry.capture_function_usage
def visualize(
self,
Expand Down
90 changes: 89 additions & 1 deletion burr/core/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Callable,
Dict,
Generator,
Hashable,
List,
Literal,
Optional,
Expand All @@ -21,7 +22,12 @@
from burr.common import async_utils
from burr.common.async_utils import SyncOrAsyncGenerator, SyncOrAsyncGeneratorOrItemOrList
from burr.core import Action, ApplicationBuilder, ApplicationContext, Graph, State
from burr.core.action import SingleStepAction
from burr.core.action import (
AsyncStreamingResultContainer,
SingleStepAction,
StreamingAction,
StreamingResultContainer,
)
from burr.core.application import ApplicationIdentifiers
from burr.core.graph import GraphBuilder
from burr.core.persistence import BaseStateLoader, BaseStateSaver
Expand Down Expand Up @@ -154,6 +160,16 @@ async def arun(self, parent_context: ApplicationContext):
)
return state

def stream_run(
    self, state: State, **run_kwargs
) -> Generator[StreamingResultContainer, None, None]:
    """Runs this sub-graph task, yielding streaming result containers as it goes.
    Not yet implemented -- placeholder for the parallel-streams work in progress.

    :param state: State to run the sub-graph with
    :param run_kwargs: Additional runtime arguments for the run
    :raises NotImplementedError: always, until streaming sub-graph execution is implemented
    """
    raise NotImplementedError

async def astream_run(
    self, state: State, **run_kwargs
) -> Generator[AsyncStreamingResultContainer, None, None]:
    """Async version of stream_run -- runs this sub-graph task, yielding async
    streaming result containers. Not yet implemented -- placeholder for the
    parallel-streams work in progress.

    NOTE(review): the return annotation should likely be
    ``AsyncGenerator[AsyncStreamingResultContainer, None]`` -- confirm and add the
    ``AsyncGenerator`` import once implemented.

    :param state: State to run the sub-graph with
    :param run_kwargs: Additional runtime arguments for the run
    :raises NotImplementedError: always, until streaming sub-graph execution is implemented
    """
    raise NotImplementedError


def _stable_app_id_hash(app_id: str, child_key: str) -> str:
"""Gives a stable hash for an application. Given the parent app_id and a child key,
Expand Down Expand Up @@ -341,6 +357,78 @@ def reads(self) -> list[str]:
pass


StreamedType = TypeVar("StreamedType")
KeyType = TypeVar("KeyType", bound=Hashable)


class MultiProducerSingleConsumerSharedQueue(abc.ABC):
    """Abstract interface for a non-blocking shared queue with multiple producers
    and a single consumer.

    This allows us to unify multiprocessing, multithreading, etc... behind one
    queue contract.
    """

    @abc.abstractmethod
    def put(self, item: StreamedType) -> None:
        """Puts an item onto the queue.

        :param item: The item to enqueue
        """
        pass

    @abc.abstractmethod
    def get(self, block: bool = True) -> Optional[StreamedType]:
        """Gets the latest from the queue. None if the queue is empty
        and block is False.

        :param block: Whether to block until an item is available
        :return: The next item, or None if the queue is empty and block is False
        """
        pass

    @abc.abstractmethod
    def is_done(self) -> bool:
        """Whether or not the queue is done.

        :return: True if no further items will be produced, False otherwise
        """
        pass


# def sync_stream_merge(
# generators: Dict[KeyType, StreamedType]
# ) -> Generator[Tuple[KeyType, StreamedType], None, None]:
# """Merges multiple streams in the order in which they appear (insomuch as the
# receiver, e.g. this function, can capably discern)

# TODO -- add optional error handling capabilities

# :param keyed_streams: Streams with keys
# :yield: A tuple of key, stream item for each stream
# """
# pass


class TaskBasedStreamingAction(StreamingAction):
    """Streaming action whose work is delegated to a set of sub-graph tasks.

    WIP: the standard Action interface (``reads``/``writes``/``update``/``stream_run``)
    is not yet implemented; subclasses currently only define :py:meth:`tasks`.

    Note: the no-op ``__init__`` that only called ``super().__init__()`` was removed --
    it added nothing and needlessly pinned the constructor signature.
    """

    @property
    def writes(self) -> List[str]:
        # Not yet implemented -- placeholder for the parallel-streams WIP.
        raise NotImplementedError

    def update(self, result: Dict, state: State) -> State:
        # Not yet implemented -- placeholder for the parallel-streams WIP.
        raise NotImplementedError

    @property
    def reads(self) -> List[str]:
        # Not yet implemented -- placeholder for the parallel-streams WIP.
        raise NotImplementedError

    def stream_run(self, state: State, **run_kwargs) -> Generator[Dict, None, None]:
        # Not yet implemented -- placeholder for the parallel-streams WIP.
        raise NotImplementedError

    @abc.abstractmethod
    def tasks(
        self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
    ) -> Generator[SubGraphTask, None, None]:
        """Generates the sub-graph tasks this action should execute.

        :param state: Current application state
        :param context: Application context of the parent application
        :param inputs: Runtime inputs to the action
        :yield: Sub-graph tasks to run
        """
        raise NotImplementedError


def _cascade_adapter(
behavior: Union[Literal["cascade"], AdapterType, None],
adapter: Union[AdapterType, None],
Expand Down
2 changes: 2 additions & 0 deletions burr/lifecycle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ class ExecuteMethod(enum.Enum):
arun = "arun"
stream_result = "stream_result"
astream_result = "astream_result"
stream_iterate = "stream_iterate"
astream_iterate = "astream_iterate"


@lifecycle.base_hook("pre_run_execute_call")
Expand Down
109 changes: 109 additions & 0 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2045,6 +2045,115 @@ def test_stream_result_halt_after_run_through_streaming():
assert len(stream_event_tracker.post_end_stream_calls) == 1


@pytest.mark.parametrize("exhaust_intermediate_generators", [True, False])
def test_stream_iterate(exhaust_intermediate_generators: bool):
    """Tests that we can pass through streaming results in stream_iterate. Note that this tests two cases:

    1. We exhaust the intermediate generators, and then call get() to get the final result
    2. We don't exhaust the intermediate generators, and then call get() to get the final result

    This ensures that the application effectively does it for us.
    """
    action_tracker = CallCaptureTracker()
    stream_event_tracker = StreamEventCaptureTracker()
    counter_action = base_streaming_single_step_counter.with_name("counter")
    counter_action_2 = base_streaming_single_step_counter.with_name("counter_2")
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter",
        adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker),
        partition_key="test",
        uid="test-123",
        graph=Graph(
            actions=[counter_action, counter_action_2],
            transitions=[
                Transition(counter_action, counter_action_2, default),
            ],
        ),
    )
    streaming_container = None  # Defined outside the loop to access the last container
    for _, streaming_container in app.stream_iterate(halt_after=["counter_2"]):
        if exhaust_intermediate_generators:
            results = list(streaming_container)
            assert len(results) == 10
    assert streaming_container is not None  # Ensure the loop ran
    result, state = streaming_container.get()
    assert result["count"] == state["count"] == 2
    assert state["tracker"] == [1, 2]
    assert len(action_tracker.pre_called) == 2
    assert len(action_tracker.post_called) == 2
    assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"}
    assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"}
    assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [
        0,
        1,
    ]  # ensure sequence ID is respected
    assert [item["sequence_id"] for _, item in action_tracker.post_called] == [
        0,
        1,
    ]  # ensure sequence ID is respected

    assert len(stream_event_tracker.pre_start_stream_calls) == 2
    assert len(stream_event_tracker.post_end_stream_calls) == 2
    assert len(stream_event_tracker.post_stream_item_calls) == 20


@pytest.mark.asyncio
@pytest.mark.parametrize("exhaust_intermediate_generators", [True, False])
async def test_astream_iterate(exhaust_intermediate_generators: bool):
    """Tests that we can pass through streaming results in astream_iterate. Note that this tests two cases:

    1. We exhaust the intermediate generators, and then call get() to get the final result
    2. We don't exhaust the intermediate generators, and then call get() to get the final result

    This ensures that the application effectively does it for us.
    """
    action_tracker = CallCaptureTracker()
    stream_event_tracker = StreamEventCaptureTracker()
    counter_action = base_streaming_single_step_counter_async.with_name(
        "counter"
    )  # Use async action
    counter_action_2 = base_streaming_single_step_counter_async.with_name(
        "counter_2"
    )  # Use async action
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter",
        adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker),
        partition_key="test",
        uid="test-123",
        graph=Graph(
            actions=[counter_action, counter_action_2],
            transitions=[
                Transition(counter_action, counter_action_2, default),
            ],
        ),
    )
    streaming_container = None  # Define outside the loop to access later
    async for _, streaming_container in app.astream_iterate(halt_after=["counter_2"]):
        if exhaust_intermediate_generators:
            results = []
            async for item in streaming_container:  # Use async for
                results.append(item)
            assert len(results) == 10
    assert streaming_container is not None  # Ensure the loop ran
    result, state = await streaming_container.get()  # Use await
    assert result["count"] == state["count"] == 2
    assert state["tracker"] == [1, 2]
    assert len(action_tracker.pre_called) == 2
    assert len(action_tracker.post_called) == 2
    assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"}
    assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"}
    assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [
        0,
        1,
    ]  # ensure sequence ID is respected
    assert [item["sequence_id"] for _, item in action_tracker.post_called] == [
        0,
        1,
    ]  # ensure sequence ID is respected

    assert len(stream_event_tracker.pre_start_stream_calls) == 2
    assert len(stream_event_tracker.post_end_stream_calls) == 2
    assert len(stream_event_tracker.post_stream_item_calls) == 20


async def test_astream_result_halt_after_run_through_streaming():
action_tracker = CallCaptureTracker()
stream_event_tracker = StreamEventCaptureTrackerAsync()
Expand Down