Skip to content

Commit 83ca87f

Browse files
authored
feat(transform-flow): support eval() for transient flow (#508)
1 parent 0638cfd commit 83ca87f

File tree

10 files changed

+178
-59
lines changed

10 files changed

+178
-59
lines changed

python/cocoindex/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Cocoindex is a framework for building and running indexing pipelines.
33
"""
44
from . import functions, query, sources, storages, cli
5-
from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def
5+
from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def, transform_flow
66
from .flow import EvaluateAndDumpOptions, GeneratedField
77
from .flow import update_all_flows_async, FlowLiveUpdater, FlowLiveUpdaterOptions
88
from .llm import LlmSpec, LlmApiType

python/cocoindex/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def make_engine_value_decoder(
4444

4545
src_type_kind = src_type['kind']
4646

47-
if dst_annotation is inspect.Parameter.empty:
47+
if dst_annotation is None or dst_annotation is inspect.Parameter.empty or dst_annotation is Any:
4848
if src_type_kind == 'Struct' or src_type_kind in TABLE_TYPES:
4949
raise ValueError(f"Missing type annotation for `{''.join(field_path)}`."
5050
f"It's required for {src_type_kind} type.")

python/cocoindex/flow.py

Lines changed: 110 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
import re
99
import inspect
1010
import datetime
11+
import functools
1112

12-
from typing import Any, Callable, Sequence, TypeVar
13+
from typing import Any, Callable, Sequence, TypeVar, Generic, get_args, get_origin, Type, NamedTuple
1314
from threading import Lock
1415
from enum import Enum
1516
from dataclasses import dataclass
@@ -20,7 +21,7 @@
2021
from . import index
2122
from . import op
2223
from . import setting
23-
from .convert import dump_engine_object
24+
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
2425
from .typing import encode_enriched_type
2526
from .runtime import execution_context
2627

@@ -123,7 +124,7 @@ def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None:
123124
# TODO: We'll support this by an identity transformer or "aliasing" in the future.
124125
raise ValueError("DataSlice is already attached to a field")
125126

126-
class DataSlice:
127+
class DataSlice(Generic[T]):
127128
"""A data slice represents a slice of data in a flow. It's readonly."""
128129

129130
_state: _DataSliceState
@@ -183,11 +184,11 @@ def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
183184
name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'),
184185
))
185186

186-
def call(self, func: Callable[[DataSlice], T]) -> T:
187+
def call(self, func: Callable[[DataSlice], T], *args, **kwargs) -> T:
187188
"""
188189
Call a function with the data slice.
189190
"""
190-
return func(self)
191+
return func(self, *args, **kwargs)
191192

