@@ -132,15 +132,140 @@ def _make_engine_value_converter(
132
132
133
133
_gpu_dispatch_lock = Lock ()
134
134
135
- def executor_class (gpu : bool = False , cache : bool = False , behavior_version : int | None = None ) -> Callable [[type ], type ]:
135
+ @dataclasses .dataclass
136
+ class OpArgs :
136
137
"""
137
- Decorate a class to provide an executor for an op.
138
+ - gpu: Whether the executor will be executed on GPU.
139
+ - cache: Whether the executor will be cached.
140
+ - behavior_version: The behavior version of the executor. Cache will be invalidated if it
141
+ changes. Must be provided if `cache` is True.
142
+ """
143
+ gpu : bool = False
144
+ cache : bool = False
145
+ behavior_version : int | None = None
146
+
147
+ def _register_op_factory (
148
+ category : OpCategory ,
149
+ expected_args : list [tuple [str , inspect .Parameter ]],
150
+ expected_return ,
151
+ executor_cls : type ,
152
+ spec_cls : type ,
153
+ op_args : OpArgs ,
154
+ ):
155
+ """
156
+ Register an op factory.
157
+ """
158
+ class _Fallback :
159
+ def enable_cache (self ):
160
+ return op_args .cache
161
+
162
+ def behavior_version (self ):
163
+ return op_args .behavior_version
164
+
165
+ class _WrappedClass (executor_cls , _Fallback ):
166
+ _args_converters : list [Callable [[Any ], Any ]]
167
+ _kwargs_converters : dict [str , Callable [[str , Any ], Any ]]
168
+
169
+ def __init__ (self , spec ):
170
+ super ().__init__ ()
171
+ self .spec = spec
172
+
173
+ def analyze (self , * args , ** kwargs ):
174
+ """
175
+ Analyze the spec and arguments. In this phase, argument types should be validated.
176
+ It should return the expected result type for the current op.
177
+ """
178
+ self ._args_converters = []
179
+ self ._kwargs_converters = {}
180
+
181
+ # Match arguments with parameters.
182
+ next_param_idx = 0
183
+ for arg in args :
184
+ if next_param_idx >= len (expected_args ):
185
+ raise ValueError (
186
+ f"Too many arguments passed in: { len (args )} > { len (expected_args )} " )
187
+ arg_name , arg_param = expected_args [next_param_idx ]
188
+ if arg_param .kind in (
189
+ inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .VAR_KEYWORD ):
190
+ raise ValueError (
191
+ f"Too many positional arguments passed in: { len (args )} > { next_param_idx } " )
192
+ self ._args_converters .append (
193
+ _make_engine_value_converter (
194
+ [arg_name ], arg .value_type ['type' ], arg_param .annotation ))
195
+ if arg_param .kind != inspect .Parameter .VAR_POSITIONAL :
196
+ next_param_idx += 1
197
+
198
+ expected_kwargs = expected_args [next_param_idx :]
199
+
200
+ for kwarg_name , kwarg in kwargs .items ():
201
+ expected_arg = next (
202
+ (arg for arg in expected_kwargs
203
+ if (arg [0 ] == kwarg_name and arg [1 ].kind in (
204
+ inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .POSITIONAL_OR_KEYWORD ))
205
+ or arg [1 ].kind == inspect .Parameter .VAR_KEYWORD ),
206
+ None )
207
+ if expected_arg is None :
208
+ raise ValueError (f"Unexpected keyword argument passed in: { kwarg_name } " )
209
+ arg_param = expected_arg [1 ]
210
+ self ._kwargs_converters [kwarg_name ] = _make_engine_value_converter (
211
+ [kwarg_name ], kwarg .value_type ['type' ], arg_param .annotation )
212
+
213
+ missing_args = [name for (name , arg ) in expected_kwargs
214
+ if arg .default is inspect .Parameter .empty
215
+ and (arg .kind == inspect .Parameter .POSITIONAL_ONLY or
216
+ (arg .kind in (inspect .Parameter .KEYWORD_ONLY ,
217
+ inspect .Parameter .POSITIONAL_OR_KEYWORD )
218
+ and name not in kwargs ))]
219
+ if len (missing_args ) > 0 :
220
+ raise ValueError (f"Missing arguments: { ', ' .join (missing_args )} " )
221
+
222
+ prepare_method = getattr (executor_cls , 'analyze' , None )
223
+ if prepare_method is not None :
224
+ return prepare_method (self , * args , ** kwargs )
225
+ else :
226
+ return expected_return
227
+
228
+ def prepare (self ):
229
+ """
230
+ Prepare for execution.
231
+ It's executed after `analyze` and before any `__call__` execution.
232
+ """
233
+ setup_method = getattr (executor_cls , 'prepare' , None )
234
+ if setup_method is not None :
235
+ setup_method (self )
236
+
237
+ def __call__ (self , * args , ** kwargs ):
238
+ converted_args = (converter (arg ) for converter , arg in zip (self ._args_converters , args ))
239
+ converted_kwargs = {arg_name : self ._kwargs_converters [arg_name ](arg )
240
+ for arg_name , arg in kwargs .items ()}
241
+ if op_args .gpu :
242
+ # For GPU executions, data-level parallelism is applied, so we don't want to
243
+ # execute different tasks in parallel.
244
+ # Besides, multiprocessing is more appropriate for pytorch.
245
+ # For now, we use a lock to ensure only one task is executed at a time.
246
+ # TODO: Implement multi-processing dispatching.
247
+ with _gpu_dispatch_lock :
248
+ output = super ().__call__ (* converted_args , ** converted_kwargs )
249
+ else :
250
+ output = super ().__call__ (* converted_args , ** converted_kwargs )
251
+ return to_engine_value (output )
252
+
253
+ _WrappedClass .__name__ = executor_cls .__name__
138
254
139
- Args:
140
- gpu: Whether the executor will be executed on GPU.
141
- cache: Whether the executor will be cached.
142
- behavior_version: The behavior version of the executor. Cache will be invalidated if it changes. Must be provided if `cache` is True.
255
+ if category == OpCategory .FUNCTION :
256
+ _engine .register_function_factory (
257
+ spec_cls .__name__ ,
258
+ _FunctionExecutorFactory (spec_cls , _WrappedClass ))
259
+ else :
260
+ raise ValueError (f"Unsupported executor type { category } " )
261
+
262
+ return _WrappedClass
263
+
264
+ def executor_class (** args ) -> Callable [[type ], type ]:
143
265
"""
266
+ Decorate a class to provide an executor for an op.
267
+ """
268
+ op_args = OpArgs (** args )
144
269
145
270
def _inner (cls : type [Executor ]) -> type :
146
271
"""
@@ -149,110 +274,46 @@ def _inner(cls: type[Executor]) -> type:
149
274
type_hints = get_type_hints (cls )
150
275
if 'spec' not in type_hints :
151
276
raise TypeError ("Expect a `spec` field with type hint" )
152
-
153
277
spec_cls = type_hints ['spec' ]
154
- op_name = spec_cls .__name__
155
- category = spec_cls ._op_category
156
-
157
278
sig = inspect .signature (cls .__call__ )
158
- expected_args = list (sig .parameters .items ())[1 :] # First argument is `self`
159
- expected_return = sig .return_annotation
160
-
161
- cls_type : type = cls
162
-
163
- class _Fallback :
164
- def enable_cache (self ):
165
- return cache
166
-
167
- def behavior_version (self ):
168
- return behavior_version
169
-
170
- class _WrappedClass (cls_type , _Fallback ):
171
- _args_converters : list [Callable [[Any ], Any ]]
172
- _kwargs_converters : dict [str , Callable [[str , Any ], Any ]]
173
-
174
- def __init__ (self , spec ):
175
- super ().__init__ ()
176
- self .spec = spec
177
-
178
- def analyze (self , * args , ** kwargs ):
179
- """
180
- Analyze the spec and arguments. In this phase, argument types should be validated.
181
- It should return the expected result type for the current op.
182
- """
183
- self ._args_converters = []
184
- self ._kwargs_converters = {}
185
-
186
- # Match arguments with parameters.
187
- next_param_idx = 0
188
- for arg in args :
189
- if next_param_idx >= len (expected_args ):
190
- raise ValueError (f"Too many arguments passed in: { len (args )} > { len (expected_args )} " )
191
- arg_name , arg_param = expected_args [next_param_idx ]
192
- if arg_param .kind == inspect .Parameter .KEYWORD_ONLY or arg_param .kind == inspect .Parameter .VAR_KEYWORD :
193
- raise ValueError (f"Too many positional arguments passed in: { len (args )} > { next_param_idx } " )
194
- self ._args_converters .append (
195
- _make_engine_value_converter ([arg_name ], arg .value_type ['type' ], arg_param .annotation ))
196
- if arg_param .kind != inspect .Parameter .VAR_POSITIONAL :
197
- next_param_idx += 1
198
-
199
- expected_kwargs = expected_args [next_param_idx :]
200
-
201
- for kwarg_name , kwarg in kwargs .items ():
202
- expected_arg = next (
203
- (arg for arg in expected_kwargs
204
- if (arg [0 ] == kwarg_name and arg [1 ].kind in (inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .POSITIONAL_OR_KEYWORD ))
205
- or arg [1 ].kind == inspect .Parameter .VAR_KEYWORD ),
206
- None )
207
- if expected_arg is None :
208
- raise ValueError (f"Unexpected keyword argument passed in: { kwarg_name } " )
209
- arg_param = expected_arg [1 ]
210
- self ._kwargs_converters [kwarg_name ] = _make_engine_value_converter (
211
- [kwarg_name ], kwarg .value_type ['type' ], arg_param .annotation )
212
-
213
- missing_args = [name for (name , arg ) in expected_kwargs
214
- if arg .default is inspect .Parameter .empty
215
- and (arg .kind == inspect .Parameter .POSITIONAL_ONLY or
216
- (arg .kind in (inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .POSITIONAL_OR_KEYWORD ) and name not in kwargs ))]
217
- if len (missing_args ) > 0 :
218
- raise ValueError (f"Missing arguments: { ', ' .join (missing_args )} " )
219
-
220
- prepare_method = getattr (cls_type , 'analyze' , None )
221
- if prepare_method is not None :
222
- return prepare_method (self , * args , ** kwargs )
223
- else :
224
- return expected_return
225
-
226
- def prepare (self ):
227
- """
228
- Prepare for execution.
229
- It's executed after `analyze` and before any `__call__` execution.
230
- """
231
- setup_method = getattr (cls_type , 'prepare' , None )
232
- if setup_method is not None :
233
- setup_method (self )
279
+ return _register_op_factory (
280
+ category = spec_cls ._op_category ,
281
+ expected_args = list (sig .parameters .items ())[1 :], # First argument is `self`
282
+ expected_return = sig .return_annotation ,
283
+ executor_cls = cls ,
284
+ spec_cls = spec_cls ,
285
+ op_args = op_args )
286
+
287
+ return _inner
288
+
289
+ def function (** args ) -> Callable [[Callable ], FunctionSpec ]:
290
+ """
291
+ Decorate a function to provide a function for an op.
292
+ """
293
+ op_args = OpArgs (** args )
294
+
295
+ def _inner (fn : Callable ) -> FunctionSpec :
296
+
297
+ # Convert snake case to camel case.
298
+ op_name = '' .join (word .capitalize () for word in fn .__name__ .split ('_' ))
299
+ sig = inspect .signature (fn )
234
300
301
+ class _Executor :
235
302
def __call__ (self , * args , ** kwargs ):
236
- converted_args = (converter (arg ) for converter , arg in zip (self ._args_converters , args ))
237
- converted_kwargs = {arg_name : self ._kwargs_converters [arg_name ](arg ) for arg_name , arg in kwargs .items ()}
238
- if gpu :
239
- # For GPU executions, data-level parallelism is applied, so we don't want to execute different tasks in parallel.
240
- # Besides, multiprocessing is more appropriate for pytorch.
241
- # For now, we use a lock to ensure only one task is executed at a time.
242
- # TODO: Implement multi-processing dispatching.
243
- with _gpu_dispatch_lock :
244
- output = super ().__call__ (* converted_args , ** converted_kwargs )
245
- else :
246
- output = super ().__call__ (* converted_args , ** converted_kwargs )
247
- return to_engine_value (output )
303
+ return fn (* args , ** kwargs )
248
304
249
- _WrappedClass .__name__ = cls .__name__
305
+ class _Spec (FunctionSpec ):
306
+ pass
307
+ _Spec .__name__ = op_name
250
308
251
- if category == OpCategory .FUNCTION :
252
- _engine .register_function_factory (op_name , _FunctionExecutorFactory (spec_cls , _WrappedClass ))
253
- else :
254
- raise ValueError (f"Unsupported executor type { category } " )
309
+ _register_op_factory (
310
+ category = OpCategory .FUNCTION ,
311
+ expected_args = list (sig .parameters .items ()),
312
+ expected_return = sig .return_annotation ,
313
+ executor_cls = _Executor ,
314
+ spec_cls = _Spec ,
315
+ op_args = op_args )
255
316
256
- return _WrappedClass
317
+ return _Spec ()
257
318
258
319
return _inner
0 commit comments