7
7
8
8
from scvi import REGISTRY_KEYS
9
9
from scvi .data import AnnDataManager
10
+ from scvi .data ._constants import ADATA_MINIFY_TYPE
11
+ from scvi .data ._utils import _get_adata_minify_type
10
12
from scvi .data .fields import CategoricalObsField , LayerField
11
13
from scvi .model ._utils import _init_library_size
12
14
from scvi .model .base import UnsupervisedTrainingMixin
13
15
from scvi .module import LDVAE
14
16
from scvi .utils import setup_anndata_dsp
15
17
16
- from .base import BaseModelClass , RNASeqMixin , VAEMixin
18
+ from .base import BaseMinifiedModeModelClass , RNASeqMixin , VAEMixin
17
19
18
20
if TYPE_CHECKING :
19
21
from typing import Literal
23
25
logger = logging .getLogger (__name__ )
24
26
25
27
26
- class LinearSCVI (RNASeqMixin , VAEMixin , UnsupervisedTrainingMixin , BaseModelClass ):
28
+ class LinearSCVI (
29
+ RNASeqMixin ,
30
+ VAEMixin ,
31
+ UnsupervisedTrainingMixin ,
32
+ BaseMinifiedModeModelClass ,
33
+ ):
27
34
"""Linearly-decoded VAE :cite:p:`Svensson20`.
28
35
29
36
Parameters
@@ -51,6 +58,9 @@ class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClas
51
58
* ``'nb'`` - Negative binomial distribution
52
59
* ``'zinb'`` - Zero-inflated negative binomial distribution
53
60
* ``'poisson'`` - Poisson distribution
61
+ use_observed_lib_size
62
+ If ``True``, use the observed library size for RNA as the scaling factor in the mean of the
63
+ conditional distribution.
54
64
latent_distribution
55
65
One of:
56
66
@@ -75,6 +85,8 @@ class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClas
75
85
"""
76
86
77
87
_module_cls = LDVAE
88
+ _LATENT_QZM_KEY = "ldvae_latent_qzm"
89
+ _LATENT_QZV_KEY = "ldvae_latent_qzv"
78
90
79
91
def __init__ (
80
92
self ,
@@ -85,13 +97,19 @@ def __init__(
85
97
dropout_rate : float = 0.1 ,
86
98
dispersion : Literal ["gene" , "gene-batch" , "gene-label" , "gene-cell" ] = "gene" ,
87
99
gene_likelihood : Literal ["zinb" , "nb" , "poisson" ] = "nb" ,
100
+ use_observed_lib_size : bool = False ,
88
101
latent_distribution : Literal ["normal" , "ln" ] = "normal" ,
89
102
** model_kwargs ,
90
103
):
91
104
super ().__init__ (adata )
92
105
93
106
n_batch = self .summary_stats .n_batch
94
- library_log_means , library_log_vars = _init_library_size (self .adata_manager , n_batch )
107
+ library_log_means , library_log_vars = None , None
108
+ if (
109
+ self .minified_data_type != ADATA_MINIFY_TYPE .LATENT_POSTERIOR
110
+ and not use_observed_lib_size
111
+ ):
112
+ library_log_means , library_log_vars = _init_library_size (self .adata_manager , n_batch )
95
113
96
114
self .module = self ._module_cls (
97
115
n_input = self .summary_stats .n_vars ,
@@ -105,6 +123,7 @@ def __init__(
105
123
latent_distribution = latent_distribution ,
106
124
library_log_means = library_log_means ,
107
125
library_log_vars = library_log_vars ,
126
+ use_observed_lib_size = use_observed_lib_size ,
108
127
** model_kwargs ,
109
128
)
110
129
self ._model_summary_string = (
@@ -153,6 +172,10 @@ def setup_anndata(
153
172
CategoricalObsField (REGISTRY_KEYS .BATCH_KEY , batch_key ),
154
173
CategoricalObsField (REGISTRY_KEYS .LABELS_KEY , labels_key ),
155
174
]
175
+ # register new fields if the adata is minified
176
+ adata_minify_type = _get_adata_minify_type (adata )
177
+ if adata_minify_type is not None :
178
+ anndata_fields += cls ._get_fields_for_adata_minification (adata_minify_type )
156
179
adata_manager = AnnDataManager (fields = anndata_fields , setup_method_args = setup_method_args )
157
180
adata_manager .register_fields (adata , ** kwargs )
158
181
cls .register_manager (adata_manager )
0 commit comments