Skip to content

Commit 7b51b6c

Browse files
thomas authored and thomas committed
update
1 parent 265025b commit 7b51b6c

File tree

2 files changed: +31 -20 lines changed

src/lightning/data/processing/data_processor.py

+28 -17
Original file line number | Diff line number | Diff line change
@@ -250,23 +250,34 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
250250

251251
def _map_items_to_workers_sequentially(num_workers: int, user_items: List[Any]) -> List[List[Any]]:
252252
num_nodes = _get_num_nodes()
253-
current_node_rank = _get_node_rank()
254-
node_size = len(user_items) // num_nodes
255-
workers_user_items = []
256-
for node_rank in range(num_nodes):
257-
if node_rank != current_node_rank:
258-
continue
259-
is_last_node = node_rank == num_nodes - 1
260-
start_node = node_rank * node_size
261-
end_node = len(user_items) if is_last_node else (node_rank + 1) * node_size
262-
node_user_items = user_items[start_node:end_node]
263-
worker_size = len(node_user_items) // num_workers
264-
for worker_idx in range(num_workers):
265-
is_last = worker_idx == num_workers - 1
266-
begin = worker_idx * worker_size
267-
end = len(node_user_items) if is_last else (worker_idx + 1) * worker_size
268-
workers_user_items.append(node_user_items[begin:end])
269-
return workers_user_items
253+
world_size = (num_nodes * num_workers)
254+
num_items_per_worker = len(user_items) // world_size
255+
256+
num_items_per_worker: List[int] = [num_items_per_worker for _ in range(world_size)]
257+
reminder = len(user_items) % world_size
258+
259+
for worker_idx in range(len(num_items_per_worker) - 1, -1, -1):
260+
if reminder == 0:
261+
break
262+
num_items_per_worker[worker_idx] += 1
263+
reminder -= 1
264+
265+
num_items_cumsum_per_worker = np.cumsum([0] + num_items_per_worker)
266+
267+
out = []
268+
node_rank = _get_node_rank()
269+
worker_idx_start = node_rank * num_workers
270+
worker_idx_end = (node_rank + 1) * num_workers
271+
272+
for worker_idx in range(world_size):
273+
if worker_idx_start <= worker_idx and worker_idx < worker_idx_end:
274+
start = num_items_cumsum_per_worker[worker_idx]
275+
end = num_items_cumsum_per_worker[worker_idx + 1]
276+
out.append(user_items[start : end])
277+
278+
assert len(out) == num_workers
279+
280+
return out
270281

271282

272283
def _map_items_to_workers_weighted(

tests/tests_data/processing/test_data_processor.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -310,7 +310,7 @@ def test_map_items_to_workers_sequentially(monkeypatch):
310310
workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
311311
assert workers_user_items == [[0, 1], [2, 3, 4]]
312312
workers_user_items = _map_items_to_workers_sequentially(3, list(range(5)))
313-
assert workers_user_items == [[0], [1], [2, 3, 4]]
313+
assert workers_user_items == [[0], [1, 2], [3, 4]]
314314
workers_user_items = _map_items_to_workers_sequentially(4, list(range(5)))
315315
assert workers_user_items == [[0], [1], [2], [3, 4]]
316316

@@ -335,7 +335,7 @@ def test_map_items_to_workers_sequentially(monkeypatch):
335335
workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
336336
assert workers_user_items == [[0, 1, 2, 3], [4, 5, 6, 7]]
337337
workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
338-
assert workers_user_items == [[0, 1], [2, 3], [4, 5, 6, 7]]
338+
assert workers_user_items == [[0, 1], [2, 3], [4, 5]]
339339
workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
340340
assert workers_user_items == [[0, 1], [2, 3], [4, 5], [6, 7]]
341341

@@ -346,7 +346,7 @@ def test_map_items_to_workers_sequentially(monkeypatch):
346346
workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
347347
assert workers_user_items == [[24, 25, 26, 27], [28, 29, 30, 31]]
348348
workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
349-
assert workers_user_items == [[24, 25], [26, 27], [28, 29, 30, 31]]
349+
assert workers_user_items == [[23, 24, 25], [26, 27, 28], [29, 30, 31]]
350350
workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
351351
assert workers_user_items == [[24, 25], [26, 27], [28, 29], [30, 31]]
352352

0 commit comments

Comments
 (0)