Skip to content

Commit b3648e0

Browse files
Bug fixes for resolVI. (#3308)
Several bug fixes for reported issues. #3283 #3208 #3289 #3267 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 17da5f7 commit b3648e0

File tree

5 files changed

+46
-24
lines changed

5 files changed

+46
-24
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@ to [Semantic Versioning]. Full commit history is available in the
2323
- Add consideration for missing monitor set during early stopping. {pr}`3226`.
2424
- Fix bug in SysVI get_normalized_expression function. {pr}`3255`.
2525
- Add support for IntegratedGradients for multimodal models. {pr}`3264`.
26+
- Fix bug in resolVI get_normalized_expression function. {pr}`3308`.
27+
- Fix bug in resolVI gene-batch dispersion. {pr}`3308`.
2628

2729
#### Changed
2830

2931
- Updated Scvi-Tools AWS hub to Weizmann instead of Berkeley. {pr}`3246`.
32+
- Updated resolVI to use rapids-singlecell. {pr}`3308`.
3033

3134
#### Removed
3235

src/scvi/external/resolvi/_model.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import importlib.util
43
import logging
54
from functools import partial
65
from typing import TYPE_CHECKING
@@ -343,7 +342,11 @@ def setup_anndata(
343342
cls.register_manager(adata_manager)
344343

345344
@staticmethod
346-
def _prepare_data(adata, n_neighbors=10, spatial_rep="X_spatial", batch_key=None, **kwargs):
345+
def _prepare_data(
346+
adata, n_neighbors=10, spatial_rep="X_spatial", batch_key=None, slice_key=None, **kwargs
347+
):
348+
if slice_key is not None:
349+
batch_key = slice_key
347350
try:
348351
import scanpy
349352
from sklearn.neighbors._base import _kneighbors_from_graph
@@ -365,13 +368,15 @@ def _prepare_data(adata, n_neighbors=10, spatial_rep="X_spatial", batch_key=None
365368

366369
for index in indices:
367370
sub_data = adata[index].copy()
368-
if importlib.util.find_spec("cuml") is not None:
369-
method = "rapids"
370-
else:
371-
method = "umap"
372-
scanpy.pp.neighbors(
373-
sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep, method=method
374-
)
371+
try:
372+
import rapids_singlecell
373+
374+
print("RAPIDS SingleCell is installed and can be imported")
375+
rapids_singlecell.pp.neighbors(
376+
sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep
377+
)
378+
except ImportError:
379+
scanpy.pp.neighbors(sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep)
375380
distances = sub_data.obsp["distances"] ** 2
376381

377382
distance_neighbor[index, :], index_neighbor_batch = _kneighbors_from_graph(

src/scvi/external/resolvi/_module.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,7 @@ def __init__(
163163
init_px_r = torch.full([n_input, n_batch], 0.01)
164164
else:
165165
raise ValueError(
166-
"dispersion must be one of ['gene', 'gene-batch', 'gene-label'], but input was "
167-
"{}.format(self.dispersion)"
166+
f"dispersion must be one of ['gene', 'gene-batch'], but input was {dispersion}."
168167
)
169168
self.register_buffer("px_r", init_px_r)
170169

@@ -751,8 +750,7 @@ def __init__(
751750
init_px_r = torch.full([n_input, n_batch], 0.01)
752751
else:
753752
raise ValueError(
754-
"dispersion must be one of ['gene', 'gene-batch', 'gene-label'], but input was "
755-
"{}.format(dispersion)"
753+
f"dispersion must be one of ['gene', 'gene-batch'], but input was {dispersion}."
756754
)
757755
self.register_buffer("px_r", init_px_r)
758756
self.register_buffer("per_neighbor_diffusion_init", torch.zeros([n_obs, n_neighbors]))
@@ -868,7 +866,10 @@ def forward( # not used arguments to have same set of arguments in model and gu
868866

869867
if self.dispersion == "gene-batch":
870868
px_r_inv = F.linear(
871-
torch.nn.functional.one_hot(batch_index.flatten(), self.n_batch), px_r_mle
869+
torch.nn.functional.one_hot(batch_index.flatten(), self.n_batch).to(
870+
px_r_mle.dtype
871+
),
872+
px_r_mle,
872873
)
873874
elif self.dispersion == "gene":
874875
px_r_inv = px_r_mle

src/scvi/external/resolvi/_utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def get_normalized_expression(
229229
library_size
230230
Scale the expression frequencies to a common library size.
231231
This allows gene expression levels to be interpreted on a common scale of relevant
232-
magnitude. If set to `"latent"`, use the latent library size.
232+
magnitude.
233233
n_samples
234234
Number of posterior samples to use for estimation.
235235
n_samples_overall
@@ -301,32 +301,28 @@ def get_normalized_expression(
301301
kwargs["batch_index"],
302302
*categorical_input,
303303
)
304-
z = torch.distributions.Normal(qz_m, qz_v.sqrt()).sample(
305-
[
306-
n_samples,
307-
]
308-
)
304+
z = torch.distributions.Normal(qz_m, qz_v.sqrt()).sample([n_samples])
309305

310306
if kwargs["cat_covs"] is not None:
311307
categorical_input = list(torch.split(kwargs["cat_covs"], 1, dim=1))
312308
else:
313309
categorical_input = ()
314310
if batch is not None:
315-
batch = torch.full_like(kwargs["batch"], batch)
311+
batch = torch.full_like(kwargs["batch_index"], batch)
316312
else:
317313
batch = kwargs["batch_index"]
318314

319315
px_scale, _, px_rate, _ = self.module.model.decoder(
320316
self.module.model.dispersion, z, kwargs["library"], batch, *categorical_input
321317
)
322318
if library_size is not None:
323-
exp_ = library_size * px_scale.reshape(-1, px_scale.shape[-1])
319+
exp_ = library_size * px_scale
324320
else:
325-
exp_ = px_rate.reshape(-1, px_scale.shape[-1])
321+
exp_ = px_rate
326322

327323
exp_ = exp_[..., gene_mask]
328324
per_batch_exprs.append(exp_[None].cpu())
329-
per_batch_exprs = torch.cat(per_batch_exprs, dim=0).numpy()
325+
per_batch_exprs = torch.cat(per_batch_exprs, dim=0).mean(0).numpy()
330326
exprs.append(per_batch_exprs)
331327

332328
exprs = np.concatenate(exprs, axis=1)

tests/external/resolvi/test_resolvi.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def test_resolvi_train(adata):
2323
model.train(
2424
max_epochs=2,
2525
)
26+
model = RESOLVI(adata, dispersion="gene-batch")
27+
model.train(
28+
max_epochs=2,
29+
)
2630

2731

2832
def test_resolvi_save_load(adata):
@@ -52,8 +56,21 @@ def test_resolvi_downstream(adata):
5256
)
5357
latent = model.get_latent_representation()
5458
assert latent.shape == (adata.n_obs, model.module.n_latent)
59+
counts = model.get_normalized_expression(n_samples=31, library_size=10000)
60+
counts = model.get_normalized_expression_importance(n_samples=30, library_size=10000)
61+
print("FFFFFF", counts.shape)
5562
model.differential_expression(groupby="labels")
5663
model.differential_expression(groupby="labels", weights="importance")
64+
model.sample_posterior(
65+
model=model.module.model_residuals,
66+
num_samples=30,
67+
return_samples=False,
68+
return_sites=None,
69+
batch_size=1000,
70+
)
71+
model.sample_posterior(
72+
model=model.module.model_residuals, num_samples=30, return_samples=False, batch_size=1000
73+
)
5774
model_query = model.load_query_data(reference_model=model, adata=adata)
5875
model_query = model.load_query_data(reference_model="test_resolvi", adata=adata)
5976
model_query.train(

0 commit comments

Comments
 (0)