@@ -91,23 +91,55 @@ def uses_accelerator(self) -> bool:
91
91
"""
92
92
Returns True if the task can take advantage of matrix operation
93
93
acceleration (such as on a GPU).
94
+
95
+ Tasks that perform heavy matrix operations should return True here
96
+ so they can be scheduled on appropriate devices.
97
+
98
+ Returns:
99
+ bool: True if the task benefits from acceleration, False otherwise
94
100
"""
95
101
return False
96
102
97
103
def main_thread_only (self ) -> bool :
98
104
"""
99
105
Returns True if the task should only be executed on the main thread.
106
+
107
+ Tasks with side effects like file I/O or that require specific thread
108
+ context should return True here to avoid parallel execution issues.
109
+
110
+ Returns:
111
+ bool: True if the task must run on the main thread, False otherwise
100
112
"""
101
113
return False
102
114
103
115
def duplicate_per_gpu (self ) -> bool :
104
116
"""
105
117
Returns True if the task should be duplicated for each GPU.
118
+
119
+ Tasks that are faster to execute than to transfer between devices
120
+ or are common dependencies of otherwise independent tasks should
121
+ return True here to maximize parallelism.
122
+
123
+ Returns:
124
+ bool: True if the task should be duplicated per GPU, False otherwise
106
125
"""
107
126
return False
108
127
109
128
110
129
class TaskUniverse :
130
+ """
131
+ Container for tasks and their relationships.
132
+
133
+ Maintains a registry of tasks and their dependencies, allowing efficient
134
+ lookup and traversal of the task graph.
135
+
136
+ Attributes:
137
+ tasks: List of all tasks in this universe
138
+ task_to_index: Mapping from task instances to their indices
139
+ task_arguments: Mapping from task indices to their argument dependencies
140
+ _type_id_to_index: Quick lookup for seen task instances
141
+ """
142
+
111
143
tasks : List [Task ]
112
144
task_to_index : Dict [Task , int ]
113
145
task_arguments : Dict [int , Dict [str , int ]]
@@ -123,6 +155,18 @@ def __init__(self, tasks: Optional[Iterable[Task]] = None):
123
155
self .add_task (task )
124
156
125
157
def add_task (self , task : Task , recursive : bool = True ) -> "TaskHandle" :
158
+ """
159
+ Add a task to the universe and return a handle to it.
160
+
161
+ If the task already exists in the universe, returns a handle to the existing instance.
162
+
163
+ Args:
164
+ task: The task to add
165
+ recursive: If True, also add all dependent tasks recursively
166
+
167
+ Returns:
168
+ TaskHandle: A handle to the added task
169
+ """
126
170
_ti_key = (type (task ), id (task ))
127
171
if _ti_key in self ._type_id_to_index :
128
172
index = self ._type_id_to_index [_ti_key ]
@@ -144,30 +188,80 @@ def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
144
188
return TaskHandle (self , index )
145
189
146
190
def get_handle (self , task : Task ) -> Optional ["TaskHandle" ]:
191
+ """
192
+ Get a TaskHandle for an existing task, if it exists in this universe.
193
+
194
+ Args:
195
+ task: The task to look up
196
+
197
+ Returns:
198
+ Optional[TaskHandle]: A handle to the task, or None if not found
199
+ """
147
200
if task not in self .task_to_index :
148
201
return None
149
202
return TaskHandle (self , self .task_to_index [task ])
150
203
151
204
152
205
class TaskHandle :
206
+ """
207
+ A reference to a task within a specific TaskUniverse.
208
+
209
+ TaskHandle provides a lightweight way to refer to tasks without directly
210
+ holding the task instances themselves. Particularly useful for putting
211
+ tasks in sets or as keys in dictionaries. Much faster to compare and hash
212
+ than full Task instances.
213
+
214
+ Attributes:
215
+ _universe: The TaskUniverse containing the referenced task
216
+ _index: The index of the task within the universe
217
+ """
218
+
153
219
__slots__ = ["_universe" , "_index" ]
154
220
_universe : TaskUniverse
155
221
_index : int
156
222
157
223
def __init__ (self , universe : TaskUniverse , index : int ):
224
+ """
225
+ Initialize a TaskHandle.
226
+
227
+ Args:
228
+ universe: The TaskUniverse containing the task
229
+ index: The index of the task within the universe
230
+ """
158
231
self ._universe = universe
159
232
self ._index = index
160
233
161
234
def task (self ) -> Task :
235
+ """
236
+ Get the actual Task instance referenced by this handle.
237
+
238
+ Returns:
239
+ Task: The referenced task
240
+ """
162
241
return self ._universe .tasks [self ._index ]
163
242
164
243
def arguments (self ) -> Dict [str , "TaskHandle" ]:
244
+ """
245
+ Get handles to all argument tasks (dependencies) of this task.
246
+
247
+ Returns:
248
+ Dict[str, TaskHandle]: Mapping from argument names to task handles
249
+ """
165
250
return {
166
251
k : TaskHandle (self ._universe , v )
167
252
for k , v in self ._universe .task_arguments [self ._index ].items ()
168
253
}
169
254
170
255
def __eq__ (self , other ):
256
+ """
257
+ Check if two TaskHandles refer to the same task in the same universe.
258
+
259
+ Args:
260
+ other: Another object to compare with
261
+
262
+ Returns:
263
+ bool: True if equal, False otherwise
264
+ """
171
265
if not isinstance (other , TaskHandle ):
172
266
return False
173
267
if self ._index != other ._index :
@@ -180,21 +274,53 @@ def __hash__(self):
180
274
return self ._index
181
275
182
276
def __str__ (self ):
183
- return f"TaskHandle({ self ._index } )"
277
+ return f"TaskHandle({ type (self .task ()).__name__ } , { self ._index } )"
278
+
279
+ __repr__ = __str__
184
280
185
281
186
282
class ExecutionSchedule :
283
+ """
284
+ Represents an ordered schedule of tasks for execution and their lifecycle information.
285
+
286
+ Tracks when each task's result can be discarded to optimize memory usage.
287
+
288
+ Attributes:
289
+ tasks: Ordered list of tasks to execute
290
+ last_use_index: Maps each task to the index in the schedule where its result is last used
291
+ """
292
+
187
293
tasks : List [TaskHandle ]
188
294
last_use_index : Dict [TaskHandle , int ]
189
295
190
296
def __init__ (self , tasks : List [TaskHandle ], last_use_index : Dict [TaskHandle , int ]):
297
+ """
298
+ Initialize an execution schedule.
299
+
300
+ Args:
301
+ tasks: Ordered list of tasks to execute
302
+ last_use_index: Dictionary mapping tasks to their last use index in the schedule
303
+ """
191
304
self .tasks = tasks
192
305
self .last_use_index = last_use_index
193
306
194
307
195
308
def build_schedule (
196
309
targets : List [TaskHandle ], cached_values : Dict [TaskHandle , Any ]
197
310
) -> ExecutionSchedule :
311
+ """
312
+ Build an execution schedule for the given target tasks.
313
+
314
+ Creates a topologically sorted schedule that respects task dependencies and
315
+ tracks when each task's result can be discarded to optimize memory usage.
316
+
317
+ Args:
318
+ targets: List of target tasks that need to be executed
319
+ cached_values: Dictionary of task results that are already available
320
+
321
+ Returns:
322
+ ExecutionSchedule: A schedule containing tasks to execute and their lifecycle info
323
+ """
198
324
if not targets :
199
325
return ExecutionSchedule (tasks = [], last_use_index = {})
200
326
@@ -241,6 +367,7 @@ def _compare_key(node: TaskHandle) -> Tuple[str, int]:
241
367
if (node != dummy_handle ) and node not in (cached_values or {})
242
368
]
243
369
370
+ # Calculate last use indices for memory optimization
244
371
last_use_index = {}
245
372
for idx , task in reversed (list (enumerate (schedule ))):
246
373
for dep in task .arguments ().values ():
0 commit comments