Skip to content

Commit

Permalink
fix: fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed May 26, 2024
1 parent e5ebaf8 commit d6c1663
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ class UMIDataset(torch.utils.data.Dataset):

def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
"""Initialization"""
self.device = device
self.raw_count = torch.from_numpy(raw_count).int().to(device)
self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
self.batch_id = batch_id.to(torch.int64).to(device)
Expand All @@ -734,7 +735,7 @@ def __getitem__(self, index):

def _onehot(self, batch_id):
"""One-hot encoding"""
n_batch = batch_id.unique().size()[0]
n_batch = batch_id.to(self.device).unique().size()[0]
x_onehot = torch.zeros(n_batch, n_batch)
x_onehot.scatter_(1, batch_id.unique().unsqueeze(1), 1)
return x_onehot

0 comments on commit d6c1663

Please sign in to comment.