15
15
import inspect
16
16
import logging
17
17
import os
18
+ from copy import deepcopy
18
19
from importlib import reload
20
+ from itertools import cycle
19
21
from typing import Any , Callable , Dict , List , Optional , Union
20
22
21
23
import torch
32
34
from torch .utils .data .sampler import BatchSampler , Sampler
33
35
34
36
from lightning .data .streaming import Cache
35
- from lightning .data .streaming .combined import CombinedStreamingDataset
37
+ from lightning .data .streaming .combined import (
38
+ __NUM_SAMPLES_YIELDED_KEY__ ,
39
+ __SAMPLES_KEY__ ,
40
+ CombinedStreamingDataset ,
41
+ )
36
42
from lightning .data .streaming .constants import _DEFAULT_CHUNK_BYTES , _TORCH_GREATER_EQUAL_2_1_0 , _VIZ_TRACKER_AVAILABLE
37
43
from lightning .data .streaming .dataset import StreamingDataset
38
44
from lightning .data .streaming .sampler import CacheBatchSampler
@@ -341,6 +347,35 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
341
347
return _MultiProcessingDataLoaderIterPatch (self )
342
348
343
349
350
class _StreamingMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
    """Multi-process iterator that can resume dispatching work on the DataLoader
    worker that was active when a previous run stopped.

    When the owning loader is restoring from a checkpoint, the first few index
    batches are routed to the workers following ``_latest_worker_idx`` instead of
    the base class's default round-robin, so batch-to-worker assignment matches
    the interrupted run.
    """

    def __init__(self, loader: DataLoader) -> None:
        self._loader = loader
        # Workers still owed an index when restoring: the range from the last
        # active worker up to num_workers. Empty when starting fresh
        # (_latest_worker_idx == 0), which disables the custom dispatch below.
        self._indexes = (
            list(range(self._loader._latest_worker_idx, self._loader.num_workers))
            if self._loader._latest_worker_idx > 0
            else []
        )
        super().__init__(loader)

    def _try_put_index(self) -> None:
        # Used to restart on the right DataLoader worker
        if self._loader.restore and self._indexes:
            assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

            try:
                index = self._next_index()
            except StopIteration:
                return
            # Send to the next pending worker from the restored run rather than
            # the base class's round-robin choice.
            worker_queue_idx = self._indexes.pop(0)

            self._index_queues[worker_queue_idx].put((self._send_idx, index))
            self._task_info[self._send_idx] = (worker_queue_idx,)
            self._tasks_outstanding += 1
            self._send_idx += 1
        else:
            # Normal operation (or restore queue exhausted): default scheduling.
            super()._try_put_index()
378
+
344
379
class StreamingDataLoader (DataLoader ):
345
380
"""The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
346
381
dataset."""
@@ -355,27 +390,82 @@ def __init__(
355
390
num_workers : int = 0 ,
356
391
** kwargs : Any ,
357
392
) -> None : # pyright: ignore
393
+ if not isinstance (dataset , (StreamingDataset , CombinedStreamingDataset )):
394
+ raise RuntimeError (
395
+ "The provided dataset should be either an instance of StreamingDataset or CombinedStreamingDataset."
396
+ f" Found { dataset } ."
397
+ )
398
+
399
+ self .current_epoch = 0
358
400
self .batch_size = batch_size
359
401
self .num_workers = num_workers
360
- self .num_samples_yielded = 0
402
+ self ._num_samples_yielded_streaming = 0
403
+ self ._num_samples_yielded_combined : Dict [int , List [Any ]] = {}
404
+ self .rng_state : Optional [Any ] = None
405
+ self ._worker_idx = cycle (list (range (self .num_workers if self .num_workers > 0 else 1 )))
406
+ self ._worker_idx_iter : Optional [Any ] = None
407
+ self ._latest_worker_idx = 0
408
+ self .restore = False
361
409
super ().__init__ (dataset , * args , batch_size = batch_size , num_workers = num_workers , ** kwargs ) # type: ignore
362
410
363
411
def __iter__(self) -> Any:
    """Yield batches while tracking progress (epoch, samples yielded, active worker)
    so ``state_dict`` / ``load_state_dict`` can resume mid-epoch.

    When ``self.restore`` is False this is a fresh epoch and all resumability
    counters are reset; when True, the counters loaded by ``load_state_dict``
    are kept so iteration continues where the previous run stopped.
    """
    if not self.restore:
        # Fresh epoch: reset the worker rotation and all sample counters.
        self._latest_worker_idx = 0
        self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
        self._worker_idx_iter = iter(self._worker_idx)
        self.current_epoch += 1
        self._num_samples_yielded_combined = {}
        self._num_samples_yielded_streaming = 0

    self.dataset.set_epoch(self.current_epoch)

    if isinstance(self.dataset, StreamingDataset):
        assert self.batch_size
        for batch in super().__iter__():
            # Track which worker produced this batch so a restart can resume
            # on the following worker.
            self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
            self._num_samples_yielded_streaming += self.batch_size
            yield batch
    else:
        self.dataset._set_use_streaming_dataloader(True)
        assert self.batch_size
        # TODO: Inject a custom collate function to avoid collating the __NUM_SAMPLES_YIELDED__ key
        for batch in super().__iter__():
            self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
            if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
                # Record, per worker, how many samples each wrapped dataset of the
                # CombinedStreamingDataset has yielded so far. With batch_size > 1
                # each entry is a tensor; take the last (most recent) value.
                self._num_samples_yielded_combined[self._latest_worker_idx] = [
                    sample[-1].item() if self.batch_size > 1 else sample.item()
                    for sample in batch[__NUM_SAMPLES_YIELDED_KEY__]
                ]

                yield batch[__SAMPLES_KEY__]
            else:
                yield batch

    # The restore flag only applies until the first post-restore epoch completes.
    self.restore = False
