Fix docs formatting (#1412)
* fix formatting

* fix formatting again
divyanshk authored Dec 21, 2024
1 parent 2707d18 commit 62092dd
Showing 2 changed files with 15 additions and 6 deletions.
18 changes: 12 additions & 6 deletions torchdata/nodes/base_node.py
@@ -47,7 +47,7 @@ def __iter__(self):
     def reset(self, initial_state: Optional[dict] = None):
         """Resets the iterator to the beginning, or to the state passed in by initial_state.
-        Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called.
+        Reset is a good place to put expensive initialization, as it will be lazily called when ``next()`` or ``state_dict()`` is called.
 
         Subclasses must call ``super().reset(initial_state)``.
 
         Args:
@@ -57,14 +57,18 @@ def reset(self, initial_state: Optional[dict] = None):
         self.__initialized = True
 
     def get_state(self) -> Dict[str, Any]:
-        """Subclasses must implement this method, instead of state_dict(). Should only be called by BaseNode.
-        :return: Dict[str, Any] - a state dict that may be passed to reset() at some point in the future
+        """Subclasses must implement this method, instead of ``state_dict()``. Should only be called by BaseNode.
+
+        Returns:
+            Dict[str, Any] - a state dict that may be passed to ``reset()`` at some point in the future
         """
         raise NotImplementedError(type(self))
 
     def next(self) -> T:
-        """Subclasses must implement this method, instead of ``__next``. Should only be called by BaseNode.
-        :return: T - the next value in the sequence, or throw StopIteration
+        """Subclasses must implement this method, instead of ``__next__``. Should only be called by BaseNode.
+
+        Returns:
+            T - the next value in the sequence, or throw StopIteration
         """
         raise NotImplementedError(type(self))
 

@@ -83,7 +87,9 @@ def __next__(self):
 
     def state_dict(self) -> Dict[str, Any]:
         """Get a state_dict for this BaseNode.
-        :return: Dict[str, Any] - a state dict that may be passed to reset() at some point in the future.
+
+        Returns:
+            Dict[str, Any] - a state dict that may be passed to ``reset()`` at some point in the future.
         """
         try:
             self.__initialized
3 changes: 3 additions & 0 deletions torchdata/nodes/samplers/multi_node_weighted_sampler.py
@@ -24,13 +24,15 @@ class MultiNodeWeightedSampler(BaseNode[T]):
     weights for sampling. `seed` is used to initialize the random number generator.
 
     The node implements the state using the following keys:
+
     - DATASET_NODE_STATES_KEY: A dictionary of states for each source node.
     - DATASETS_EXHAUSTED_KEY: A dictionary of booleans indicating whether each source node is exhausted.
     - EPOCH_KEY: An epoch counter used to initialize the random number generator.
     - NUM_YIELDED_KEY: The number of items yielded.
     - WEIGHTED_SAMPLER_STATE_KEY: The state of the weighted sampler.
 
     We support multiple stopping criteria:
+
     - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets are exhausted. This is the default behavior.
     - FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
     - ALL_DATASETS_EXHAUSTED: Stop when all datasets are exhausted.
@@ -203,6 +205,7 @@ class _WeightedSampler:
     """A weighted sampler that samples from a list of weights.
 
     The class implements the state using the following keys:
+
     - g_state: The state of the random number generator.
     - g_rank_state: The state of the random number generator for the rank.
     - offset: The offset of the batch of indices.
