diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml new file mode 100644 index 0000000000..d2e1be2aaa --- /dev/null +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -0,0 +1,72 @@ +name: test (custom dataloaders) + +on: + push: + branches: [main, "[0-9]+.[0-9]+.x"] + pull_request: + branches: [main, "[0-9]+.[0-9]+.x"] + types: [labeled, synchronize, opened] + schedule: + - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + # if PR has label "custom_dataloader" or "all tests" or if scheduled or manually triggered + if: >- + ( + contains(github.event.pull_request.labels.*.name, 'custom_dataloader') || + contains(github.event.pull_request.labels.*.name, 'all tests') || + contains(github.event_name, 'schedule') || + contains(github.event_name, 'workflow_dispatch') + ) + + runs-on: ${{ matrix.os }} + + defaults: + run: + shell: bash -e {0} # -e to fail on error + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.12"] + + name: integration + + env: + OS: ${{ matrix.os }} + PYTHON: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: "pip" + cache-dependency-path: "**/pyproject.toml" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel uv + python -m uv pip install --system "scvi-tools[tests] @ ." + + - name: Run specific custom dataloader pytest + env: + MPLBACKEND: agg + PLATFORM: ${{ matrix.os }} + DISPLAY: :42 + COLUMNS: 120 + run: | + coverage run -m pytest -v --color=yes --custom-dataloader-tests + coverage report + + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index b801fba04b..d5cd973806 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ to [Semantic Versioning]. Full commit history is available in the - Add supervised module class {class}`scvi.module.base.SupervisedModuleClass`. {pr}`3237`. - Add get normalized function model property for any generative model {pr}`3238` and changed get_accessibility_estimates to get_normalized_accessibility, where needed. +- Add support for using Lamin custom dataloaders with {class}`scvi.model.SCVI`, {pr}`2932`. - Add Early stopping KL warmup steps. {pr}`3262`. - Add Minification option to {class}`~scvi.model.LinearSCVI` {pr}`3294`. 
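The workflow above opts into the suite with a dedicated `--custom-dataloader-tests` pytest flag. The conftest wiring for that flag is not part of this excerpt; the sketch below shows the usual pattern for such an opt-in flag, and the `dataloader` keyword used for filtering is an assumption rather than something defined in this diff.

```python
# Hypothetical conftest.py sketch (not part of this diff) showing how an opt-in
# flag such as --custom-dataloader-tests is typically registered and applied.
import pytest


def pytest_addoption(parser):
    # Register the flag that the CI workflow above passes explicitly.
    parser.addoption(
        "--custom-dataloader-tests",
        action="store_true",
        default=False,
        help="Run tests that exercise the Lamin/TileDB custom dataloaders.",
    )


def pytest_collection_modifyitems(config, items):
    # Without the flag, skip anything carrying the (assumed) dataloader keyword.
    if config.getoption("--custom-dataloader-tests"):
        return
    skip_dl = pytest.mark.skip(reason="needs --custom-dataloader-tests to run")
    for item in items:
        if "dataloader" in item.keywords:
            item.add_marker(skip_dl)
```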
diff --git a/docs/tutorials/index_use_cases.md b/docs/tutorials/index_use_cases.md index 1c9b728987..2f10e9a52e 100644 --- a/docs/tutorials/index_use_cases.md +++ b/docs/tutorials/index_use_cases.md @@ -6,4 +6,6 @@ notebooks/use_cases/autotune_scvi notebooks/use_cases/minification notebooks/use_cases/interpretability +notebooks/use_cases/custom_dl/tiledb +notebooks/use_cases/custom_dl/lamin ``` diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 7c43ad2ee6..d7867ab62b 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 7c43ad2ee650de39a2220470288a72bb0b78a50d +Subproject commit d7867ab62b4917ed60229927a1d4bdb6003f079b diff --git a/docs/user_guide/use_case/custom_dataloaders.md b/docs/user_guide/use_case/custom_dataloaders.md index 086b8d42ec..22bb6cc87c 100644 --- a/docs/user_guide/use_case/custom_dataloaders.md +++ b/docs/user_guide/use_case/custom_dataloaders.md @@ -21,13 +21,17 @@ Pros: - Optimized for ML Workflows: If your dataset is structured as tables (rows and columns), LamindDB’s format aligns well with SCVI's expectations, potentially reducing the need for complex transformations. ```python -os.system("lamin init --storage ./test-registries") import lamindb as ln +from scvi.dataloaders import MappedCollectionDataModule +import scvi +import os + +os.system("lamin init --storage ./test-registries") ln.setup.init(name="lamindb_instance_name", storage=save_path) # a test for mapped collection -collection = ln.Collection.get(name="covid_normal_lung") +collection = ln.Collection.using("laminlabs/cellxgene").get(name="covid_normal_lung") artifacts = collection.artifacts.all() artifacts.df() @@ -35,11 +39,12 @@ datamodule = MappedCollectionDataModule( collection, batch_key="assay", batch_size=1024, join="inner" ) model = scvi.model.SCVI(adata=None, registry=datamodule.registry) +model.train(max_epochs=1, batch_size=1024, datamodule=datamodule) ... ``` LamindDB may not be as efficient or flexible as TileDB for handling complex multi-dimensional data -2. [CZI](https://chanzuckerberg.com/) based [tiledb](https://tiledb.com/) custom dataloader is based on CensusSCVIDataModule and can run a large multi-dimensional datasets that are stored in TileDB’s format. +2. [CZI](https://chanzuckerberg.com/) based [tiledb](https://tiledb.com/) custom dataloader is based on TileDBDataModule and can handle large multi-dimensional datasets stored in TileDB’s format. TileDB is a general-purpose, multi-dimensional array storage engine designed for high-performance, scalable data access. It supports various data types, including dense and sparse arrays, and is optimized for handling large datasets efficiently. TileDB’s strength lies in its ability to store and query data across multiple dimensions and scale efficiently with large volumes of data.
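Whichever backing store is used, a model built from `registry=datamodule.registry` is queried through the datamodule rather than through an AnnData object. A minimal sketch, assuming the `model` and `datamodule` objects from the Lamin example above; the `dataloader=` argument is the one this PR adds to `get_normalized_expression`, while support in other query methods is an assumption:

```python
# Stream inference batches straight from the datamodule; no AnnData is attached to the model.
infer_dl = datamodule.inference_dataloader()

# Normalized expression is computed minibatch by minibatch over the streamed cells.
norm_expr = model.get_normalized_expression(dataloader=infer_dl)
```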
@@ -52,9 +57,10 @@ Scalability: Handles large datasets that exceed your system's memory capacity, m ```python import cellxgene_census import tiledbsoma as soma -from cellxgene_census.experimental.ml import experiment_dataloader -from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule +import tiledbsoma_ml +from scvi.dataloaders import TileDBDataModule import numpy as np +import scvi # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") @@ -66,25 +72,52 @@ obs_value_filter = ( hv_idx = np.arange(100) # just ot make it smaller and faster for debug -# this is CZI part to be taken once all is ready -batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"] -datamodule = CensusSCVIDataModule( - census["census_data"][experiment_name], +# For HVG, we can use the highly_variable_genes function provided in cellxgene_census, +# which can compute HVGs in constant memory: +hvg_query = census["census_data"][experiment_name].axis_query( measurement_name="RNA", - X_name="raw", obs_query=soma.AxisQuery(value_filter=obs_value_filter), var_query=soma.AxisQuery(coords=(list(hv_idx),)), +) + +# this is CZI part to be taken once all is ready +batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"] +label_keys = ["tissue_general"] +datamodule = TileDBDataModule( + hvg_query, + layer_name="raw", batch_size=1024, shuffle=True, - batch_keys=batch_keys, + seed=42, + batch_column_names=batch_keys, + label_keys=label_keys, + train_size=0.9, + unlabeled_category="label_0", dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, ) +# We can now create the scVI model object and train it: +model = scvi.model.SCVI( + adata=None, + registry=datamodule.registry, + gene_likelihood="nb", + encode_covariates=False, +) + +# creating the dataloader for trainset +datamodule.setup() +training_dataloader = ( + datamodule.on_before_batch_transfer(batch, None) + for batch in datamodule.train_dataloader() +) -# basicaly we should mimiC everything below to any model census in scvi -adata_orig = synthetic_iid() -scvi.model.SCVI.setup_anndata(adata_orig, batch_key="batch") -model = scvi.model.SCVI(adata_orig) +model.train( + datamodule=training_dataloader, + max_epochs=1, + batch_size=1024, + train_size=0.9, + early_stopping=False, +) ... ``` Key Differences between them in terms of Custom Dataloaders: @@ -110,6 +143,8 @@ When to Use Each: Writing custom dataloaders requires a good understanding of PyTorch’s DataLoader class and how to integrate it with SCVI, which may be difficult for beginners. It will also requite maintenance: If the data format or preprocessing needs change, you’ll have to modify and maintain the custom dataloader code, But it can be a greate addition to the model pipeline, in terms of runtime and how much data we can digest. +See relevant tutorials in this subject for further examples. + :::{note} -As for SCVI-Tools v1.3.0 Custom Dataloaders are experimental. 
+As for SCVI-Tools v1.3.0 Custom Dataloaders are experimental and only supported for adata and SCVI and SCANVI models ::: diff --git a/pyproject.toml b/pyproject.toml index ea9be12629..ecb31d639b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,11 +82,9 @@ docs = [ docsbuild = ["scvi-tools[docs,optional]"] # scvi.autotune -autotune = ["hyperopt>=0.2", "ray[tune]","scib-metrics"] +autotune = ["hyperopt>=0.2", "ray[tune]", "scib-metrics"] # scvi.hub.HubModel.pull_from_s3 aws = ["boto3"] -# scvi.data.cellxgene -census = ["cellxgene-census", "numpy<2.0"] # scvi.hub dependencies hub = ["huggingface_hub", "igraph", "leidenalg", "dvc[s3]"] # scvi.data.add_dna_sequence @@ -96,13 +94,15 @@ scanpy = ["scanpy>=1.10", "scikit-misc"] # for convinient files sharing file_sharing = ["pooch", "cellxgene-census"] # for parallelization engine -parallel = ["dask[array]>=2023.5.1,<2024.8.0"] +parallel = ["dask[array]>=2023.5.1,<2024.8.0", "zarr<3.0.0"] # for supervised models interpretability -interpretability = ["captum","shap"] +interpretability = ["captum", "shap"] +# for custom dataloders +dataloaders = ["lamindb>=1.3.0", "biomart", "bionty", "cellxgene_lamin", "cellxgene-census", "numpy<2.0", "tiledbsoma", "tiledb", "tiledbsoma_ml", "torchdata==0.9.0"] optional = [ - "scvi-tools[autotune,aws,hub,file_sharing,regseq,scanpy,parallel,interpretability]" + "scvi-tools[autotune,aws,hub,file_sharing,regseq,scanpy,parallel,interpretability,dataloaders]" ] tutorials = [ "cell2location", @@ -137,6 +137,7 @@ markers = [ "private: mark tests that uses private keys, like HF", "multigpu: mark tests that are used to check multi GPU performance", "autotune: mark tests that are used to check ray autotune capabilities", + "custom dataloaders: mark tests that are used to check different custom data loaders", ] [tool.ruff] diff --git a/src/scvi/data/_utils.py b/src/scvi/data/_utils.py index 871cc7b15c..1b5c5a5253 100644 --- a/src/scvi/data/_utils.py +++ b/src/scvi/data/_utils.py @@ -21,6 +21,8 @@ from . 
import _constants if TYPE_CHECKING: + from collections.abc import Iterator + import numpy.typing as npt from pandas.api.types import CategoricalDtype from torch import Tensor @@ -361,3 +363,21 @@ def _check_fragment_counts( ) # True if there are more 2s than 1s ret = not (non_fragments or binary) return ret + + +def _validate_adata_dataloader_input( + model, + adata: AnnOrMuData | None = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, +): + """Validate that model uses adata or custom dataloader""" + if adata is not None and dataloader is not None: + raise ValueError("Only one of `adata` or `dataloader` can be provided.") + elif ( + hasattr(model, "registry") + and "setup_method_name" in model.registry.keys() + and model.registry["setup_method_name"] == "setup_datamodule" + and dataloader is None + ): + raise ValueError("`dataloader` must be provided.") + return diff --git a/src/scvi/dataloaders/__init__.py b/src/scvi/dataloaders/__init__.py index 302055c3d5..79d967199b 100644 --- a/src/scvi/dataloaders/__init__.py +++ b/src/scvi/dataloaders/__init__.py @@ -3,6 +3,7 @@ from ._ann_dataloader import AnnDataLoader from ._concat_dataloader import ConcatDataLoader +from ._custom_dataloders import MappedCollectionDataModule, TileDBDataModule from ._data_splitting import ( DataSplitter, DeviceBackedDataSplitter, @@ -20,4 +21,6 @@ "DataSplitter", "SemiSupervisedDataSplitter", "BatchDistributedSampler", + "MappedCollectionDataModule", + "TileDBDataModule", ] diff --git a/src/scvi/dataloaders/_custom_dataloders.py b/src/scvi/dataloaders/_custom_dataloders.py new file mode 100644 index 0000000000..ae3f60c7c0 --- /dev/null +++ b/src/scvi/dataloaders/_custom_dataloders.py @@ -0,0 +1,650 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import numpy as np +import torch +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import LabelEncoder +from torch.utils.data import DataLoader + +import scvi +from scvi.model._utils import parse_device_args +from scvi.utils import dependencies + +if TYPE_CHECKING: + from typing import Any + + import pandas as pd + + +@dependencies("lamindb") +class MappedCollectionDataModule(LightningDataModule): + import lamindb as ln + + def __init__( + self, + collection: ln.Collection, + batch_key: str | None = None, + label_key: str | None = None, + unlabeled_category: str | None = "Unknown", + batch_size: int = 128, + collection_val: ln.Collection | None = None, + accelerator: str = "auto", + device: int | str = "auto", + shuffle: bool = True, + model_name: str = "SCVI", + **kwargs, + ): + super().__init__() + self._batch_size = batch_size + self._batch_key = batch_key + self._label_key = label_key + self.model_name = model_name + self.shuffle = shuffle + self.unlabeled_category = unlabeled_category + self._parallel = kwargs.pop("parallel", True) + self.labels_ = None + + # here we initialize MappedCollection to use in a pytorch DataLoader + if self._label_key is not None: + self._dataset = collection.mapped( + obs_keys=[self._batch_key, self._label_key], parallel=self._parallel, **kwargs + ) + adata = collection.load(join="inner") + self.labels_ = adata.obs[self._label_key].values.astype(str) + if collection_val is not None: + self._validset = collection_val.mapped( + obs_keys=[self._batch_key, self._label_key], parallel=self._parallel, **kwargs + ) + else: + self._validset = None + else: + self._dataset = collection.mapped( + obs_keys=self._batch_key, parallel=self._parallel, **kwargs + ) + 
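+ # `collection.mapped()` returns a lamindb MappedCollection: a map-style dataset that
+ # virtually concatenates the collection's backed AnnData artifacts and integer-encodes the
+ # requested obs columns, so the DataLoaders below can stream cells lazily.
+ # An optional second collection supplies the validation split: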
if collection_val is not None: + self._validset = collection_val.mapped( + obs_keys=self._batch_key, parallel=self._parallel, **kwargs + ) + else: + self._validset = None + # need by scvi and lightning.pytorch + self._log_hyperparams = False + self.allow_zero_length_dataloader_with_multiple_devices = False + _, _, self.device = parse_device_args( + accelerator=accelerator, devices=device, return_device="torch" + ) + + def close(self): + self._dataset.close() + self._validset.close() + + def train_dataloader(self) -> DataLoader: + return self._create_dataloader(shuffle=self.shuffle) + + def val_dataloader(self) -> DataLoader: + return self._create_dataloader_val(shuffle=self.shuffle) + + def inference_dataloader(self, shuffle=False, batch_size=4096, indices=None): + """Dataloader for inference with `on_before_batch_transfer` applied.""" + dataloader = self._create_dataloader(shuffle, batch_size, indices) + return self._InferenceDataloader(dataloader, self.on_before_batch_transfer) + + def _create_dataloader(self, shuffle, batch_size=None, indices=None): + if self._parallel: + num_workers = os.cpu_count() - 1 + worker_init_fn = self._dataset.torch_worker_init_fn + else: + num_workers = 0 + worker_init_fn = None + if batch_size is None: + batch_size = self._batch_size + if indices is not None: + dataset = self._dataset[indices] # TODO find a better way + else: + dataset = self._dataset + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + worker_init_fn=worker_init_fn, + ) + + def _create_dataloader_val(self, shuffle, batch_size=None, indices=None): + if self._validset is not None: + if self._parallel: + num_workers = os.cpu_count() - 1 + worker_init_fn = self._validset.torch_worker_init_fn + else: + num_workers = 0 + worker_init_fn = None + if batch_size is None: + batch_size = self._batch_size + if indices is not None: + validset = self._validset[indices] # TODO find a better way + else: + validset = self._validset + return DataLoader( + validset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + worker_init_fn=worker_init_fn, + ) + else: + pass + + @property + def n_obs(self) -> int: + return self._dataset.n_obs + + @property + def var_names(self) -> int: + return self._dataset.var_joint + + @property + def n_vars(self) -> int: + return self._dataset.n_vars + + @property + def n_batch(self) -> int: + if self._batch_key is None: + return 1 + return len(self._dataset.encoders[self._batch_key]) + + @property + def n_labels(self) -> int: + if self._label_key is None: + return 1 + combined = np.concatenate( + ([self.unlabeled_category], list(self._dataset.encoders[self._label_key].keys())) + ) + unique_values = np.unique(combined) + return len(unique_values) + + @property + def labels(self) -> np.ndarray: + if self._label_key is None: + return None + return np.array(list(self._dataset.encoders[self._label_key].keys())).astype(object) + + @property + def unlabeled_category(self) -> str: + """String assigned to unlabeled cells.""" + if not hasattr(self, "_unlabeled_category"): + raise AttributeError("`unlabeled_category` not set.") + return self._unlabeled_category + + @unlabeled_category.setter + def unlabeled_category(self, value: str | None): + if not (value is None or isinstance(value, str)): + raise ValueError("`unlabeled_category` must be a string or None.") + self._unlabeled_category = value + + @property + def registry(self) -> dict: + return { + "scvi_version": scvi.__version__, + "model_name": self.model_name, + 
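+ # The mapping below mirrors the registry that `setup_anndata` would normally produce via
+ # AnnDataManager, which is what lets `scvi.model.SCVI(adata=None, registry=datamodule.registry)`
+ # build its module without ever registering an AnnData object.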
"setup_args": { + "layer": None, + "batch_key": self._batch_key, + "labels_key": self._label_key, + "size_factor_key": None, + "categorical_covariate_keys": None, + "continuous_covariate_keys": None, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": self.n_obs, + "n_vars": self.n_vars, + "column_names": self.var_names, + }, + "summary_stats": {"n_vars": self.n_vars, "n_cells": self.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": self.batch_labels, + "original_key": self._batch_key, + }, + "summary_stats": {"n_batch": self.n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": self.labels, + "original_key": self._label_key, + "unlabeled_category": self.unlabeled_category, + }, + "summary_stats": {"n_labels": self.n_labels}, + }, + "size_factor": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {}, + }, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_datamodule", + } + + @property + def batch_labels(self) -> int | None: + if self._batch_key is None: + return None + return self._dataset.encoders[self._batch_key] + + @property + def label_keys(self) -> int | None: + if self._label_key is None: + return None + return self._dataset.encoders[self._label_key] + + def on_before_batch_transfer( + self, + batch, + dataloader_idx, + ): + X_KEY: str = "X" + BATCH_KEY: str = "batch" + LABEL_KEY: str = "labels" + + return { + X_KEY: batch["X"].float(), + BATCH_KEY: batch[self._batch_key][:, None] if self._batch_key is not None else None, + LABEL_KEY: batch[self._label_key][:, None] if self._label_key is not None else 0, + } + + class _InferenceDataloader: + """Wrapper to apply `on_before_batch_transfer` during iteration.""" + + def __init__(self, dataloader, transform_fn): + self.dataloader = dataloader + self.transform_fn = transform_fn + + def __iter__(self): + for batch in self.dataloader: + yield self.transform_fn(batch, dataloader_idx=None) + + def __len__(self): + return len(self.dataloader) + + +@dependencies("tiledbsoma") +@dependencies("tiledbsoma_ml") +class TileDBDataModule(LightningDataModule): + import tiledbsoma as soma + + """PyTorch Lightning DataModule for training scVI models from SOMA data + + Wraps a `tiledbsoma_ml.ExperimentDataset` to stream the results of a SOMA + `ExperimentAxisQuery`, exposing a `DataLoader` to generate tensors ready for scVI model + training. Also handles deriving the scVI batch label as a tuple of obs columns. + """ + + def __init__( + self, + query: soma.ExperimentAxisQuery, + *args, + batch_column_names: list[str] | None = None, + batch_labels: list[str] | None = None, + label_keys: list[str] | None = None, + unlabeled_category: str | None = "Unknown", + train_size: float | None = 1.0, + split_seed: int | None = None, + dataloader_kwargs: dict[str, Any] | None = None, + accelerator: str = "auto", + device: int | str = "auto", + model_name: str = "SCVI", + **kwargs, + ): + """ + Args: + + query: tiledbsoma.ExperimentAxisQuery + Defines the desired result set from a SOMA Experiment. 
+ *args, **kwargs: + Additional arguments passed through to `tiledbsoma_ml.ExperimentDataset`. + + batch_column_names: List[str], optional + List of obs column names, the tuple of which defines the scVI batch label + (not to be confused with a batch of training data). + + batch_labels: List[str], optional + List of possible values of the batch label, for mapping to label tensors. By default, + this will be derived from the unique labels in the given query results (given + `batch_column_names`), making the label mapping depend on the query. The `batch_labels` + attribute in the `TileDBDataModule` used for training may be saved and here restored in + another instance for a different query. That ensures the label mapping will be correct + for the trained model, even if the second query doesn't return examples of every + training batch label. + + label_keys + List of obs column names concatenated to form the label column. + unlabeled_category + Value used for unlabeled cells in `labels_key` used to set up CZI datamodule with scvi. + + train_size + Fraction of data to use for training. + split_seed + Seed for data split. + + dataloader_kwargs: dict, optional + %(param_accelerator)s + %(param_device)s + + model_name + The SCVI-Tools Model we are running + + Keyword arguments passed to `tiledbsoma_ml.experiment_dataloader()`, e.g. `num_workers`. + """ + super().__init__() + self.query = query + self.dataset_args = args + self.dataset_kwargs = kwargs + self.dataloader_kwargs = dataloader_kwargs if dataloader_kwargs is not None else {} + self.train_size = train_size + self.split_seed = split_seed + self.model_name = model_name + + # deal with labels if needed + self.unlabeled_category = unlabeled_category + self.label_keys = label_keys + self.labels_colsep = "//" + self.label_colname = "_scvi_labels" + self.labels = None + self.label_encoder = None + self.labels_ = None + + # deal with batches + self.batch_column_names = batch_column_names + self.batch_colsep = "//" + self.batch_colname = "_scvi_batch" + # prepare LabelEncoder for the scVI batch label: + # 1. read obs DataFrame for the whole query result set + # 2. add scvi_batch column + # 3. fit LabelEncoder to the scvi_batch column's unique values + if batch_labels is None: + cols_sel = ( + self.batch_column_names + if self.label_keys is None + else self.batch_column_names + self.label_keys + ) + obs_df = self.query.obs(column_names=cols_sel).concat().to_pandas() + obs_df = obs_df[cols_sel] + self._add_batch_col(obs_df, inplace=True) + batch_labels = obs_df[self.batch_colname].unique() + self.batch_labels = batch_labels + self.batch_encoder = LabelEncoder().fit(self.batch_labels) + + if label_keys is not None: + obs_label_df = self.query.obs(column_names=self.label_keys).concat().to_pandas() + obs_label_df = obs_label_df[self.label_keys] + self._add_label_col(obs_label_df, inplace=True) + labels = obs_label_df[self.label_colname].unique() + self.labels = labels + self.label_encoder = LabelEncoder().fit(self.labels) + self.labels_ = obs_label_df["_scvi_labels"].values + + _, _, self.device = parse_device_args( + accelerator=accelerator, devices=device, return_device="torch" + ) + + def setup(self, stage: str | None = None) -> None: + # Instantiate the ExperimentDataset with the provided args and kwargs. 
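+ # ExperimentDataset streams (X ndarray, obs DataFrame) batches out of the SOMA query;
+ # when train_size < 1.0 it is randomly split below into train/validation subsets using
+ # split_seed, and val_dataloader only returns a loader in that case.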
+ from tiledbsoma_ml import ExperimentDataset + + cols_sel = ( + self.batch_column_names + if self.label_keys is None + else self.batch_column_names + self.label_keys + ) + + self.train_dataset = ExperimentDataset( + self.query, + *self.dataset_args, + obs_column_names=cols_sel, + **self.dataset_kwargs, + ) + + if self.validation_size > 0.0: + datapipes = self.train_dataset.random_split( + self.train_size, self.validation_size, seed=self.split_seed + ) + self.train_dataset = datapipes[0] + self.val_dataset = datapipes[1] + else: + self.val_dataset = None + + def train_dataloader(self) -> DataLoader: + from tiledbsoma_ml import experiment_dataloader + + return experiment_dataloader( + self.train_dataset, + **self.dataloader_kwargs, + ) + + def val_dataloader(self) -> DataLoader: + from tiledbsoma_ml import experiment_dataloader + + if self.val_dataset is not None: + return experiment_dataloader( + self.val_dataset, + **self.dataloader_kwargs, + ) + else: + pass + + def _add_batch_col(self, obs_df: pd.DataFrame, inplace: bool = False): + # synthesize a new column for obs_df by concatenating the self.batch_column_names columns + if not inplace: + obs_df = obs_df.copy() + obs_df[self.batch_colname] = ( + obs_df[self.batch_column_names].astype(str).agg(self.batch_colsep.join, axis=1) + ) + if self.labels is not None: + obs_df[self.label_colname] = ( + obs_df[self.label_keys].astype(str).agg(self.labels_colsep.join, axis=1) + ) + return obs_df + + def _add_label_col(self, obs_label_df: pd.DataFrame, inplace: bool = False): + # synthesize a new column for obs_label_df by concatenating + # the self.batch_column_names columns + if not inplace: + obs_label_df = obs_label_df.copy() + obs_label_df[self.label_colname] = ( + obs_label_df[self.label_keys].astype(str).agg(self.labels_colsep.join, axis=1) + ) + return obs_label_df + + def on_before_batch_transfer( + self, + batch, + dataloader_idx: int, + ) -> dict[str, torch.Tensor | None]: + # DataModule hook: transform the ExperimentDataset data batch + # (X: ndarray, obs_df: DataFrame) + # into X & batch variable tensors for scVI (using batch_encoder on scvi_batch) + batch_X, batch_obs = batch + self._add_batch_col(batch_obs, inplace=True) + return { + "X": torch.from_numpy(batch_X).float(), + "batch": torch.from_numpy( + self.batch_encoder.transform(batch_obs[self.batch_colname]) + ).unsqueeze(1) + if self.batch_column_names is not None + else None, + "labels": torch.from_numpy( + self.label_encoder.transform(batch_obs[self.label_colname]) + ).unsqueeze(1) + if self.label_keys is not None + else torch.empty(0), + } + + # scVI code expects these properties on the DataModule: + + @property + def unlabeled_category(self) -> str: + """String assigned to unlabeled cells.""" + if not hasattr(self, "_unlabeled_category"): + raise AttributeError("`unlabeled_category` not set.") + return self._unlabeled_category + + @unlabeled_category.setter + def unlabeled_category(self, value: str | None): + if not (value is None or isinstance(value, str)): + raise ValueError("`unlabeled_category` must be a string or None.") + self._unlabeled_category = value + + @property + def split_seed(self) -> int: + """Seed for data split.""" + if not hasattr(self, "_split_seed"): + raise AttributeError("`split_seed` not set.") + return self._split_seed + + @split_seed.setter + def split_seed(self, value: int | None): + if value is not None and not isinstance(value, int): + raise ValueError("`split_seed` must be an integer.") + self._split_seed = value or 0 + + @property + def 
train_size(self) -> float: + """Fraction of data to use for training.""" + if not hasattr(self, "_train_size"): + raise AttributeError("`train_size` not set.") + return self._train_size + + @train_size.setter + def train_size(self, value: float | None): + if value is not None and not isinstance(value, float): + raise ValueError("`train_size` must be a float.") + elif value is not None and (value < 0.0 or value > 1.0): + raise ValueError("`train_size` must be between 0.0 and 1.0.") + self._train_size = value or 1.0 + + @property + def validation_size(self) -> float: + """Fraction of data to use for validation.""" + if not hasattr(self, "_train_size"): + raise AttributeError("`validation_size` not available.") + return 1.0 - self.train_size + + @property + def n_obs(self) -> int: + return len(self.query.obs_joinids()) + + @property + def n_vars(self) -> int: + return len(self.query.var_joinids()) + + @property + def n_batch(self) -> int: + return len(self.batch_encoder.classes_) + + @property + def n_labels(self) -> int: + if self.label_keys is not None: + combined = np.concatenate(([self.unlabeled_category], self.label_encoder.classes_)) + unique_values = np.unique(combined) + return len(unique_values) + else: + return 1 + + @property + def registry(self) -> dict: + batch_mapping = self.batch_labels + labels_mapping = self.labels + features_names = list( + self.query.var_joinids().tolist() if self.query is not None else range(self.n_vars) + ) + return { + "scvi_version": scvi.__version__, + "model_name": self.model_name, + "setup_args": { + "layer": None, + "batch_key": self.batch_colname, + "labels_key": self.label_keys[0] if self.label_keys is not None else "label", + "size_factor_key": None, + "categorical_covariate_keys": None, + "continuous_covariate_keys": None, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": self.n_obs, + "n_vars": self.n_vars, + "column_names": [str(i) for i in features_names], + }, + "summary_stats": {"n_vars": self.n_vars, "n_cells": self.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": batch_mapping, + "original_key": "batch", + }, + "summary_stats": {"n_batch": self.n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": labels_mapping, + "original_key": self.label_keys[0] + if self.label_keys is not None + else "label", + "unlabeled_category": self.unlabeled_category, + }, + "summary_stats": {"n_labels": self.n_labels}, + }, + "size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}}, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_datamodule", + } + + def inference_dataloader(self): + """Dataloader for inference with `on_before_batch_transfer` applied.""" + dataloader = self.train_dataloader() + return self._InferenceDataloader(dataloader, self.on_before_batch_transfer) + + class _InferenceDataloader: + """Wrapper to apply `on_before_batch_transfer` during iteration.""" + + def __init__(self, dataloader, transform_fn): + self.dataloader = dataloader + self.transform_fn = transform_fn + + def __iter__(self): + for batch in 
self.dataloader: + yield self.transform_fn(batch, dataloader_idx=None) + + def __len__(self): + return len(self.dataloader) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 5504fad503..72c6bf8ca4 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -386,7 +386,8 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`, def __init__( self, - adata_manager: AnnDataManager, + adata_manager: AnnDataManager | None = None, + datamodule: pl.LightningDataModule | None = None, train_size: float | None = None, validation_size: float | None = None, shuffle_set_split: bool = True, diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 7f0907c78a..7036ea9538 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -37,9 +37,14 @@ from typing import Literal from anndata import AnnData + from lightning import LightningDataModule from ._scvi import SCVI +_SCANVI_LATENT_QZM = "_scanvi_latent_qzm" +_SCANVI_LATENT_QZV = "_scanvi_latent_qzv" +_SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" + logger = logging.getLogger(__name__) @@ -109,7 +114,8 @@ class SCANVI( def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -118,23 +124,30 @@ def __init__( gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", use_observed_lib_size: bool = True, linear_classifier: bool = False, + datamodule: LightningDataModule | None = None, **model_kwargs, ): - super().__init__(adata) + super().__init__(adata, registry) scanvae_model_kwargs = dict(model_kwargs) - self._set_indices_and_labels() + self._set_indices_and_labels(datamodule) - # ignores unlabeled catgegory + # ignores unlabeled category n_labels = self.summary_stats.n_labels - 1 - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) + if adata is not None: + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + else: + # custom datamodule + n_cats_per_cov = self.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] + if n_cats_per_cov == 0: + n_cats_per_cov = None n_batch = self.summary_stats.n_batch - use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + use_size_factor_key = self.registry_["setup_args"][f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key"] library_log_means, library_log_vars = None, None if ( not use_size_factor_key @@ -184,6 +197,7 @@ def from_scvi_model( unlabeled_category: str, labels_key: str | None = None, adata: AnnData | None = None, + registry: dict | None = None, **scanvi_kwargs, ): """Initialize scanVI model with weights from pretrained :class:`~scvi.model.SCVI` model. @@ -200,6 +214,8 @@ def from_scvi_model( Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi. adata AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + registry + Registry of the datamodule used to train scANVI model. 
scanvi_kwargs kwargs for scANVI model """ @@ -229,13 +245,15 @@ def from_scvi_model( if adata is None: adata = scvi_model.adata - else: + elif adata: if _is_minified(adata): raise ValueError("Please provide a non-minified `adata` to initialize scANVI.") # validate new anndata against old model scvi_model._validate_anndata(adata) + else: + adata = None - scvi_setup_args = deepcopy(scvi_model.adata_manager.registry[_SETUP_ARGS_KEY]) + scvi_setup_args = deepcopy(scvi_model.registry[_SETUP_ARGS_KEY]) scvi_labels_key = scvi_setup_args["labels_key"] if labels_key is None and scvi_labels_key is None: raise ValueError( @@ -243,13 +261,15 @@ def from_scvi_model( ) if scvi_labels_key is None: scvi_setup_args.update({"labels_key": labels_key}) - cls.setup_anndata( - adata, - unlabeled_category=unlabeled_category, - use_minified=False, - **scvi_setup_args, - ) - scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) + if adata is not None: + cls.setup_anndata( + adata, + unlabeled_category=unlabeled_category, + use_minified=False, + **scvi_setup_args, + ) + + scanvi_model = cls(adata, scvi_model.registry, **non_kwargs, **kwargs, **scanvi_kwargs) scvi_state_dict = scvi_model.module.state_dict() scanvi_model.module.load_state_dict(scvi_state_dict, strict=False) scanvi_model.was_pretrained = True @@ -296,9 +316,12 @@ def setup_anndata( NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] # register new fields if the adata is minified - adata_minify_type = _get_adata_minify_type(adata) - if adata_minify_type is not None and use_minified: - anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) - adata_manager.register_fields(adata, **kwargs) - cls.register_manager(adata_manager) + if adata: + adata_minify_type = _get_adata_minify_type(adata) + if adata_minify_type is not None and use_minified: + anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 3deebf64fd..77533c7e31 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -27,6 +27,11 @@ from anndata import AnnData + +_SCVI_LATENT_QZM = "_scvi_latent_qzm" +_SCVI_LATENT_QZV = "_scvi_latent_qzv" +_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" + logger = logging.getLogger(__name__) @@ -110,6 +115,7 @@ class SCVI( def __init__( self, adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -120,7 +126,7 @@ def __init__( latent_distribution: Literal["normal", "ln"] = "normal", **kwargs, ): - super().__init__(adata) + super().__init__(adata, registry) self._module_kwargs = { "n_hidden": n_hidden, @@ -148,13 +154,24 @@ def __init__( stacklevel=settings.warnings_stacklevel, ) else: - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) + if adata is not None: + n_cats_per_cov = ( + self.adata_manager.get_state_registry( + REGISTRY_KEYS.CAT_COVS_KEY + ).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + else: + # custom datamodule + n_cats_per_cov = 
self.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] + if n_cats_per_cov == 0: + n_cats_per_cov = None + n_batch = self.summary_stats.n_batch - use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + use_size_factor_key = self.registry_["setup_args"][ + f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key" + ] library_log_means, library_log_vars = None, None if ( not use_size_factor_key diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 84293969ff..9a0e0ceb44 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +import inspect import logging import warnings from copy import deepcopy +from typing import TYPE_CHECKING import anndata import numpy as np @@ -13,12 +17,10 @@ from torch.distributions import transform_to from scvi import settings -from scvi._types import AnnOrMuData from scvi.data import _constants from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME from scvi.model._utils import parse_device_args from scvi.model.base._save_load import ( - _get_var_names, _initialize_model, _load_saved_files, _validate_var_names, @@ -26,7 +28,10 @@ from scvi.nn import FCLayers from scvi.utils._docstrings import devices_dsp -from ._base_model import BaseModelClass +if TYPE_CHECKING: + from scvi._types import AnnOrMuData + + from ._base_model import BaseModelClass logger = logging.getLogger(__name__) @@ -40,8 +45,9 @@ class ArchesMixin: @devices_dsp.dedent def load_query_data( cls, - adata: AnnOrMuData, - reference_model: str | BaseModelClass, + adata: AnnOrMuData = None, + reference_model: str | BaseModelClass = None, + registry: dict = None, inplace_subset_query_vars: bool = False, accelerator: str = "auto", device: int | str = "auto", @@ -86,6 +92,8 @@ def load_query_data( """ if reference_model is None: raise ValueError("Please provide a reference model as string or loaded model.") + if adata is None and registry is None: + raise ValueError("Please provide either an AnnData or a registry dictionary.") _, _, device = parse_device_args( accelerator=accelerator, @@ -95,50 +103,52 @@ def load_query_data( ) attr_dict, var_names, load_state_dict, pyro_param_store = _get_loaded_data( - reference_model, device=device + reference_model, device=device, adata=adata ) - if isinstance(adata, MuData): - for modality in adata.mod: + if adata: + if isinstance(adata, MuData): + for modality in adata.mod: + if inplace_subset_query_vars: + logger.debug(f"Subsetting {modality} query vars to reference vars.") + adata[modality]._inplace_subset_var(var_names[modality]) + _validate_var_names(adata[modality], var_names[modality]) + + else: if inplace_subset_query_vars: - logger.debug(f"Subsetting {modality} query vars to reference vars.") - adata[modality]._inplace_subset_var(var_names[modality]) - _validate_var_names(adata[modality], var_names[modality]) + logger.debug("Subsetting query vars to reference vars.") + adata._inplace_subset_var(var_names) + _validate_var_names(adata, var_names) - else: if inplace_subset_query_vars: logger.debug("Subsetting query vars to reference vars.") adata._inplace_subset_var(var_names) _validate_var_names(adata, var_names) - if inplace_subset_query_vars: - logger.debug("Subsetting query vars to reference vars.") - adata._inplace_subset_var(var_names) - _validate_var_names(adata, var_names) + registry = attr_dict.pop("registry_") + if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != 
cls.__name__: + raise ValueError("It appears you are loading a model from a different class.") - registry = attr_dict.pop("registry_") - if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: - raise ValueError("It appears you are loading a model from a different class.") + if _SETUP_ARGS_KEY not in registry: + raise ValueError( + "Saved model does not contain original setup inputs. " + "Cannot load the original setup." + ) - if _SETUP_ARGS_KEY not in registry: - raise ValueError( - "Saved model does not contain original setup inputs. " - "Cannot load the original setup." - ) + if registry[_SETUP_METHOD_NAME] != "setup_datamodule": + setup_method = getattr(cls, registry[_SETUP_METHOD_NAME]) + setup_method( + adata, + source_registry=registry, + extend_categories=True, + allow_missing_labels=True, + **registry[_SETUP_ARGS_KEY], + ) - setup_method = getattr(cls, registry[_SETUP_METHOD_NAME]) - setup_method( - adata, - source_registry=registry, - extend_categories=True, - allow_missing_labels=True, - **registry[_SETUP_ARGS_KEY], - ) + model = _initialize_model(cls, adata, registry, attr_dict) - model = _initialize_model(cls, adata, attr_dict) - adata_manager = model.get_anndata_manager(adata, required=True) + version_split = model.registry[_constants._SCVI_VERSION_KEY].split(".") - version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".") if int(version_split[1]) < 8 and int(version_split[0]) == 0: warnings.warn( "Query integration should be performed using models trained with version >= 0.8", @@ -146,6 +156,19 @@ def load_query_data( stacklevel=settings.warnings_stacklevel, ) + method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") + if method_name == "setup_datamodule": + attr_dict["n_input"] = attr_dict["n_vars"] + module_exp_params = inspect.signature(model._module_cls).parameters.keys() + common_keys1 = list(attr_dict.keys() & module_exp_params) + common_keys2 = model.init_params_["non_kwargs"].keys() & module_exp_params + common_items1 = {key: attr_dict[key] for key in common_keys1} + common_items2 = {key: model.init_params_["non_kwargs"][key] for key in common_keys2} + module = model._module_cls(**common_items1, **common_items2) + model.module = module + + model.module.load_state_dict(load_state_dict) + model.to_device(device) # model tweaking @@ -155,6 +178,12 @@ def load_query_data( load_ten = load_ten.to(new_ten.device) if new_ten.size() == load_ten.size(): continue + # new categoricals changed size + else: + dim_diff = new_ten.size()[-1] - load_ten.size()[-1] + fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]], dim=-1) + load_state_dict[key] = fixed_ten + # TODO VERIFY THIS! 
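+ # The concatenation above pads the reference tensor along its last dimension with the
+ # freshly initialized entries for categories present only in the query data, so the
+ # state dict matches the shape of the new module before it is loaded.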
fixed_ten = load_ten.clone() for dim in range(len(new_ten.shape)): if new_ten.size(dim) != load_ten.size(dim): @@ -408,7 +437,7 @@ def requires_grad(key): par.requires_grad = False -def _get_loaded_data(reference_model, device=None): +def _get_loaded_data(reference_model, device=None, adata=None): if isinstance(reference_model, str): attr_dict, var_names, load_state_dict, _ = _load_saved_files( reference_model, load_adata=False, map_location=device @@ -417,7 +446,7 @@ def _get_loaded_data(reference_model, device=None): else: attr_dict = reference_model._get_user_attributes() attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"} - var_names = _get_var_names(reference_model.adata) + var_names = reference_model.get_var_names() load_state_dict = deepcopy(reference_model.module.state_dict()) pyro_param_store = pyro.get_param_store().get_state() diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 13bd986493..af7bc4cbb5 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -3,8 +3,10 @@ import inspect import logging import os +import sys import warnings from abc import ABCMeta, abstractmethod +from io import StringIO from typing import TYPE_CHECKING from uuid import uuid4 @@ -14,19 +16,28 @@ import torch from anndata import AnnData from mudata import MuData +from rich import box +from rich.console import Console from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager, fields from scvi.data._compat import registry_from_setup_dict from scvi.data._constants import ( _ADATA_MINIFY_TYPE_UNS_KEY, + _FIELD_REGISTRIES_KEY, _MODEL_NAME_KEY, _SCVI_UUID_KEY, + _SCVI_VERSION_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME, + _STATE_REGISTRY_KEY, ADATA_MINIFY_TYPE, ) -from scvi.data._utils import _assign_adata_uuid, _check_if_view, _get_adata_minify_type +from scvi.data._utils import ( + _assign_adata_uuid, + _check_if_view, + _get_adata_minify_type, +) from scvi.dataloaders import AnnDataLoader from scvi.model._utils import parse_device_args from scvi.model.base._constants import SAVE_KEYS @@ -40,9 +51,14 @@ from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +from . import _constants + if TYPE_CHECKING: from collections.abc import Sequence + import pandas as pd + from lightning import LightningDataModule + from scvi._types import AnnOrMuData, MinifiedDataType logger = logging.getLogger(__name__) @@ -94,7 +110,7 @@ class BaseModelClass(metaclass=BaseModelMetaClass): _OBSERVED_LIB_SIZE_KEY = "observed_lib_size" _data_loader_cls = AnnDataLoader - def __init__(self, adata: AnnOrMuData | None = None): + def __init__(self, adata: AnnOrMuData | None = None, registry: object | None = None): # check if the given adata is minified and check if the model being created # supports minified-data mode (i.e. inherits from the abstract BaseMinifiedModeModelClass). # If not, raise an error to inform the user of the lack of minified-data functionality @@ -110,10 +126,21 @@ def __init__(self, adata: AnnOrMuData | None = None): self._adata_manager = self._get_most_recent_anndata_manager(adata, required=True) self._register_manager_for_instance(self.adata_manager) # Suffix registry instance variable with _ to include it when saving the model. 
- self.registry_ = self._adata_manager.registry - self.summary_stats = self._adata_manager.summary_stats + self.registry_ = self._adata_manager._registry + self.summary_stats = AnnDataManager._get_summary_stats_from_registry(self.registry_) + elif registry is not None: + self._adata = None + self._adata_manager = None + # Suffix registry instance variable with _ to include it when saving the model. + self.registry_ = registry + self.summary_stats = AnnDataManager._get_summary_stats_from_registry(registry) + elif (self.__class__.__name__ == "GIMVI") or (self.__class__.__name__ == "SCVI"): + # note some models do accept empty registry/adata (e.g: gimvi) + pass + else: + raise ValueError("adata or registry must be provided.") - self._module_init_on_train = adata is None + self._module_init_on_train = adata is None and registry is None self.is_trained_ = False self._model_summary_string = "" self.train_indices_ = None @@ -123,10 +150,24 @@ def __init__(self, adata: AnnOrMuData | None = None): self.get_normalized_function_name_ = "get_normalized_expression" @property - def adata(self) -> AnnOrMuData: + def adata(self) -> None | AnnOrMuData: """Data attached to model instance.""" return self._adata + @property + def registry(self) -> dict: + """Data attached to model instance.""" + return self.registry_ + + def get_var_names(self, legacy_mudata_format=False) -> dict: + """Variable names of input data.""" + from scvi.model.base._save_load import _get_var_names + + if self.adata: + return _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + else: + return self.registry[_FIELD_REGISTRIES_KEY]["X"][_STATE_REGISTRY_KEY]["column_names"] + @adata.setter def adata(self, adata: AnnOrMuData): if adata is None: @@ -248,6 +289,23 @@ def _register_manager_for_instance(self, adata_manager: AnnDataManager): instance_manager_store = self._per_instance_manager_store[self.id] instance_manager_store[adata_id] = adata_manager + def data_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: + """Returns the object in AnnData associated with the key in the data registry. + + Parameters + ---------- + registry_key + key of object to get from ``self.data_registry`` + + Returns + ------- + The requested data. + """ + if not self.adata: + raise ValueError("self.adata is None. Please register AnnData object to access data.") + else: + return self._adata_manager.get_from_registry(registry_key) + def deregister_manager(self, adata: AnnData | None = None): """Deregisters the :class:`~scvi.data.AnnDataManager` instance associated with `adata`. @@ -340,6 +398,9 @@ def get_anndata_manager( If True, errors on missing manager. Otherwise, returns None when manager is missing. """ cls = self.__class__ + if not adata: + return None + if _SCVI_UUID_KEY not in adata.uns: if required: raise ValueError( @@ -479,6 +540,13 @@ def _validate_anndata( return adata + def transfer_fields(self, adata: AnnOrMuData, **kwargs) -> AnnData: + """Transfer fields from a model to an AnnData object.""" + if self.adata: + return self.adata_manager.transfer_fields(adata, **kwargs) + else: + raise ValueError("Model need to be initialized with AnnData to transfer fields.") + def _check_if_trained(self, warn: bool = True, message: str = _UNTRAINED_WARNING_MESSAGE): """Check if the model is trained. @@ -541,7 +609,7 @@ def _get_user_attributes(self): def _get_init_params(self, locals): """Returns the model init signature with associated passed in values. - Ignores the initial AnnData. + Ignores the initial AnnData or Registry. 
""" init = self.__init__ sig = inspect.signature(init) @@ -552,7 +620,9 @@ def _get_init_params(self, locals): all_params = { k: v for (k, v) in all_params.items() - if not isinstance(v, AnnData) and not isinstance(v, MuData) + if not isinstance(v, AnnData) + and not isinstance(v, MuData) + and k not in ("adata", "registry") } # not very efficient but is explicit # separates variable params (**kwargs) from non variable params into two dicts @@ -577,6 +647,7 @@ def save( save_anndata: bool = False, save_kwargs: dict | None = None, legacy_mudata_format: bool = False, + datamodule: LightningDataModule | None = None, **anndata_write_kwargs, ): """Save the state of the model. @@ -604,11 +675,13 @@ def save( variable names across all modalities concatenated, while the new format is a dictionary with keys corresponding to the modality names and values corresponding to the variable names for each modality. + datamodule + ``EXPERIMENTAL`` A :class:`~lightning.pytorch.core.LightningDataModule` instance to use + for training in place of the default :class:`~scvi.dataloaders.DataSplitter`. Can only + be passed in if the model was not initialized with :class:`~anndata.AnnData`. anndata_write_kwargs Kwargs for :meth:`~anndata.AnnData.write` """ - from scvi.model.base._save_load import _get_var_names - if not os.path.exists(dir_path) or overwrite: os.makedirs(dir_path, exist_ok=overwrite) else: @@ -636,13 +709,30 @@ def save( model_state_dict = self.module.state_dict() model_state_dict["pyro_param_store"] = pyro.get_param_store().get_state() - var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + var_names = self.get_var_names(legacy_mudata_format=legacy_mudata_format) # get all the user attributes user_attributes = self._get_user_attributes() # only save the public attributes with _ at the very end user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} + method_name = self.registry.get(_SETUP_METHOD_NAME, "setup_anndata") + if method_name == "setup_datamodule": + user_attributes.update( + { + "n_batch": datamodule.n_batch, + "n_extra_categorical_covs": datamodule.registry["field_registries"][ + "extra_categorical_covs" + ]["summary_stats"]["n_extra_categorical_covs"], + "n_extra_continuous_covs": datamodule.registry["field_registries"][ + "extra_continuous_covs" + ]["summary_stats"]["n_extra_continuous_covs"], + "n_labels": datamodule.n_labels, + "n_vars": datamodule.n_vars, + "batch_labels": datamodule.batch_labels, + } + ) + torch.save( { SAVE_KEYS.MODEL_STATE_DICT_KEY: model_state_dict, @@ -675,6 +765,7 @@ def load( It is not necessary to run setup_anndata, as AnnData is validated against the saved `scvi` setup dictionary. If None, will check for and load anndata saved with the model. + If False, will load the model without AnnData. %(param_accelerator)s %(param_device)s prefix @@ -713,32 +804,48 @@ def load( ) adata = new_adata if new_adata is not None else adata - _validate_var_names(adata, var_names) - registry = attr_dict.pop("registry_") if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") - if _SETUP_ARGS_KEY not in registry: - raise ValueError( - "Saved model does not contain original setup inputs. " - "Cannot load the original setup." - ) - # Calling ``setup_anndata`` method with the original arguments passed into # the saved model. This enables simple backwards compatibility in the case of # newly introduced fields or parameters. 
- method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") - getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) + if adata: + if _SETUP_ARGS_KEY not in registry: + raise ValueError( + "Saved model does not contain original setup inputs. " + "Cannot load the original setup." + ) + _validate_var_names(adata, var_names) + method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") + if method_name != "setup_datamodule": + getattr(cls, method_name)( + adata, source_registry=registry, **registry[_SETUP_ARGS_KEY] + ) - model = _initialize_model(cls, adata, attr_dict) + model = _initialize_model(cls, adata, registry, attr_dict) pyro_param_store = model_state_dict.pop("pyro_param_store", None) + + method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") + if method_name == "setup_datamodule": + attr_dict["n_input"] = attr_dict["n_vars"] + module_exp_params = inspect.signature(model._module_cls).parameters.keys() + common_keys1 = list(attr_dict.keys() & module_exp_params) + common_keys2 = model.init_params_["non_kwargs"].keys() & module_exp_params + common_items1 = {key: attr_dict[key] for key in common_keys1} + common_items2 = {key: model.init_params_["non_kwargs"][key] for key in common_keys2} + module = model._module_cls(**common_items1, **common_items2) + model.module = module + model.module.on_load(model, pyro_param_store=pyro_param_store) model.module.load_state_dict(model_state_dict) model.to_device(device) + model.module.eval() - model._validate_anndata(adata) + if adata: + model._validate_anndata(adata) return model @classmethod @@ -903,6 +1010,149 @@ def view_anndata_setup( ) from err adata_manager.view_registry(hide_state_registries=hide_state_registries) + def view_setup_method_args(self) -> None: + """Prints setup kwargs used to produce a given registry. + + Parameters + ---------- + registry + Registry produced by an AnnDataManager. + """ + model_name = self.registry_[_MODEL_NAME_KEY] + setup_args = self.registry_[_SETUP_ARGS_KEY] + if model_name is not None and setup_args is not None: + rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:") + rich.pretty.pprint(setup_args) + rich.print() + + def view_registry(self, hide_state_registries: bool = False) -> None: + """Prints summary of the registry. + + Parameters + ---------- + hide_state_registries + If True, prints a shortened summary without details of each state registry. 
+ """ + version = self.registry_[_SCVI_VERSION_KEY] + rich.print(f"Anndata setup with scvi-tools version {version}.") + rich.print() + self.view_setup_method_args(self._registry) + + in_colab = "google.colab" in sys.modules + force_jupyter = None if not in_colab else True + console = rich.console.Console(force_jupyter=force_jupyter) + + ss = AnnDataManager._get_summary_stats_from_registry(self._registry) + dr = self._get_data_registry_from_registry(self._registry) + console.print(self._view_summary_stats(ss)) + console.print(self._view_data_registry(dr)) + + if not hide_state_registries: + for field in self.fields: + state_registry = self.get_state_registry(field.registry_key) + t = field.view_state_registry(state_registry) + if t is not None: + console.print(t) + + def get_state_registry(self, registry_key: str) -> attrdict: + """Returns the state registry for the AnnDataField registered with this instance.""" + return attrdict(self.registry_[_FIELD_REGISTRIES_KEY][registry_key][_STATE_REGISTRY_KEY]) + + def get_setup_arg(self, setup_arg: str) -> attrdict: + """Returns the string provided to setup of a specific setup_arg.""" + return self.registry_[_SETUP_ARGS_KEY][setup_arg] + + @staticmethod + def _view_summary_stats( + summary_stats: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints summary stats.""" + if not as_markdown: + t = rich.table.Table(title="Summary Statistics") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Summary Stat Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "Value", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + for stat_key, count in summary_stats.items(): + t.add_row(stat_key, str(count)) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + @staticmethod + def _view_data_registry( + data_registry: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints data registry.""" + if not as_markdown: + t = rich.table.Table(title="Data Registry") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Registry Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "scvi-tools Location", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + + for registry_key, data_loc in data_registry.items(): + mod_key = getattr(data_loc, _constants._DR_MOD_KEY, None) + attr_name = data_loc.attr_name + attr_key = data_loc.attr_key + scvi_data_str = "adata" + if mod_key is not None: + scvi_data_str += f".mod['{mod_key}']" + if attr_key is None: + scvi_data_str += f".{attr_name}" + else: + scvi_data_str += f".{attr_name}['{attr_key}']" + t.add_row(registry_key, scvi_data_str) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + def update_setup_method_args(self, setup_method_args: dict): + """Update setup method args. + + Parameters + ---------- + setup_method_args + This is a bit of a misnomer, this is a dict representing kwargs + of the setup method that will be used to update the existing values + in the registry of this instance. 
+ """ + self._registry[_SETUP_ARGS_KEY].update(setup_method_args) + def get_normalized_expression(self, *args, **kwargs): msg = f"get_normalized_expression is not implemented for {self.__class__.__name__}." raise NotImplementedError(msg) @@ -914,11 +1164,14 @@ class BaseMinifiedModeModelClass(BaseModelClass): @property def minified_data_type(self) -> MinifiedDataType | None: """The type of minified data associated with this model, if applicable.""" - return ( - self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) - if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry - else None - ) + if self.adata_manager: + return ( + self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) + if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry + else None + ) + else: + return None def minify_adata( self, diff --git a/src/scvi/model/base/_rnamixin.py b/src/scvi/model/base/_rnamixin.py index 22c35704e8..daada26c05 100644 --- a/src/scvi/model/base/_rnamixin.py +++ b/src/scvi/model/base/_rnamixin.py @@ -13,6 +13,7 @@ from pyro.distributions.util import deep_to from scvi import REGISTRY_KEYS, settings +from scvi.data._utils import _validate_adata_dataloader_input from scvi.distributions._utils import DistributionConcatenator, subset_distribution from scvi.model._utils import _get_batch_code_from_category, scrna_raw_counts_properties from scvi.model.base._de_core import _de_core @@ -20,9 +21,11 @@ from scvi.utils import de_dsp, dependencies, track, unsupported_if_adata_minified if TYPE_CHECKING: + from collections.abc import Iterator from typing import Literal from anndata import AnnData + from torch import Tensor from scvi._types import Number @@ -162,6 +165,7 @@ def get_normalized_expression( return_mean: bool = True, return_numpy: bool | None = None, silent: bool = True, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, **importance_weighting_kwargs, ) -> np.ndarray | pd.DataFrame: r"""Returns the normalized (decoded) gene expression. @@ -204,6 +208,10 @@ def get_normalized_expression( includes gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`. %(de_silent)s + dataloader + An iterator over minibatches of data on which to compute the metric. The minibatches + should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by + the model. If ``None``, a dataloader is created from ``adata``. importance_weighting_kwargs Keyword arguments passed into :meth:`~scvi.model.base.RNASeqMixin._get_importance_weights`. @@ -218,20 +226,34 @@ def get_normalized_expression( Otherwise, the method expects `n_samples_overall` to be provided and returns a 2d tensor of shape (n_samples_overall, n_genes). 
""" - adata = self._validate_anndata(adata) + _validate_adata_dataloader_input(self, adata, dataloader) - if indices is None: - indices = np.arange(adata.n_obs) - if n_samples_overall is not None: - assert n_samples == 1 # default value - n_samples = n_samples_overall // len(indices) + 1 - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + if dataloader is None: + adata = self._validate_anndata(adata) - transform_batch = _get_batch_code_from_category( - self.get_anndata_manager(adata, required=True), transform_batch - ) + if indices is None: + indices = np.arange(adata.n_obs) + if n_samples_overall is not None: + assert n_samples == 1 # default value + n_samples = n_samples_overall // len(indices) + 1 + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + + transform_batch = _get_batch_code_from_category( + self.get_anndata_manager(adata, required=True), transform_batch + ) + + gene_mask = slice(None) if gene_list is None else adata.var_names.isin(gene_list) - gene_mask = slice(None) if gene_list is None else adata.var_names.isin(gene_list) + else: + scdl = dataloader + for param in [indices, batch_size, n_samples]: + if param is not None: + Warning( + f"Using {param} after custom Dataloader was initialize is redundant, " + f"please re-initialize with selected {param}", + ) + gene_mask = slice(None) + transform_batch = [None] if n_samples > 1 and return_mean is False: if return_numpy is False: @@ -314,7 +336,7 @@ def get_normalized_expression( elif n_samples > 1 and return_mean: exprs = exprs.mean(0) - if return_numpy is None or return_numpy is False: + if (return_numpy is None or return_numpy is False) and dataloader is None: return pd.DataFrame( exprs, columns=adata.var_names[gene_mask], @@ -427,6 +449,7 @@ def posterior_predictive_sample( n_samples: int = 1, gene_list: list[str] | None = None, batch_size: int | None = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, ) -> GCXS: r"""Generate predictive samples from the posterior predictive distribution. @@ -455,6 +478,10 @@ def posterior_predictive_sample( Minibatch size to use for data loading and model inference. Defaults to ``scvi.settings.batch_size``. Passed into :meth:`~scvi.model.base.BaseModelClass._make_data_loader`. + dataloader + An iterator over minibatches of data on which to compute the metric. The minibatches + should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by + the model. If ``None``, a dataloader is created from ``adata``. Returns ------- @@ -463,17 +490,30 @@ def posterior_predictive_sample( """ import sparse - adata = self._validate_anndata(adata) - dataloader = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + _validate_adata_dataloader_input(self, adata, dataloader) - if gene_list is None: - gene_mask = slice(None) + if dataloader is None: + adata = self._validate_anndata(adata) + dataloader = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size + ) + + if gene_list is None: + gene_mask = slice(None) + else: + gene_mask = [gene in gene_list for gene in adata.var_names] + if not np.any(gene_mask): + raise ValueError( + "None of the provided genes in ``gene_list`` were detected in the data." + ) else: - gene_mask = [gene in gene_list for gene in adata.var_names] - if not np.any(gene_mask): - raise ValueError( - "None of the provided genes in ``gene_list`` were detected in the data." 
- ) + for param in [indices, batch_size, gene_list]: + if param is not None: + Warning( + f"Using {param} after custom Dataloader was initialize is redundant, " + f"please re-initialize with selected {param}", + ) + gene_mask = slice(None) x_hat = [] for tensors in dataloader: @@ -494,6 +534,7 @@ def _get_denoised_samples( batch_size: int = 64, rna_size_factor: int = 1000, transform_batch: list[int] | None = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, ) -> np.ndarray: """Return samples from an adjusted posterior predictive. @@ -512,13 +553,29 @@ def _get_denoised_samples( size factor for RNA prior to sampling gamma distribution. transform_batch int of which batch to condition on for all cells. + dataloader + An iterator over minibatches of data on which to compute the metric. The minibatches + should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by + the model. If ``None``, a dataloader is created from ``adata``. Returns ------- denoised_samples """ - adata = self._validate_anndata(adata) - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + _validate_adata_dataloader_input(self, adata, dataloader) + + if dataloader is None: + adata = self._validate_anndata(adata) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + else: + scdl = dataloader + for param in [indices, batch_size, n_samples]: + if param is not None: + Warning( + f"Using {param} after custom Dataloader was initialize is redundant, " + f"please re-initialize with selected {param}", + ) + transform_batch = None data_loader_list = [] for tensors in scdl: @@ -651,6 +708,7 @@ def get_likelihood_parameters( n_samples: int | None = 1, give_mean: bool | None = False, batch_size: int | None = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, ) -> dict[str, np.ndarray]: r"""Estimates for the parameters of the likelihood :math:`p(x \mid z)`. @@ -667,10 +725,24 @@ def get_likelihood_parameters( Return expected value of parameters or a samples batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + dataloader + An iterator over minibatches of data on which to compute the metric. The minibatches + should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by + the model. If ``None``, a dataloader is created from ``adata``. """ - adata = self._validate_anndata(adata) + _validate_adata_dataloader_input(self, adata, dataloader) - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + if dataloader is None: + adata = self._validate_anndata(adata) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + else: + scdl = dataloader + for param in [indices, batch_size, n_samples]: + if param is not None: + Warning( + f"Using {param} after custom Dataloader was initialize is redundant, " + f"please re-initialize with selected {param}", + ) dropout_list = [] mean_list = [] @@ -727,6 +799,7 @@ def get_latent_library_size( indices: list[int] | None = None, give_mean: bool = True, batch_size: int | None = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, ) -> np.ndarray: r"""Returns the latent library size for each cell. @@ -743,11 +816,26 @@ def get_latent_library_size( Return the mean or a sample from the posterior distribution. batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 
+ dataloader + An iterator over minibatches of data on which to compute the metric. The minibatches + should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by + the model. If ``None``, a dataloader is created from ``adata``. """ - self._check_if_trained(warn=False) + _validate_adata_dataloader_input(self, adata, dataloader) + + if dataloader is None: + self._check_if_trained(warn=False) + adata = self._validate_anndata(adata) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + else: + scdl = dataloader + for param in [indices, batch_size]: + if param is not None: + Warning( + f"Using {param} after custom Dataloader was initialize is redundant, " + f"please re-initialize with selected {param}", + ) - adata = self._validate_anndata(adata) - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) libraries = [] for tensors in scdl: inference_inputs = self.module._get_inference_input(tensors) diff --git a/src/scvi/model/base/_save_load.py b/src/scvi/model/base/_save_load.py index 4b5a60e411..1c742da21f 100644 --- a/src/scvi/model/base/_save_load.py +++ b/src/scvi/model/base/_save_load.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import logging import os import warnings @@ -102,7 +103,7 @@ def _load_saved_files( return attr_dict, var_names, model_state_dict, adata -def _initialize_model(cls, adata, attr_dict): +def _initialize_model(cls, adata, registry, attr_dict): """Helper to initialize a model.""" if "init_params_" not in attr_dict.keys(): raise ValueError( @@ -133,7 +134,13 @@ def _initialize_model(cls, adata, attr_dict): if "pretrained_model" in non_kwargs.keys(): non_kwargs.pop("pretrained_model") - model = cls(adata, **non_kwargs, **kwargs) + if not adata: + adata = None + + if "registry" in inspect.signature(cls).parameters: + model = cls(adata, registry=registry, **non_kwargs, **kwargs) + else: + model = cls(adata, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) @@ -177,7 +184,9 @@ def _get_var_names( def _validate_var_names( - adata: AnnOrMuData, source_var_names: npt.NDArray | dict[str, npt.NDArray] + adata: AnnOrMuData | None, + source_var_names: npt.NDArray | dict[str, npt.NDArray], + load_var_names: npt.NDArray | dict[str, npt.NDArray] | None = None, ) -> None: """Validate that source and loaded variable names match. @@ -188,15 +197,19 @@ def _validate_var_names( source_var_names Variable names from a saved model file corresponding to the variable names used during training. + load_var_names + Variable names from the loaded registry. 
""" from numpy import array_equal - is_anndata = isinstance(adata, AnnData) source_per_mod_var_names = isinstance(source_var_names, dict) - load_var_names = _get_var_names( - adata, - legacy_mudata_format=(not is_anndata and not source_per_mod_var_names), - ) + + if load_var_names is None: + is_anndata = isinstance(adata, AnnData) + load_var_names = _get_var_names( + adata, + legacy_mudata_format=(not is_anndata and not source_per_mod_var_names), + ) if source_per_mod_var_names: valid_load_var_names = all( diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index 9251dfa084..7a0d1effd6 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -10,7 +10,7 @@ import torch from scvi import REGISTRY_KEYS -from scvi.data._utils import get_anndata_attribute +from scvi.data._utils import _validate_adata_dataloader_input, get_anndata_attribute from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter from scvi.model._utils import get_max_epochs_heuristic, use_distributed_sampler from scvi.train import SemiSupervisedTrainingPlan, TrainingPlan, TrainRunner @@ -18,13 +18,13 @@ from scvi.utils._docstrings import devices_dsp if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterator, Sequence from lightning import LightningDataModule + from torch import Tensor from scvi._types import AnnOrMuData - logger = logging.getLogger(__name__) @@ -103,15 +103,6 @@ def train( **kwargs Additional keyword arguments passed into :class:`~scvi.train.Trainer`. """ - if datamodule is not None and not self._module_init_on_train: - raise ValueError( - "Cannot pass in `datamodule` if the model was initialized with `adata`." - ) - elif datamodule is None and self._module_init_on_train: - raise ValueError( - "If the model was not initialized with `adata`, a `datamodule` must be passed in." 
- ) - if max_epochs is None: if datamodule is None: max_epochs = get_max_epochs_heuristic(self.adata.n_obs) @@ -169,23 +160,29 @@ class SemisupervisedTrainingMixin: _training_plan_cls = SemiSupervisedTrainingPlan - def _set_indices_and_labels(self): + def _set_indices_and_labels(self, datamodule=None): """Set indices for labeled and unlabeled cells.""" - labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key self.unlabeled_category_ = labels_state_registry.unlabeled_category - labels = get_anndata_attribute( - self.adata, - self.adata_manager.data_registry.labels.attr_name, - self.original_label_key, - mod_key=getattr(self.adata_manager.data_registry.labels, "mod_key", None), - ).ravel() + if datamodule is None: + self.labels_ = get_anndata_attribute( + self.adata, + self.adata_manager.data_registry.labels.attr_name, + self.original_label_key, + mod_key=getattr(self.adata_manager.data_registry.labels, "mod_key", None), + ).ravel() + else: + if datamodule.registry["setup_method_name"] == "setup_datamodule": + self.labels_ = datamodule.labels_.ravel() + else: + self.labels_ = datamodule.labels.ravel() self._label_mapping = labels_state_registry.categorical_mapping # set unlabeled and labeled indices - self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category_).ravel() - self._labeled_indices = np.argwhere(labels != self.unlabeled_category_).ravel() + self._unlabeled_indices = np.argwhere(self.labels_ == self.unlabeled_category_).ravel() + self._labeled_indices = np.argwhere(self.labels_ != self.unlabeled_category_).ravel() self._code_to_label = dict(enumerate(self._label_mapping)) def predict( @@ -197,6 +194,7 @@ def predict( use_posterior_mean: bool = True, ig_interpretability: bool = False, ig_args: dict | None = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, ) -> (np.ndarray | pd.DataFrame, None | np.ndarray): """Return cell label predictions. @@ -221,11 +219,32 @@ def predict( sample prediction ig_args Keyword args for IntegratedGradients + dataloader + An iterator over minibatches of data on which to compute the metric. The minibatches + should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by + the model. If ``None``, a dataloader is created from ``adata``. 
""" - adata = self._validate_anndata(adata) + _validate_adata_dataloader_input(self, adata, dataloader) - if indices is None: - indices = np.arange(adata.n_obs) + if dataloader is None: + adata = self._validate_anndata(adata) + + if indices is None: + indices = np.arange(adata.n_obs) + + scdl = self._make_data_loader( + adata=adata, + indices=indices, + batch_size=batch_size, + ) + else: + scdl = dataloader + for param in [indices, batch_size]: + if param is not None: + Warning( + f"Using {param} after custom Dataloader was initialize is redundant, " + f"please re-initialize with selected {param}", + ) attributions = None if ig_interpretability: @@ -242,18 +261,13 @@ def predict( attributions = [] # in case of no indices to predict return empty values - if len(indices) == 0: - pred = [] - if ig_interpretability: - return pred, attributions - else: - return pred - - scdl = self._make_data_loader( - adata=adata, - indices=indices, - batch_size=batch_size, - ) + if dataloader is None: + if len(indices) == 0: + pred = [] + if ig_interpretability: + return pred, attributions + else: + return pred y_pred = [] for _, tensors in enumerate(scdl): @@ -336,6 +350,7 @@ def train( devices: int | list[int] | str = "auto", datasplitter_kwargs: dict | None = None, plan_kwargs: dict | None = None, + datamodule: LightningDataModule | None = None, **trainer_kwargs, ): """Train the model. @@ -371,6 +386,10 @@ def train( Keyword args for :class:`~scvi.train.SemiSupervisedTrainingPlan`. Keyword arguments passed to `train()` will overwrite values present in `plan_kwargs`, when appropriate. + datamodule + ``EXPERIMENTAL`` A :class:`~lightning.pytorch.core.LightningDataModule` instance to use + for training in place of the default :class:`~scvi.dataloaders.DataSplitter`. Can only + be passed in if the model was not initialized with :class:`~anndata.AnnData`. **trainer_kwargs Other keyword args for :class:`~scvi.train.Trainer`. 
""" @@ -385,19 +404,26 @@ def train( plan_kwargs = {} if plan_kwargs is None else plan_kwargs datasplitter_kwargs = datasplitter_kwargs or {} - # if we have labeled cells, we want to subsample labels each epoch - sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] - - data_splitter = SemiSupervisedDataSplitter( - adata_manager=self.adata_manager, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - n_samples_per_label=n_samples_per_label, - distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), - batch_size=batch_size, - **datasplitter_kwargs, - ) + if datamodule is None: + # if we have labeled cells, we want to subsample labels each epoch + sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] + + datasplitter_kwargs = datasplitter_kwargs or {} + datamodule = SemiSupervisedDataSplitter( + adata_manager=self.adata_manager, + datamodule=datamodule, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + n_samples_per_label=n_samples_per_label, + distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), + batch_size=batch_size, + **datasplitter_kwargs, + ) + else: + Warning("Warning: SCANVI sampler is not available with custom dataloader") + sampler_callback = [] + training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) if "callbacks" in trainer_kwargs.keys(): @@ -408,7 +434,7 @@ def train( runner = TrainRunner( self, training_plan=training_plan, - data_splitter=data_splitter, + data_splitter=datamodule, max_epochs=max_epochs, accelerator=accelerator, devices=devices, diff --git a/src/scvi/model/base/_vaemixin.py b/src/scvi/model/base/_vaemixin.py index 1a2ea85bfa..87a2c150fe 100644 --- a/src/scvi/model/base/_vaemixin.py +++ b/src/scvi/model/base/_vaemixin.py @@ -5,6 +5,7 @@ import torch +from scvi.data._utils import _validate_adata_dataloader_input from scvi.utils import unsupported_if_adata_minified if TYPE_CHECKING: @@ -72,13 +73,20 @@ def get_elbo( """ from scvi.model.base._log_likelihood import compute_elbo - if adata is not None and dataloader is not None: - raise ValueError("Only one of `adata` or `dataloader` can be provided.") - elif dataloader is None: + _validate_adata_dataloader_input(self, adata, dataloader) + + if dataloader is None: adata = self._validate_anndata(adata) dataloader = self._make_data_loader( adata=adata, indices=indices, batch_size=batch_size ) + else: + for param in [indices, batch_size]: + if param is not None: + Warning( + f"Using {param} after custom Dataloader was initialize is redundant, " + f"please re-initialize with selected {param}", + ) return -compute_elbo(self.module, dataloader, return_mean=return_mean, **kwargs) @@ -139,14 +147,21 @@ def get_marginal_ll( "The model's module must implement `marginal_ll` to compute the marginal " "log-likelihood." 
            )
-        elif adata is not None and dataloader is not None:
-            raise ValueError("Only one of `adata` or `dataloader` can be provided.")
+        else:
+            _validate_adata_dataloader_input(self, adata, dataloader)

        if dataloader is None:
            adata = self._validate_anndata(adata)
            dataloader = self._make_data_loader(
                adata=adata, indices=indices, batch_size=batch_size
            )
+        else:
+            for param in [indices, batch_size]:
+                if param is not None:
+                    Warning(
+                        f"Using {param} after the custom dataloader was initialized is redundant; "
+                        f"please re-initialize the dataloader with the desired {param}.",
+                    )

        log_likelihoods: list[float | Tensor] = [
            self.module.marginal_ll(
@@ -210,14 +225,20 @@ def get_reconstruction_error(
        """
        from scvi.model.base._log_likelihood import compute_reconstruction_error

-        if adata is not None and dataloader is not None:
-            raise ValueError("Only one of `adata` or `dataloader` can be provided.")
+        _validate_adata_dataloader_input(self, adata, dataloader)

        if dataloader is None:
            adata = self._validate_anndata(adata)
            dataloader = self._make_data_loader(
                adata=adata, indices=indices, batch_size=batch_size
            )
+        else:
+            for param in [indices, batch_size]:
+                if param is not None:
+                    Warning(
+                        f"Using {param} after the custom dataloader was initialized is redundant; "
+                        f"please re-initialize the dataloader with the desired {param}.",
+                    )

        return compute_reconstruction_error(
            self.module, dataloader, return_mean=return_mean, **kwargs
@@ -277,14 +298,20 @@ def get_latent_representation(
        from scvi.module._constants import MODULE_KEYS

        self._check_if_trained(warn=False)
-        if adata is not None and dataloader is not None:
-            raise ValueError("Only one of `adata` or `dataloader` can be provided.")
+        _validate_adata_dataloader_input(self, adata, dataloader)

        if dataloader is None:
            adata = self._validate_anndata(adata)
            dataloader = self._make_data_loader(
                adata=adata, indices=indices, batch_size=batch_size
            )
+        else:
+            for param in [indices, batch_size]:
+                if param is not None:
+                    Warning(
+                        f"Using {param} after the custom dataloader was initialized is redundant; "
+                        f"please re-initialize the dataloader with the desired {param}.",
+                    )

        zs: list[Tensor] = []
        qz_means: list[Tensor] = []
diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py
index 95985605b5..7b00a276ea 100644
--- a/src/scvi/train/_trainingplans.py
+++ b/src/scvi/train/_trainingplans.py
@@ -905,8 +905,13 @@ def training_step(self, batch, batch_idx):
            full_dataset = batch[0]
            labelled_dataset = batch[1]
        else:
-            full_dataset = batch
-            labelled_dataset = None
+            if list(batch.keys()) == ["X", "batch", "labels"]:
+                # means we are loading batches from a custom dataloader; TODO: keep this?
+                full_dataset = batch
+                labelled_dataset = batch
+            else:
+                full_dataset = batch
+                labelled_dataset = None

        if "kl_weight" in self.loss_kwargs:
            self.loss_kwargs.update({"kl_weight": self.kl_weight})
diff --git a/tests/conftest.py b/tests/conftest.py
index 78f3934b82..f8a6816c84 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -27,6 +27,12 @@ def pytest_addoption(parser):
        default=False,
        help="Run tests that are designed for Ray Autotune.",
    )
+    parser.addoption(
+        "--custom-dataloader-tests",
+        action="store_true",
+        default=False,
+        help="Run tests that deal with custom dataloaders. This increases test time.",
+    )
    parser.addoption(
        "--optional",
        action="store_true",
        default=False,
        help="Run tests that are designed to be optional.",
    )
@@ -78,6 +84,23 @@ def pytest_collection_modifyitems(config, items):
        elif run_internet and ("internet" not in item.keywords):
            item.add_marker(skip_non_internet)

+    run_custom_dataloader = config.getoption("--custom-dataloader-tests")
+    skip_custom_dataloader = pytest.mark.skip(
+        reason="need --custom-dataloader-tests option to run"
+    )
+    skip_non_custom_dataloader = pytest.mark.skip(
+        reason="test not having a pytest.mark.dataloader decorator"
+    )
+    for item in items:
+        # All tests marked with `pytest.mark.dataloader` get skipped unless
+        # `--custom-dataloader-tests` is passed
+        if not run_custom_dataloader and ("dataloader" in item.keywords):
+            item.add_marker(skip_custom_dataloader)
+        # Skip all tests not marked with `pytest.mark.dataloader`
+        # if `--custom-dataloader-tests` is passed
+        elif run_custom_dataloader and ("dataloader" not in item.keywords):
+            item.add_marker(skip_non_custom_dataloader)
+
    run_optional = config.getoption("--optional")
    skip_optional = pytest.mark.skip(reason="need --optional option to run")
    skip_non_optional = pytest.mark.skip(reason="test not having a pytest.mark.optional decorator")
diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py
new file mode 100644
index 0000000000..55118de701
--- /dev/null
+++ b/tests/dataloaders/test_custom_dataloader.py
@@ -0,0 +1,721 @@
+from __future__ import annotations
+
+import os
+from pprint import pprint
+
+import numpy as np
+import pytest
+
+import scvi
+from scvi.data import synthetic_iid
+from scvi.dataloaders import MappedCollectionDataModule, TileDBDataModule
+from scvi.utils import dependencies
+
+
+@pytest.mark.dataloader
+@dependencies("lamindb")
+def test_lamindb_dataloader_scvi_small(save_path: str):
+    os.system("lamin init --storage ./lamindb_collection")  # one-time init for the github runner
+    import lamindb as ln
+
+    ln.setup.init()  # one-time init for the github runner (comment out when running locally)
+
+    # prepare test data
+    adata1 = synthetic_iid()
+    adata2 = synthetic_iid()
+
+    artifact1 = ln.Artifact.from_anndata(adata1, key="part_one.h5ad").save()
+    artifact2 = ln.Artifact.from_anndata(adata2, key="part_two.h5ad").save()
+
+    collection = ln.Collection([artifact1, artifact2], key="gather")
+    collection.save()
+
+    artifacts = collection.artifacts.all()
+    artifacts.df()
+
+    # large data example
+    # ln.track("d1kl7wobCO1H0005")
+    # ln.setup.init(name="lamindb_instance_name", storage=save_path)  # is this needed in the github test?
+    # ln.setup.init()
+    # collection = ln.Collection.using("laminlabs/cellxgene").get(name="covid_normal_lung")
+    # artifacts = collection.artifacts.all()
+    # artifacts.df()
+
+    datamodule = MappedCollectionDataModule(
+        collection,
+        batch_key="batch",
+        batch_size=1024,
+        join="inner",
+    )
+
+    print(datamodule.n_obs, datamodule.n_vars, datamodule.n_batch)
+
+    pprint(datamodule.registry)
+
+    model = scvi.model.SCVI(registry=datamodule.registry)
+    pprint(model.summary_stats)
+    pprint(model.module)
+
+    model.train(
+        max_epochs=1,
+        batch_size=1024,
+        datamodule=datamodule,
+    )
+    model.history.keys()
+
+    # Internal model analyses are extracted via the inference_dataloader;
+    # with a datamodule, it must always be passed into all downstream functions.
+ inference_dataloader = datamodule.inference_dataloader() + _ = model.get_elbo(dataloader=inference_dataloader) + _ = model.get_marginal_ll(dataloader=inference_dataloader) + _ = model.get_reconstruction_error(dataloader=inference_dataloader) + _ = model.get_latent_representation(dataloader=inference_dataloader) + _ = model.posterior_predictive_sample(dataloader=inference_dataloader) + _ = model.get_normalized_expression(dataloader=inference_dataloader) + _ = model.get_likelihood_parameters(dataloader=inference_dataloader) + _ = model._get_denoised_samples(dataloader=inference_dataloader) + _ = model.get_latent_library_size(dataloader=inference_dataloader, give_mean=False) + + # repeat but with other data with fewer indices and smaller batch size + adata1_small = synthetic_iid(batch_size=10) + adata2_small = synthetic_iid(batch_size=10) + artifact1_small = ln.Artifact.from_anndata(adata1_small, key="part_one_small.h5ad").save() + artifact2_small = ln.Artifact.from_anndata(adata2_small, key="part_two_small.h5ad").save() + collection_small = ln.Collection([artifact1_small, artifact2_small], key="gather") + datamodule_small = MappedCollectionDataModule( + collection_small, + batch_key="batch", + batch_size=1024, + join="inner", + collection_val=collection, + ) + inference_dataloader_small = datamodule_small.inference_dataloader(batch_size=128) + _ = model.get_elbo(return_mean=False, dataloader=inference_dataloader_small) + _ = model.get_marginal_ll(n_mc_samples=3, dataloader=inference_dataloader_small) + _ = model.get_reconstruction_error(return_mean=False, dataloader=inference_dataloader_small) + _ = model.get_latent_representation(dataloader=inference_dataloader_small) + _ = model.posterior_predictive_sample( + indices=[1, 2, 3], gene_list=["gene_1", "gene_2"], dataloader=inference_dataloader_small + ) + _ = model.get_normalized_expression(n_samples=2, dataloader=inference_dataloader_small) + + # load and save and make query with the other data + model.save("lamin_model", save_anndata=False, overwrite=True, datamodule=datamodule) + model_query = model.load_query_data( + adata=False, reference_model="lamin_model", registry=datamodule.registry + ) + model_query.train( + max_epochs=1, datamodule=datamodule_small, check_val_every_n_epoch=1, train_size=0.9 + ) + model_query.history.keys() + + _ = model_query.get_elbo(dataloader=inference_dataloader_small) + _ = model_query.get_marginal_ll(dataloader=inference_dataloader_small) + _ = model_query.get_reconstruction_error(dataloader=inference_dataloader_small) + _ = model_query.get_latent_representation(dataloader=inference_dataloader_small) + _ = model_query.posterior_predictive_sample(dataloader=inference_dataloader_small) + _ = model_query.get_normalized_expression(dataloader=inference_dataloader_small) + _ = model_query.get_likelihood_parameters(dataloader=inference_dataloader_small) + _ = model_query._get_denoised_samples(dataloader=inference_dataloader_small) + _ = model_query.get_latent_library_size(dataloader=inference_dataloader_small, give_mean=False) + + # query again but with the adata of the model, which might bring more functionality + adata = collection.load(join="inner") + scvi.model.SCVI.setup_anndata(adata, batch_key="batch") + with pytest.raises(ValueError): + model.load_query_data(adata=adata) + model_query_adata = model.load_query_data(adata=adata, reference_model="lamin_model") + model_query_adata.train(max_epochs=1, check_val_every_n_epoch=1, train_size=0.9) + model_query_adata.history.keys() + _ = 
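+    # Note: when a custom dataloader is supplied, per-call arguments such as indices,
+    # batch_size or gene_list are effectively ignored by these methods; re-create the
+    # dataloader with the desired settings instead (the model mixins warn that such
+    # arguments are redundant).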
model_query_adata.get_elbo() + _ = model_query_adata.get_marginal_ll() + _ = model_query_adata.get_reconstruction_error() + _ = model_query_adata.get_latent_representation() + _ = model_query_adata.get_latent_representation(dataloader=inference_dataloader) + _ = model_query_adata.posterior_predictive_sample(indices=[1, 2, 3]) + model.save("lamin_model", save_anndata=False, overwrite=True, datamodule=datamodule) + model.load("lamin_model", adata=False) + model.load_query_data(adata=False, reference_model="lamin_model", registry=datamodule.registry) + + # cretae a regular model + model.load_query_data(adata=adata, reference_model="lamin_model") + model_adata = model.load("lamin_model", adata=adata) + scvi.model.SCVI.setup_anndata(adata, batch_key="batch") + model_adata.train(max_epochs=1, check_val_every_n_epoch=1, train_size=0.9) + model_adata.save( + "lamin_model_anndata", save_anndata=True, overwrite=True, datamodule=datamodule + ) + model_adata.load("lamin_model_anndata") + model_adata.load_query_data( + adata=adata, reference_model="lamin_model_anndata", registry=datamodule.registry + ) + model_adata.history.keys() + # test different gene_likelihoods + for gene_likelihood in ["zinb", "nb", "poisson"]: + model_adata = scvi.model.SCVI(adata, gene_likelihood=gene_likelihood) + model_adata.train(1, check_val_every_n_epoch=1, train_size=0.9) + model_adata.posterior_predictive_sample() + model_adata.get_latent_representation() + model_adata.get_normalized_expression() + + +@pytest.mark.dataloader +@dependencies("lamindb") +def test_lamindb_dataloader_scanvi_small(save_path: str): + # os.system("lamin init --storage ./lamindb_collection") + import lamindb as ln + + # ln.setup.init() + + # prepare test data + adata1 = synthetic_iid() + adata2 = synthetic_iid() + + artifact1 = ln.Artifact.from_anndata(adata1, key="part_one.h5ad").save() + artifact2 = ln.Artifact.from_anndata(adata2, key="part_two.h5ad").save() + + collection = ln.Collection([artifact1, artifact2], key="gather") + collection.save() + + artifacts = collection.artifacts.all() + artifacts.df() + + # large data example + # ln.track("d1kl7wobCO1H0005") + # ln.setup.init(name="lamindb_instance_name", storage=save_path) # is this need in github test + # ln.setup.init() + # collection = ln.Collection.using("laminlabs/cellxgene").get(name="covid_normal_lung") + # artifacts = collection.artifacts.all() + # artifacts.df() + + datamodule = MappedCollectionDataModule( + collection, + label_key="labels", + batch_key="batch", + batch_size=1024, + join="inner", + unlabeled_category="label_0", + ) + + # We can now create the scVI model object and train it: + model = scvi.model.SCANVI( + adata=None, + registry=datamodule.registry, + encode_covariates=False, + datamodule=datamodule, + ) + + model.train( + datamodule=datamodule, + max_epochs=1, + batch_size=1024, + train_size=1, + early_stopping=False, + ) + + user_attributes = model._get_user_attributes() + pprint(user_attributes) + model.history.keys() + + # save the model + # model.save("lamin_model_scanvi", save_anndata=False, overwrite=True, datamodule=datamodule) + # load it back and do downstream analysis (not working) + # scvi.model.SCANVI.load("lamin_model_scanvi", adata=False) + + inference_dataloader = datamodule.inference_dataloader() + + _ = model.get_elbo(dataloader=inference_dataloader) + _ = model.get_marginal_ll(dataloader=inference_dataloader) + _ = model.get_reconstruction_error(dataloader=inference_dataloader) + _ = model.get_latent_representation(dataloader=inference_dataloader) 
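+    # wrap each synthetic AnnData in a lamin Artifact and gather them into a Collection,
+    # mirroring the setup of the SCVI test above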
+ _ = model.posterior_predictive_sample(dataloader=inference_dataloader) + _ = model.get_normalized_expression(dataloader=inference_dataloader) + _ = model.get_likelihood_parameters(dataloader=inference_dataloader) + _ = model._get_denoised_samples(dataloader=inference_dataloader) + _ = model.get_latent_library_size(dataloader=inference_dataloader, give_mean=False) + + logged_keys = model.history.keys() + # assert "elbo_validation" in logged_keys + # assert "reconstruction_loss_validation" in logged_keys + # assert "kl_local_validation" in logged_keys + assert "elbo_train" in logged_keys + assert "reconstruction_loss_train" in logged_keys + assert "kl_local_train" in logged_keys + # assert "validation_classification_loss" in logged_keys + # assert "validation_accuracy" in logged_keys + # assert "validation_f1_score" in logged_keys + # assert "validation_calibration_error" in logged_keys + + # repeat but with other data with fewer indices and smaller batch size + adata1_small = synthetic_iid(batch_size=10) + adata2_small = synthetic_iid(batch_size=10) + artifact1_small = ln.Artifact.from_anndata(adata1_small, key="part_one_small.h5ad").save() + artifact2_small = ln.Artifact.from_anndata(adata2_small, key="part_two_small.h5ad").save() + collection_small = ln.Collection([artifact1_small, artifact2_small], key="gather") + datamodule_small = MappedCollectionDataModule( + collection_small, + batch_key="batch", + batch_size=1024, + join="inner", + ) + inference_dataloader_small = datamodule_small.inference_dataloader(batch_size=128) + + model.predict(dataloader=inference_dataloader_small, soft=False) + + # train from scvi model + model_scvi = scvi.model.SCVI(registry=datamodule.registry) + + # with validation collection + datamodule = MappedCollectionDataModule( + collection, + label_key="labels", + batch_key="batch", + batch_size=1024, + join="inner", + unlabeled_category="label_0", + collection_val=collection_small, + ) + + model_scvi.train( + max_epochs=1, + batch_size=1024, + datamodule=datamodule, + check_val_every_n_epoch=1, + train_size=0.9, + ) + model_scvi.save("lamin_model_scvi", save_anndata=False, overwrite=True, datamodule=datamodule) + + logged_keys = model_scvi.history.keys() + assert "elbo_validation" in logged_keys + assert "reconstruction_loss_validation" in logged_keys + assert "kl_local_validation" in logged_keys + assert "elbo_train" in logged_keys + assert "reconstruction_loss_train" in logged_keys + assert "kl_local_train" in logged_keys + assert "validation_loss" in logged_keys + + # We can now create the scVI model object and train it: + model_scanvi_from_scvi = scvi.model.SCANVI.from_scvi_model( + scvi_model=model_scvi, + adata=None, + registry=datamodule.registry, + encode_covariates=False, + unlabeled_category="label_0", + labels_key="labels", + datamodule=datamodule, + ) + model_scanvi_from_scvi.train( + datamodule=datamodule, + max_epochs=1, + batch_size=1024, + train_size=0.9, + check_val_every_n_epoch=1, + early_stopping=False, + ) + # save the model + # model_scanvi_from_scvi.save( + # "lamin_model_scanvi_from_scvi", save_anndata=False, overwrite=True, datamodule=datamodule + # ) + # # load it back and do downstream analysis (not working) + # scvi.model.SCANVI.load("lamin_model_scanvi_from_scvi", adata=False) + + logged_keys = model_scanvi_from_scvi.history.keys() + assert "elbo_validation" in logged_keys + assert "reconstruction_loss_validation" in logged_keys + assert "kl_local_validation" in logged_keys + assert "elbo_train" in logged_keys + assert 
"reconstruction_loss_train" in logged_keys + assert "kl_local_train" in logged_keys + assert "validation_loss" in logged_keys + + inference_dataloader = datamodule.inference_dataloader() + + _ = model_scanvi_from_scvi.get_elbo(dataloader=inference_dataloader) + _ = model_scanvi_from_scvi.get_marginal_ll(dataloader=inference_dataloader) + _ = model_scanvi_from_scvi.get_reconstruction_error(dataloader=inference_dataloader) + _ = model_scanvi_from_scvi.get_latent_representation(dataloader=inference_dataloader) + _ = model_scanvi_from_scvi.posterior_predictive_sample(dataloader=inference_dataloader) + _ = model_scanvi_from_scvi.get_normalized_expression(dataloader=inference_dataloader) + _ = model_scanvi_from_scvi.get_likelihood_parameters(dataloader=inference_dataloader) + _ = model_scanvi_from_scvi._get_denoised_samples(dataloader=inference_dataloader) + _ = model_scanvi_from_scvi.get_latent_library_size( + dataloader=inference_dataloader, give_mean=False + ) + + # create scanvi from adata + adata = collection.load(join="inner") # we can continue to + + scvi.model.SCANVI.setup_anndata( + adata, batch_key="batch", labels_key="labels", unlabeled_category="label_0" + ) + model_query_adata = scvi.model.SCANVI(adata, encode_covariates=True) + model_query_adata.train(max_epochs=1, check_val_every_n_epoch=1, train_size=0.9) + model_query_adata.predict(adata=adata, soft=True) + model.history.keys() + + +@pytest.mark.dataloader +@dependencies("tiledbsoma") +@dependencies("cellxgene_census") +def test_census_custom_dataloader_scvi(save_path: str): + import cellxgene_census + import tiledbsoma as soma + + # load census + census = cellxgene_census.open_soma(census_version="2023-12-15") + + # do obs filtering (in this test we take a small dataset) + experiment_name = "mus_musculus" + obs_value_filter = ( + 'is_primary_data == True and tissue_general in ["liver","heart"] and nnz >= 5000' + ) + + # in order to save time in this test we manulay filter var + hv_idx = np.arange(10) # just to make it smaller and faster for debug + + # For HVG, we can use the highly_variable_genes function provided in cellxgene_census, + # which can compute HVGs in constant memory: + hvg_query = census["census_data"][experiment_name].axis_query( + measurement_name="RNA", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + ) + + # We will now use class TileDBDataModule to connect TileDB-SOMA-ML with PyTorch Lightning. 
+ batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"] + datamodule = TileDBDataModule( + hvg_query, + layer_name="raw", + batch_size=1024, + shuffle=True, + train_size=0.9, + seed=42, + batch_column_names=batch_keys, + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, + ) + + print(datamodule.n_obs, datamodule.n_vars, datamodule.n_batch) + + n_layers = 1 + n_latent = 5 + + pprint(datamodule.registry) + + # creating the dataloader for trainset + datamodule.setup() + + # We can now create the scVI model object and train it: + model = scvi.model.SCVI( + adata=None, + registry=datamodule.registry, + n_layers=n_layers, + n_latent=n_latent, + gene_likelihood="nb", + encode_covariates=False, + ) + + model.train( + datamodule=datamodule, + max_epochs=1, + batch_size=1024, + train_size=0.9, + check_val_every_n_epoch=1, + early_stopping=False, + ) + + user_attributes = model._get_user_attributes() + pprint(user_attributes) + model.history.keys() + + # save the model + model.save("census_model", save_anndata=False, overwrite=True, datamodule=datamodule) + # load it back and do downstream analysis (not working) + scvi.model.SCVI.load("census_model", adata=False) + + # Generate cell embeddings + inference_datamodule = TileDBDataModule( + hvg_query, + layer_name="raw", + batch_labels=datamodule.batch_labels, + batch_size=1024, + shuffle=False, + batch_column_names=batch_keys, + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, + ) + + inference_datamodule.setup() + + # Datamodule will always require to pass it into all downstream functions. + # need to init the inference_dataloader before each of those commands: + latent = model.get_latent_representation( + dataloader=inference_datamodule.inference_dataloader() + ) + print(latent.shape) + _ = model.get_elbo(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_marginal_ll(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_reconstruction_error(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_latent_representation(dataloader=inference_datamodule.inference_dataloader()) + _ = model.posterior_predictive_sample(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_normalized_expression(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_likelihood_parameters(dataloader=inference_datamodule.inference_dataloader()) + _ = model._get_denoised_samples(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_latent_library_size( + dataloader=inference_datamodule.inference_dataloader(), give_mean=False + ) + + # generating data from this census + adata = cellxgene_census.get_anndata( + census, + organism=experiment_name, + obs_value_filter=obs_value_filter, + var_coords=hv_idx, + ) + # verify cell order: + assert np.array_equal( + np.array(adata.obs["soma_joinid"]), + inference_datamodule.train_dataset.query_ids.obs_joinids, + ) + + adata.obsm["scvi"] = latent + + # Additional things we would like to check + # we make the batch name the same as in the model + adata.obs["batch"] = adata.obs[batch_keys].agg("//".join, axis=1).astype("category") + + # query data + model_query = model.load_query_data( + adata=False, reference_model="census_model", registry=datamodule.registry + ) + model_query.history.keys() + + scvi.model.SCVI.prepare_query_anndata(adata, "census_model", return_reference_var_names=True) + scvi.model.SCVI.load_query_data(registry=datamodule.registry, reference_model="census_model") + + 
scvi.model.SCVI.prepare_query_anndata(adata, model) + + model.save("census_model2", save_anndata=False, overwrite=True, datamodule=datamodule) + + scvi.model.SCVI.setup_anndata(adata, batch_key="batch") + model_census3 = scvi.model.SCVI.load("census_model2", adata=adata) + + model_census3.train( + datamodule=datamodule, + max_epochs=1, + check_val_every_n_epoch=1, + train_size=0.9, + early_stopping=False, + ) + + user_attributes_model_census3 = model_census3._get_user_attributes() + pprint(user_attributes_model_census3) + _ = model_census3.get_elbo() + _ = model_census3.get_marginal_ll() + _ = model_census3.get_reconstruction_error() + _ = model_census3.get_latent_representation() + _ = model_census3.posterior_predictive_sample() + _ = model_census3.get_normalized_expression() + _ = model_census3.get_likelihood_parameters() + _ = model_census3._get_denoised_samples() + _ = model_census3.get_latent_library_size(give_mean=False) + for gene_likelihood in ["zinb", "nb", "poisson"]: + model_adata = scvi.model.SCVI(adata, gene_likelihood=gene_likelihood) + model_adata.train(1, check_val_every_n_epoch=1, train_size=0.9) + + scvi.model.SCVI.prepare_query_anndata(adata, "census_model2", return_reference_var_names=True) + scvi.model.SCVI.load_query_data(adata, "census_model2") + + scvi.model.SCVI.prepare_query_anndata(adata, model_census3) + scvi.model.SCVI.load_query_data(adata, model_census3) + + +@pytest.mark.dataloader +@dependencies("tiledbsoma") +@dependencies("cellxgene_census") +def test_census_custom_dataloader_scanvi(save_path: str): + import cellxgene_census + import tiledbsoma as soma + + # load census + census = cellxgene_census.open_soma(census_version="2023-12-15") + + # do obs filtering (in this test we take a small dataset) + experiment_name = "mus_musculus" + obs_value_filter = ( + 'is_primary_data == True and tissue_general in ["liver","heart"] and nnz >= 5000' + ) + + # in order to save time in this test we manually filter var + hv_idx = np.arange(10) # just ot make it smaller and faster for debug + + # For HVG, we can use the highly_variable_genes function provided in cellxgene_census, + # which can compute HVGs in constant memory: + hvg_query = census["census_data"][experiment_name].axis_query( + measurement_name="RNA", + obs_query=soma.AxisQuery(value_filter=obs_value_filter), + var_query=soma.AxisQuery(coords=(list(hv_idx),)), + ) + + # We will now use class TileDBDataModule to connect TileDB-SOMA-ML with PyTorch Lightning. 
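+    # For SCANVI, in addition to the batch columns, the datamodule is given label_keys and an
+    # unlabeled_category so that labels are exposed alongside the batch covariates (see the
+    # arguments passed to TileDBDataModule below).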
+ batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"] + label_keys = ["tissue_general"] + datamodule = TileDBDataModule( + hvg_query, + layer_name="raw", + batch_size=1024, + shuffle=True, + seed=42, + batch_column_names=batch_keys, + label_keys=label_keys, + train_size=0.9, + unlabeled_category="label_0", + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, + ) + + print(datamodule.n_obs, datamodule.n_vars, datamodule.n_batch) + + n_layers = 1 + n_latent = 5 + + pprint(datamodule.registry) + + # creating the dataloader for trainset + datamodule.setup() + + # We can now create the scVI model object and train it: + model = scvi.model.SCANVI( + adata=None, + registry=datamodule.registry, + n_layers=n_layers, + n_latent=n_latent, + gene_likelihood="nb", + encode_covariates=False, + datamodule=datamodule, + ) + + model.train( + datamodule=datamodule, + max_epochs=1, + batch_size=1024, + train_size=0.9, + check_val_every_n_epoch=1, + early_stopping=False, + ) + + user_attributes = model._get_user_attributes() + pprint(user_attributes) + model.history.keys() + + # save the model + # model.save("census_model_scanvi", save_anndata=False, overwrite=True, datamodule=datamodule) + # load it back and do downstream analysis (not working) + # scvi.model.SCANVI.load("census_model_scanvi", adata=False) + + # Generate cell embeddings + inference_datamodule = TileDBDataModule( + hvg_query, + layer_name="raw", + batch_labels=datamodule.batch_labels, + batch_size=1024, + shuffle=False, + batch_column_names=batch_keys, + label_keys=label_keys, + unlabeled_category="label_0", + dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, + ) + + inference_datamodule.setup() + + latent = model.get_latent_representation( + dataloader=inference_datamodule.inference_dataloader() + ) + print(latent.shape) + _ = model.get_elbo(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_marginal_ll(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_reconstruction_error(dataloader=inference_datamodule.inference_dataloader()) + _ = model.posterior_predictive_sample(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_normalized_expression(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_likelihood_parameters(dataloader=inference_datamodule.inference_dataloader()) + _ = model._get_denoised_samples(dataloader=inference_datamodule.inference_dataloader()) + _ = model.get_latent_library_size( + dataloader=inference_datamodule.inference_dataloader(), give_mean=False + ) + + logged_keys = model.history.keys() + assert "elbo_validation" in logged_keys + assert "reconstruction_loss_validation" in logged_keys + assert "kl_local_validation" in logged_keys + assert "elbo_train" in logged_keys + assert "reconstruction_loss_train" in logged_keys + assert "kl_local_train" in logged_keys + # assert "validation_classification_loss" in logged_keys + # assert "validation_accuracy" in logged_keys + # assert "validation_f1_score" in logged_keys + # assert "validation_calibration_error" in logged_keys + assert "kl_global_validation" in logged_keys + assert "kl_global_train" in logged_keys + + model.predict(dataloader=inference_datamodule.inference_dataloader(), soft=False) + + # train from scvi model + model_scvi = scvi.model.SCVI(registry=datamodule.registry) + + model_scvi.train( + max_epochs=1, + batch_size=1024, + datamodule=datamodule, + check_val_every_n_epoch=1, + train_size=0.9, + ) + model_scvi.save("census_model_scvi", 
save_anndata=False, overwrite=True, datamodule=datamodule) + # We can now create the scVI model object and train it: + model_scanvi_from_scvi = scvi.model.SCANVI.from_scvi_model( + scvi_model=model_scvi, + adata=None, + registry=datamodule.registry, + encode_covariates=False, + unlabeled_category="label_0", + labels_key="labels", + datamodule=datamodule, + ) + model_scanvi_from_scvi.train( + datamodule=datamodule, + max_epochs=1, + batch_size=1024, + train_size=0.9, + check_val_every_n_epoch=1, + early_stopping=False, + ) + # # save the model + # model_scanvi_from_scvi.save( + # "census_model_scanvi_from_scvi",save_anndata=False, overwrite=True, datamodule=datamodule + # ) + # # load it back and do downstream analysis (not working) + # model_scanvi_from_scvi_loaded = scvi.model.SCANVI.load( + # "census_model_scanvi_from_scvi", adata=False + # ) + + # generating adata from this census + adata = cellxgene_census.get_anndata( + census, + organism=experiment_name, + obs_value_filter=obs_value_filter, + var_coords=hv_idx, + ) + # verify cell order: + assert np.array_equal( + np.array(adata.obs["soma_joinid"]), + inference_datamodule.train_dataset.query_ids.obs_joinids, + ) + + adata.obsm["scvi"] = latent + + # Additional things we would like to check + # we make the batch name the same as in the model + adata.obs["batch"] = adata.obs[batch_keys].agg("//".join, axis=1).astype("category") + + # scvi.model.SCVI.setup_anndata(adata, batch_key="batch") + # model_query_adata = model.load_query_data( + # adata=adata, reference_model="census_model_scanvi_from_scvi" + # ) + # model_query_adata.train(max_epochs=1, check_val_every_n_epoch=1, train_size=0.9) + # model_query_adata.predict(adata=adata) diff --git a/tests/model/saved_model/model.pt b/tests/model/saved_model/model.pt index eabae350c7..40256772aa 100644 Binary files a/tests/model/saved_model/model.pt and b/tests/model/saved_model/model.pt differ diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index fd04ef99c3..9b9d06ad03 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -1168,7 +1168,7 @@ def test_scvi_no_anndata(n_batches: int = 3, n_latent: int = 5): model.train(datamodule=datamodule) # must pass in datamodule if not initialized with adata - with pytest.raises(ValueError): + with pytest.raises(AttributeError): model.train() model.train(max_epochs=1, datamodule=datamodule) @@ -1182,10 +1182,6 @@ def test_scvi_no_anndata(n_batches: int = 3, n_latent: int = 5): assert model.module is not None assert hasattr(model, "adata") - # initialized with adata, cannot pass in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) - def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int = 5): from scvi.dataloaders import DataSplitter @@ -1207,7 +1203,7 @@ def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int datamodule.n_vars = adata.n_vars datamodule.n_batch = n_batches - model = SCVI(n_latent=n_latent) + model = SCVI(adata=None, n_latent=n_latent) # model with no adata assert model._module_init_on_train assert model.module is None @@ -1216,7 +1212,7 @@ def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int model.train(datamodule=datamodule) # must pass in datamodule if not initialized with adata - with pytest.raises(ValueError): + with pytest.raises(AttributeError): model.train() model.train(max_epochs=1, datamodule=datamodule) @@ -1230,10 +1226,6 @@ def test_scvi_no_anndata_with_external_indices(n_batches: int = 
3, n_latent: int assert model.module is not None assert hasattr(model, "adata") - # initialized with adata, cannot pass in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) - @pytest.mark.parametrize("embedding_dim", [5, 10]) @pytest.mark.parametrize("encode_covariates", [True, False])