
Commit d6e2bcf
Better documentation
1 parent 121bb48

File tree

2 files changed: +199 -20 lines


mergekit/graph.py (+128 -1)
@@ -91,23 +91,55 @@ def uses_accelerator(self) -> bool:
         """
         Returns True if the task can take advantage of matrix operation
         acceleration (such as on a GPU).
+
+        Tasks that perform heavy matrix operations should return True here
+        so they can be scheduled on appropriate devices.
+
+        Returns:
+            bool: True if the task benefits from acceleration, False otherwise
         """
         return False
 
     def main_thread_only(self) -> bool:
         """
         Returns True if the task should only be executed on the main thread.
+
+        Tasks with side effects like file I/O or that require specific thread
+        context should return True here to avoid parallel execution issues.
+
+        Returns:
+            bool: True if the task must run on the main thread, False otherwise
         """
         return False
 
     def duplicate_per_gpu(self) -> bool:
         """
         Returns True if the task should be duplicated for each GPU.
+
+        Tasks that are faster to execute than to transfer between devices
+        or are common dependencies of otherwise independent tasks should
+        return True here to maximize parallelism.
+
+        Returns:
+            bool: True if the task should be duplicated per GPU, False otherwise
         """
         return False
 
 
 class TaskUniverse:
+    """
+    Container for tasks and their relationships.
+
+    Maintains a registry of tasks and their dependencies, allowing efficient
+    lookup and traversal of the task graph.
+
+    Attributes:
+        tasks: List of all tasks in this universe
+        task_to_index: Mapping from task instances to their indices
+        task_arguments: Mapping from task indices to their argument dependencies
+        _type_id_to_index: Quick lookup for seen task instances
+    """
+
     tasks: List[Task]
     task_to_index: Dict[Task, int]
     task_arguments: Dict[int, Dict[str, int]]
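
For illustration, a hypothetical task that opts into the hooks documented above. This is a sketch, not code from this commit: it assumes the usual mergekit Task interface of arguments() and execute() plus pydantic-style field declarations; MatMulTask and its fields are invented names.

from typing import Any, Dict

from mergekit.graph import Task


class MatMulTask(Task):
    # Hypothetical upstream tasks whose results feed execute().
    lhs_task: Task
    rhs_task: Task

    def arguments(self) -> Dict[str, Task]:
        # Keys become keyword argument names passed to execute().
        return {"lhs": self.lhs_task, "rhs": self.rhs_task}

    def execute(self, lhs: Any, rhs: Any) -> Any:
        return lhs @ rhs

    def uses_accelerator(self) -> bool:
        # Dense matrix multiplication benefits from GPU scheduling.
        return True

    def main_thread_only(self) -> bool:
        # No file I/O or thread-affine state, so worker threads are fine.
        return False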
@@ -123,6 +155,18 @@ def __init__(self, tasks: Optional[Iterable[Task]] = None):
             self.add_task(task)
 
     def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
+        """
+        Add a task to the universe and return a handle to it.
+
+        If the task already exists in the universe, returns a handle to the existing instance.
+
+        Args:
+            task: The task to add
+            recursive: If True, also add all dependent tasks recursively
+
+        Returns:
+            TaskHandle: A handle to the added task
+        """
         _ti_key = (type(task), id(task))
         if _ti_key in self._type_id_to_index:
             index = self._type_id_to_index[_ti_key]
@@ -144,30 +188,80 @@ def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
         return TaskHandle(self, index)
 
     def get_handle(self, task: Task) -> Optional["TaskHandle"]:
+        """
+        Get a TaskHandle for an existing task, if it exists in this universe.
+
+        Args:
+            task: The task to look up
+
+        Returns:
+            Optional[TaskHandle]: A handle to the task, or None if not found
+        """
         if task not in self.task_to_index:
             return None
         return TaskHandle(self, self.task_to_index[task])
 
 
 class TaskHandle:
+    """
+    A reference to a task within a specific TaskUniverse.
+
+    TaskHandle provides a lightweight way to refer to tasks without directly
+    holding the task instances themselves. Particularly useful for putting
+    tasks in sets or as keys in dictionaries. Much faster to compare and hash
+    than full Task instances.
+
+    Attributes:
+        _universe: The TaskUniverse containing the referenced task
+        _index: The index of the task within the universe
+    """
+
     __slots__ = ["_universe", "_index"]
     _universe: TaskUniverse
     _index: int
 
     def __init__(self, universe: TaskUniverse, index: int):
+        """
+        Initialize a TaskHandle.
+
+        Args:
+            universe: The TaskUniverse containing the task
+            index: The index of the task within the universe
+        """
         self._universe = universe
         self._index = index
 
     def task(self) -> Task:
+        """
+        Get the actual Task instance referenced by this handle.
+
+        Returns:
+            Task: The referenced task
+        """
         return self._universe.tasks[self._index]
 
     def arguments(self) -> Dict[str, "TaskHandle"]:
+        """
+        Get handles to all argument tasks (dependencies) of this task.
+
+        Returns:
+            Dict[str, TaskHandle]: Mapping from argument names to task handles
+        """
         return {
             k: TaskHandle(self._universe, v)
             for k, v in self._universe.task_arguments[self._index].items()
         }
 
     def __eq__(self, other):
+        """
+        Check if two TaskHandles refer to the same task in the same universe.
+
+        Args:
+            other: Another object to compare with
+
+        Returns:
+            bool: True if equal, False otherwise
+        """
         if not isinstance(other, TaskHandle):
             return False
         if self._index != other._index:
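
A short usage sketch of the handle machinery above (my_task is a stand-in for any Task instance; only the API shown in this diff is assumed):

universe = TaskUniverse()
handle = universe.add_task(my_task)  # dependencies are registered recursively
assert universe.get_handle(my_task) == handle

# Handles hash by index, so they are cheap set members and dict keys:
results = {handle: None}
for name, dep_handle in handle.arguments().items():
    results.setdefault(dep_handle, None)  # reserve a slot per dependency result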
@@ -180,21 +274,53 @@ def __hash__(self):
         return self._index
 
     def __str__(self):
-        return f"TaskHandle({self._index})"
+        return f"TaskHandle({type(self.task()).__name__}, {self._index})"
+
+    __repr__ = __str__
 
 
 class ExecutionSchedule:
+    """
+    Represents an ordered schedule of tasks for execution and their lifecycle information.
+
+    Tracks when each task's result can be discarded to optimize memory usage.
+
+    Attributes:
+        tasks: Ordered list of tasks to execute
+        last_use_index: Maps each task to the index in the schedule where its result is last used
+    """
+
     tasks: List[TaskHandle]
     last_use_index: Dict[TaskHandle, int]
 
     def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]):
+        """
+        Initialize an execution schedule.
+
+        Args:
+            tasks: Ordered list of tasks to execute
+            last_use_index: Dictionary mapping tasks to their last use index in the schedule
+        """
         self.tasks = tasks
         self.last_use_index = last_use_index
 
 
 def build_schedule(
     targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any]
 ) -> ExecutionSchedule:
+    """
+    Build an execution schedule for the given target tasks.
+
+    Creates a topologically sorted schedule that respects task dependencies and
+    tracks when each task's result can be discarded to optimize memory usage.
+
+    Args:
+        targets: List of target tasks that need to be executed
+        cached_values: Dictionary of task results that are already available
+
+    Returns:
+        ExecutionSchedule: A schedule containing tasks to execute and their lifecycle info
+    """
     if not targets:
         return ExecutionSchedule(tasks=[], last_use_index={})

@@ -241,6 +367,7 @@ def _compare_key(node: TaskHandle) -> Tuple[str, int]:
         if (node != dummy_handle) and node not in (cached_values or {})
     ]
 
+    # Calculate last use indices for memory optimization
     last_use_index = {}
     for idx, task in reversed(list(enumerate(schedule))):
         for dep in task.arguments().values():
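
The reverse sweep above records, for each task, the latest schedule position that consumes its result, letting the executor free that result once execution passes the recorded index. A self-contained sketch of the same pattern on toy data (not mergekit code):

# Toy schedule: position -> dependency positions consumed at that step.
schedule = {0: [], 1: [0], 2: [0, 1], 3: [2]}

last_use = {}
for idx in reversed(sorted(schedule)):
    for dep in schedule[idx]:
        if dep not in last_use:  # first hit in reverse order is the last use
            last_use[dep] = idx
    last_use.setdefault(idx, idx)  # a task's own slot counts as a use

print(last_use)  # {2: 3, 3: 3, 0: 2, 1: 2} -> task 0 can be freed after step 2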

mergekit/multigpu_executor.py (+71 -19)
@@ -34,7 +34,17 @@
 
 class MultiGPUExecutor:
     """
-    Execute tasks across multiple GPUs.
+    Execute computational tasks in parallel across multiple GPUs.
+
+    This class analyzes the dependency structure of a task graph and distributes
+    the workload across available GPUs while respecting:
+    1. Tasks requiring main thread execution
+    2. Tasks that need to be duplicated on each GPU
+    3. Task dependencies and data locality
+    4. Memory management for intermediate results
+
+    It automatically partitions the task graph into leading tasks (main thread, pre-GPU),
+    parallel tasks (distributed across GPUs), and trailing tasks (main thread, post-GPU).
 
     Attributes:
         num_gpus: Number of GPUs to utilize (None = all available)
@@ -49,12 +59,18 @@ def __init__(
         storage_device: Optional[torch.device] = None,
     ):
         """
-        Initialize the executor with a list of tasks.
+        Initialize the executor with a list of target tasks.
+
+        This performs initial task graph analysis, including:
+        - Finding tasks that must run on the main thread before parallel execution
+        - Finding tasks that must run on the main thread after parallel execution
+        - Partitioning parallel tasks into islands that can run independently
+        - Assigning islands to GPUs using a load-balancing approach
 
         Args:
-            tasks: List of tasks to execute
+            targets: List of final target tasks to execute
             num_gpus: Number of GPUs to utilize (None = all available)
-            storage_device: Device for storing tensors between stages
+            storage_device: Device for storing intermediate results between execution stages
         """
         self.results: Dict[TaskHandle, Any] = {}
         self.storage_device = storage_device
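
A sketch of driving the executor (target_tasks is hypothetical; the constructor arguments follow the docstring above, while the run() method yielding (task, result) pairs is an assumption by analogy with mergekit's single-GPU Executor, not something this diff shows):

import torch

executor = MultiGPUExecutor(
    targets=target_tasks,                # hypothetical list of final Task objects
    num_gpus=2,                          # None would use all visible GPUs
    storage_device=torch.device("cpu"),  # park intermediate results in host memory
)
for task, result in executor.run():      # assumed API, see note above
    ...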
@@ -191,9 +207,19 @@ def _find_trailing_tasks(self, tasks: List[TaskHandle]) -> Set[TaskHandle]:
         """
         Identify tasks that must execute AFTER parallel GPU tasks complete.
 
-        Trailing tasks must:
-        - Require main thread execution
-        - Not have non-trailing dependants
+        This method finds tasks that need to run after parallel execution because they
+        require the main thread and consume results produced by the parallel tasks.
+
+        A task is considered "trailing" if:
+        - It requires main thread execution (task.main_thread_only() is True)
+        - All tasks dependent on it are also trailing tasks (recursive condition)
+        - OR it has no dependents (terminal task)
+
+        Args:
+            tasks: List of task handles to analyze
+
+        Returns:
+            Set[TaskHandle]: Set of tasks that should be executed after parallel processing
         """
         dependants = defaultdict(set)
         for task_idx, arg_indices in self.universe.task_arguments.items():
@@ -215,11 +241,21 @@ def _find_trailing_tasks(self, tasks: List[TaskHandle]) -> Set[TaskHandle]:
         return trailing_tasks
 
     def _find_leading_tasks(self, tasks: List[TaskHandle]) -> Set[TaskHandle]:
-        """Identify tasks that must execute BEFORE parallel GPU tasks.
+        """
+        Identify tasks that must execute BEFORE parallel GPU tasks.
+
+        This method finds tasks that need to run before parallel execution because they
+        require the main thread and are dependencies of other tasks.
 
-        Leading tasks must:
-        - Require main thread execution
-        - Not have non-leading dependencies
+        A task is considered "leading" if:
+        - It requires main thread execution (task.main_thread_only() is True)
+        - It has no dependencies, or all its dependencies are also leading tasks
+
+        Args:
+            tasks: List of task handles to analyze
+
+        Returns:
+            Set[TaskHandle]: Set of tasks that should be executed before parallel processing
         """
         leading_tasks = set()
         for task_handle in tasks:
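
Leading and trailing sets are mirror-image fixpoints over the dependency graph. A standalone sketch of the "leading" rule on a toy graph (not mergekit code):

# Toy graph: task -> its dependencies; plus a main-thread-only flag per task.
deps = {"load": [], "tokenize": ["load"], "merge": ["tokenize"], "save": ["merge"]}
main_only = {"load": True, "tokenize": True, "merge": False, "save": True}

leading = set()
changed = True
while changed:  # iterate until no task changes classification
    changed = False
    for task, ds in deps.items():
        if task in leading or not main_only[task]:
            continue
        if all(d in leading for d in ds):  # no deps, or all deps already leading
            leading.add(task)
            changed = True

print(sorted(leading))  # ['load', 'tokenize']; 'save' is blocked by 'merge'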
@@ -236,11 +272,22 @@ def _assign_islands_to_gpus(
         self, tasks: List[TaskHandle], num_gpus: int
     ) -> Dict[torch.device, List[TaskHandle]]:
         """
-        Assign task islands to GPUs.
+        Assign task islands to GPUs for parallel execution.
+
+        This method partitions the parallel task graph into disjoint subgraphs
+        (islands) that can be executed independently on different GPUs. It uses
+        a load-balancing approach to distribute islands across available GPUs.
 
-        Task islands (weakly connected components) are groups of tasks that
-        can execute independently. This method identifies islands in the
-        non-trailing, non-leading task graph and assigns them to devices.
+        Task islands are identified as weakly connected components of the task
+        dependency graph: groups of tasks that are connected through
+        dependencies but have no dependencies outside their group.
+
+        Args:
+            tasks: List of parallel tasks to assign to GPUs
+            num_gpus: Number of available GPUs
+
+        Returns:
+            Dict[torch.device, List[TaskHandle]]: Mapping from GPU devices to assigned tasks
         """
         task_set = set(tasks)
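
Weak connectivity treats every dependency edge as undirected, so an island is simply a flood fill over that view. A standalone sketch (toy indices, not the commit's implementation):

from collections import defaultdict

# Toy parallel subgraph: task index -> indices of its dependencies.
deps = {0: [], 1: [0], 2: [], 3: [2], 4: []}

adj = defaultdict(set)  # undirected adjacency: direction is ignored
for t, ds in deps.items():
    adj[t]  # make sure isolated nodes appear as their own island
    for d in ds:
        adj[t].add(d)
        adj[d].add(t)

islands, seen = [], set()
for start in list(adj):
    if start in seen:
        continue
    stack, island = [start], set()
    while stack:  # depth-first flood fill
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        island.add(node)
        stack.extend(adj[node] - seen)
    islands.append(island)

print(islands)  # [{0, 1}, {2, 3}, {4}]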

@@ -262,7 +309,7 @@
             continue
         # don't need to sort, inner executor will handle
         island_tasks = [TaskHandle(self.universe, idx) for idx in island]
-        # assign to GPU with fewest tasks
+        # assign to GPU with fewest tasks (load balancing)
         device_idx = min(
             range(num_gpus),
             key=lambda i: len(assignments.get(torch.device(f"cuda:{i}"), [])),
@@ -281,11 +328,15 @@ def _device_worker(
         """
         Execute a set of tasks on a single GPU.
 
+        This method runs as a thread worker for a specific GPU. It creates an execution
+        stream on the assigned GPU, runs the tasks, and queues results back to the main
+        thread. Only results needed for target tasks or trailing tasks are retained.
+
         Args:
-            island_tasks: List of tasks to execute
+            task_list: List of tasks to execute on this GPU
             cached_values: Values of previously-executed dependent tasks
-            device: Device to execute tasks on
-            quiet: Suppress progress bar output
+            device: GPU device to execute tasks on
+            quiet: Whether to suppress progress bar output
         """
         LOG.debug(f"Device {device} starting")
         with torch.device(device):
@@ -300,6 +351,7 @@
             count = 0
             for task_handle, result in exec._run(quiet=quiet):
                 count += 1
+                # Only keep results needed for target tasks or trailing tasks
                 if not (
                     task_handle in self.targets
                     or task_handle in self.trailing_dependencies
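
The worker-to-main-thread handoff described in the docstring follows a standard queue pattern; a standalone sketch (not the mergekit implementation):

import queue
import threading

results: queue.Queue = queue.Queue()

def device_worker(task_list):
    # Compute on the worker thread; hand results to the main thread via queue.
    for name in task_list:
        results.put((name, len(name)))
    results.put(None)  # sentinel: this worker is done

t = threading.Thread(target=device_worker, args=(["a", "bb"],))
t.start()
while (item := results.get()) is not None:
    print(item)  # ('a', 1) then ('bb', 2)
t.join()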