372
445
373
446
def state_dict(self) -> Dict[str, Any]:
    """Return a snapshot of the loader progress for later resumption.

    The dict always contains ``dataset`` (the wrapped dataset's own state),
    ``current_epoch``, ``num_samples_yielded`` and ``latest_worker_idx``; the
    shape of ``num_samples_yielded`` differs between the StreamingDataset case
    (an int) and the CombinedStreamingDataset case (a per-worker dict of lists).
    """
    if isinstance(self.dataset, StreamingDataset):
        assert self.batch_size
        return {
            "dataset": self.dataset.state_dict(
                self._num_samples_yielded_streaming, self.num_workers, self.batch_size
            ),
            "current_epoch": self.current_epoch,
            "num_samples_yielded": self._num_samples_yielded_streaming,
            "latest_worker_idx": self._latest_worker_idx,
        }

    # CombinedStreamingDataset: aggregate the per-worker counters into one total
    # per wrapped dataset.
    # NOTE(review): assumes at least one batch was yielded — calling this before
    # iterating raises IndexError on the empty dict; confirm callers guard this.
    num_samples_yielded = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))]
    for worker_samples in self._num_samples_yielded_combined.values():
        for dataset_idx, samples_yielded in enumerate(worker_samples):
            num_samples_yielded[dataset_idx] += samples_yielded

    return {
        "dataset": self.dataset.state_dict(self.num_workers, self.batch_size, num_samples_yielded),
        # `__iter__` increments the epoch when starting fresh, so outside a
        # restore the snapshot belongs to the previous epoch number.
        "current_epoch": self.current_epoch if self.restore else self.current_epoch - 1,
        "latest_worker_idx": self._latest_worker_idx,
        # Deep-copied so later iteration does not mutate the saved state.
        "num_samples_yielded": deepcopy(self._num_samples_yielded_combined),
    }
379
469
380
470
def load_state_dict(self, obj: Dict[str, Any]) -> None:
    """Load a dict containing training state (called from non-worker process).

    Args:
        obj: The state previously produced by ``state_dict``.

    Raises:
        RuntimeError: If the wrapped dataset is neither a ``StreamingDataset``
            nor a ``CombinedStreamingDataset``.

    """
    self.current_epoch = obj["current_epoch"]

    # `num_samples_yielded` is an int for StreamingDataset and a per-worker
    # dict of lists for CombinedStreamingDataset (see `state_dict`).
    if isinstance(self.dataset, StreamingDataset):
        self._num_samples_yielded_streaming = obj["num_samples_yielded"]
    else:
        self._num_samples_yielded_combined = obj["num_samples_yielded"]

    # Used to restart on the next DataLoader worker from the previous run.
    self._latest_worker_idx = obj["latest_worker_idx"] + 1
    self._worker_idx_iter = iter(self._worker_idx)
    for _ in range(self._latest_worker_idx):
        next(self._worker_idx_iter)

    # Inform we are resuming and disable resetting the StreamingDataLoader state.
    # This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
    self.restore = True

    if isinstance(self.dataset, CombinedStreamingDataset):
        self.dataset._set_use_streaming_dataloader(True)
        self.dataset.load_state_dict(obj)
    elif isinstance(self.dataset, StreamingDataset):
        self.dataset.load_state_dict(obj["dataset"])
    else:
        raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")
503
+
504
def _get_iterator(self) -> "_BaseDataLoaderIter":
    """Overridden to use `_StreamingMultiProcessingDataLoaderIter`, which can
    resume dispatching on the correct worker after `load_state_dict`."""
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    self.check_worker_number_rationality()
    return _StreamingMultiProcessingDataLoaderIter(self)
0 commit comments