
RESOLVI RuntimeError: dtype mismatch when dispersion set to "gene-batch" #3289


Closed
tsvvas opened this issue Apr 10, 2025 · 1 comment


tsvvas commented Apr 10, 2025

Hi team,

I'm running into a PyTorch dtype mismatch error when training the RESOLVI model. The traceback suggests a matrix multiplication is being attempted between a LongTensor and a FloatTensor.

import scvi

params = {
    "n_hidden": 32,
    "dropout_rate": 0.05,
    "dispersion": "gene-batch",
    "downsample_counts": False,
}

scvi.external.RESOLVI.setup_anndata(
    adata,
    layer=LAYER_KEY,
    batch_key=BATCH_KEY,
    prepare_data_kwargs={"spatial_rep": "spatial"},
)
model = scvi.external.RESOLVI(adata, **params)
model.to_device(device)

model.train(**train_params)

The final and most relevant part of the traceback is:

File /opt/conda/envs/spatial/lib/python3.11/site-packages/scvi/external/resolvi/_module.py:870, in RESOLVAEGuide.forward(self, x, ind_x, library, y, batch_index, cat_covs, x_n, distances_n, n_obs, kl_weight)
    862 px_r_mle = pyro.param(
    863     "px_r_mle",
    864     self.px_r,
    865     constraint=constraints.greater_than(self.eps),
    866     event_dim=len(self.px_r.shape),
    867 )
    869 if self.dispersion == "gene-batch":
--> 870     px_r_inv = F.linear(
    871         torch.nn.functional.one_hot(batch_index.flatten(), self.n_batch), px_r_mle
    872     )
    873 elif self.dispersion == "gene":
    874     px_r_inv = px_r_mle

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: long int != float
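
For context, this is reproducible with plain PyTorch, independent of scvi-tools: one_hot() always returns int64, while the Pyro parameter is a floating-point tensor, and F.linear refuses to mix the two. A minimal sketch (the shapes below are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

n_batch, n_genes = 3, 5
batch_index = torch.tensor([[0], [2], [1]])  # int64, as produced by the data loader
px_r_mle = torch.rand(n_genes, n_batch)      # float32 parameter, one column per batch

one_hot_batch = F.one_hot(batch_index.flatten(), n_batch)  # dtype is torch.int64
px_r_inv = F.linear(one_hot_batch, px_r_mle)
# RuntimeError: expected mat1 and mat2 to have the same dtype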

Proposed solution:

The issue seems to be solved by casting the result of one_hot() to the dtype of px_r_mle:

if self.dispersion == "gene-batch":
    one_hot_batch = torch.nn.functional.one_hot(batch_index.flatten(), self.n_batch).to(px_r_mle.dtype)
    px_r_inv = F.linear(one_hot_batch, px_r_mle)
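
With the cast in place, the toy example above goes through, and each row of the result picks out the per-gene dispersion values for that cell's batch:

px_r_inv = F.linear(one_hot_batch.to(px_r_mle.dtype), px_r_mle)
print(px_r_inv.shape, px_r_inv.dtype)  # torch.Size([3, 5]) torch.float32

Casting to px_r_mle.dtype rather than hard-coding .float() also keeps the code correct under mixed precision, where the parameter may be float16 or bfloat16.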

Versions:

scvi-tools: 1.3.0
PyTorch: 2.6.0+cu124
GPU: NVIDIA A100 (CUDA 12.4)

Best regards,
Vasily

@tsvvas tsvvas added the bug label Apr 10, 2025
canergen (Member) commented May 8, 2025

Yes, this is correct, and a fix will be pushed later.

@canergen canergen closed this as completed May 8, 2025
canergen added a commit that referenced this issue May 8, 2025
Several bug fixes for reported issues.

#3283
#3208
#3289
#3267