192193
def _data_slice_state(data_slice: DataSlice) -> _DataSliceState:
193194
return data_slice._state # pylint: disable=protected-access
@@ -642,48 +643,136 @@ async def _update_flow(name: str, fl: Flow) -> tuple[str, _engine.IndexUpdateInf
642643
all_stats = await asyncio.gather(*(_update_flow(name, fl) for (name, fl) in fls.items()))
643644
return dict(all_stats)
644645

645-
_transient_flow_name_builder = _NameBuilder()
646-
class TransientFlow:
646+
def _get_data_slice_annotation_type(data_slice_type: Type[DataSlice[T]]) -> Type[T] | None:
647+
type_args = get_args(data_slice_type)
648+
if data_slice_type is DataSlice:
649+
return None
650+
if get_origin(data_slice_type) != DataSlice or len(type_args) != 1:
651+
raise ValueError(f"Expect a DataSlice[T] type, but got {data_slice_type}")
652+
return type_args[0]
653+
654+
_transform_flow_name_builder = _NameBuilder()
655+
656+
class TransformFlowInfo(NamedTuple):
657+
engine_flow: _engine.TransientFlow
658+
result_decoder: Callable[[Any], T]
659+
660+
class TransformFlow(Generic[T]):
647661
"""
648662
A transient transformation flow that transforms in-memory data.
649663
"""
650-
_engine_flow: _engine.TransientFlow
664+
_flow_fn: Callable[..., DataSlice[T]]
665+
_flow_name: str
666+
_flow_arg_types: list[Any]
667+
_param_names: list[str]
668+
669+
_lazy_lock: asyncio.Lock
670+
_lazy_flow_info: TransformFlowInfo | None = None
651671

652672
def __init__(
653-
self, flow_fn: Callable[..., DataSlice],
673+
self, flow_fn: Callable[..., DataSlice[T]],
654674
flow_arg_types: Sequence[Any], /, name: str | None = None):
675+
self._flow_fn = flow_fn
676+
self._flow_name = _transform_flow_name_builder.build_name(name, prefix="_transform_flow_")
677+
self._flow_arg_types = list(flow_arg_types)
678+
self._lazy_lock = asyncio.Lock()
679+
680+
def __call__(self, *args, **kwargs) -> DataSlice[T]:
681+
return self._flow_fn(*args, **kwargs)
655682

656-
flow_builder_state = _FlowBuilderState(
657-
name=_transient_flow_name_builder.build_name(name, prefix="_transient_flow_"))
658-
sig = inspect.signature(flow_fn)
659-
if len(sig.parameters) != len(flow_arg_types):
683+
@property
684+
def _flow_info(self) -> TransformFlowInfo:
685+
if self._lazy_flow_info is not None:
686+
return self._lazy_flow_info
687+
return execution_context.run(self._flow_info_async())
688+
689+
async def _flow_info_async(self) -> TransformFlowInfo:
690+
if self._lazy_flow_info is not None:
691+
return self._lazy_flow_info
692+
async with self._lazy_lock:
693+
if self._lazy_flow_info is None:
694+
self._lazy_flow_info = await self._build_flow_info_async()
695+
return self._lazy_flow_info
696+
697+
async def _build_flow_info_async(self) -> TransformFlowInfo:
698+
flow_builder_state = _FlowBuilderState(name=self._flow_name)
699+
sig = inspect.signature(self._flow_fn)
700+
if len(sig.parameters) != len(self._flow_arg_types):
660701
raise ValueError(
661702
f"Number of parameters in the flow function ({len(sig.parameters)}) "
662-
"does not match the number of argument types ({len(flow_arg_types)})")
703+
f"does not match the number of argument types ({len(self._flow_arg_types)})")
663704

664705
kwargs: dict[str, DataSlice] = {}
665-
for (param_name, param), param_type in zip(sig.parameters.items(), flow_arg_types):
706+
for (param_name, param), param_type in zip(sig.parameters.items(), self._flow_arg_types):
666707
if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
667708
inspect.Parameter.KEYWORD_ONLY):
668709
raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
669710
engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
670711
param_name, encode_enriched_type(param_type))
671712
kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))
672713

673-
output = flow_fn(**kwargs)
714+
output = self._flow_fn(**kwargs)
674715
flow_builder_state.engine_flow_builder.set_direct_output(
675716
_data_slice_state(output).engine_data_slice)
676-
self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow(
677-
execution_context.event_loop)
717+
engine_flow = await flow_builder_state.engine_flow_builder.build_transient_flow_async(execution_context.event_loop)
718+
self._param_names = list(sig.parameters.keys())
719+
720+
engine_return_type = _data_slice_state(output).engine_data_slice.data_type().schema()
721+
python_return_type = _get_data_slice_annotation_type(sig.return_annotation)
722+
result_decoder = make_engine_value_decoder([], engine_return_type['type'], python_return_type)
723+
724+
return TransformFlowInfo(engine_flow, result_decoder)
678725

679726
def __str__(self):
680-
return str(self._engine_flow)
727+
return str(self._flow_info.engine_flow)
681728

682729
def __repr__(self):
683-
return repr(self._engine_flow)
730+
return repr(self._flow_info.engine_flow)
684731

