
Commit 916788c

Add @function decorator for simple functions. (#146)
1 parent: 03e3732

File tree

examples/code_embedding/main.py
examples/manuals_llm_extraction/main.py
python/cocoindex/op.py

3 files changed: +177 -131 lines

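
In short: the commit collapses the spec-class-plus-executor-class pattern into a single decorated function for simple ops. A condensed before/after, distilled from the example diffs below:

    # Before: a FunctionSpec subclass plus a separate executor class.
    class ExtractExtension(cocoindex.op.FunctionSpec):
        ...

    @cocoindex.op.executor_class()
    class ExtractExtensionExecutor:
        spec: ExtractExtension

        def __call__(self, filename: str) -> str:
            return os.path.splitext(filename)[1]

    # After: one plain function.
    @cocoindex.op.function()
    def extract_extension(filename: str) -> str:
        """Extract the extension of a filename."""
        return os.path.splitext(filename)[1]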

examples/code_embedding/main.py

Lines changed: 5 additions & 12 deletions
@@ -3,17 +3,10 @@
 import cocoindex
 import os
 
-class ExtractExtension(cocoindex.op.FunctionSpec):
-    """Summarize a Python module."""
-
-@cocoindex.op.executor_class()
-class ExtractExtensionExecutor:
-    """Executor for ExtractExtension."""
-
-    spec: ExtractExtension
-
-    def __call__(self, filename: str) -> str:
-        return os.path.splitext(filename)[1]
+@cocoindex.op.function()
+def extract_extension(filename: str) -> str:
+    """Extract the extension of a filename."""
+    return os.path.splitext(filename)[1]
 
 def code_to_embedding(text: cocoindex.DataSlice) -> cocoindex.DataSlice:
     """
@@ -35,7 +28,7 @@ def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
     code_embeddings = data_scope.add_collector()
 
     with data_scope["files"].row() as file:
-        file["extension"] = file["filename"].transform(ExtractExtension())
+        file["extension"] = file["filename"].transform(extract_extension)
         file["chunks"] = file["content"].transform(
             cocoindex.functions.SplitRecursively(),
             language=file["extension"], chunk_size=1000, chunk_overlap=300)
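
Note the call-site change as well: @cocoindex.op.function() returns a FunctionSpec instance (see return _Spec() in op.py below), so the decorated extract_extension is passed to .transform(...) directly instead of instantiating a spec class.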

examples/manuals_llm_extraction/main.py

Lines changed: 7 additions & 15 deletions
@@ -64,21 +64,13 @@ class ModuleSummary:
     num_classes: int
     num_methods: int
 
-@dataclasses.dataclass
-class SummarizeModule(cocoindex.op.FunctionSpec):
+@cocoindex.op.function()
+def summarize_module(module_info: ModuleInfo) -> ModuleSummary:
     """Summarize a Python module."""
-
-@cocoindex.op.executor_class()
-class SummarizeModuleExecutor:
-    """Executor for SummarizeModule."""
-
-    spec: SummarizeModule
-
-    def __call__(self, module_info: ModuleInfo) -> ModuleSummary:
-        return ModuleSummary(
-            num_classes=len(module_info.classes),
-            num_methods=len(module_info.methods),
-        )
+    return ModuleSummary(
+        num_classes=len(module_info.classes),
+        num_methods=len(module_info.methods),
+    )
 
 @cocoindex.flow_def(name="ManualExtraction")
 def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
@@ -103,7 +95,7 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
             #     api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
             output_type=ModuleInfo,
             instruction="Please extract Python module information from the manual."))
-        doc["module_summary"] = doc["module_info"].transform(SummarizeModule())
+        doc["module_summary"] = doc["module_info"].transform(summarize_module)
         modules_index.collect(
             filename=doc["filename"],
             module_info=doc["module_info"],
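
Here the op's interface comes from the function signature itself: as op.py below shows, function() feeds inspect.signature(fn)'s parameters and return annotation into _register_op_factory, so the typed summarize_module(module_info: ModuleInfo) -> ModuleSummary needs no hand-written analyze method.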

python/cocoindex/op.py

Lines changed: 165 additions & 104 deletions
@@ -132,15 +132,140 @@ def _make_engine_value_converter(
 
 _gpu_dispatch_lock = Lock()
 
-def executor_class(gpu: bool = False, cache: bool = False, behavior_version: int | None = None) -> Callable[[type], type]:
+@dataclasses.dataclass
+class OpArgs:
     """
-    Decorate a class to provide an executor for an op.
+    - gpu: Whether the executor will be executed on GPU.
+    - cache: Whether the executor will be cached.
+    - behavior_version: The behavior version of the executor. Cache will be invalidated if it
+      changes. Must be provided if `cache` is True.
+    """
+    gpu: bool = False
+    cache: bool = False
+    behavior_version: int | None = None
+
+def _register_op_factory(
+        category: OpCategory,
+        expected_args: list[tuple[str, inspect.Parameter]],
+        expected_return,
+        executor_cls: type,
+        spec_cls: type,
+        op_args: OpArgs,
+):
+    """
+    Register an op factory.
+    """
+    class _Fallback:
+        def enable_cache(self):
+            return op_args.cache
+
+        def behavior_version(self):
+            return op_args.behavior_version
+
+    class _WrappedClass(executor_cls, _Fallback):
+        _args_converters: list[Callable[[Any], Any]]
+        _kwargs_converters: dict[str, Callable[[str, Any], Any]]
+
+        def __init__(self, spec):
+            super().__init__()
+            self.spec = spec
+
+        def analyze(self, *args, **kwargs):
+            """
+            Analyze the spec and arguments. In this phase, argument types should be validated.
+            It should return the expected result type for the current op.
+            """
+            self._args_converters = []
+            self._kwargs_converters = {}
+
+            # Match arguments with parameters.
+            next_param_idx = 0
+            for arg in args:
+                if next_param_idx >= len(expected_args):
+                    raise ValueError(
+                        f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
+                arg_name, arg_param = expected_args[next_param_idx]
+                if arg_param.kind in (
+                        inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD):
+                    raise ValueError(
+                        f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
+                self._args_converters.append(
+                    _make_engine_value_converter(
+                        [arg_name], arg.value_type['type'], arg_param.annotation))
+                if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
+                    next_param_idx += 1
+
+            expected_kwargs = expected_args[next_param_idx:]
+
+            for kwarg_name, kwarg in kwargs.items():
+                expected_arg = next(
+                    (arg for arg in expected_kwargs
+                     if (arg[0] == kwarg_name and arg[1].kind in (
+                         inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD))
+                     or arg[1].kind == inspect.Parameter.VAR_KEYWORD),
+                    None)
+                if expected_arg is None:
+                    raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
+                arg_param = expected_arg[1]
+                self._kwargs_converters[kwarg_name] = _make_engine_value_converter(
+                    [kwarg_name], kwarg.value_type['type'], arg_param.annotation)
+
+            missing_args = [name for (name, arg) in expected_kwargs
+                            if arg.default is inspect.Parameter.empty
+                            and (arg.kind == inspect.Parameter.POSITIONAL_ONLY or
+                                 (arg.kind in (inspect.Parameter.KEYWORD_ONLY,
+                                               inspect.Parameter.POSITIONAL_OR_KEYWORD)
+                                  and name not in kwargs))]
+            if len(missing_args) > 0:
+                raise ValueError(f"Missing arguments: {', '.join(missing_args)}")
+
+            prepare_method = getattr(executor_cls, 'analyze', None)
+            if prepare_method is not None:
+                return prepare_method(self, *args, **kwargs)
+            else:
+                return expected_return
+
+        def prepare(self):
+            """
+            Prepare for execution.
+            It's executed after `analyze` and before any `__call__` execution.
+            """
+            setup_method = getattr(executor_cls, 'prepare', None)
+            if setup_method is not None:
+                setup_method(self)
+
+        def __call__(self, *args, **kwargs):
+            converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
+            converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
+                                for arg_name, arg in kwargs.items()}
+            if op_args.gpu:
+                # For GPU executions, data-level parallelism is applied, so we don't want to
+                # execute different tasks in parallel.
+                # Besides, multiprocessing is more appropriate for pytorch.
+                # For now, we use a lock to ensure only one task is executed at a time.
+                # TODO: Implement multi-processing dispatching.
+                with _gpu_dispatch_lock:
+                    output = super().__call__(*converted_args, **converted_kwargs)
+            else:
+                output = super().__call__(*converted_args, **converted_kwargs)
+            return to_engine_value(output)
+
+    _WrappedClass.__name__ = executor_cls.__name__
 
-    Args:
-        gpu: Whether the executor will be executed on GPU.
-        cache: Whether the executor will be cached.
-        behavior_version: The behavior version of the executor. Cache will be invalidated if it changes. Must be provided if `cache` is True.
+    if category == OpCategory.FUNCTION:
+        _engine.register_function_factory(
+            spec_cls.__name__,
+            _FunctionExecutorFactory(spec_cls, _WrappedClass))
+    else:
+        raise ValueError(f"Unsupported executor type {category}")
+
+    return _WrappedClass
+
+def executor_class(**args) -> Callable[[type], type]:
     """
+    Decorate a class to provide an executor for an op.
+    """
+    op_args = OpArgs(**args)
 
     def _inner(cls: type[Executor]) -> type:
         """
@@ -149,110 +274,46 @@ def _inner(cls: type[Executor]) -> type:
         type_hints = get_type_hints(cls)
         if 'spec' not in type_hints:
             raise TypeError("Expect a `spec` field with type hint")
-
         spec_cls = type_hints['spec']
-        op_name = spec_cls.__name__
-        category = spec_cls._op_category
-
         sig = inspect.signature(cls.__call__)
-        expected_args = list(sig.parameters.items())[1:]  # First argument is `self`
-        expected_return = sig.return_annotation
-
-        cls_type: type = cls
-
-        class _Fallback:
-            def enable_cache(self):
-                return cache
-
-            def behavior_version(self):
-                return behavior_version
-
-        class _WrappedClass(cls_type, _Fallback):
-            _args_converters: list[Callable[[Any], Any]]
-            _kwargs_converters: dict[str, Callable[[str, Any], Any]]
-
-            def __init__(self, spec):
-                super().__init__()
-                self.spec = spec
-
-            def analyze(self, *args, **kwargs):
-                """
-                Analyze the spec and arguments. In this phase, argument types should be validated.
-                It should return the expected result type for the current op.
-                """
-                self._args_converters = []
-                self._kwargs_converters = {}
-
-                # Match arguments with parameters.
-                next_param_idx = 0
-                for arg in args:
-                    if next_param_idx >= len(expected_args):
-                        raise ValueError(f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
-                    arg_name, arg_param = expected_args[next_param_idx]
-                    if arg_param.kind == inspect.Parameter.KEYWORD_ONLY or arg_param.kind == inspect.Parameter.VAR_KEYWORD:
-                        raise ValueError(f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
-                    self._args_converters.append(
-                        _make_engine_value_converter([arg_name], arg.value_type['type'], arg_param.annotation))
-                    if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
-                        next_param_idx += 1
-
-                expected_kwargs = expected_args[next_param_idx:]
-
-                for kwarg_name, kwarg in kwargs.items():
-                    expected_arg = next(
-                        (arg for arg in expected_kwargs
-                         if (arg[0] == kwarg_name and arg[1].kind in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD))
-                         or arg[1].kind == inspect.Parameter.VAR_KEYWORD),
-                        None)
-                    if expected_arg is None:
-                        raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
-                    arg_param = expected_arg[1]
-                    self._kwargs_converters[kwarg_name] = _make_engine_value_converter(
-                        [kwarg_name], kwarg.value_type['type'], arg_param.annotation)
-
-                missing_args = [name for (name, arg) in expected_kwargs
-                                if arg.default is inspect.Parameter.empty
-                                and (arg.kind == inspect.Parameter.POSITIONAL_ONLY or
-                                     (arg.kind in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) and name not in kwargs))]
-                if len(missing_args) > 0:
-                    raise ValueError(f"Missing arguments: {', '.join(missing_args)}")
-
-                prepare_method = getattr(cls_type, 'analyze', None)
-                if prepare_method is not None:
-                    return prepare_method(self, *args, **kwargs)
-                else:
-                    return expected_return
-
-            def prepare(self):
-                """
-                Prepare for execution.
-                It's executed after `analyze` and before any `__call__` execution.
-                """
-                setup_method = getattr(cls_type, 'prepare', None)
-                if setup_method is not None:
-                    setup_method(self)
+        return _register_op_factory(
+            category=spec_cls._op_category,
+            expected_args=list(sig.parameters.items())[1:],  # First argument is `self`
+            expected_return=sig.return_annotation,
+            executor_cls=cls,
+            spec_cls=spec_cls,
+            op_args=op_args)
+
+    return _inner
+
+def function(**args) -> Callable[[Callable], FunctionSpec]:
+    """
+    Decorate a function to provide a function for an op.
+    """
+    op_args = OpArgs(**args)
+
+    def _inner(fn: Callable) -> FunctionSpec:
+
+        # Convert snake case to camel case.
+        op_name = ''.join(word.capitalize() for word in fn.__name__.split('_'))
+        sig = inspect.signature(fn)
 
+        class _Executor:
             def __call__(self, *args, **kwargs):
-                converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
-                converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg) for arg_name, arg in kwargs.items()}
-                if gpu:
-                    # For GPU executions, data-level parallelism is applied, so we don't want to execute different tasks in parallel.
-                    # Besides, multiprocessing is more appropriate for pytorch.
-                    # For now, we use a lock to ensure only one task is executed at a time.
-                    # TODO: Implement multi-processing dispatching.
-                    with _gpu_dispatch_lock:
-                        output = super().__call__(*converted_args, **converted_kwargs)
-                else:
-                    output = super().__call__(*converted_args, **converted_kwargs)
-                return to_engine_value(output)
+                return fn(*args, **kwargs)
 
-        _WrappedClass.__name__ = cls.__name__
+        class _Spec(FunctionSpec):
+            pass
+        _Spec.__name__ = op_name
 
-        if category == OpCategory.FUNCTION:
-            _engine.register_function_factory(op_name, _FunctionExecutorFactory(spec_cls, _WrappedClass))
-        else:
-            raise ValueError(f"Unsupported executor type {category}")
+        _register_op_factory(
+            category=OpCategory.FUNCTION,
+            expected_args=list(sig.parameters.items()),
+            expected_return=sig.return_annotation,
+            executor_cls=_Executor,
+            spec_cls=_Spec,
+            op_args=op_args)
 
-        return _WrappedClass
+        return _Spec()
 
     return _inner