Skip to content

Commit b3b5432

Browse files
committed
docstring updates, attempt to make shape/len more precise
1 parent 579727a commit b3b5432

File tree

1 file changed

+40
-44
lines changed

1 file changed

+40
-44
lines changed

src/tiledbsoma_ml/pytorch.py

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,14 @@ def open_experiment(self) -> Generator[soma.Experiment, None, None]:
8484

8585

8686
class ExperimentAxisQueryIterable(Iterable[XObsDatum]):
87-
"""An :class:`Iterator` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as
87+
"""An :class:`Iterable` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as
8888
selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator
89-
produces equal sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and
89+
produces a batch containing equal-sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and
9090
:class:`pandas.DataFrame`, respectively.
9191
9292
Private base class for subclasses of :class:`torch.utils.data.IterableDataset` and
9393
:class:`torchdata.datapipes.iter.IterDataPipe`. Refer to :class:`ExperimentAxisQueryIterableDataset`
94-
and `ExperimentAxisQueryIterDataPipe` for more details on usage.
94+
and :class:`ExperimentAxisQueryIterDataPipe` for more details on usage.
9595
9696
Lifecycle:
9797
experimental
@@ -136,14 +136,10 @@ def __init__(
136136
If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the
137137
default), will return ``X`` data as a :class:`numpy.ndarray`.
138138
use_eager_fetch:
139-
Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
140-
available for processing via the iterator. This allows network (or filesystem) requests to be made in
141-
parallel with client-side processing of the SOMA data, potentially improving overall performance at the
142-
cost of doubling memory utilization. Defaults to ``True``.
143-
144-
Returns:
145-
An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to
146-
a :class:`torch.utils.data.DataLoader` instance.
139+
Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is
140+
made available for processing via the iterator. This allows network (or filesystem) requests to be made
141+
in parallel with client-side processing of the SOMA data, potentially improving overall performance at
142+
the cost of doubling memory utilization. Defaults to ``True``.
147143
148144
Raises:
149145
``ValueError`` on various unsupported or malformed parameter values.
@@ -300,16 +296,16 @@ def __iter__(self) -> Iterator[XObsDatum]:
300296
yield from _mini_batch_iter
301297

