
Custom dataloader registry support #2932


Open

wants to merge 148 commits into base: main

Conversation

ori-kron-wis
Collaborator

No description provided.

@ori-kron-wis ori-kron-wis added this to the scvi-tools 1.2 milestone Aug 7, 2024
@ori-kron-wis ori-kron-wis self-assigned this Aug 7, 2024
@ori-kron-wis ori-kron-wis linked an issue Aug 7, 2024 that may be closed by this pull request

codecov bot commented Aug 11, 2024

Codecov Report

Attention: Patch coverage is 79.22948% with 124 lines in your changes missing coverage. Please review.

Project coverage is 80.10%. Comparing base (c97655d) to head (d80ff0f).

Files with missing lines                      Patch %   Lines
src/scvi/model/base/_base_model.py            48.50%    69 Missing ⚠️
src/scvi/dataloaders/_custom_dataloders.py    89.45%    27 Missing ⚠️
src/scvi/model/base/_archesmixin.py           84.09%     7 Missing ⚠️
src/scvi/model/base/_training_mixin.py        78.12%     7 Missing ⚠️
src/scvi/model/base/_rnamixin.py              93.33%     4 Missing ⚠️
src/scvi/model/base/_vaemixin.py              77.77%     4 Missing ⚠️
src/scvi/data/_utils.py                       57.14%     3 Missing ⚠️
src/scvi/model/_scanvi.py                     88.00%     3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2932      +/-   ##
==========================================
- Coverage   80.14%   80.10%   -0.05%     
==========================================
  Files         193      194       +1     
  Lines       17121    17604     +483     
==========================================
+ Hits        13722    14101     +379     
- Misses       3399     3503     +104     
Files with missing lines                    Coverage Δ
src/scvi/dataloaders/__init__.py            100.00% <100.00%>  (ø)
src/scvi/dataloaders/_data_splitting.py      95.47% <ø>        (ø)
src/scvi/model/_scvi.py                      96.42% <100.00%>  (+0.51%) ⬆️
src/scvi/model/base/_save_load.py            83.16% <100.00%>  (+1.06%) ⬆️
src/scvi/data/_utils.py                      85.00% <57.14%>   (-1.13%) ⬇️
src/scvi/model/_scanvi.py                    91.83% <88.00%>   (-1.19%) ⬇️
src/scvi/model/base/_rnamixin.py             94.17% <93.33%>   (-0.36%) ⬇️
src/scvi/model/base/_vaemixin.py             89.13% <77.77%>   (+1.17%) ⬆️
src/scvi/model/base/_archesmixin.py          78.54% <84.09%>   (+1.65%) ⬆️
src/scvi/model/base/_training_mixin.py       89.32% <78.12%>   (-0.48%) ⬇️
... and 2 more

... and 1 file with indirect coverage changes


@marianogabitto
Contributor

Hi Ori,
I gave this branch a try. First, I needed to install psutil. Second, when I ran the tutorial, I hit an error:

Code:

model = scvi.model.SCVI(adata=None, registry=datamodule.registry, datamodule=datamodule)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/allen/programs/celltypes/workgroups/rnaseqanalysis/Mariano/Anaconda3/envs/scvi-env-largedata/lib/python3.12/site-packages/scvi/model/_scvi.py", line 184, in __init__
    self.module = self._module_cls(
                  ^^^^^^^^^^^^^^^^^
TypeError: VAE.__init__() got an unexpected keyword argument 'datamodule'

@marianogabitto
Contributor

I fixed it.

In the tutorial file scvi-tools/docs/user_guide/use_case/custom_dataloaders.md,

replace the line:

model = scvi.model.SCVI(adata=None, registry=datamodule.registry, datamodule=datamodule)

with:

model = scvi.model.SCVI(adata=None, registry=datamodule.registry)

@ori-kron-wis
Collaborator Author

ori-kron-wis commented Apr 21, 2025

@marianogabitto thanks. I fixed those things.
Also see the test_custom_dataloder.py file and the corresponding tutorial to get a sense of the full capabilities we currently have for custom dataloaders (scverse/scvi-tutorials#425).

As you also saw, my main concern right now is the slowness of the data loading between batches. Let's focus on that.
Do you have any suggestions on how we could make it faster?

Feel free to fork the branch and add your commits.

@marianogabitto
Contributor

marianogabitto commented Apr 21, 2025

Hi Ori,
I am adding the point I brought up on Discourse: I compared the TileDB and the regular AnnData loaders head to head, and TileDB is 50% slower. I tested the TileDB dataloader in both regular and DDP mode, and what is causing the delay is slow access to the data: the GPU spikes and processes very fast, but there is a long wait between batches.

Let me tell you how I did this comparison. On the one hand, I saved the TileDB experiment as an AnnData and ran scvi regularly. On the other hand, I grabbed the data loader code from scvi (splitting, AnnDataLoader, AnnDataset) and created an scvi-external AnnData loader (just to be sure there was no difference between running the AnnData in regular scvi versus passing it in as a new AnnData loader).
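
For concreteness, a minimal sketch of the first of those two paths, assuming the cellxgene-census convenience API; the census version, filter, and batch key below are placeholder examples, not the values used in this comparison:

import cellxgene_census
import scvi

# Export the query result to an in-memory AnnData, then run scvi-tools as usual.
with cellxgene_census.open_soma(census_version="stable") as census:
    adata = cellxgene_census.get_anndata(
        census,
        organism="Homo sapiens",
        obs_value_filter="tissue_general == 'liver'",  # placeholder filter
    )

scvi.model.SCVI.setup_anndata(adata, batch_key="dataset_id")  # placeholder batch key
model = scvi.model.SCVI(adata)
model.train(max_epochs=1)  # baseline timing for the in-memory path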

Two questions:
1) If you can wait until Wednesday, I am going to talk to people at my high-performance computing cluster and to representatives from TileDB, and try to work out whether there is any way to accelerate the data loader with multiple workers or other options. A data loader that reads from disk is expected to be slower; I want to see how much slower.
2) What data loader is used in regular scvi? Should I also use the following? Wouldn't the train dataloader only contain an amount of data equivalent to train_size?

inference_dataloader = (
    inference_datamodule.on_before_batch_transfer(batch, None)
    for batch in inference_datamodule.train_dataloader()
)

M

@canergen
Member

Just checking: is it still faster if you load the AnnData in disk-backed mode? That would indeed be surprising, whereas the remaining overhead could simply come from reading from disk. Any chance you can put the data on fast SSD storage, as is usually recommended for from-disk loading? The data could also be scattered across several SSDs, and TileDB might have suggestions on how to optimize this; imperfect randomness is not really an issue if they first randomize the order on disk (rather than keeping a single experiment on one SSD).
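
A minimal sketch of that disk-backed comparison, assuming a local .h5ad copy of the data; the file path and batch key are placeholders:

import anndata as ad
import scvi

# Open the AnnData in backed mode: obs/var live in memory, X stays on disk.
adata_backed = ad.read_h5ad("experiment.h5ad", backed="r")

scvi.model.SCVI.setup_anndata(adata_backed, batch_key="batch")
model = scvi.model.SCVI(adata_backed)
model.train(max_epochs=1)  # compare wall time against the in-memory and TileDB runs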

@marianogabitto
Contributor

Speeds are:
anndata in memory > anndata backed >~ TileDB.

Yes, I am testing two things: I am talking with my support team about a partition with fast disk access, and touching base with TileDB about how to optimize this.

@ori-kron-wis
Collaborator Author

ori-kron-wis commented Apr 21, 2025

@marianogabitto I will try to do some testing on my end.
Actually, the whole "on_before_batch_transfer" part is something I inherited from the census implementation long ago, but I never had the time to check it fully (see https://github.com/chanzuckerberg/cellxgene-census/blob/756708e9aa18791b7bae3712e9dd66d2b6ce9d75/api/python/notebooks/experimental/pytorch_scvi.ipynb), so it makes sense to me that it might be the bottleneck.
We are not in a rush to release this; we can wait for the input from your support team.

@marianogabitto
Contributor

But you need "on_before_batch_transfer" to move the batch to PyTorch, no? Isn't the point of that hook to take the X, batch, and labels data and turn them into tensors? What happens if they are already tensors?
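
For reference, a rough sketch of what such a hook typically does in a LightningDataModule, assuming scvi's usual registry keys ("X", "batch", "labels"); this is an illustration, not the actual implementation in this PR:

import torch
from lightning.pytorch import LightningDataModule

class ExampleDataModule(LightningDataModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # If the loader yields numpy arrays, convert them to tensors; if they
        # are already tensors, torch.as_tensor is essentially a no-op aside
        # from the dtype casts.
        x, batch_index, labels = batch
        return {
            "X": torch.as_tensor(x, dtype=torch.float32),
            "batch": torch.as_tensor(batch_index, dtype=torch.long).reshape(-1, 1),
            "labels": torch.as_tensor(labels, dtype=torch.long).reshape(-1, 1),
        }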

@marianogabitto
Contributor

One more thing... how about writing a tutorial for the anndata_manager and the registry?

@ori-kron-wis
Collaborator Author

We will deal with everything, but let's take it step by step. on_before_batch_transfer is needed, of course, but let's see how we can improve it.

@ori-kron-wis
Collaborator Author

Your suggestion worked, and I do see much better training-time performance when using:

import time

datamodule.setup()
model = scvi.model.SCVI(
    adata=None,
    registry=datamodule.registry,
    n_layers=n_layers,        # defined earlier in the tutorial
    n_latent=n_latent,
    gene_likelihood="nb",
    encode_covariates=False,
)
# creating the dataloader for the train set
training_dataloader = (
    datamodule.on_before_batch_transfer(batch, None)
    for batch in datamodule.train_dataloader()
)
start = time.time()
model.train(
    datamodule=training_dataloader,
    # datamodule=datamodule,
    max_epochs=100,
    batch_size=1024,
    # accelerator="gpu",
    # devices=-1,
    # strategy="ddp_find_unused_parameters_true",
)
end = time.time()
print(f"Elapsed time: {end - start:.2f} seconds")

See my updated tutorial on the other branch

@marianogabitto
Contributor

Ori, I got lost... which branch should I look at for the tutorial? And which branch should I check out to test your commits?

@ori-kron-wis
Collaborator Author

These are the tutorials:
https://app.reviewnb.com/scverse/scvi-tutorials/pull/425/

The tests are in the current branch.

@marianogabitto
Contributor

Ori, this is not working for me. When I invoke the following in the notebook:

training_dataloader = (
    datamodule.on_before_batch_transfer(batch, None)
    for batch in datamodule.train_dataloader()
)

I get:

switching torch multiprocessing start method from "fork" to "spawn"

and then it errors out...
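
One generic thing to try (an assumption on my part, not a confirmed fix for this datamodule): the "fork" to "spawn" switch is tied to multi-process loading, so forcing single-process loading can isolate whether the failure comes from worker startup rather than from the dataloader itself. A minimal, self-contained illustration with a plain torch DataLoader:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 4))
# num_workers=0 keeps loading in the main process, so no worker processes are
# spawned and no start-method switch is needed.
loader = DataLoader(dataset, batch_size=2, num_workers=0)

for batch in loader:
    pass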

@marianogabitto
Contributor

marianogabitto commented Apr 25, 2025

Ori, all the examples I list below were run after removing the ".on_before_batch_transfer()" call, the way I posted before.

  1. With num_workers=0 I can train, at the low speeds listed below. When I set the number of workers, e.g. num_workers=4, 12, or 24, the trainer takes forever to initialize and is then even slower than the numbers below.

  2. Can you monitor your GPU usage with nvitop or nvtop? Here are my head-to-head comparisons (see the timing sketch after this list):

  • TileDB from Cell Census. I believe this reads from S3, so it never actually copies the data to disk.
    It takes 120 sec/it to train. GPU activity is almost zero the whole time, except for moments when it peaks at 100%.

  • TileDB from an AnnData created from the query. This reads from a local disk directory.
    It takes 11 sec/it to train. GPU activity is almost zero the whole time, except for moments when it peaks at 100%.

  • The regular way of loading the AnnData into memory. It takes 1.2 sec/it to train, with GPU activity at 40% the whole time.

These numbers lead me to believe that we are not feeding data to the GPU fast enough.

  3. I forgot to tell you, but the TileDB representative sent me this as a reference. It is different from the way we run, because they launch the processes themselves:
    https://github.com/single-cell-data/TileDB-SOMA-ML/tree/rw/cli/src/tiledbsoma_ml/cli#example-invocation
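
A small, generic timing sketch (not code from this PR) that separates the time spent waiting for data from compute time; it assumes a training_dataloader generator built as in the earlier comments:

import time

load_times = []
t0 = time.perf_counter()
for i, batch in enumerate(training_dataloader):
    t1 = time.perf_counter()
    load_times.append(t1 - t0)  # time spent waiting for this batch to arrive
    # ... the forward/backward pass would run here during real training ...
    t0 = time.perf_counter()
    if i >= 49:  # sample 50 batches only
        break

print(f"mean wait per batch: {sum(load_times) / len(load_times):.3f} s")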

Labels
custom_dataloader PR 2932, on-merge: backport to 1.3.x
Development

Successfully merging this pull request may close these issues.

Fix custom dataloader registry
3 participants