From d6c16630692538d4cd0dd24cd3901301f891b28f Mon Sep 17 00:00:00 2001 From: Caibin Sheng Date: Sun, 26 May 2024 23:52:47 +0200 Subject: [PATCH] fix: fix a bug --- scar/main/_scar.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 9491f88..7be5f07 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -709,6 +709,7 @@ class UMIDataset(torch.utils.data.Dataset): def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None): """Initialization""" + self.device = device self.raw_count = torch.from_numpy(raw_count).int().to(device) self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device) self.batch_id = batch_id.to(torch.int64).to(device) @@ -734,7 +735,7 @@ def __getitem__(self, index): def _onehot(self, batch_id): """One-hot encoding""" - n_batch = batch_id.unique().size()[0] + n_batch = batch_id.to(self.device).unique().size()[0] x_onehot = torch.zeros(n_batch, n_batch) x_onehot.scatter_(1, batch_id.unique().unsqueeze(1), 1) return x_onehot \ No newline at end of file