From 7130ee148f14a9b08ea79b405bd7636df10b904a Mon Sep 17 00:00:00 2001 From: LJ Date: Sat, 17 May 2025 22:12:40 -0700 Subject: [PATCH] feat(transform-flow): support `eval()` for transient flow --- examples/docs_to_knowledge_graph/.env | 1 - python/cocoindex/__init__.py | 2 +- python/cocoindex/convert.py | 2 +- python/cocoindex/flow.py | 131 +++++++++++++++++++++----- python/cocoindex/op.py | 24 ++--- python/cocoindex/query.py | 2 +- python/cocoindex/setting.py | 2 +- src/builder/analyzed_flow.rs | 3 +- src/builder/analyzer.rs | 7 +- src/builder/flow_builder.rs | 41 ++++---- src/py/mod.rs | 23 +++++ 11 files changed, 178 insertions(+), 60 deletions(-) diff --git a/examples/docs_to_knowledge_graph/.env b/examples/docs_to_knowledge_graph/.env index 2ff10b65..335f3060 100644 --- a/examples/docs_to_knowledge_graph/.env +++ b/examples/docs_to_knowledge_graph/.env @@ -1,3 +1,2 @@ # Postgres database address for cocoindex COCOINDEX_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex -COCOINDEX_APP_NAMESPACE=Dev0 diff --git a/python/cocoindex/__init__.py b/python/cocoindex/__init__.py index 6c23d746..4ab0b4a1 100644 --- a/python/cocoindex/__init__.py +++ b/python/cocoindex/__init__.py @@ -2,7 +2,7 @@ Cocoindex is a framework for building and running indexing pipelines. """ from . import functions, query, sources, storages, cli -from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def +from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def, transform_flow from .flow import EvaluateAndDumpOptions, GeneratedField from .flow import update_all_flows_async, FlowLiveUpdater, FlowLiveUpdaterOptions from .llm import LlmSpec, LlmApiType diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index fad21b65..646e5bae 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -44,7 +44,7 @@ def make_engine_value_decoder( src_type_kind = src_type['kind'] - if dst_annotation is inspect.Parameter.empty: + if dst_annotation is None or dst_annotation is inspect.Parameter.empty or dst_annotation is Any: if src_type_kind == 'Struct' or src_type_kind in TABLE_TYPES: raise ValueError(f"Missing type annotation for `{''.join(field_path)}`." f"It's required for {src_type_kind} type.") diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 79382d41..8e166afc 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -8,8 +8,9 @@ import re import inspect import datetime +import functools -from typing import Any, Callable, Sequence, TypeVar +from typing import Any, Callable, Sequence, TypeVar, Generic, get_args, get_origin, Type, NamedTuple from threading import Lock from enum import Enum from dataclasses import dataclass @@ -20,7 +21,7 @@ from . import index from . import op from . import setting -from .convert import dump_engine_object +from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder from .typing import encode_enriched_type from .runtime import execution_context @@ -123,7 +124,7 @@ def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None: # TODO: We'll support this by an identity transformer or "aliasing" in the future. raise ValueError("DataSlice is already attached to a field") -class DataSlice: +class DataSlice(Generic[T]): """A data slice represents a slice of data in a flow. 
It's readonly."""
 
     _state: _DataSliceState
@@ -183,11 +184,11 @@ def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
                 name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'),
         ))
 
-    def call(self, func: Callable[[DataSlice], T]) -> T:
+    def call(self, func: Callable[[DataSlice], T], *args, **kwargs) -> T:
         """
         Call a function with the data slice.
         """
-        return func(self)
+        return func(self, *args, **kwargs)
 
 def _data_slice_state(data_slice: DataSlice) -> _DataSliceState:
     return data_slice._state  # pylint: disable=protected-access
@@ -642,27 +643,67 @@ async def _update_flow(name: str, fl: Flow) -> tuple[str, _engine.IndexUpdateInf
     all_stats = await asyncio.gather(*(_update_flow(name, fl) for (name, fl) in fls.items()))
     return dict(all_stats)
 
-_transient_flow_name_builder = _NameBuilder()
-class TransientFlow:
+def _get_data_slice_annotation_type(data_slice_type: Type[DataSlice[T]]) -> Type[T] | None:
+    type_args = get_args(data_slice_type)
+    if data_slice_type is DataSlice:
+        return None
+    if get_origin(data_slice_type) != DataSlice or len(type_args) != 1:
+        raise ValueError(f"Expected a DataSlice[T] type, but got {data_slice_type}")
+    return type_args[0]
+
+_transform_flow_name_builder = _NameBuilder()
+
+class TransformFlowInfo(NamedTuple):
+    engine_flow: _engine.TransientFlow
+    result_decoder: Callable[[Any], T]
+
+class TransformFlow(Generic[T]):
     """
     A transient transformation flow that transforms in-memory data.
     """
-    _engine_flow: _engine.TransientFlow
+    _flow_fn: Callable[..., DataSlice[T]]
+    _flow_name: str
+    _flow_arg_types: list[Any]
+    _param_names: list[str]
+
+    _lazy_lock: asyncio.Lock
+    _lazy_flow_info: TransformFlowInfo | None = None
 
     def __init__(
-            self, flow_fn: Callable[..., DataSlice],
+            self, flow_fn: Callable[..., DataSlice[T]],
             flow_arg_types: Sequence[Any], /, name: str | None = None):
+        self._flow_fn = flow_fn
+        self._flow_name = _transform_flow_name_builder.build_name(name, prefix="_transform_flow_")
+        self._flow_arg_types = list(flow_arg_types)
+        self._lazy_lock = asyncio.Lock()
+
+    def __call__(self, *args, **kwargs) -> DataSlice[T]:
+        return self._flow_fn(*args, **kwargs)
 
-        flow_builder_state = _FlowBuilderState(
-            name=_transient_flow_name_builder.build_name(name, prefix="_transient_flow_"))
-        sig = inspect.signature(flow_fn)
-        if len(sig.parameters) != len(flow_arg_types):
+    @property
+    def _flow_info(self) -> TransformFlowInfo:
+        if self._lazy_flow_info is not None:
+            return self._lazy_flow_info
+        return execution_context.run(self._flow_info_async())
+
+    async def _flow_info_async(self) -> TransformFlowInfo:
+        if self._lazy_flow_info is not None:
+            return self._lazy_flow_info
+        async with self._lazy_lock:
+            if self._lazy_flow_info is None:
+                self._lazy_flow_info = await self._build_flow_info_async()
+            return self._lazy_flow_info
+
+    async def _build_flow_info_async(self) -> TransformFlowInfo:
+        flow_builder_state = _FlowBuilderState(name=self._flow_name)
+        sig = inspect.signature(self._flow_fn)
+        if len(sig.parameters) != len(self._flow_arg_types):
             raise ValueError(
                 f"Number of parameters in the flow function ({len(sig.parameters)}) "
-                "does not match the number of argument types ({len(flow_arg_types)})")
+                f"does not match the number of argument types ({len(self._flow_arg_types)})")
 
         kwargs: dict[str, DataSlice] = {}
-        for (param_name, param), param_type in zip(sig.parameters.items(), flow_arg_types):
+        for (param_name, param), param_type in zip(sig.parameters.items(), self._flow_arg_types):
             if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                   inspect.Parameter.KEYWORD_ONLY):
                 raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
@@ -670,20 +711,68 @@ def __init__(
                 param_name, encode_enriched_type(param_type))
             kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))
 
-        output = flow_fn(**kwargs)
+        output = self._flow_fn(**kwargs)
         flow_builder_state.engine_flow_builder.set_direct_output(
             _data_slice_state(output).engine_data_slice)
-        self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow(
-            execution_context.event_loop)
+        engine_flow = await flow_builder_state.engine_flow_builder.build_transient_flow_async(execution_context.event_loop)
+        self._param_names = list(sig.parameters.keys())
+
+        engine_return_type = _data_slice_state(output).engine_data_slice.data_type().schema()
+        python_return_type = _get_data_slice_annotation_type(sig.return_annotation)
+        result_decoder = make_engine_value_decoder([], engine_return_type['type'], python_return_type)
+
+        return TransformFlowInfo(engine_flow, result_decoder)
 
     def __str__(self):
-        return str(self._engine_flow)
+        return str(self._flow_info.engine_flow)
 
     def __repr__(self):
-        return repr(self._engine_flow)
+        return repr(self._flow_info.engine_flow)
 
     def internal_flow(self) -> _engine.TransientFlow:
         """
         Get the internal flow.
         """
-        return self._engine_flow
+        return self._flow_info.engine_flow
+
+    def eval(self, *args, **kwargs) -> T:
+        """
+        Evaluate the transform flow synchronously, blocking until the result is ready.
+        """
+        return execution_context.run(self.eval_async(*args, **kwargs))
+
+    async def eval_async(self, *args, **kwargs) -> T:
+        """
+        Evaluate the transform flow asynchronously.
+        """
+        flow_info = await self._flow_info_async()
+        params = []
+        for i, param_name in enumerate(self._param_names):
+            if i < len(args):
+                params.append(encode_engine_value(args[i]))
+            elif param_name in kwargs:
+                params.append(encode_engine_value(kwargs[param_name]))
+            else:
+                raise ValueError(f"Parameter {param_name} is not provided")
+        engine_result = await flow_info.engine_flow.evaluate_async(params)
+        return flow_info.result_decoder(engine_result)
+
+
+def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]]:
+    """
+    A decorator that wraps a transform function into a reusable `TransformFlow`.
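+
+    The wrapped flow can still be called like the original function when building
+    an indexing flow, and can also be evaluated directly on in-memory values via
+    `eval()` / `eval_async()`. An illustrative sketch (`SentenceTransformerEmbed`
+    is one of the library's builtin functions, used here only as an example):
+
+        @cocoindex.transform_flow()
+        def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]:
+            return text.transform(
+                cocoindex.functions.SentenceTransformerEmbed(
+                    model="sentence-transformers/all-MiniLM-L6-v2"))
+
+        embedding = text_to_embedding.eval("Hello, world!")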
+    """
+    def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]):
+        sig = inspect.signature(fn)
+        arg_types = []
+        for (param_name, param) in sig.parameters.items():
+            if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                                  inspect.Parameter.KEYWORD_ONLY):
+                raise ValueError(f"Parameter {param_name} is not a parameter that can be passed by name")
+            arg_types.append(_get_data_slice_annotation_type(param.annotation))
+
+        _transform_flow = TransformFlow(fn, arg_types)
+        functools.update_wrapper(_transform_flow, fn)
+        return _transform_flow
+
+    return _transform_flow_wrapper
diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py
index 911b6d30..886c3ae4 100644
--- a/python/cocoindex/op.py
+++ b/python/cocoindex/op.py
@@ -100,8 +100,8 @@ def behavior_version(self):
             return op_args.behavior_version
 
     class _WrappedClass(executor_cls, _Fallback):
-        _args_converters: list[Callable[[Any], Any]]
-        _kwargs_converters: dict[str, Callable[[str, Any], Any]]
+        _args_decoders: list[Callable[[Any], Any]]
+        _kwargs_decoders: dict[str, Callable[[str, Any], Any]]
         _acall: Callable
 
         def __init__(self, spec):
@@ -109,17 +109,17 @@ def __init__(self, spec):
             self.spec = spec
             self._acall = _to_async_call(super().__call__)
 
-        def analyze(self, *args, **kwargs):
+        def analyze(self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema):
             """
             Analyze the spec and arguments. In this phase, argument types should be validated.
             It should return the expected result type for the current op.
             """
-            self._args_converters = []
-            self._kwargs_converters = {}
+            self._args_decoders = []
+            self._kwargs_decoders = {}
 
             # Match arguments with parameters.
             next_param_idx = 0
-            for arg in args:
+            for arg in args: 
                 if next_param_idx >= len(expected_args):
                     raise ValueError(
                         f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
@@ -128,7 +128,7 @@ def analyze(self, *args, **kwargs):
                             inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD):
                         raise ValueError(
                             f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
-                self._args_converters.append(
+                self._args_decoders.append(
                     make_engine_value_decoder(
                         [arg_name], arg.value_type['type'], arg_param.annotation))
                 if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
@@ -146,7 +146,7 @@ def analyze(self, *args, **kwargs):
                 if expected_arg is None:
                     raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
                 arg_param = expected_arg[1]
-                self._kwargs_converters[kwarg_name] = make_engine_value_decoder(
+                self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
                     [kwarg_name], kwarg.value_type['type'], arg_param.annotation)
 
             missing_args = [name for (name, arg) in expected_kwargs
@@ -174,8 +174,8 @@ async def prepare(self):
                 await _to_async_call(setup_method)()
 
         async def __call__(self, *args, **kwargs):
-            converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
-            converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
+            decoded_args = (decoder(arg) for decoder, arg in zip(self._args_decoders, args))
+            decoded_kwargs = {arg_name: self._kwargs_decoders[arg_name](arg)
                                 for arg_name, arg in kwargs.items()}
 
             if op_args.gpu:
@@ -185,9 +185,9 @@ async def __call__(self, *args, **kwargs):
                 # For now, we use a lock to ensure only one task is executed at a time.
                 # TODO: Implement multi-processing dispatching.
                 async with _gpu_dispatch_lock:
-                    output = await self._acall(*converted_args, **converted_kwargs)
+                    output = await self._acall(*decoded_args, **decoded_kwargs)
             else:
-                output = await self._acall(*converted_args, **converted_kwargs)
+                output = await self._acall(*decoded_args, **decoded_kwargs)
             return encode_engine_value(output)
 
     _WrappedClass.__name__ = executor_cls.__name__
diff --git a/python/cocoindex/query.py b/python/cocoindex/query.py
index 9b5f1056..7e2c5197 100644
--- a/python/cocoindex/query.py
+++ b/python/cocoindex/query.py
@@ -50,7 +50,7 @@ def _lazy_handler() -> _engine.SimpleSemanticsQueryHandler:
             if engine_handler is None:
                 engine_handler = _engine.SimpleSemanticsQueryHandler(
                     flow.internal_flow(), target_name,
-                    fl.TransientFlow(query_transform_flow, [str]).internal_flow(),
+                    fl.TransformFlow(query_transform_flow, [str]).internal_flow(),
                     default_similarity_metric.value)
                 engine_handler.register_query_handler(name)
             return engine_handler
diff --git a/python/cocoindex/setting.py b/python/cocoindex/setting.py
index bbdb87fd..4f382b0b 100644
--- a/python/cocoindex/setting.py
+++ b/python/cocoindex/setting.py
@@ -49,7 +49,7 @@ def _load_field(target: dict[str, Any], name: str, env_name: str, required: bool
 class Settings:
     """Settings for the cocoindex library."""
     database: DatabaseConnectionSpec
-    app_namespace: str
+    app_namespace: str = ""
 
     @classmethod
     def from_env(cls) -> Self:
diff --git a/src/builder/analyzed_flow.rs b/src/builder/analyzed_flow.rs
index 544cd95c..7109963c 100644
--- a/src/builder/analyzed_flow.rs
+++ b/src/builder/analyzed_flow.rs
@@ -75,12 +75,11 @@ pub struct AnalyzedTransientFlow {
 impl AnalyzedTransientFlow {
     pub async fn from_transient_flow(
         transient_flow: spec::TransientFlowSpec,
-        registry: &ExecutorFactoryRegistry,
         py_exec_ctx: Option<crate::py::PythonExecutionContext>,
     ) -> Result<Self> {
         let ctx = analyzer::build_flow_instance_context(&transient_flow.name, py_exec_ctx);
         let (output_type, data_schema, execution_plan_fut) =
-            analyzer::analyze_transient_flow(&transient_flow, &ctx, registry)?;
+            analyzer::analyze_transient_flow(&transient_flow, &ctx)?;
         Ok(Self {
             transient_flow_instance: transient_flow,
             data_schema,
diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs
index 2cd7ce1a..88bddc2b 100644
--- a/src/builder/analyzer.rs
+++ b/src/builder/analyzer.rs
@@ -1302,14 +1302,17 @@ pub fn analyze_flow(
 pub fn analyze_transient_flow<'a>(
     flow_inst: &TransientFlowSpec,
     flow_ctx: &'_ Arc<FlowInstanceContext>,
-    registry: &'a ExecutorFactoryRegistry,
 ) -> Result<(
     EnrichedValueType,
     FlowSchema,
     impl Future<Output = Result<TransientExecutionPlan>> + Send + 'a,
 )> {
     let mut root_data_scope = DataScopeBuilder::new();
-    let analyzer_ctx = AnalyzerContext { registry, flow_ctx };
+    let registry = crate::ops::executor_factory_registry();
+    let analyzer_ctx = AnalyzerContext {
+        registry: &registry,
+        flow_ctx,
+    };
     let mut input_fields = vec![];
     for field in flow_inst.input_fields.iter() {
         let analyzed_field = root_data_scope.add_field(field.name.clone(), &field.value_type)?;
diff --git a/src/builder/flow_builder.rs b/src/builder/flow_builder.rs
index e7e39f53..20676490 100644
--- a/src/builder/flow_builder.rs
+++ b/src/builder/flow_builder.rs
@@ -1,10 +1,11 @@
-use crate::prelude::*;
+use crate::{prelude::*, py::Pythonized};
 
 use pyo3::{exceptions::PyException, prelude::*};
+use pyo3_async_runtimes::tokio::future_into_py;
 use std::{collections::btree_map, ops::Deref};
 
 use super::analyzer::{
-    build_flow_instance_context, AnalyzerContext, CollectorBuilder, DataScopeBuilder, OpScope,
+    AnalyzerContext, CollectorBuilder, DataScopeBuilder, OpScope, build_flow_instance_context,
 };
 use crate::{
     base::{
@@ -83,6 +84,10 @@ impl DataType {
     pub fn __repr__(&self) -> String {
         self.__str__()
     }
+
+    pub fn schema(&self) -> Pythonized<EnrichedValueType> {
+        Pythonized(self.schema.clone())
+    }
 }
 
 #[pyclass]
@@ -142,7 +147,7 @@ impl DataSlice {
             spec::ValueMapping::Constant { .. } => {
                 return Err(PyException::new_err(
                     "field access not supported for literal",
-                ))
+                ));
             }
         };
         Ok(Some(DataSlice {
@@ -585,11 +590,11 @@ impl FlowBuilder {
         Ok(py::Flow(flow_ctx))
     }
 
-    pub fn build_transient_flow(
+    pub fn build_transient_flow_async<'py>(
         &self,
-        py: Python<'_>,
+        py: Python<'py>,
         py_event_loop: Py<PyAny>,
-    ) -> PyResult<py::TransientFlow> {
+    ) -> PyResult<Bound<'py, PyAny>> {
         if self.direct_input_fields.is_empty() {
             return Err(PyException::new_err("expect at least one direct input"));
         }
@@ -605,16 +610,14 @@ impl FlowBuilder {
             output_value: direct_output_value.clone(),
         };
         let py_ctx = crate::py::PythonExecutionContext::new(py, py_event_loop);
-        let analyzed_flow = py
-            .allow_threads(|| {
-                get_runtime().block_on(super::AnalyzedTransientFlow::from_transient_flow(
-                    spec,
-                    &crate::ops::executor_factory_registry(),
-                    Some(py_ctx),
-                ))
-            })
-            .into_py_result()?;
-        Ok(py::TransientFlow(Arc::new(analyzed_flow)))
+
+        future_into_py(py, async move {
+            let analyzed_flow =
+                super::AnalyzedTransientFlow::from_transient_flow(spec, Some(py_ctx))
+                    .await
+                    .into_py_result()?;
+            Ok(py::TransientFlow(Arc::new(analyzed_flow)))
+        })
     }
 
     pub fn __str__(&self) -> String {
@@ -695,7 +698,8 @@ impl FlowBuilder {
             } else if !common_scope.is_op_scope_descendant(scope) {
                 api_bail!(
                     "expect all arguments share the common scope, got {} and {} exclusive to each other",
-                    common_scope, scope
+                    common_scope,
+                    scope
                 );
             }
         }
@@ -703,7 +707,8 @@ impl FlowBuilder {
         if !target_scope.is_op_scope_descendant(common_scope) {
             api_bail!(
                 "the field can only be attached to a scope or sub-scope of the input value. Target scope: {}, input scope: {}",
-                target_scope, common_scope
+                target_scope,
+                common_scope
             );
         }
         common_scope = target_scope;
diff --git a/src/py/mod.rs b/src/py/mod.rs
index e0048d0c..cfa6fee4 100644
--- a/src/py/mod.rs
+++ b/src/py/mod.rs
@@ -1,3 +1,4 @@
+use crate::execution::evaluator::evaluate_transient_flow;
 use crate::prelude::*;
 
 use crate::base::schema::{FieldSchema, ValueType};
@@ -11,6 +12,7 @@ use crate::ops::{interface::ExecutorFactory, py_factory::PyFunctionFactory, regi
 use crate::server::{self, ServerSettings};
 use crate::settings::Settings;
 use crate::setup;
+use pyo3::IntoPyObjectExt;
 use pyo3::{exceptions::PyException, prelude::*};
 use pyo3_async_runtimes::tokio::future_into_py;
 use std::collections::btree_map;
@@ -349,6 +351,27 @@ impl TransientFlow {
     pub fn __repr__(&self) -> String {
         self.__str__()
     }
+
+    pub fn evaluate_async<'py>(
+        &self,
+        py: Python<'py>,
+        args: Vec<Bound<'py, PyAny>>,
+    ) -> PyResult<Bound<'py, PyAny>> {
+        let flow = self.0.clone();
+        let input_values: Vec<value::Value> = std::iter::zip(
+            self.0.transient_flow_instance.input_fields.iter(),
+            args.into_iter(),
+        )
+        .map(|(input_schema, arg)| value_from_py_object(&input_schema.value_type.typ, &arg))
+        .collect::<PyResult<_>>()?;
+
+        future_into_py(py, async move {
+            let result = evaluate_transient_flow(&flow, &input_values)
+                .await
+                .into_py_result()?;
+            Python::with_gil(|py| value_to_py_object(py, &result)?.into_py_any(py))
+        })
+    }
 }
 
 #[pyclass]
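
Usage sketch for the new eval API (illustrative only; reuses the hypothetical
`text_to_embedding` transform flow from the docstring example above):

    # Blocking one-off evaluation, e.g. to embed a query string:
    query_vector = text_to_embedding.eval("how to build an index")

    # Non-blocking variant, for use inside an async function:
    query_vector = await text_to_embedding.eval_async("how to build an index")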