
feat(transform-flow): support eval() for transient flow #508


Merged: 1 commit, May 18, 2025
1 change: 0 additions & 1 deletion examples/docs_to_knowledge_graph/.env
@@ -1,3 +1,2 @@
# Postgres database address for cocoindex
COCOINDEX_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex
COCOINDEX_APP_NAMESPACE=Dev0
2 changes: 1 addition & 1 deletion python/cocoindex/__init__.py
@@ -2,7 +2,7 @@
Cocoindex is a framework for building and running indexing pipelines.
"""
from . import functions, query, sources, storages, cli
from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def
from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def, transform_flow
from .flow import EvaluateAndDumpOptions, GeneratedField
from .flow import update_all_flows_async, FlowLiveUpdater, FlowLiveUpdaterOptions
from .llm import LlmSpec, LlmApiType
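With this re-export, downstream code can import the decorator from the package root rather than from cocoindex.flow; a one-line sketch of what the change enables:

from cocoindex import transform_flow  # newly available at the package root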
2 changes: 1 addition & 1 deletion python/cocoindex/convert.py
@@ -44,7 +44,7 @@ def make_engine_value_decoder(

src_type_kind = src_type['kind']

if dst_annotation is inspect.Parameter.empty:
if dst_annotation is None or dst_annotation is inspect.Parameter.empty or dst_annotation is Any:
if src_type_kind == 'Struct' or src_type_kind in TABLE_TYPES:
raise ValueError(f"Missing type annotation for `{''.join(field_path)}`."
f"It's required for {src_type_kind} type.")
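The widened guard treats a None annotation and typing.Any the same as an unannotated parameter, while Struct and table types still require an explicit annotation. A minimal standalone sketch of the new condition; the helper name is invented for illustration:

import inspect
from typing import Any

def _annotation_is_unusable(dst_annotation: Any) -> bool:
    # None, a missing annotation, and typing.Any all mean "no usable
    # type hint"; Struct/Table values then raise the ValueError above.
    return (dst_annotation is None
            or dst_annotation is inspect.Parameter.empty
            or dst_annotation is Any)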
131 changes: 110 additions & 21 deletions python/cocoindex/flow.py
@@ -8,8 +8,9 @@
import re
import inspect
import datetime
import functools

from typing import Any, Callable, Sequence, TypeVar
from typing import Any, Callable, Sequence, TypeVar, Generic, get_args, get_origin, Type, NamedTuple
from threading import Lock
from enum import Enum
from dataclasses import dataclass
@@ -20,7 +21,7 @@
from . import index
from . import op
from . import setting
from .convert import dump_engine_object
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
from .typing import encode_enriched_type
from .runtime import execution_context

@@ -123,7 +124,7 @@ def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None:
# TODO: We'll support this by an identity transformer or "aliasing" in the future.
raise ValueError("DataSlice is already attached to a field")

class DataSlice:
class DataSlice(Generic[T]):
"""A data slice represents a slice of data in a flow. It's readonly."""

_state: _DataSliceState
@@ -183,11 +184,11 @@ def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'),
))

def call(self, func: Callable[[DataSlice], T]) -> T:
def call(self, func: Callable[[DataSlice], T], *args, **kwargs) -> T:
"""
Call a function with the data slice.
"""
return func(self)
return func(self, *args, **kwargs)

def _data_slice_state(data_slice: DataSlice) -> _DataSliceState:
return data_slice._state # pylint: disable=protected-access
@@ -642,48 +643,136 @@ async def _update_flow(name: str, fl: Flow) -> tuple[str, _engine.IndexUpdateInf
all_stats = await asyncio.gather(*(_update_flow(name, fl) for (name, fl) in fls.items()))
return dict(all_stats)

_transient_flow_name_builder = _NameBuilder()
class TransientFlow:
def _get_data_slice_annotation_type(data_slice_type: Type[DataSlice[T]]) -> Type[T] | None:
type_args = get_args(data_slice_type)
if data_slice_type is DataSlice:
return None
if get_origin(data_slice_type) != DataSlice or len(type_args) != 1:
raise ValueError(f"Expect a DataSlice[T] type, but got {data_slice_type}")
return type_args[0]

_transform_flow_name_builder = _NameBuilder()

class TransformFlowInfo(NamedTuple):
engine_flow: _engine.TransientFlow
result_decoder: Callable[[Any], T]

class TransformFlow(Generic[T]):
"""
A transient transformation flow that transforms in-memory data.
"""
_engine_flow: _engine.TransientFlow
_flow_fn: Callable[..., DataSlice[T]]
_flow_name: str
_flow_arg_types: list[Any]
_param_names: list[str]

_lazy_lock: asyncio.Lock
_lazy_flow_info: TransformFlowInfo | None = None

def __init__(
self, flow_fn: Callable[..., DataSlice],
self, flow_fn: Callable[..., DataSlice[T]],
flow_arg_types: Sequence[Any], /, name: str | None = None):
self._flow_fn = flow_fn
self._flow_name = _transform_flow_name_builder.build_name(name, prefix="_transform_flow_")
self._flow_arg_types = list(flow_arg_types)
self._lazy_lock = asyncio.Lock()

def __call__(self, *args, **kwargs) -> DataSlice[T]:
return self._flow_fn(*args, **kwargs)

flow_builder_state = _FlowBuilderState(
name=_transient_flow_name_builder.build_name(name, prefix="_transient_flow_"))
sig = inspect.signature(flow_fn)
if len(sig.parameters) != len(flow_arg_types):
@property
def _flow_info(self) -> TransformFlowInfo:
if self._lazy_flow_info is not None:
return self._lazy_flow_info
return execution_context.run(self._flow_info_async())

async def _flow_info_async(self) -> TransformFlowInfo:
if self._lazy_flow_info is not None:
return self._lazy_flow_info
async with self._lazy_lock:
if self._lazy_flow_info is None:
self._lazy_flow_info = await self._build_flow_info_async()
return self._lazy_flow_info

async def _build_flow_info_async(self) -> TransformFlowInfo:
flow_builder_state = _FlowBuilderState(name=self._flow_name)
sig = inspect.signature(self._flow_fn)
if len(sig.parameters) != len(self._flow_arg_types):
raise ValueError(
f"Number of parameters in the flow function ({len(sig.parameters)}) "
"does not match the number of argument types ({len(flow_arg_types)})")
f"does not match the number of argument types ({len(self._flow_arg_types)})")

kwargs: dict[str, DataSlice] = {}
for (param_name, param), param_type in zip(sig.parameters.items(), flow_arg_types):
for (param_name, param), param_type in zip(sig.parameters.items(), self._flow_arg_types):
if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY):
raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
param_name, encode_enriched_type(param_type))
kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))

output = flow_fn(**kwargs)
output = self._flow_fn(**kwargs)
flow_builder_state.engine_flow_builder.set_direct_output(
_data_slice_state(output).engine_data_slice)
self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow(
execution_context.event_loop)
engine_flow = await flow_builder_state.engine_flow_builder.build_transient_flow_async(execution_context.event_loop)
self._param_names = list(sig.parameters.keys())

engine_return_type = _data_slice_state(output).engine_data_slice.data_type().schema()
python_return_type = _get_data_slice_annotation_type(sig.return_annotation)
result_decoder = make_engine_value_decoder([], engine_return_type['type'], python_return_type)

return TransformFlowInfo(engine_flow, result_decoder)

def __str__(self):
return str(self._engine_flow)
return str(self._flow_info.engine_flow)

def __repr__(self):
return repr(self._engine_flow)
return repr(self._flow_info.engine_flow)

def internal_flow(self) -> _engine.TransientFlow:
"""
Get the internal flow.
"""
return self._engine_flow
return self._flow_info.engine_flow

def eval(self, *args, **kwargs) -> T:
"""
Evaluate the transform flow.
"""
return execution_context.run(self.eval_async(*args, **kwargs))

async def eval_async(self, *args, **kwargs) -> T:
"""
Evaluate the transform flow.
"""
flow_info = await self._flow_info_async()
params = []
for i, arg in enumerate(self._param_names):
if i < len(args):
params.append(encode_engine_value(args[i]))
elif arg in kwargs:
params.append(encode_engine_value(kwargs[arg]))
else:
raise ValueError(f"Parameter {arg} is not provided")
engine_result = await flow_info.engine_flow.evaluate_async(params)
return flow_info.result_decoder(engine_result)


def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]]:
"""
A decorator to wrap the transform function.
"""
def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]):
sig = inspect.signature(fn)
arg_types = []
for (param_name, param) in sig.parameters.items():
if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY):
raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
arg_types.append(_get_data_slice_annotation_type(param.annotation))

_transform_flow = TransformFlow(fn, arg_types)
functools.update_wrapper(_transform_flow, fn)
return _transform_flow

return _transform_flow_wrapper
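Taken together, the decorator plus the new eval()/eval_async() methods let a transform flow be evaluated directly on in-memory data. A hypothetical usage sketch: the function name, model, and return type are invented, and SentenceTransformerEmbed is assumed to be available in cocoindex.functions.

import cocoindex

@cocoindex.transform_flow()
def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]:
    # The DataSlice[T] annotations drive both the engine input schema and
    # the result decoder assembled in _build_flow_info_async().
    return text.transform(
        cocoindex.functions.SentenceTransformerEmbed(
            model="sentence-transformers/all-MiniLM-L6-v2"))

# eval() encodes plain Python arguments into engine values, runs the
# transient flow, and decodes the result back into list[float].
embedding = text_to_embedding.eval("Hello, world!")

# Inside a running event loop, use the non-blocking variant instead:
#     embedding = await text_to_embedding.eval_async("Hello, world!")

The first call pays the one-time flow build cost; later calls reuse the TransformFlowInfo cached behind _lazy_lock.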
24 changes: 12 additions & 12 deletions python/cocoindex/op.py
@@ -100,26 +100,26 @@ def behavior_version(self):
return op_args.behavior_version

class _WrappedClass(executor_cls, _Fallback):
_args_converters: list[Callable[[Any], Any]]
_kwargs_converters: dict[str, Callable[[str, Any], Any]]
_args_decoders: list[Callable[[Any], Any]]
_kwargs_decoders: dict[str, Callable[[str, Any], Any]]
_acall: Callable

def __init__(self, spec):
super().__init__()
self.spec = spec
self._acall = _to_async_call(super().__call__)

def analyze(self, *args, **kwargs):
def analyze(self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema):
"""
Analyze the spec and arguments. In this phase, argument types should be validated.
It should return the expected result type for the current op.
"""
self._args_converters = []
self._kwargs_converters = {}
self._args_decoders = []
self._kwargs_decoders = {}

# Match arguments with parameters.
next_param_idx = 0
for arg in args:
for arg in args:
if next_param_idx >= len(expected_args):
raise ValueError(
f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
@@ -128,7 +128,7 @@ def analyze(self, *args, **kwargs):
inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD):
raise ValueError(
f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
self._args_converters.append(
self._args_decoders.append(
make_engine_value_decoder(
[arg_name], arg.value_type['type'], arg_param.annotation))
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
@@ -146,7 +146,7 @@ def analyze(self, *args, **kwargs):
if expected_arg is None:
raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
arg_param = expected_arg[1]
self._kwargs_converters[kwarg_name] = make_engine_value_decoder(
self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
[kwarg_name], kwarg.value_type['type'], arg_param.annotation)

missing_args = [name for (name, arg) in expected_kwargs
@@ -174,8 +174,8 @@ async def prepare(self):
await _to_async_call(setup_method)()

async def __call__(self, *args, **kwargs):
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
decoded_args = (decoder(arg) for decoder, arg in zip(self._args_decoders, args))
decoded_kwargs = {arg_name: self._kwargs_decoders[arg_name](arg)
for arg_name, arg in kwargs.items()}

if op_args.gpu:
@@ -185,9 +185,9 @@ async def __call__(self, *args, **kwargs):
# For now, we use a lock to ensure only one task is executed at a time.
# TODO: Implement multi-processing dispatching.
async with _gpu_dispatch_lock:
output = await self._acall(*converted_args, **converted_kwargs)
output = await self._acall(*decoded_args, **decoded_kwargs)
else:
output = await self._acall(*converted_args, **converted_kwargs)
output = await self._acall(*decoded_args, **decoded_kwargs)
return encode_engine_value(output)

_WrappedClass.__name__ = executor_cls.__name__
2 changes: 1 addition & 1 deletion python/cocoindex/query.py
@@ -50,7 +50,7 @@ def _lazy_handler() -> _engine.SimpleSemanticsQueryHandler:
if engine_handler is None:
engine_handler = _engine.SimpleSemanticsQueryHandler(
flow.internal_flow(), target_name,
fl.TransientFlow(query_transform_flow, [str]).internal_flow(),
fl.TransformFlow(query_transform_flow, [str]).internal_flow(),
default_similarity_metric.value)
engine_handler.register_query_handler(name)
return engine_handler
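For call sites that cannot use the decorator, direct construction still works, as the query handler does here: argument types are passed positionally instead of being read from DataSlice[T] annotations. A small sketch of the same pattern:

tf = fl.TransformFlow(query_transform_flow, [str])
engine_flow = tf.internal_flow()  # builds the transient flow lazily on first access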
2 changes: 1 addition & 1 deletion python/cocoindex/setting.py
@@ -49,7 +49,7 @@ def _load_field(target: dict[str, Any], name: str, env_name: str, required: bool
class Settings:
"""Settings for the cocoindex library."""
database: DatabaseConnectionSpec
app_namespace: str
app_namespace: str = ""

@classmethod
def from_env(cls) -> Self:
3 changes: 1 addition & 2 deletions src/builder/analyzed_flow.rs
@@ -75,12 +75,11 @@ pub struct AnalyzedTransientFlow {
impl AnalyzedTransientFlow {
pub async fn from_transient_flow(
transient_flow: spec::TransientFlowSpec,
registry: &ExecutorFactoryRegistry,
py_exec_ctx: Option<crate::py::PythonExecutionContext>,
) -> Result<Self> {
let ctx = analyzer::build_flow_instance_context(&transient_flow.name, py_exec_ctx);
let (output_type, data_schema, execution_plan_fut) =
analyzer::analyze_transient_flow(&transient_flow, &ctx, registry)?;
analyzer::analyze_transient_flow(&transient_flow, &ctx)?;
Ok(Self {
transient_flow_instance: transient_flow,
data_schema,
7 changes: 5 additions & 2 deletions src/builder/analyzer.rs
@@ -1302,14 +1302,17 @@ pub fn analyze_flow(
pub fn analyze_transient_flow<'a>(
flow_inst: &TransientFlowSpec,
flow_ctx: &'_ Arc<FlowInstanceContext>,
registry: &'a ExecutorFactoryRegistry,
) -> Result<(
EnrichedValueType,
FlowSchema,
impl Future<Output = Result<TransientExecutionPlan>> + Send + 'a,
)> {
let mut root_data_scope = DataScopeBuilder::new();
let analyzer_ctx = AnalyzerContext { registry, flow_ctx };
let registry = crate::ops::executor_factory_registry();
let analyzer_ctx = AnalyzerContext {
registry: &registry,
flow_ctx,
};
let mut input_fields = vec![];
for field in flow_inst.input_fields.iter() {
let analyzed_field = root_data_scope.add_field(field.name.clone(), &field.value_type)?;