From 1346e44546a0285cf758f73b726945973afc911f Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Thu, 6 Feb 2025 22:57:34 -0800
Subject: [PATCH] add comment about why we're updating state dict

---
 torchdata/stateful_dataloader/sampler.py |  13 +-
 .../stateful_dataloader.py               | 215 +++++++++++++-----
 2 files changed, 170 insertions(+), 58 deletions(-)

diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py
index f6396c7a1..ba3f6a61b 100644
--- a/torchdata/stateful_dataloader/sampler.py
+++ b/torchdata/stateful_dataloader/sampler.py
@@ -109,15 +109,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             assert isinstance(self.sampler_iter, Stateful)
             self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])

-        if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
-            self.sampler, _InfiniteConstantSampler
-        ):
+        if not (
+            isinstance(self.sampler, Stateful)
+            or isinstance(self.sampler_iter, Stateful)
+        ) and not isinstance(self.sampler, _InfiniteConstantSampler):
             # We skip x samples if underlying sampler is not stateful
             for _ in range(self.samples_yielded):
                 next(self.sampler_iter)

-    def update_state_dict(self):
-        if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
+    def update_state_dict(self) -> None:
+        if isinstance(self.sampler_iter, Stateful) and hasattr(
+            self.sampler_iter, "update_state_dict"
+        ):
             self.sampler_iter.update_state_dict()
diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index 2b0db11e9..ff06263dc 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -46,7 +46,10 @@
 )

 from torch.utils.data.dataloader import _BaseDataLoaderIter, _InfiniteConstantSampler
-from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
+from torch.utils.data.datapipes.datapipe import (
+    _IterDataPipeSerializationWrapper,
+    _MapDataPipeSerializationWrapper,
+)

 from .incremental_state import (
     _DATASET_ITER_STATE,
@@ -215,7 +218,8 @@ def __init__(

         if num_workers < 0:
             raise ValueError(
-                "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing."
+                "num_workers option should be non-negative; "
+                "use num_workers=0 to disable multiprocessing."
             )

         if timeout < 0:
@@ -291,7 +295,9 @@ def __init__(
             #   specific workers.
             if isinstance(dataset, IterDataPipe):
                 if shuffle is not None:
-                    dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
+                    dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
+                        dataset, shuffle=shuffle
+                    )
             # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
             elif shuffle not in {False, None}:
                 raise ValueError(
@@ -320,7 +326,9 @@ def __init__(
             # auto_collation with custom batch_sampler
             if batch_size != 1 or shuffle or sampler is not None or drop_last:
                 raise ValueError(
-                    "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last"
+                    "batch_sampler option is mutually exclusive "
+                    "with batch_size, shuffle, sampler, and "
+                    "drop_last"
                 )
             batch_size = None
             drop_last = False
@@ -328,7 +336,8 @@ def __init__(
             # no auto_collation
             if drop_last:
                 raise ValueError(
-                    "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last"
+                    "batch_size=None option disables auto-batching "
+                    "and is mutually exclusive with drop_last"
                 )

         if sampler is None:  # give default samplers
@@ -362,7 +371,9 @@ def __init__(

         # set DataLoader's __initialized attribute.
         self._DataLoader__initialized = True
-        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]
+        self._IterableDataset_len_called = (
+            None  # See NOTE [ IterableDataset and __len__ ]
+        )

         self._iterator = None

@@ -449,6 +460,9 @@ def __next__(self):
         try:
             return super().__next__()
         except StopIteration:
+            # At the end of iteration we want to update the state dict of _sampler_iter,
+            # because in __iter__, after self._iterator is set via self._get_iterator() (which sets self.next_iter_state = None),
+            # we check whether self._iterator._finished is True and, if so, reset self._iterator with next_iter_state = None.
             if hasattr(self._sampler_iter, "update_state_dict"):
                 self._sampler_iter.update_state_dict()
             self._finished = True
@@ -475,7 +489,9 @@ def __init__(self, loader, next_iter_state=None):

         # Taking care of distributed sharding
         if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
             # For BC, use default SHARDING_PRIORITIES
-            torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
+            torch.utils.data.graph_settings.apply_sharding(
+                self._dataset, self._world_size, self._rank
+            )

         if next_iter_state is not None:
             self.load_state_dict(next_iter_state)
@@ -498,7 +514,9 @@ def _next_data(self):
     def state_dict(self):
         if self._dataset_kind == _DatasetKind.Iterable:
             fetcher_state = {
-                _DATASET_ITER_STATE: try_to_serialize(self._dataset_fetcher.dataset_iter),
+                _DATASET_ITER_STATE: try_to_serialize(
+                    self._dataset_fetcher.dataset_iter
+                ),
                 _FETCHER_ENDED: self._dataset_fetcher.ended,
             }
             dataset_state = None
@@ -528,11 +546,17 @@ def load_state_dict(self, state_dict):
         self._sampler_iter_yielded = state_dict[_SAMPLER_ITER_YIELDED]

         # Try to restore from either _index_sampler state_dict or _sampler_iter state_dict
-        if isinstance(self._index_sampler, Stateful) or isinstance(self._sampler_iter, Stateful):
-            self._index_sampler = try_to_deserialize(self._index_sampler, state_dict[_INDEX_SAMPLER_STATE])
+        if isinstance(self._index_sampler, Stateful) or isinstance(
+            self._sampler_iter, Stateful
+        ):
+            self._index_sampler = try_to_deserialize(
+                self._index_sampler, state_dict[_INDEX_SAMPLER_STATE]
+            )
             self._sampler_iter = iter(self._index_sampler)
             if state_dict[_SAMPLER_ITER_STATE] is not None:
-                self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE])
+                self._sampler_iter = try_to_deserialize(
+                    self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]
+                )
         else:
             if not isinstance(
@@ -540,7 +564,9 @@ def load_state_dict(self, state_dict):
                 torch.utils.data.dataloader._InfiniteConstantSampler,
             ):
                 # Fallback to fastforward
-                self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None)
+                self._sampler_iter = itertools.islice(
+                    self._index_sampler, self._sampler_iter_yielded, None
+                )
         self._num_yielded = state_dict[self._NUM_YIELDED]
         self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED]
         self._shared_seed = state_dict[_SHARED_SEED]
@@ -549,8 +575,12 @@ def load_state_dict(self, state_dict):
            # 1. try to restore dataset state
            # 2. generate dataset iterator
            # 3. try to restore iterator state
-            if state_dict[_DATASET_STATE] is not None and isinstance(self._dataset, Stateful):
-                self._dataset = try_to_deserialize(self._dataset, state_dict[_DATASET_STATE])
+            if state_dict[_DATASET_STATE] is not None and isinstance(
+                self._dataset, Stateful
+            ):
+                self._dataset = try_to_deserialize(
+                    self._dataset, state_dict[_DATASET_STATE]
+                )
         self._dataset_fetcher = _DatasetKind.create_fetcher(
             self._dataset_kind,
             self._dataset,
@@ -560,14 +590,18 @@ def load_state_dict(self, state_dict):
         )
         if self._dataset_kind == _DatasetKind.Iterable:
             # If either dataset or it's iter is stateful, we don't fast-forward
-            if isinstance(self._dataset, Stateful) or isinstance(self._dataset_fetcher.dataset_iter, Stateful):
+            if isinstance(self._dataset, Stateful) or isinstance(
+                self._dataset_fetcher.dataset_iter, Stateful
+            ):
                 if state_dict[_FETCHER_STATE] is not None:
                     if state_dict[_FETCHER_STATE][_DATASET_ITER_STATE] is not None:
                         self._dataset_fetcher.dataset_iter = try_to_deserialize(
                             self._dataset_fetcher.dataset_iter,
                             state_dict[_FETCHER_STATE][_DATASET_ITER_STATE],
                         )
-                    self._dataset_fetcher.ended = state_dict[_FETCHER_STATE][_FETCHER_ENDED]
+                    self._dataset_fetcher.ended = state_dict[_FETCHER_STATE][
+                        _FETCHER_ENDED
+                    ]
             else:
                 # No state, just try to fastforward
                 if self._num_yielded > 0:
@@ -942,16 +976,20 @@ def __init__(self, loader, next_iter_state):
                 self._SNAPSHOT in next_iter_state
             ), f"State doesn't contain key '{self._SNAPSHOT}' expected for multiprocess dataloader"
             wstates = next_iter_state[self._SNAPSHOT].get(self._WORKER_SNAPSHOTS, {})
-            assert set(map(self._worker_key, range(len(wstates)))) == set(wstates.keys()), (
+            assert set(map(self._worker_key, range(len(wstates)))) == set(
+                wstates.keys()
+            ), (
                 len(wstates),
                 wstates.keys(),
             )
             for worker_key, sd in wstates.items():
                 worker_states[worker_key] = sd
-            self._base_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get(self._BASE_SEED, self._base_seed)
-            self._shared_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get(
-                _SHARED_SEED, self._shared_seed
+            self._base_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get(
+                self._BASE_SEED, self._base_seed
             )
+            self._shared_seed = next_iter_state[self._SNAPSHOT][
+                self._MAIN_SNAPSHOT
+            ].get(_SHARED_SEED, self._shared_seed)

         for i in range(self._num_workers):
             # No certainty which module multiprocessing_context is
@@ -999,7 +1037,9 @@ def __init__(self, loader, next_iter_state):
             if self._pin_memory_device == "xpu":
                 current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
             elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
-                custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
+                custom_device_mod = getattr(
+                    torch, torch._C._get_privateuse1_backend_name()
+                )
                 current_device = custom_device_mod.current_device()
             else:
                 current_device = torch.cuda.current_device()  # choose cuda for default
@@ -1032,7 +1072,9 @@ def __init__(self, loader, next_iter_state):
             import atexit

             for w in self._workers:
-                atexit.register(_StatefulMultiProcessingDataLoaderIter._clean_up_worker, w)
+                atexit.register(
+                    _StatefulMultiProcessingDataLoaderIter._clean_up_worker, w
+                )

         # .pid can be None only before process is spawned (not the case, so ignore)
         _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
@@ -1048,17 +1090,23 @@ def __init__(self, loader, next_iter_state):
         # We need to send initial worker state back to the main process to handle state_dict() requests
         # before n >= num_workers steps are taken.
         # self._worker_snapshots: Dict[str, _IncrementalWorkerState] = {}
-        self._worker_snapshots = {key: _IncrementalWorkerState(state) for key, state in worker_states.items()}
+        self._worker_snapshots = {
+            key: _IncrementalWorkerState(state) for key, state in worker_states.items()
+        }
         self._reset(loader, first_iter=True, prime_prefetch=next_iter_state is None)

         # Try to restore main state
         if next_iter_state is not None:
-            self._restore_main_state(next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT])
+            self._restore_main_state(
+                next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT]
+            )
             self._num_yielded = next_iter_state[self._SNAPSHOT][self._SNAPSHOT_STEP]
             self._update_snapshot(
                 snapshot_step=next_iter_state[self._SNAPSHOT][self._SNAPSHOT_STEP],
-                last_yielded_worker_id=next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID],
+                last_yielded_worker_id=next_iter_state[self._SNAPSHOT][
+                    self._LAST_YIELDED_WORKER_ID
+                ],
                 num_workers=self._num_workers,
                 main_snapshot=next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT],
                 worker_snapshots=self._worker_snapshots,
@@ -1069,7 +1117,10 @@ def __init__(self, loader, next_iter_state):
             for state in worker_states.values():
                 if state is None:
                     continue
-                if state[_DATASET_STATE] is None and state[_FETCHER_STATE][_DATASET_ITER_STATE] is None:
+                if (
+                    state[_DATASET_STATE] is None
+                    and state[_FETCHER_STATE][_DATASET_ITER_STATE] is None
+                ):
                     fast_forward = True
                     break
@@ -1086,10 +1137,17 @@ def __init__(self, loader, next_iter_state):
                 for _ in range(self._num_yielded):
                     next(self)
                 # Check if last_yielded_worker_id matches
-                if self._last_yielded_worker_id != next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID]:
-                    raise ValueError("last_yielded_worker_id does not match, the dataset may have changed")
+                if (
+                    self._last_yielded_worker_id
+                    != next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID]
+                ):
+                    raise ValueError(
+                        "last_yielded_worker_id does not match, the dataset may have changed"
+                    )
             else:
-                self._last_yielded_worker_id = next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID]
+                self._last_yielded_worker_id = next_iter_state[self._SNAPSHOT][
+                    self._LAST_YIELDED_WORKER_ID
+                ]
                 for _ in range(self._last_yielded_worker_id + 1):
                     next(self._worker_queue_idx_cycle)
                 for _ in range(self._prefetch_factor * self._num_workers):
@@ -1107,7 +1165,9 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
         self._task_info = {}
-        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
+        self._tasks_outstanding = (
+            0  # always equal to count(v for v in task_info.values() if len(v) == 1)
+        )
         # A list of booleans representing whether each worker still has work to
         # do, i.e., not having exhausted its iterable dataset object. It always
         # contains all `True`s if not using an iterable-style dataset
@@ -1132,7 +1192,9 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
             while remaining > 0:
                 _, data = self._get_data()
                 if not all(self._workers_status):
-                    raise ValueError(f"A worker has failed during startup! {self._workers_status}")
+                    raise ValueError(
+                        f"A worker has failed during startup! {self._workers_status}"
+                    )
                 elif isinstance(data, _AckStartup):
                     if isinstance(data.initial_state, ExceptionWrapper):
                         data.initial_state.reraise()
@@ -1140,27 +1202,37 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
                     if data.is_delta:
                         self._worker_snapshots[self._worker_key(data.worker_id)].apply_delta(data.initial_state)  # type: ignore[arg-type]
                     else:
-                        self._worker_snapshots[self._worker_key(data.worker_id)] = _IncrementalWorkerState(
+                        self._worker_snapshots[
+                            self._worker_key(data.worker_id)
+                        ] = _IncrementalWorkerState(
                             data.initial_state  # type: ignore[arg-type]
                         )
                     remaining -= 1
                 else:
-                    raise ValueError(f"Invalid response from worker after startup: {data}")
+                    raise ValueError(
+                        f"Invalid response from worker after startup: {data}"
+                    )
         else:
             # We resume the prefetching in case it was enabled
             for idx in range(self._num_workers):
-                self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed))
+                self._index_queues[idx].put(
+                    _utils.worker._ResumeIteration(self._shared_seed)
+                )

             resume_iteration_cnt = self._num_workers
             while resume_iteration_cnt > 0:
                 return_idx, data = self._get_data()
                 if not all(self._workers_status):
-                    raise ValueError(f"A worker has failed during Resume! {self._workers_status}")
+                    raise ValueError(
+                        f"A worker has failed during Resume! {self._workers_status}"
+                    )
                 if isinstance(return_idx, _utils.worker._ResumeIteration):
                     assert isinstance(data, _AckStartup), (return_idx, data)
                     if isinstance(data.initial_state, ExceptionWrapper):
                         data.initial_state.reraise()
                     assert data.initial_state is not None, data
-                    self._worker_snapshots[self._worker_key(data.worker_id)] = _IncrementalWorkerState(
+                    self._worker_snapshots[
+                        self._worker_key(data.worker_id)
+                    ] = _IncrementalWorkerState(
                         data.initial_state  # type: ignore[arg-type]
                     )
                     resume_iteration_cnt -= 1
@@ -1228,7 +1300,9 @@ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
                     self._mark_worker_as_unavailable(worker_id)
             if len(failed_workers) > 0:
                 pids_str = ", ".join(str(w.pid) for w in failed_workers)
-                raise RuntimeError(f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly") from e
+                raise RuntimeError(
+                    f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly"
+                ) from e
             if isinstance(e, queue.Empty):
                 return (False, None)
             import errno
@@ -1366,7 +1440,9 @@ def _get_data(self):
             if success:
                 return data
             else:
-                raise RuntimeError(f"DataLoader timed out after {self._timeout} seconds")
+                raise RuntimeError(
+                    f"DataLoader timed out after {self._timeout} seconds"
+                )
         elif self._pin_memory:
             while self._pin_memory_thread.is_alive():
                 success, data = self._try_get_data()
@@ -1399,7 +1475,9 @@ def _next_data(self):
             info = self._task_info.get(self._rcvd_idx, None)
             if info:
                 worker_id = info[0]
-                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
+                if (
+                    len(info) == 2 or self._workers_status[worker_id]
+                ):  # has data or is still active
                     break
             del self._task_info[self._rcvd_idx]
             self._rcvd_idx += 1
@@ -1415,7 +1493,9 @@ def _next_data(self):
             if len(self._task_info[self._rcvd_idx]) == 2:
                 data, worker_id, state_dict = self._task_info.pop(self._rcvd_idx)[1]
                 if isinstance(data, _utils.worker._IterableDatasetStopIteration):
-                    self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
+                    self._update_worker_snapshot(
+                        self._worker_key(data.worker_id), state_dict
+                    )
                     self._rcvd_idx += 1
                     continue
                 else:
@@ -1432,7 +1512,9 @@ def _next_data(self):
                     self._workers_status[data.worker_id] = False
                 else:
                     self._mark_worker_as_unavailable(data.worker_id)
-                assert state_dict is not None, "StopIteration should always be accompanied by a state_dict"
+                assert (
+                    state_dict is not None
+                ), "StopIteration should always be accompanied by a state_dict"
                 self._try_put_index()
                 # We want to process states until we get to that position
                 # in the worker cycle, therefore if out-of-order we want
@@ -1443,7 +1525,9 @@ def _next_data(self):
                 if not self._in_order:
                     # don't store it for later, process now
                     if isinstance(data, _utils.worker._IterableDatasetStopIteration):
-                        self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
+                        self._update_worker_snapshot(
+                            self._worker_key(data.worker_id), state_dict
+                        )
                         continue
                     del self._task_info[idx]
                     return self._process_data(data, worker_id, state_dict)
@@ -1451,7 +1535,9 @@ def _next_data(self):
             else:
                 del self._task_info[idx]
                 if isinstance(data, _utils.worker._IterableDatasetStopIteration):
-                    self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
+                    self._update_worker_snapshot(
+                        self._worker_key(data.worker_id), state_dict
+                    )
                     self._rcvd_idx += 1
                     continue
                 else:
@@ -1473,18 +1559,26 @@ def _restore_main_state(self, state_dict):
         assert self._num_workers == state_dict[self._NUM_WORKERS]
         # Try to restore from either _index_sampler state_dict or _sampler_iter state_dict
         self._sampler_iter_yielded = state_dict[_SAMPLER_ITER_YIELDED]
-        if isinstance(self._index_sampler, Stateful) or isinstance(self._sampler_iter, Stateful):
-            self._index_sampler = try_to_deserialize(self._index_sampler, state_dict[_INDEX_SAMPLER_STATE])
+        if isinstance(self._index_sampler, Stateful) or isinstance(
+            self._sampler_iter, Stateful
+        ):
+            self._index_sampler = try_to_deserialize(
+                self._index_sampler, state_dict[_INDEX_SAMPLER_STATE]
+            )
             self._sampler_iter = iter(self._index_sampler)
             if state_dict[_SAMPLER_ITER_STATE] is not None:
-                self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE])
+                self._sampler_iter = try_to_deserialize(
+                    self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]
+                )
         else:
             if not isinstance(
                 self._index_sampler,
                 torch.utils.data.dataloader._InfiniteConstantSampler,
             ):
                 # Fallback to fastforward
-                self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None)
+                self._sampler_iter = itertools.islice(
+                    self._index_sampler, self._sampler_iter_yielded, None
+                )
         self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED]
         self._shared_seed = state_dict[_SHARED_SEED]
         self._base_seed = state_dict[self._BASE_SEED]
@@ -1521,7 +1615,9 @@ def _try_put_index(self):
             if self._workers_status[worker_queue_idx]:
                 if self._in_order:
                     break
-                elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(self._workers_status):
+                elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(
+                    self._workers_status
+                ):
                     # when self._in_order is False, distribute work to a worker if it has capacity
                     # _workers_status is updated only in this thread, so the sum is guaranteed > 0
                     break
@@ -1547,14 +1643,20 @@ def _process_data(self, data, worker_id, state_dict):
         self._last_yielded_worker_id = worker_id
         # Update latest worker state
         if state_dict is not None:
-            self._update_worker_snapshot(self._worker_key(state_dict[_WORKER_ID]), state_dict)
-        if self._snapshot_interval and ((self._num_yielded + 1) % self._snapshot_interval == 0):
+            self._update_worker_snapshot(
+                self._worker_key(state_dict[_WORKER_ID]), state_dict
+            )
+        if self._snapshot_interval and (
+            (self._num_yielded + 1) % self._snapshot_interval == 0
+        ):
             self._take_snapshot()
         return data

     def _take_snapshot(self):
         main_snapshot_idx = None
-        while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1):
+        while len(self._main_snapshots) and (
+            self._main_snapshots[0][0] <= self._rcvd_idx - 1
+        ):
             main_snapshot_idx, main_snapshot = self._main_snapshots.popleft()
         if not self._in_order and main_snapshot_idx is None:
             # in_order is False and no main snapshot is available as we're ahead of rcvd_idx
@@ -1584,7 +1686,10 @@ def _update_snapshot(
             self._SNAPSHOT_STEP: snapshot_step,
             self._LAST_YIELDED_WORKER_ID: last_yielded_worker_id,
             self._MAIN_SNAPSHOT: main_snapshot,
-            self._WORKER_SNAPSHOTS: {key: worker_state.get_state() for key, worker_state in worker_snapshots.items()},
+            self._WORKER_SNAPSHOTS: {
+                key: worker_state.get_state()
+                for key, worker_state in worker_snapshots.items()
+            },
         }

     def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
@@ -1617,7 +1722,11 @@ def _shutdown_workers(self):
         # Called when shutting down this `_MultiProcessingDataLoaderIter`.
         # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
         # the logic of this function.
-        if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
+        if (
+            _utils is None
+            or _utils.python_exit_status is True
+            or _utils.python_exit_status is None
+        ):
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
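
Note (outside the patch itself): the update_state_dict() call guarded in __next__ above keeps the sampler-iterator state current when iteration ends, so a checkpoint taken at that point resumes correctly. Below is a minimal, illustrative sketch of the save/resume flow this supports, using the public StatefulDataLoader API (state_dict() / load_state_dict()); the toy dataset, batch size, and variable names are placeholders, not taken from this patch.

    # Minimal sketch of checkpointing and resuming a StatefulDataLoader.
    # Dataset and hyperparameters here are illustrative placeholders only.
    from torchdata.stateful_dataloader import StatefulDataLoader

    dataset = list(range(10))  # any map-style dataset
    loader = StatefulDataLoader(dataset, batch_size=2, num_workers=0)

    it = iter(loader)
    _ = next(it)                      # consume one batch
    checkpoint = loader.state_dict()  # snapshot includes sampler/iterator progress

    # Later (e.g. after a restart): rebuild the loader and restore the snapshot.
    resumed = StatefulDataLoader(dataset, batch_size=2, num_workers=0)
    resumed.load_state_dict(checkpoint)
    for batch in resumed:             # iteration continues from the second batch
        print(batch)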