685732
def internal_flow(self) -> _engine.TransientFlow:
686733
"""
687734
Get the internal flow.
688735
"""
689-
return self._engine_flow
736+
return self._flow_info.engine_flow
737+
738+
def eval(self, *args, **kwargs) -> T:
739+
"""
740+
Evaluate the transform flow.
741+
"""
742+
return execution_context.run(self.eval_async(*args, **kwargs))
743+
744+
async def eval_async(self, *args, **kwargs) -> T:
745+
"""
746+
Evaluate the transform flow.
747+
"""
748+
flow_info = await self._flow_info_async()
749+
params = []
750+
for i, arg in enumerate(self._param_names):
751+
if i < len(args):
752+
params.append(encode_engine_value(args[i]))
753+
elif arg in kwargs:
754+
params.append(encode_engine_value(kwargs[arg]))
755+
else:
756+
raise ValueError(f"Parameter {arg} is not provided")
757+
engine_result = await flow_info.engine_flow.evaluate_async(params)
758+
return flow_info.result_decoder(engine_result)
759+
760+
761+
def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]]:
762+
"""
763+
A decorator to wrap the transform function.
764+
"""
765+
def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]):
766+
sig = inspect.signature(fn)
767+
arg_types = []
768+
for (param_name, param) in sig.parameters.items():
769+
if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
770+
inspect.Parameter.KEYWORD_ONLY):
771+
raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
772+
arg_types.append(_get_data_slice_annotation_type(param.annotation))
773+
774+
_transform_flow = TransformFlow(fn, arg_types)
775+
functools.update_wrapper(_transform_flow, fn)
776+
return _transform_flow
777+
778+
return _transform_flow_wrapper

python/cocoindex/op.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,26 +100,26 @@ def behavior_version(self):
100100
return op_args.behavior_version
101101

102102
class _WrappedClass(executor_cls, _Fallback):
103-
_args_converters: list[Callable[[Any], Any]]
104-
_kwargs_converters: dict[str, Callable[[str, Any], Any]]
103+
_args_decoders: list[Callable[[Any], Any]]
104+
_kwargs_decoders: dict[str, Callable[[str, Any], Any]]
105105
_acall: Callable
106106

107107
def __init__(self, spec):
108108
super().__init__()
109109
self.spec = spec
110110
self._acall = _to_async_call(super().__call__)
111111

112-
def analyze(self, *args, **kwargs):
112+
def analyze(self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema):
113113
"""
114114
Analyze the spec and arguments. In this phase, argument types should be validated.
115115
It should return the expected result type for the current op.
116116
"""
117-
self._args_converters = []
118-
self._kwargs_converters = {}
117+
self._args_decoders = []
118+
self._kwargs_decoders = {}
119119

