
Commit 7100a93

Merge branch 'master' into precommit/ruff-format

2 parents 9aca6f4 + 6cb5813

File tree

7 files changed: +67 -28 lines changed


src/lightning/data/processing/data_processor.py (+29 -17)
@@ -250,23 +250,35 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
 
 def _map_items_to_workers_sequentially(num_workers: int, user_items: List[Any]) -> List[List[Any]]:
     num_nodes = _get_num_nodes()
-    current_node_rank = _get_node_rank()
-    node_size = len(user_items) // num_nodes
-    workers_user_items = []
-    for node_rank in range(num_nodes):
-        if node_rank != current_node_rank:
-            continue
-        is_last_node = node_rank == num_nodes - 1
-        start_node = node_rank * node_size
-        end_node = len(user_items) if is_last_node else (node_rank + 1) * node_size
-        node_user_items = user_items[start_node:end_node]
-        worker_size = len(node_user_items) // num_workers
-        for worker_idx in range(num_workers):
-            is_last = worker_idx == num_workers - 1
-            begin = worker_idx * worker_size
-            end = len(node_user_items) if is_last else (worker_idx + 1) * worker_size
-            workers_user_items.append(node_user_items[begin:end])
-    return workers_user_items
+    world_size = num_nodes * num_workers
+    base_num_items = len(user_items) // world_size
+
+    num_items_per_worker: List[int] = [base_num_items for _ in range(world_size)]
+    remainder = len(user_items) % world_size
+
+    for worker_idx in range(len(num_items_per_worker) - 1, -1, -1):
+        if remainder == 0:
+            break
+        num_items_per_worker[worker_idx] += 1
+        remainder -= 1
+
+    num_items_cumsum_per_worker = np.cumsum([0] + num_items_per_worker)
+
+    out = []
+    node_rank = _get_node_rank()
+    worker_idx_start = node_rank * num_workers
+    worker_idx_end = (node_rank + 1) * num_workers
+
+    for worker_idx in range(world_size):
+        if worker_idx_start <= worker_idx < worker_idx_end:
+            start = num_items_cumsum_per_worker[worker_idx]
+            end = num_items_cumsum_per_worker[worker_idx + 1]
+            out.append(user_items[start:end])
+
+    if len(out) != num_workers:
+        raise RuntimeError("The items haven't been assigned properly. Please open an issue on GitHub.")
+
+    return out
 
 
 def _map_items_to_workers_weighted(
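
For reference, a minimal standalone sketch of the new balancing scheme; the `balanced_split` name and the explicit `num_nodes`/`node_rank` parameters are illustrative, since the committed function reads them from `_get_num_nodes()` and `_get_node_rank()`:

import numpy as np
from typing import Any, List


def balanced_split(num_nodes: int, node_rank: int, num_workers: int, user_items: List[Any]) -> List[List[Any]]:
    # Every worker in the world gets len(user_items) // world_size items;
    # the remainder is handed out one item at a time, last worker first.
    world_size = num_nodes * num_workers
    counts = [len(user_items) // world_size] * world_size
    remainder = len(user_items) % world_size
    for worker_idx in range(world_size - 1, world_size - 1 - remainder, -1):
        counts[worker_idx] += 1
    boundaries = np.cumsum([0] + counts)
    # Keep only the slices that belong to this node's workers.
    return [
        user_items[boundaries[i] : boundaries[i + 1]]
        for i in range(node_rank * num_workers, (node_rank + 1) * num_workers)
    ]


# 5 items over 1 node x 3 workers: the remainder of 2 goes to the last two workers.
assert balanced_split(1, 0, 3, list(range(5))) == [[0], [1, 2], [3, 4]]
# 32 items over 4 nodes x 3 workers, seen from the last node.
assert balanced_split(4, 3, 3, list(range(32))) == [[23, 24, 25], [26, 27, 28], [29, 30, 31]]

This matches the updated expectations in tests/tests_data/processing/test_data_processor.py below: no worker ever holds more than one extra item, where the old code piled the whole remainder onto the last worker of each node.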

src/lightning/data/streaming/combined.py (+5)
@@ -66,6 +66,11 @@ def set_epoch(self, current_epoch: int) -> None:
         for dataset in self._datasets:
             dataset.set_epoch(current_epoch)
 
+    def set_shuffle(self, shuffle: bool) -> None:
+        """Set the shuffle mode across all the wrapped datasets."""
+        for dataset in self._datasets:
+            dataset.set_shuffle(shuffle)
+
     def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
         if any(not isinstance(d, StreamingDataset) for d in datasets):
             raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
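
A short usage sketch of the new fan-out; the input directories are made up, and the imports assume the public `lightning.data` entry point:

from lightning.data import CombinedStreamingDataset, StreamingDataset

# Hypothetical shard locations.
datasets = [StreamingDataset("s3://my-bucket/shards-a"), StreamingDataset("s3://my-bucket/shards-b")]
combined = CombinedStreamingDataset(datasets)
combined.set_shuffle(True)  # every wrapped StreamingDataset now has shuffle=True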

src/lightning/data/streaming/dataloader.py (+6)
@@ -541,6 +541,7 @@ def __init__(
         profile_batches: Union[bool, int] = False,
         profile_dir: Optional[str] = None,
         prefetch_factor: Optional[int] = None,
+        shuffle: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:  # pyright: ignore
         if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -549,6 +550,11 @@ def __init__(
                 f" Found {dataset}."
             )
 
+        if shuffle is not None:
+            dataset.set_shuffle(shuffle)
+
+        shuffle = None
+
         if profile_batches and not _VIZ_TRACKER_AVAILABLE:
             raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")
 
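
With the new keyword, the same toggle is available at the loader level; a sketch reusing the hypothetical `combined` dataset from the previous snippet:

from lightning.data import StreamingDataLoader

# The loader forwards the flag through dataset.set_shuffle(...) and then resets
# it to None, so the base torch DataLoader never receives a conflicting
# `shuffle` argument alongside its sampler.
loader = StreamingDataLoader(combined, batch_size=8, num_workers=2, shuffle=True)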

src/lightning/data/streaming/dataset.py (+3)
@@ -107,6 +107,9 @@ def __init__(
         self.serializers = serializers
         self._state_dict: Optional[Dict[str, Any]] = None
 
+    def set_shuffle(self, shuffle: bool) -> None:
+        self.shuffle = shuffle
+
     def set_epoch(self, current_epoch: int) -> None:
         """Set the current epoch to the dataset on epoch starts.

src/lightning/data/utilities/env.py (+6 -8)
@@ -31,21 +31,19 @@ def detect(cls) -> "_DistributedEnv":
         if torch.distributed.is_available() and torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             global_rank = torch.distributed.get_rank()
+            # Note: On multi-node CPU, the number of nodes won't be correct.
+            num_nodes = world_size // torch.cuda.device_count() if torch.cuda.is_available() else world_size
+            if torch.cuda.is_available() and world_size % torch.cuda.device_count() != 0:
+                raise RuntimeError("The world size should be divisible by the number of GPUs.")
         else:
             world_size = None
             global_rank = 0
+            num_nodes = 1
 
         if world_size is None or world_size == -1:
             world_size = 1
 
-        # TODO: Add support for other accelerators
-        num_nodes = (world_size // torch.cuda.device_count()) if torch.cuda.is_available() else 1
-
-        if num_nodes > 1:
-            # validate the world size is divisible by the number of GPUs
-            assert world_size % torch.cuda.device_count() == 0
-
-        return cls(world_size=world_size, global_rank=global_rank, num_nodes=max(1, num_nodes))
+        return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"
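
A worked example of the relocated computation, with assumed values: for world_size = 8 on nodes exposing 4 GPUs each, num_nodes = 8 // 4 = 2, while world_size = 9 now raises a RuntimeError instead of failing the old assert. On CPU-only clusters num_nodes falls back to world_size, which is exactly what the in-diff note flags as incorrect for multi-node CPU.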

tests/tests_data/processing/test_data_processor.py (+3 -3)
@@ -310,7 +310,7 @@ def test_map_items_to_workers_sequentially(monkeypatch):
     workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
     assert workers_user_items == [[0, 1], [2, 3, 4]]
     workers_user_items = _map_items_to_workers_sequentially(3, list(range(5)))
-    assert workers_user_items == [[0], [1], [2, 3, 4]]
+    assert workers_user_items == [[0], [1, 2], [3, 4]]
     workers_user_items = _map_items_to_workers_sequentially(4, list(range(5)))
     assert workers_user_items == [[0], [1], [2], [3, 4]]
@@ -335,7 +335,7 @@ def test_map_items_to_workers_sequentially(monkeypatch):
     workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
     assert workers_user_items == [[0, 1, 2, 3], [4, 5, 6, 7]]
     workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
-    assert workers_user_items == [[0, 1], [2, 3], [4, 5, 6, 7]]
+    assert workers_user_items == [[0, 1], [2, 3], [4, 5]]
     workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
     assert workers_user_items == [[0, 1], [2, 3], [4, 5], [6, 7]]
@@ -346,7 +346,7 @@ def test_map_items_to_workers_sequentially(monkeypatch):
     workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
     assert workers_user_items == [[24, 25, 26, 27], [28, 29, 30, 31]]
     workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
-    assert workers_user_items == [[24, 25], [26, 27], [28, 29, 30, 31]]
+    assert workers_user_items == [[23, 24, 25], [26, 27, 28], [29, 30, 31]]
     workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
     assert workers_user_items == [[24, 25], [26, 27], [28, 29], [30, 31]]

tests/tests_data/streaming/test_dataloader.py (+15)
@@ -12,6 +12,10 @@ def __init__(self, size, step):
         self.size = size
         self.step = step
         self.counter = 0
+        self.shuffle = None
+
+    def set_shuffle(self, shuffle):
+        self.shuffle = shuffle
 
     def __len__(self):
         return self.size
@@ -92,3 +96,14 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch):
         batches.append(batch)
 
     assert os.path.exists(os.path.join(tmpdir, "result.json"))
+
+
+def test_dataloader_shuffle():
+    dataset = TestCombinedStreamingDataset(
+        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
+    )
+    assert dataset._datasets[0].shuffle is None
+    assert dataset._datasets[1].shuffle is None
+    StreamingDataLoader(dataset, batch_size=2, num_workers=1, shuffle=True)
+    assert dataset._datasets[0].shuffle
+    assert dataset._datasets[1].shuffle
