Skip to content

refactor(python-sdk): make async APIs clear #465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/gdrive_text_embedding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def gdrive_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope:
default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)

@cocoindex.main_fn()
async def _run():
def _run():
# Use a `FlowLiveUpdater` to keep the flow data updated.
async with cocoindex.FlowLiveUpdater(gdrive_text_embedding_flow):
with cocoindex.FlowLiveUpdater(gdrive_text_embedding_flow):
# Run queries in a loop to demonstrate the query capabilities.
while True:
try:
Expand All @@ -74,4 +74,4 @@ async def _run():

if __name__ == "__main__":
load_dotenv(override=True)
asyncio.run(_run())
_run()
2 changes: 1 addition & 1 deletion python/cocoindex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from . import functions, query, sources, storages, cli
from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def
from .flow import EvaluateAndDumpOptions, GeneratedField
from .flow import update_all_flows, FlowLiveUpdater, FlowLiveUpdaterOptions
from .flow import update_all_flows_async, FlowLiveUpdater, FlowLiveUpdaterOptions
from .llm import LlmSpec, LlmApiType
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
Expand Down
17 changes: 7 additions & 10 deletions python/cocoindex/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import click
import datetime

Expand All @@ -7,7 +6,6 @@

from . import flow, lib, setting
from .setup import sync_setup, drop_setup, flow_names_with_setup, apply_setup_changes
from .runtime import execution_context

@click.group()
def cli():
Expand Down Expand Up @@ -136,13 +134,12 @@ def update(flow_name: str | None, live: bool, quiet: bool):
Update the index to reflect the latest data from data sources.
"""
options = flow.FlowLiveUpdaterOptions(live_mode=live, print_stats=not quiet)
async def _update():
if flow_name is None:
await flow.update_all_flows(options)
else:
updater = await flow.FlowLiveUpdater.create(_flow_by_name(flow_name), options)
await updater.wait()
execution_context.run(_update())
if flow_name is None:
return flow.update_all_flows(options)
else:
updater = flow.FlowLiveUpdater(_flow_by_name(flow_name), options)
updater.wait()
return updater.update_stats()

@cli.command()
@click.argument("flow_name", type=str, required=False)
Expand Down Expand Up @@ -217,7 +214,7 @@ def server(address: str | None, live_update: bool, quiet: bool, cors_origin: str

if live_update:
options = flow.FlowLiveUpdaterOptions(live_mode=True, print_stats=not quiet)
execution_context.run(flow.update_all_flows(options))
flow.update_all_flows(options)
if COCOINDEX_HOST in cors_origins:
click.echo(f"Open CocoInsight at: {COCOINDEX_HOST}/cocoinsight")
input("Press Enter to stop...")
Expand Down
54 changes: 37 additions & 17 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import re
import inspect
import datetime
import json

from typing import Any, Callable, Sequence, TypeVar
from threading import Lock
Expand Down Expand Up @@ -394,12 +393,13 @@ def __init__(self, arg: Flow | _engine.FlowLiveUpdater, options: FlowLiveUpdater
arg.internal_flow(), dump_engine_object(options or FlowLiveUpdaterOptions())))

@staticmethod
async def create(fl: Flow, options: FlowLiveUpdaterOptions | None = None) -> FlowLiveUpdater:
async def create_async(fl: Flow, options: FlowLiveUpdaterOptions | None = None) -> FlowLiveUpdater:
"""
Create a live updater for a flow.
Similar to the constructor, but for async usage.
"""
engine_live_updater = await _engine.FlowLiveUpdater.create(
await fl.ainternal_flow(),
await fl.internal_flow_async(),
dump_engine_object(options or FlowLiveUpdaterOptions()))
return FlowLiveUpdater(engine_live_updater)

Expand All @@ -408,21 +408,28 @@ def __enter__(self) -> FlowLiveUpdater:

def __exit__(self, exc_type, exc_value, traceback):
self.abort()
execution_context.run(self.wait())
self.wait()

async def __aenter__(self) -> FlowLiveUpdater:
return self

async def __aexit__(self, exc_type, exc_value, traceback):
self.abort()
await self.wait()
await self.wait_async()

async def wait(self) -> None:
def wait(self) -> None:
"""
Wait for the live updater to finish.
"""
execution_context.run(self.wait_async())

async def wait_async(self) -> None:
"""
Wait for the live updater to finish. Async version.
"""
await self._engine_live_updater.wait()


def abort(self) -> None:
"""
Abort the live updater.
Expand Down Expand Up @@ -500,13 +507,20 @@ def name(self) -> str:
"""
return self._lazy_engine_flow().name()

async def update(self) -> _engine.IndexUpdateInfo:
def update(self) -> _engine.IndexUpdateInfo:
"""
Update the index defined by the flow.
Once the function returns, the index is fresh up to the moment when the function is called.
"""
return execution_context.run(self.update_async())

async def update_async(self) -> _engine.IndexUpdateInfo:
"""
Update the index defined by the flow.
Once the function returns, the indice is fresh up to the moment when the function is called.
Once the function returns, the index is fresh up to the moment when the function is called.
"""
updater = await FlowLiveUpdater.create(self, FlowLiveUpdaterOptions(live_mode=False))
await updater.wait()
updater = await FlowLiveUpdater.create_async(self, FlowLiveUpdaterOptions(live_mode=False))
await updater.wait_async()
return updater.update_stats()

def evaluate_and_dump(self, options: EvaluateAndDumpOptions):
Expand All @@ -521,7 +535,7 @@ def internal_flow(self) -> _engine.Flow:
"""
return self._lazy_engine_flow()

async def ainternal_flow(self) -> _engine.Flow:
async def internal_flow_async(self) -> _engine.Flow:
"""
Get the engine flow. The async version.
"""
Expand Down Expand Up @@ -587,21 +601,27 @@ def ensure_all_flows_built() -> None:
for fl in flows():
fl.internal_flow()

async def aensure_all_flows_built() -> None:
async def ensure_all_flows_built_async() -> None:
"""
Ensure all flows are built.
"""
for fl in flows():
await fl.ainternal_flow()
await fl.internal_flow_async()

def update_all_flows(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
"""
Update all flows.
"""
return execution_context.run(update_all_flows_async(options))

async def update_all_flows(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
async def update_all_flows_async(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
"""
Update all flows.
"""
await aensure_all_flows_built()
await ensure_all_flows_built_async()
async def _update_flow(fl: Flow) -> _engine.IndexUpdateInfo:
updater = await FlowLiveUpdater.create(fl, options)
await updater.wait()
updater = await FlowLiveUpdater.create_async(fl, options)
await updater.wait_async()
return updater.update_stats()
fls = flows()
all_stats = await asyncio.gather(*(_update_flow(fl) for fl in fls))
Expand Down