Skip to content

Commit c83b890

Browse files
authored
feat: Linear SCVI adata minifiction (#3294)
added adata minification for linear scvi
1 parent c97655d commit c83b890

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ to [Semantic Versioning]. Full commit history is available in the
1616
- Add get normalized function model property for any generative model {pr}`3238` and changed
1717
get_accessibility_estimates to get_normalized_accessibility, where needed.
1818
- Add Early stopping KL warmup steps. {pr}`3262`.
19+
- Add Minification option to {class}`~scvi.model.LinearSCVI` {pr}`3294`.
1920

2021
#### Fixed
2122

src/scvi/model/_linear_scvi.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77

88
from scvi import REGISTRY_KEYS
99
from scvi.data import AnnDataManager
10+
from scvi.data._constants import ADATA_MINIFY_TYPE
11+
from scvi.data._utils import _get_adata_minify_type
1012
from scvi.data.fields import CategoricalObsField, LayerField
1113
from scvi.model._utils import _init_library_size
1214
from scvi.model.base import UnsupervisedTrainingMixin
1315
from scvi.module import LDVAE
1416
from scvi.utils import setup_anndata_dsp
1517

16-
from .base import BaseModelClass, RNASeqMixin, VAEMixin
18+
from .base import BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin
1719

1820
if TYPE_CHECKING:
1921
from typing import Literal
@@ -23,7 +25,12 @@
2325
logger = logging.getLogger(__name__)
2426

2527

26-
class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
28+
class LinearSCVI(
29+
RNASeqMixin,
30+
VAEMixin,
31+
UnsupervisedTrainingMixin,
32+
BaseMinifiedModeModelClass,
33+
):
2734
"""Linearly-decoded VAE :cite:p:`Svensson20`.
2835
2936
Parameters
@@ -51,6 +58,9 @@ class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClas
5158
* ``'nb'`` - Negative binomial distribution
5259
* ``'zinb'`` - Zero-inflated negative binomial distribution
5360
* ``'poisson'`` - Poisson distribution
61+
use_observed_lib_size
62+
If ``True``, use the observed library size for RNA as the scaling factor in the mean of the
63+
conditional distribution.
5464
latent_distribution
5565
One of:
5666
@@ -75,6 +85,8 @@ class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClas
7585
"""
7686

7787
_module_cls = LDVAE
88+
_LATENT_QZM_KEY = "ldvae_latent_qzm"
89+
_LATENT_QZV_KEY = "ldvae_latent_qzv"
7890

7991
def __init__(
8092
self,
@@ -85,13 +97,19 @@ def __init__(
8597
dropout_rate: float = 0.1,
8698
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
8799
gene_likelihood: Literal["zinb", "nb", "poisson"] = "nb",
100+
use_observed_lib_size: bool = False,
88101
latent_distribution: Literal["normal", "ln"] = "normal",
89102
**model_kwargs,
90103
):
91104
super().__init__(adata)
92105

93106
n_batch = self.summary_stats.n_batch
94-
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)
107+
library_log_means, library_log_vars = None, None
108+
if (
109+
self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
110+
and not use_observed_lib_size
111+
):
112+
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)
95113

96114
self.module = self._module_cls(
97115
n_input=self.summary_stats.n_vars,
@@ -105,6 +123,7 @@ def __init__(
105123
latent_distribution=latent_distribution,
106124
library_log_means=library_log_means,
107125
library_log_vars=library_log_vars,
126+
use_observed_lib_size=use_observed_lib_size,
108127
**model_kwargs,
109128
)
110129
self._model_summary_string = (
@@ -153,6 +172,10 @@ def setup_anndata(
153172
CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
154173
CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
155174
]
175+
# register new fields if the adata is minified
176+
adata_minify_type = _get_adata_minify_type(adata)
177+
if adata_minify_type is not None:
178+
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
156179
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
157180
adata_manager.register_fields(adata, **kwargs)
158181
cls.register_manager(adata_manager)

tests/model/test_linear_scvi.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import torch
77

88
from scvi.data import synthetic_iid
9+
from scvi.data._constants import ADATA_MINIFY_TYPE
10+
from scvi.data._utils import _is_minified
911
from scvi.model import LinearSCVI
1012
from scvi.utils import attrdict
1113

@@ -131,3 +133,25 @@ def test_linear_scvi_use_observed_lib_size():
131133
model.get_loadings()
132134
model.differential_expression(groupby="labels", group1="label_1")
133135
model.differential_expression(groupby="labels", group1="label_1", group2="label_2")
136+
137+
138+
def test_linear_scvi_with_minification(save_path):
139+
adata = synthetic_iid()
140+
adata = adata[:, :10].copy()
141+
LinearSCVI.setup_anndata(adata)
142+
model = LinearSCVI(adata, n_latent=10, use_observed_lib_size=True)
143+
model.train(1, check_val_every_n_epoch=1, train_size=0.5)
144+
assert len(model.history["elbo_train"]) == 1
145+
assert len(model.history["elbo_validation"]) == 1
146+
qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
147+
model.adata.obsm["X_latent_qzm"] = qzm
148+
model.adata.obsm["X_latent_qzv"] = qzv
149+
model.minify_adata()
150+
assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
151+
assert model.adata_manager.registry is model.registry_
152+
assert not _is_minified(adata)
153+
assert adata is not model.adata
154+
assert len(model.adata.X.data) == 0
155+
assert model.adata.raw is None
156+
minified_model_path = os.path.join(save_path, "linear_scvi_minified")
157+
model.save(minified_model_path, save_anndata=True, overwrite=True)

0 commit comments

Comments
 (0)