Skip to content

Commit 86e6aec

Browse files
authored
Make @main_fn decorator support async functions. (#259)
1 parent ce0a7b3 commit 86e6aec

File tree

2 files changed

+41
-18
lines changed

2 files changed

+41
-18
lines changed

examples/gdrive_text_embedding/main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def gdrive_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope:
4949
default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)
5050

5151
@cocoindex.main_fn()
52-
def _run():
52+
async def _run():
5353
# Use a `FlowLiveUpdater` to keep the flow data updated.
54-
with cocoindex.FlowLiveUpdater(gdrive_text_embedding_flow):
54+
async with cocoindex.FlowLiveUpdater(gdrive_text_embedding_flow):
5555
# Run queries in a loop to demonstrate the query capabilities.
5656
while True:
5757
try:
@@ -70,4 +70,5 @@ def _run():
7070

7171
if __name__ == "__main__":
7272
load_dotenv(override=True)
73-
_run()
73+
import asyncio
74+
asyncio.run(_run())

python/cocoindex/lib.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""
22
Library level functions and states.
33
"""
4-
import json
54
import os
65
import sys
7-
from typing import Callable, Self
6+
import functools
7+
import inspect
8+
import asyncio
9+
from typing import Callable, Self, Any
810
from dataclasses import dataclass
911

1012
from . import _engine
@@ -78,20 +80,40 @@ def main_fn(
7880
7981
If the settings are not provided, they are loaded from the environment variables.
8082
"""
81-
def _main_wrapper(fn: Callable) -> Callable:
8283

83-
def _inner(*args, **kwargs):
84-
effective_settings = settings or Settings.from_env()
85-
init(effective_settings)
86-
try:
87-
if len(sys.argv) > 1 and sys.argv[1] == cocoindex_cmd:
88-
return cli.cli.main(sys.argv[2:], prog_name=f"{sys.argv[0]} {sys.argv[1]}")
89-
else:
90-
return fn(*args, **kwargs)
91-
finally:
92-
stop()
84+
def _pre_init() -> None:
85+
effective_settings = settings or Settings.from_env()
86+
init(effective_settings)
87+
88+
def _should_run_cli() -> bool:
89+
return len(sys.argv) > 1 and sys.argv[1] == cocoindex_cmd
9390

94-
_inner.__name__ = fn.__name__
95-
return _inner
91+
def _run_cli():
92+
return cli.cli.main(sys.argv[2:], prog_name=f"{sys.argv[0]} {sys.argv[1]}")
93+
94+
def _main_wrapper(fn: Callable) -> Callable:
95+
if inspect.iscoroutinefunction(fn):
96+
@functools.wraps(fn)
97+
async def _inner(*args, **kwargs):
98+
_pre_init()
99+
try:
100+
if _should_run_cli():
101+
# Schedule to a separate thread as it invokes nested event loop.
102+
return await asyncio.to_thread(_run_cli)
103+
return await fn(*args, **kwargs)
104+
finally:
105+
stop()
106+
return _inner
107+
else:
108+
@functools.wraps(fn)
109+
def _inner(*args, **kwargs):
110+
_pre_init()
111+
try:
112+
if _should_run_cli():
113+
return _run_cli()
114+
return fn(*args, **kwargs)
115+
finally:
116+
stop()
117+
return _inner
96118

97119
return _main_wrapper

0 commit comments

Comments
 (0)