diff --git a/burr/core/application.py b/burr/core/application.py
index a4c15a5a..12e28a07 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -1604,7 +1604,7 @@ async def astream_result(
         which will be empty. Thus ``halt_after`` takes precedence -- if it is met, the streaming
         result container will contain the result of the halt_after condition.
-        The :py:class:`StreamingResultContainer` is meant as a convenience -- specifically this allows for
+        The :py:class:`AsyncStreamingResultContainer` is meant as a convenience -- specifically this allows for
         hooks, callbacks, etc... so you can take the control flow and still have state updated afterwards.
         Hooks/state update will be called after an exception is thrown during streaming, or the stream is completed.
         Note that it is undefined behavior to attempt to execute another action while a stream is in progress.
@@ -1842,6 +1842,85 @@ async def callback(
             generator, self._state, process_result, callback
         )

+    @telemetry.capture_function_usage
+    @_call_execute_method_pre_post(ExecuteMethod.stream_iterate)
+    def stream_iterate(
+        self,
+        halt_after: Optional[Union[str, List[str]]] = None,
+        halt_before: Optional[Union[str, List[str]]] = None,
+        inputs: Optional[Dict[str, Any]] = None,
+    ) -> Generator[
+        Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]], None, None
+    ]:
+        """Produces a generator that iterates through intermediate streams. You may want
+        to use this in something like a deep research mode, in which:
+
+        1. The user queries for a result
+        2. The application runs through a workflow
+        3. For each step of the workflow, the application
+            a. Streams out an intermediate result
+            b. Selects the next action/transition when the intermediate result is complete
+
+        Note that there are control-flow complexities involved here -- we need to ensure that the iterator
+        is properly pushed along and the prior streaming result containers are all finished before going to the next one.
+
+        :param halt_after: Action names/tags to halt after
+        :param halt_before: Action names/tags to halt before
+        :param inputs: Inputs to the first action run
+        :return: Generator yielding tuples of (action, streaming result container)
+        :yield: Tuples of (action, streaming result container)
+        """
+        self.validate_correct_async_use()
+        halt_before, halt_after, inputs = self._process_control_flow_params(
+            halt_before, halt_after, inputs
+        )
+        self._validate_halt_conditions(halt_before, halt_after)
+        while self.has_next_action():
+            next_action = self.get_next_action()
+            _, streaming_result = self.stream_result(
+                halt_after=[next_action.name], halt_before=None, inputs=inputs
+            )
+            yield next_action, streaming_result
+            # We need to ensure it's fully exhausted before going to the next action
+            streaming_result.get()
+            if self._should_halt_iterate(halt_before, halt_after, next_action):
+                break
+
+    @telemetry.capture_function_usage
+    @_call_execute_method_pre_post(ExecuteMethod.astream_iterate)
+    async def astream_iterate(
+        self,
+        halt_after: Optional[Union[str, List[str]]] = None,
+        halt_before: Optional[Union[str, List[str]]] = None,
+        inputs: Optional[Dict[str, Any]] = None,
+    ) -> AsyncGenerator[
+        Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]], None
+    ]:
+        """Async version of stream_iterate. Produces an async generator that iterates
+        through intermediate streams. See stream_iterate for more details.
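+
+        A minimal usage sketch (the ``"final"`` action name and ``app`` are
+        hypothetical stand-ins for your own application):
+
+        .. code-block:: python
+
+            async for action, container in app.astream_iterate(halt_after=["final"]):
+                async for chunk in container:
+                    print(chunk)  # render each intermediate chunk as it arrives
+                result, state = await container.get()  # final result/state for this action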
+
+        :param halt_after: Action names/tags to halt after
+        :param halt_before: Action names/tags to halt before
+        :param inputs: Inputs to the first action run
+        :return: Async generator yielding tuples of (action, streaming result container)
+        :yield: Tuples of (action, streaming result container)
+        """
+        self.validate_correct_async_use()
+        halt_before, halt_after, inputs = self._process_control_flow_params(
+            halt_before, halt_after, inputs
+        )
+        self._validate_halt_conditions(halt_before, halt_after)
+        while self.has_next_action():
+            next_action = self.get_next_action()
+            _, streaming_result = await self.astream_result(
+                halt_after=[next_action.name], halt_before=None, inputs=inputs
+            )
+            yield next_action, streaming_result
+            # We need to ensure it's fully exhausted before going to the next action
+            await streaming_result.get()
+            if self._should_halt_iterate(halt_before, halt_after, next_action):
+                break
+
     @telemetry.capture_function_usage
     def visualize(
         self,
diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py
index 0c5c2376..5bdeba34 100644
--- a/burr/core/parallelism.py
+++ b/burr/core/parallelism.py
@@ -10,6 +10,8 @@
+    AsyncGenerator,
     Callable,
     Dict,
     Generator,
+    Hashable,
     List,
     Literal,
     Optional,
@@ -21,7 +22,12 @@
 from burr.common import async_utils
 from burr.common.async_utils import SyncOrAsyncGenerator, SyncOrAsyncGeneratorOrItemOrList
 from burr.core import Action, ApplicationBuilder, ApplicationContext, Graph, State
-from burr.core.action import SingleStepAction
+from burr.core.action import (
+    AsyncStreamingResultContainer,
+    SingleStepAction,
+    StreamingAction,
+    StreamingResultContainer,
+)
 from burr.core.application import ApplicationIdentifiers
 from burr.core.graph import GraphBuilder
 from burr.core.persistence import BaseStateLoader, BaseStateSaver
@@ -154,6 +160,16 @@ async def arun(self, parent_context: ApplicationContext):
         )
         return state

+    def stream_run(
+        self, state: State, **run_kwargs
+    ) -> Generator[StreamingResultContainer, None, None]:
+        raise NotImplementedError
+
+    async def astream_run(
+        self, state: State, **run_kwargs
+    ) -> AsyncGenerator[AsyncStreamingResultContainer, None]:
+        raise NotImplementedError
+

 def _stable_app_id_hash(app_id: str, child_key: str) -> str:
     """Gives a stable hash for an application. Given the parent app_id and a child key,
@@ -341,6 +357,78 @@ def reads(self) -> list[str]:
         pass


+StreamedType = TypeVar("StreamedType")
+KeyType = TypeVar("KeyType", bound=Hashable)
+
+
+class MultiProducerSingleConsumerSharedQueue(abc.ABC):
+    """Abstract interface for a non-blocking shared queue.
+    This allows us to unify multiprocessing, multithreading, etc...
+    """
+
+    @abc.abstractmethod
+    def put(self, item: StreamedType) -> None:
+        pass
+
+    @abc.abstractmethod
+    def get(self, block: bool = True) -> Optional[StreamedType]:
+        """Gets the latest item from the queue.
+
+        :return: The next item, or None if the queue is empty and block is False.
+        """
+        pass
+
+    @abc.abstractmethod
+    def is_done(self) -> bool:
+        """Whether or not the queue is done producing items.
+
+        :return: True if no more items will be produced, False otherwise.
+        """
+        pass
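+
+
+# A minimal thread-backed sketch of the interface above -- illustrative only, not
+# yet part of the public API. The `mark_done` producer-side signal is an assumption;
+# the abstract interface does not specify how completion is communicated.
+class _IllustrativeThreadedQueue(MultiProducerSingleConsumerSharedQueue):
+    def __init__(self):
+        import queue
+        import threading
+
+        self._queue: "queue.Queue[StreamedType]" = queue.Queue()
+        self._done = threading.Event()
+
+    def put(self, item: StreamedType) -> None:
+        self._queue.put(item)
+
+    def get(self, block: bool = True) -> Optional[StreamedType]:
+        import queue
+
+        try:
+            return self._queue.get(block=block)
+        except queue.Empty:
+            return None
+
+    def mark_done(self) -> None:
+        self._done.set()
+
+    def is_done(self) -> bool:
+        # Done only once producers have signaled *and* the consumer has drained the queue
+        return self._done.is_set() and self._queue.empty()
+
+
+# A consumer would poll roughly as: `while not q.is_done(): item = q.get(block=False); ...`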
+# def sync_stream_merge(
+#     generators: Dict[KeyType, Generator[StreamedType, None, None]]
+# ) -> Generator[Tuple[KeyType, StreamedType], None, None]:
+#     """Merges multiple streams in the order in which they appear (insomuch as the
+#     receiver, e.g. this function, can capably discern)
+
+#     TODO -- add optional error handling capabilities
+
+#     :param generators: Streams with keys
+#     :yield: A tuple of key, stream item for each stream
+#     """
+#     pass
+
+
+class TaskBasedStreamingAction(StreamingAction):
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def writes(self) -> List[str]:
+        raise NotImplementedError
+
+    def update(self, result: Dict, state: State) -> State:
+        raise NotImplementedError
+
+    @property
+    def reads(self) -> List[str]:
+        raise NotImplementedError
+
+    def stream_run(self, state: State, **run_kwargs) -> Generator[Dict, None, None]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def tasks(
+        self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
+    ) -> Generator[SubGraphTask, None, None]:
+        raise NotImplementedError
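+
+
+# An illustrative subclass, commented out like the merge helper above since the
+# streaming semantics are still unimplemented. `summarize_graph` is a hypothetical
+# sub-graph, and the SubGraphTask field names here are assumptions:
+#
+# class FanOutSummaries(TaskBasedStreamingAction):
+#     @property
+#     def reads(self) -> List[str]:
+#         return ["documents"]
+#
+#     @property
+#     def writes(self) -> List[str]:
+#         return ["summaries"]
+#
+#     def update(self, result: Dict, state: State) -> State:
+#         return state.update(summaries=result["summaries"])
+#
+#     def tasks(
+#         self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
+#     ) -> Generator[SubGraphTask, None, None]:
+#         # One sub-application per document, with a stable child app ID
+#         for i, doc in enumerate(state["documents"]):
+#             yield SubGraphTask(
+#                 graph=summarize_graph,
+#                 inputs={"document": doc},
+#                 state=state,
+#                 application_id=_stable_app_id_hash(context.app_id, str(i)),
+#             )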
+ """ + action_tracker = CallCaptureTracker() + stream_event_tracker = StreamEventCaptureTracker() + counter_action = base_streaming_single_step_counter.with_name("counter") + counter_action_2 = base_streaming_single_step_counter.with_name("counter_2") + app = Application( + state=State({"count": 0}), + entrypoint="counter", + adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker), + partition_key="test", + uid="test-123", + graph=Graph( + actions=[counter_action, counter_action_2], + transitions=[ + Transition(counter_action, counter_action_2, default), + ], + ), + ) + for _, streaming_container in app.stream_iterate(halt_after=["counter_2"]): + if exhaust_intermediate_generators: + results = list(streaming_container) + assert len(results) == 10 + result, state = streaming_container.get() + assert result["count"] == state["count"] == 2 + assert state["tracker"] == [1, 2] + assert len(action_tracker.pre_called) == 2 + assert len(action_tracker.post_called) == 2 + assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"} + assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"} + assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [ + 0, + 1, + ] # ensure sequence ID is respected + assert [item["sequence_id"] for _, item in action_tracker.post_called] == [ + 0, + 1, + ] # ensure sequence ID is respected + + assert len(stream_event_tracker.pre_start_stream_calls) == 2 + assert len(stream_event_tracker.post_end_stream_calls) == 2 + assert len(stream_event_tracker.post_stream_item_calls) == 20 + assert len(stream_event_tracker.post_stream_item_calls) == 20 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("exhaust_intermediate_generators", [True, False]) +async def test_astream_iterate(exhaust_intermediate_generators: bool): + """Tests that we can pass through streaming results in astream_iterate. Note that this tests two cases: + 1. We exhaust the intermediate generators, and then call get() to get the final result + 2. We don't exhaust the intermediate generators, and then call get() to get the final result + This ensures that the application effectively does it for us. 
+ """ + action_tracker = CallCaptureTracker() + stream_event_tracker = StreamEventCaptureTracker() + counter_action = base_streaming_single_step_counter_async.with_name( + "counter" + ) # Use async action + counter_action_2 = base_streaming_single_step_counter_async.with_name( + "counter_2" + ) # Use async action + app = Application( + state=State({"count": 0}), + entrypoint="counter", + adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker), + partition_key="test", + uid="test-123", + graph=Graph( + actions=[counter_action, counter_action_2], + transitions=[ + Transition(counter_action, counter_action_2, default), + ], + ), + ) + streaming_container = None # Define outside the loop to access later + async for _, streaming_container in app.astream_iterate(halt_after=["counter_2"]): + if exhaust_intermediate_generators: + results = [] + async for item in streaming_container: # Use async for + results.append(item) + assert len(results) == 10 + assert streaming_container is not None # Ensure the loop ran + result, state = await streaming_container.get() # Use await + assert result["count"] == state["count"] == 2 + assert state["tracker"] == [1, 2] + assert len(action_tracker.pre_called) == 2 + assert len(action_tracker.post_called) == 2 + assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"} + assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"} + assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [ + 0, + 1, + ] # ensure sequence ID is respected + assert [item["sequence_id"] for _, item in action_tracker.post_called] == [ + 0, + 1, + ] # ensure sequence ID is respected + + assert len(stream_event_tracker.pre_start_stream_calls) == 2 + assert len(stream_event_tracker.post_end_stream_calls) == 2 + assert len(stream_event_tracker.post_stream_item_calls) == 20 + assert len(stream_event_tracker.post_stream_item_calls) == 20 + + async def test_astream_result_halt_after_run_through_streaming(): action_tracker = CallCaptureTracker() stream_event_tracker = StreamEventCaptureTrackerAsync()