diff --git a/python/cocoindex/cli.py b/python/cocoindex/cli.py index ae92387d..11dfe00c 100644 --- a/python/cocoindex/cli.py +++ b/python/cocoindex/cli.py @@ -137,9 +137,9 @@ def update(flow_name: str | None, live: bool, quiet: bool): if flow_name is None: return flow.update_all_flows(options) else: - updater = flow.FlowLiveUpdater(_flow_by_name(flow_name), options) - updater.wait() - return updater.update_stats() + with flow.FlowLiveUpdater(_flow_by_name(flow_name), options) as updater: + updater.wait() + return updater.update_stats() @cli.command() @click.argument("flow_name", type=str, required=False) diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index b989381a..f85f77b7 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -383,27 +383,16 @@ class FlowLiveUpdater: """ A live updater for a flow. """ - _engine_live_updater: _engine.FlowLiveUpdater + _flow: Flow + _options: FlowLiveUpdaterOptions + _engine_live_updater: _engine.FlowLiveUpdater | None = None - def __init__(self, arg: Flow | _engine.FlowLiveUpdater, options: FlowLiveUpdaterOptions | None = None): - if isinstance(arg, _engine.FlowLiveUpdater): - self._engine_live_updater = arg - else: - self._engine_live_updater = execution_context.run(_engine.FlowLiveUpdater( - arg.internal_flow(), dump_engine_object(options or FlowLiveUpdaterOptions()))) - - @staticmethod - async def create_async(fl: Flow, options: FlowLiveUpdaterOptions | None = None) -> FlowLiveUpdater: - """ - Create a live updater for a flow. - Similar to the constructor, but for async usage. - """ - engine_live_updater = await _engine.FlowLiveUpdater.create( - await fl.internal_flow_async(), - dump_engine_object(options or FlowLiveUpdaterOptions())) - return FlowLiveUpdater(engine_live_updater) + def __init__(self, fl: Flow, options: FlowLiveUpdaterOptions | None = None): + self._flow = fl + self._options = options or FlowLiveUpdaterOptions() def __enter__(self) -> FlowLiveUpdater: + self.start() return self def __exit__(self, exc_type, exc_value, traceback): @@ -411,12 +400,26 @@ def __exit__(self, exc_type, exc_value, traceback): self.wait() async def __aenter__(self) -> FlowLiveUpdater: + await self.start_async() return self async def __aexit__(self, exc_type, exc_value, traceback): self.abort() await self.wait_async() + def start(self) -> None: + """ + Start the live updater. + """ + execution_context.run(self.start_async()) + + async def start_async(self) -> None: + """ + Start the live updater. + """ + self._engine_live_updater = await _engine.FlowLiveUpdater.create( + await self._flow.internal_flow_async(), dump_engine_object(self._options)) + def wait(self) -> None: """ Wait for the live updater to finish. @@ -427,20 +430,24 @@ async def wait_async(self) -> None: """ Wait for the live updater to finish. Async version. """ - await self._engine_live_updater.wait() - + await self._get_engine_live_updater().wait() def abort(self) -> None: """ Abort the live updater. """ - self._engine_live_updater.abort() + self._get_engine_live_updater().abort() def update_stats(self) -> _engine.IndexUpdateInfo: """ Get the index update info. """ - return self._engine_live_updater.index_update_info() + return self._get_engine_live_updater().index_update_info() + + def _get_engine_live_updater(self) -> _engine.FlowLiveUpdater: + if self._engine_live_updater is None: + raise RuntimeError("Live updater is not started") + return self._engine_live_updater @dataclass @@ -620,9 +627,9 @@ async def update_all_flows_async(options: FlowLiveUpdaterOptions) -> dict[str, _ """ await ensure_all_flows_built_async() async def _update_flow(fl: Flow) -> _engine.IndexUpdateInfo: - updater = await FlowLiveUpdater.create_async(fl, options) - await updater.wait_async() - return updater.update_stats() + async with FlowLiveUpdater(fl, options) as updater: + await updater.wait_async() + return updater.update_stats() fls = flows() all_stats = await asyncio.gather(*(_update_flow(fl) for fl in fls)) return {fl.name: stats for fl, stats in zip(fls, all_stats)}