302298
def __len__(self) -> int:
303-
"""Return the approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or
304-
as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell)
305-
count will reflect the size of the data partition assigned to the active process.
299+
"""Return the number of batches this iterable will produce. If run in the context of :class:`torch.distributed`
300+
or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the
301+
batch count will reflect the size of the data partition assigned to the active process.
306302
307303
See important caveats in the PyTorch
308304
[:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
309305
documentation regarding ``len(dataloader)``, which also apply to this class.
310306
311307
Returns:
312-
An ``int``.
308+
``int`` (number of batches).
313309
314310
Lifecycle:
315311
experimental
@@ -318,25 +314,31 @@ def __len__(self) -> int:
318314

319315
@property
320316
def shape(self) -> Tuple[int, int]:
321-
"""Get the approximate shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`.
322-
This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
323-
(i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect
324-
the size of the data partition assigned to the active process.
317+
"""Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`.
318+
319+
If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0),
320+
the number of batches will reflect the size of the data partition assigned to the active process.
325321
326322
Returns:
327-
A tuple of two ``int`` values: number of obs, number of vars.
323+
A tuple of two ``int`` values: number of batches, number of vars.
328324
329325
Lifecycle:
330326
experimental
331327
"""
332328
self._init_once()
333329
assert self._obs_joinids is not None
334330
assert self._var_joinids is not None
335-
world_size, _ = _get_distributed_world_rank()
336-
n_workers, _ = _get_worker_world_rank()
337-
partition_len = len(self._obs_joinids) // world_size // n_workers
338-
div, rem = divmod(partition_len, self.batch_size)
339-
return div + bool(rem), len(self._var_joinids)
331+
world_size, rank = _get_distributed_world_rank()
332+
n_workers, worker_id = _get_worker_world_rank()
333+
obs_per_proc, obs_rem = divmod(len(self._obs_joinids), world_size)
334+
# obs rows assigned to this "distributed" process
335+
n_proc_obs = obs_per_proc + bool(rank < obs_rem)
336+
obs_per_worker, obs_rem = divmod(n_proc_obs, n_workers)
337+
# obs rows assigned to this worker process
338+
n_worker_obs = obs_per_worker + bool(worker_id < obs_rem)
339+
n_batches, rem = divmod(n_worker_obs, self.batch_size)
340+
# (num batches this worker will produce, num features)
341+
return n_batches + bool(rem), len(self._var_joinids)
340342

341343
def __getitem__(self, index: int) -> XObsDatum:
342344
raise NotImplementedError(
@@ -349,11 +351,9 @@ def _io_batch_iter(
349351
X: soma.SparseNDArray,
350352
obs_joinid_iter: Iterator[npt.NDArray[np.int64]],
351353
) -> Iterator[Tuple[sparse.csr_matrix, pd.DataFrame]]:
352-
"""Iterate over IO batches, i.e., SOMA query/read, producing a tuple of
353-
(X: csr_array, obs: DataFrame).
354+
"""Iterate over IO batches, i.e., SOMA query reads, producing tuples of ``(X: csr_array, obs: DataFrame)``.
354355
355-
obs joinids read are controlled by the obs_joinid_iter. Iterator results will
356-
be reindexed.
356+
``obs`` joinids read are controlled by the ``obs_joinid_iter``. Iterator results will be reindexed.
357357
358358
Private method.
359359
"""
@@ -475,7 +475,7 @@ class ExperimentAxisQueryIterDataPipe(
475475
torch.utils.data.dataset.Dataset[XObsDatum]
476476
],
477477
):
478-
"""A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`.
478+
"""A :class:`torchdata.datapipes.iter.IterDataPipe` implementation that loads from a :class:`tiledbsoma.Experiment`.
479479
480480
This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for
481481
legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the
@@ -534,7 +534,7 @@ def __len__(self) -> int:
534534
Lifecycle:
535535
deprecated
536536
"""
537-
return self._exp_iter.__len__()
537+
return len(self._exp_iter)
538538

539539
@property
540540
def shape(self) -> Tuple[int, int]:
@@ -640,10 +640,6 @@ def __init__(
640640
parallel with client-side processing of the SOMA data, potentially improving overall performance at the
641641
cost of doubling memory utilization. Defaults to ``True``.
642642
643-
Returns:
644-
An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to
645-
a :class:`torch.data.utils.DataLoader` instance.
646-
647643
Raises:
648644
``ValueError`` on various unsupported or malformed parameter values.
649645
@@ -663,7 +659,8 @@ def __init__(
663659
)
664660

665661
def __iter__(self) -> Iterator[XObsDatum]:
666-
"""Create Iterator yielding tuples of :class:`numpy.ndarray` and :class:`pandas.DataFrame`.
662+
"""Create ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.sparse.csr_matrix`) and
663+
:class:`pandas.DataFrame`.
667664
668665
Returns:
669666
``iterator``
@@ -678,30 +675,29 @@ def __iter__(self) -> Iterator[XObsDatum]:
678675
yield X, obs
679676

680677
def __len__(self) -> int:
681-
"""Return approximate number of batches this iterable will produce.
678+
"""Return the number of batches this iterable will produce.
682679
683680
See important caveats in the PyTorch
684681
[:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
685682
documentation regarding ``len(dataloader)``, which also apply to this class.
686683
687684
Returns:
688-
An ``int``.
685+
``int`` (number of batches).
689686
690687
Lifecycle:
691688
experimental
692689
"""
693-
return self._exp_iter.__len__()
690+
return len(self._exp_iter)
694691

695692
@property
696693
def shape(self) -> Tuple[int, int]:
697-
"""Get the shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`.
694+
"""Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`.
698695
699-
This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
700-
(i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect
701-
the size of the partition of the data assigned to the active process.
696+
If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0),
697+
the number of batches will reflect the size of the data partition assigned to the active process.
702698
703699
Returns:
704-
A tuple of ``int``s, for obs and var counts, respectively.
700+
A tuple of two ``int`` values: number of batches, number of vars.
705701
706702
Lifecycle:
707703
experimental

0 commit comments

Comments
 (0)