@@ -84,14 +84,14 @@ def open_experiment(self) -> Generator[soma.Experiment, None, None]:
84
84
85
85
86
86
class ExperimentAxisQueryIterable (Iterable [XObsDatum ]):
87
- """An :class:`Iterator ` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as
87
+ """An :class:`Iterable ` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as
88
88
selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator
89
- produces equal sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and
89
+ produces a batch containing equal- sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and
90
90
:class:`pandas.DataFrame`, respectively.
91
91
92
92
Private base class for subclasses of :class:`torch.utils.data.IterableDataset` and
93
93
:class:`torchdata.datapipes.iter.IterDataPipe`. Refer to :class:`ExperimentAxisQueryIterableDataset`
94
- and `ExperimentAxisQueryIterDataPipe` for more details on usage.
94
+ and :class: `ExperimentAxisQueryIterDataPipe` for more details on usage.
95
95
96
96
Lifecycle:
97
97
experimental
@@ -136,14 +136,10 @@ def __init__(
136
136
If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the
137
137
default), will return ``X`` data as a :class:`numpy.ndarray`.
138
138
use_eager_fetch:
139
- Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
140
- available for processing via the iterator. This allows network (or filesystem) requests to be made in
141
- parallel with client-side processing of the SOMA data, potentially improving overall performance at the
142
- cost of doubling memory utilization. Defaults to ``True``.
143
-
144
- Returns:
145
- An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to
146
- a :class:`torch.utils.data.DataLoader` instance.
139
+ Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is
140
+ made available for processing via the iterator. This allows network (or filesystem) requests to be made
141
+ in parallel with client-side processing of the SOMA data, potentially improving overall performance at
142
+ the cost of doubling memory utilization. Defaults to ``True``.
147
143
148
144
Raises:
149
145
``ValueError`` on various unsupported or malformed parameter values.
@@ -300,16 +296,16 @@ def __iter__(self) -> Iterator[XObsDatum]:
300
296
yield from _mini_batch_iter
301
297
302
298
def __len__ (self ) -> int :
303
- """Return the approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or
304
- as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell)
305
- count will reflect the size of the data partition assigned to the active process.
299
+ """Return the number of batches this iterable will produce. If run in the context of :class:`torch.distributed`
300
+ or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the
301
+ batch count will reflect the size of the data partition assigned to the active process.
306
302
307
303
See important caveats in the PyTorch
308
304
[:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
309
305
documentation regarding ``len(dataloader)``, which also apply to this class.
310
306
311
307
Returns:
312
- An ``int``.
308
+ ``int`` (Number of batches) .
313
309
314
310
Lifecycle:
315
311
experimental
@@ -318,25 +314,31 @@ def __len__(self) -> int:
318
314
319
315
@property
320
316
def shape (self ) -> Tuple [int , int ]:
321
- """Get the approximate shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`.
322
- This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
323
- (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect
324
- the size of the data partition assigned to the active process.
317
+ """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`.
318
+
319
+ If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0),
320
+ the number of batches will reflect the size of the data partition assigned to the active process.
325
321
326
322
Returns:
327
- A tuple of two ``int`` values: number of obs , number of vars.
323
+ A tuple of two ``int`` values: number of batches , number of vars.
328
324
329
325
Lifecycle:
330
326
experimental
331
327
"""
332
328
self ._init_once ()
333
329
assert self ._obs_joinids is not None
334
330
assert self ._var_joinids is not None
335
- world_size , _ = _get_distributed_world_rank ()
336
- n_workers , _ = _get_worker_world_rank ()
337
- partition_len = len (self ._obs_joinids ) // world_size // n_workers
338
- div , rem = divmod (partition_len , self .batch_size )
339
- return div + bool (rem ), len (self ._var_joinids )
331
+ world_size , rank = _get_distributed_world_rank ()
332
+ n_workers , worker_id = _get_worker_world_rank ()
333
+ obs_per_proc , obs_rem = divmod (len (self ._obs_joinids ), world_size )
334
+ # obs rows assigned to this "distributed" process
335
+ n_proc_obs = obs_per_proc + bool (rank < obs_rem )
336
+ obs_per_worker , obs_rem = divmod (n_proc_obs , n_workers )
337
+ # obs rows assigned to this worker process
338
+ n_worker_obs = obs_per_worker + bool (worker_id < obs_rem )
339
+ n_batches , rem = divmod (n_worker_obs , self .batch_size )
340
+ # (num batches this worker will produce, num features)
341
+ return n_batches + bool (rem ), len (self ._var_joinids )
340
342
341
343
def __getitem__ (self , index : int ) -> XObsDatum :
342
344
raise NotImplementedError (
@@ -349,11 +351,9 @@ def _io_batch_iter(
349
351
X : soma .SparseNDArray ,
350
352
obs_joinid_iter : Iterator [npt .NDArray [np .int64 ]],
351
353
) -> Iterator [Tuple [sparse .csr_matrix , pd .DataFrame ]]:
352
- """Iterate over IO batches, i.e., SOMA query/read, producing a tuple of
353
- (X: csr_array, obs: DataFrame).
354
+ """Iterate over IO batches, i.e., SOMA query reads, producing tuples of ``(X: csr_array, obs: DataFrame)``.
354
355
355
- obs joinids read are controlled by the obs_joinid_iter. Iterator results will
356
- be reindexed.
356
+ ``obs`` joinids read are controlled by the ``obs_joinid_iter``. Iterator results will be reindexed.
357
357
358
358
Private method.
359
359
"""
@@ -475,7 +475,7 @@ class ExperimentAxisQueryIterDataPipe(
475
475
torch .utils .data .dataset .Dataset [XObsDatum ]
476
476
],
477
477
):
478
- """A :class:`torch.utils.data.IterableDataset ` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`.
478
+ """A :class:`torchdata.datapipes.iter.IterDataPipe ` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`.
479
479
480
480
This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for
481
481
legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the
@@ -534,7 +534,7 @@ def __len__(self) -> int:
534
534
Lifecycle:
535
535
deprecated
536
536
"""
537
- return self ._exp_iter . __len__ ( )
537
+ return len ( self ._exp_iter )
538
538
539
539
@property
540
540
def shape (self ) -> Tuple [int , int ]:
@@ -640,10 +640,6 @@ def __init__(
640
640
parallel with client-side processing of the SOMA data, potentially improving overall performance at the
641
641
cost of doubling memory utilization. Defaults to ``True``.
642
642
643
- Returns:
644
- An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to
645
- a :class:`torch.data.utils.DataLoader` instance.
646
-
647
643
Raises:
648
644
``ValueError`` on various unsupported or malformed parameter values.
649
645
@@ -663,7 +659,8 @@ def __init__(
663
659
)
664
660
665
661
def __iter__ (self ) -> Iterator [XObsDatum ]:
666
- """Create Iterator yielding tuples of :class:`numpy.ndarray` and :class:`pandas.DataFrame`.
662
+ """Create ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.csr_matrix`) and
663
+ :class:`pandas.DataFrame`.
667
664
668
665
Returns:
669
666
``iterator``
@@ -678,30 +675,29 @@ def __iter__(self) -> Iterator[XObsDatum]:
678
675
yield X , obs
679
676
680
677
def __len__ (self ) -> int :
681
- """Return approximate number of batches this iterable will produce.
678
+ """Return number of batches this iterable will produce.
682
679
683
680
See important caveats in the PyTorch
684
681
[:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
685
682
documentation regarding ``len(dataloader)``, which also apply to this class.
686
683
687
684
Returns:
688
- An ``int``.
685
+ ``int`` (number of batches) .
689
686
690
687
Lifecycle:
691
688
experimental
692
689
"""
693
- return self ._exp_iter . __len__ ( )
690
+ return len ( self ._exp_iter )
694
691
695
692
@property
696
693
def shape (self ) -> Tuple [int , int ]:
697
- """Get the shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset `.
694
+ """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable `.
698
695
699
- This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
700
- (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect
701
- the size of the partition of the data assigned to the active process.
696
+ If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0),
697
+ the number of batches will reflect the size of the data partition assigned to the active process.
702
698
703
699
Returns:
704
- A tuple of ``int``s, for obs and var counts, respectively .
700
+ A tuple of two ``int`` values: number of batches, number of vars .
705
701
706
702
Lifecycle:
707
703
experimental
0 commit comments