120120
# Match arguments with parameters.
121121
next_param_idx = 0
122-
for arg in args:
122+
for arg in args:
123123
if next_param_idx >= len(expected_args):
124124
raise ValueError(
125125
f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
@@ -128,7 +128,7 @@ def analyze(self, *args, **kwargs):
128128
inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD):
129129
raise ValueError(
130130
f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
131-
self._args_converters.append(
131+
self._args_decoders.append(
132132
make_engine_value_decoder(
133133
[arg_name], arg.value_type['type'], arg_param.annotation))
134134
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
@@ -146,7 +146,7 @@ def analyze(self, *args, **kwargs):
146146
if expected_arg is None:
147147
raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
148148
arg_param = expected_arg[1]
149-
self._kwargs_converters[kwarg_name] = make_engine_value_decoder(
149+
self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
150150
[kwarg_name], kwarg.value_type['type'], arg_param.annotation)
151151

152152
missing_args = [name for (name, arg) in expected_kwargs
@@ -174,8 +174,8 @@ async def prepare(self):
174174
await _to_async_call(setup_method)()
175175

176176
async def __call__(self, *args, **kwargs):
177-
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
178-
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
177+
decoded_args = (decoder(arg) for decoder, arg in zip(self._args_decoders, args))
178+
decoded_kwargs = {arg_name: self._kwargs_decoders[arg_name](arg)
179179
for arg_name, arg in kwargs.items()}
180180

181181
if op_args.gpu:
@@ -185,9 +185,9 @@ async def __call__(self, *args, **kwargs):
185185
# For now, we use a lock to ensure only one task is executed at a time.
186186
# TODO: Implement multi-processing dispatching.
187187
async with _gpu_dispatch_lock:
188-
output = await self._acall(*converted_args, **converted_kwargs)
188+
output = await self._acall(*decoded_args, **decoded_kwargs)
189189
else:
190-
output = await self._acall(*converted_args, **converted_kwargs)
190+
output = await self._acall(*decoded_args, **decoded_kwargs)
191191
return encode_engine_value(output)
192192

193193
_WrappedClass.__name__ = executor_cls.__name__

python/cocoindex/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _lazy_handler() -> _engine.SimpleSemanticsQueryHandler:
5050
if engine_handler is None:
5151
engine_handler = _engine.SimpleSemanticsQueryHandler(
5252
flow.internal_flow(), target_name,
53-
fl.TransientFlow(query_transform_flow, [str]).internal_flow(),
53+
fl.TransformFlow(query_transform_flow, [str]).internal_flow(),
5454
default_similarity_metric.value)
5555
engine_handler.register_query_handler(name)
5656
return engine_handler

python/cocoindex/setting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _load_field(target: dict[str, Any], name: str, env_name: str, required: bool
4949
class Settings:
5050
"""Settings for the cocoindex library."""
5151
database: DatabaseConnectionSpec
52-
app_namespace: str
52+
app_namespace: str = ""
5353

5454
@classmethod
5555
def from_env(cls) -> Self:

src/builder/analyzed_flow.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,11 @@ pub struct AnalyzedTransientFlow {
7575
impl AnalyzedTransientFlow {
7676
pub async fn from_transient_flow(
7777
transient_flow: spec::TransientFlowSpec,
78-
registry: &ExecutorFactoryRegistry,
7978
py_exec_ctx: Option<crate::py::PythonExecutionContext>,
8079
) -> Result<Self> {
8180
let ctx = analyzer::build_flow_instance_context(&transient_flow.name, py_exec_ctx);
8281
let (output_type, data_schema, execution_plan_fut) =
83-
analyzer::analyze_transient_flow(&transient_flow, &ctx, registry)?;
82+
analyzer::analyze_transient_flow(&transient_flow, &ctx)?;
8483
Ok(Self {
8584
transient_flow_instance: transient_flow,
8685
data_schema,

src/builder/analyzer.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,14 +1302,17 @@ pub fn analyze_flow(
13021302
pub fn analyze_transient_flow<'a>(
13031303
flow_inst: &TransientFlowSpec,
13041304
flow_ctx: &'_ Arc<FlowInstanceContext>,
1305-
registry: &'a ExecutorFactoryRegistry,
13061305
) -> Result<(
13071306
EnrichedValueType,
13081307
FlowSchema,
13091308
impl Future<Output = Result<TransientExecutionPlan>> + Send + 'a,
13101309
)> {
13111310
let mut root_data_scope = DataScopeBuilder::new();
1312-
let analyzer_ctx = AnalyzerContext { registry, flow_ctx };
1311+
let registry = crate::ops::executor_factory_registry();
1312+
let analyzer_ctx = AnalyzerContext {
1313+
registry: &registry,
1314+
flow_ctx,
1315+
};
13131316
let mut input_fields = vec![];
13141317
for field in flow_inst.input_fields.iter() {
13151318
let analyzed_field = root_data_scope.add_field(field.name.clone(), &field.value_type)?;

0 commit comments

Comments
 (0)