|
8 | 8 | import re
|
9 | 9 | import inspect
|
10 | 10 | import datetime
|
| 11 | +import functools |
11 | 12 |
|
12 |
| -from typing import Any, Callable, Sequence, TypeVar |
| 13 | +from typing import Any, Callable, Sequence, TypeVar, Generic, get_args, get_origin, Type, NamedTuple |
13 | 14 | from threading import Lock
|
14 | 15 | from enum import Enum
|
15 | 16 | from dataclasses import dataclass
|
|
20 | 21 | from . import index
|
21 | 22 | from . import op
|
22 | 23 | from . import setting
|
23 |
| -from .convert import dump_engine_object |
| 24 | +from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder |
24 | 25 | from .typing import encode_enriched_type
|
25 | 26 | from .runtime import execution_context
|
26 | 27 |
|
@@ -123,7 +124,7 @@ def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None:
|
123 | 124 | # TODO: We'll support this by an identity transformer or "aliasing" in the future.
|
124 | 125 | raise ValueError("DataSlice is already attached to a field")
|
125 | 126 |
|
126 |
| -class DataSlice: |
| 127 | +class DataSlice(Generic[T]): |
127 | 128 | """A data slice represents a slice of data in a flow. It's readonly."""
|
128 | 129 |
|
129 | 130 | _state: _DataSliceState
|
@@ -183,11 +184,11 @@ def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
|
183 | 184 | name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'),
|
184 | 185 | ))
|
185 | 186 |
|
186 |
| - def call(self, func: Callable[[DataSlice], T]) -> T: |
| 187 | + def call(self, func: Callable[[DataSlice], T], *args, **kwargs) -> T: |
187 | 188 | """
|
188 | 189 | Call a function with the data slice.
|
189 | 190 | """
|
190 |
| - return func(self) |
| 191 | + return func(self, *args, **kwargs) |
191 | 192 |
|
192 | 193 | def _data_slice_state(data_slice: DataSlice) -> _DataSliceState:
|
193 | 194 | return data_slice._state # pylint: disable=protected-access
|
@@ -642,48 +643,136 @@ async def _update_flow(name: str, fl: Flow) -> tuple[str, _engine.IndexUpdateInf
|
642 | 643 | all_stats = await asyncio.gather(*(_update_flow(name, fl) for (name, fl) in fls.items()))
|
643 | 644 | return dict(all_stats)
|
644 | 645 |
|
645 |
| -_transient_flow_name_builder = _NameBuilder() |
646 |
| -class TransientFlow: |
| 646 | +def _get_data_slice_annotation_type(data_slice_type: Type[DataSlice[T]]) -> Type[T] | None: |
| 647 | + type_args = get_args(data_slice_type) |
| 648 | + if data_slice_type is DataSlice: |
| 649 | + return None |
| 650 | + if get_origin(data_slice_type) != DataSlice or len(type_args) != 1: |
| 651 | + raise ValueError(f"Expect a DataSlice[T] type, but got {data_slice_type}") |
| 652 | + return type_args[0] |
| 653 | + |
| 654 | +_transform_flow_name_builder = _NameBuilder() |
| 655 | + |
| 656 | +class TransformFlowInfo(NamedTuple): |
| 657 | + engine_flow: _engine.TransientFlow |
| 658 | + result_decoder: Callable[[Any], T] |
| 659 | + |
| 660 | +class TransformFlow(Generic[T]): |
647 | 661 | """
|
648 | 662 | A transient transformation flow that transforms in-memory data.
|
649 | 663 | """
|
650 |
| - _engine_flow: _engine.TransientFlow |
| 664 | + _flow_fn: Callable[..., DataSlice[T]] |
| 665 | + _flow_name: str |
| 666 | + _flow_arg_types: list[Any] |
| 667 | + _param_names: list[str] |
| 668 | + |
| 669 | + _lazy_lock: asyncio.Lock |
| 670 | + _lazy_flow_info: TransformFlowInfo | None = None |
651 | 671 |
|
652 | 672 | def __init__(
|
653 |
| - self, flow_fn: Callable[..., DataSlice], |
| 673 | + self, flow_fn: Callable[..., DataSlice[T]], |
654 | 674 | flow_arg_types: Sequence[Any], /, name: str | None = None):
|
| 675 | + self._flow_fn = flow_fn |
| 676 | + self._flow_name = _transform_flow_name_builder.build_name(name, prefix="_transform_flow_") |
| 677 | + self._flow_arg_types = list(flow_arg_types) |
| 678 | + self._lazy_lock = asyncio.Lock() |
| 679 | + |
| 680 | + def __call__(self, *args, **kwargs) -> DataSlice[T]: |
| 681 | + return self._flow_fn(*args, **kwargs) |
655 | 682 |
|
656 |
| - flow_builder_state = _FlowBuilderState( |
657 |
| - name=_transient_flow_name_builder.build_name(name, prefix="_transient_flow_")) |
658 |
| - sig = inspect.signature(flow_fn) |
659 |
| - if len(sig.parameters) != len(flow_arg_types): |
| 683 | + @property |
| 684 | + def _flow_info(self) -> TransformFlowInfo: |
| 685 | + if self._lazy_flow_info is not None: |
| 686 | + return self._lazy_flow_info |
| 687 | + return execution_context.run(self._flow_info_async()) |
| 688 | + |
| 689 | + async def _flow_info_async(self) -> TransformFlowInfo: |
| 690 | + if self._lazy_flow_info is not None: |
| 691 | + return self._lazy_flow_info |
| 692 | + async with self._lazy_lock: |
| 693 | + if self._lazy_flow_info is None: |
| 694 | + self._lazy_flow_info = await self._build_flow_info_async() |
| 695 | + return self._lazy_flow_info |
| 696 | + |
| 697 | + async def _build_flow_info_async(self) -> TransformFlowInfo: |
| 698 | + flow_builder_state = _FlowBuilderState(name=self._flow_name) |
| 699 | + sig = inspect.signature(self._flow_fn) |
| 700 | + if len(sig.parameters) != len(self._flow_arg_types): |
660 | 701 | raise ValueError(
|
661 | 702 | f"Number of parameters in the flow function ({len(sig.parameters)}) "
|
662 |
| - "does not match the number of argument types ({len(flow_arg_types)})") |
| 703 | + f"does not match the number of argument types ({len(self._flow_arg_types)})") |
663 | 704 |
|
664 | 705 | kwargs: dict[str, DataSlice] = {}
|
665 |
| - for (param_name, param), param_type in zip(sig.parameters.items(), flow_arg_types): |
| 706 | + for (param_name, param), param_type in zip(sig.parameters.items(), self._flow_arg_types): |
666 | 707 | if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
667 | 708 | inspect.Parameter.KEYWORD_ONLY):
|
668 | 709 | raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
|
669 | 710 | engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
|
670 | 711 | param_name, encode_enriched_type(param_type))
|
671 | 712 | kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))
|
672 | 713 |
|
673 |
| - output = flow_fn(**kwargs) |
| 714 | + output = self._flow_fn(**kwargs) |
674 | 715 | flow_builder_state.engine_flow_builder.set_direct_output(
|
675 | 716 | _data_slice_state(output).engine_data_slice)
|
676 |
| - self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow( |
677 |
| - execution_context.event_loop) |
| 717 | + engine_flow = await flow_builder_state.engine_flow_builder.build_transient_flow_async(execution_context.event_loop) |
| 718 | + self._param_names = list(sig.parameters.keys()) |
| 719 | + |
| 720 | + engine_return_type = _data_slice_state(output).engine_data_slice.data_type().schema() |
| 721 | + python_return_type = _get_data_slice_annotation_type(sig.return_annotation) |
| 722 | + result_decoder = make_engine_value_decoder([], engine_return_type['type'], python_return_type) |
| 723 | + |
| 724 | + return TransformFlowInfo(engine_flow, result_decoder) |
678 | 725 |
|
679 | 726 | def __str__(self):
|
680 |
| - return str(self._engine_flow) |
| 727 | + return str(self._flow_info.engine_flow) |
681 | 728 |
|
682 | 729 | def __repr__(self):
|
683 |
| - return repr(self._engine_flow) |
| 730 | + return repr(self._flow_info.engine_flow) |
684 | 731 |
|
685 | 732 | def internal_flow(self) -> _engine.TransientFlow:
|
686 | 733 | """
|
687 | 734 | Get the internal flow.
|
688 | 735 | """
|
689 |
| - return self._engine_flow |
| 736 | + return self._flow_info.engine_flow |
| 737 | + |
| 738 | + def eval(self, *args, **kwargs) -> T: |
| 739 | + """ |
| 740 | + Evaluate the transform flow. |
| 741 | + """ |
| 742 | + return execution_context.run(self.eval_async(*args, **kwargs)) |
| 743 | + |
| 744 | + async def eval_async(self, *args, **kwargs) -> T: |
| 745 | + """ |
| 746 | + Evaluate the transform flow. |
| 747 | + """ |
| 748 | + flow_info = await self._flow_info_async() |
| 749 | + params = [] |
| 750 | + for i, arg in enumerate(self._param_names): |
| 751 | + if i < len(args): |
| 752 | + params.append(encode_engine_value(args[i])) |
| 753 | + elif arg in kwargs: |
| 754 | + params.append(encode_engine_value(kwargs[arg])) |
| 755 | + else: |
| 756 | + raise ValueError(f"Parameter {arg} is not provided") |
| 757 | + engine_result = await flow_info.engine_flow.evaluate_async(params) |
| 758 | + return flow_info.result_decoder(engine_result) |
| 759 | + |
| 760 | + |
| 761 | +def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]]: |
| 762 | + """ |
| 763 | + A decorator to wrap the transform function. |
| 764 | + """ |
| 765 | + def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]): |
| 766 | + sig = inspect.signature(fn) |
| 767 | + arg_types = [] |
| 768 | + for (param_name, param) in sig.parameters.items(): |
| 769 | + if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 770 | + inspect.Parameter.KEYWORD_ONLY): |
| 771 | + raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name") |
| 772 | + arg_types.append(_get_data_slice_annotation_type(param.annotation)) |
| 773 | + |
| 774 | + _transform_flow = TransformFlow(fn, arg_types) |
| 775 | + functools.update_wrapper(_transform_flow, fn) |
| 776 | + return _transform_flow |
| 777 | + |
| 778 | + return _transform_flow_wrapper |
0 commit comments