diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 94a302f..5d8a94d 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -713,7 +713,7 @@ def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None): self.raw_count = torch.from_numpy(raw_count).int().to(device) self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device) self.batch_id = batch_id.to(torch.int64).to(device) - self.batch_onehot = self._onehot(batch_id.to(torch.int64)).to(device) + self.batch_onehot = self._onehot() if list_ids: self.list_ids = list_ids @@ -733,10 +733,9 @@ def __getitem__(self, index): sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :] return sc_count, sc_ambient, sc_batch_id_onehot - def _onehot(self, batch_id): + def _onehot(self): """One-hot encoding""" - batch_id = batch_id.to(self.device) - n_batch = batch_id.unique().size()[0] - x_onehot = torch.zeros(n_batch, n_batch) - x_onehot.scatter_(1, batch_id.unique().unsqueeze(1), 1) + n_batch = self.batch_id.unique().size()[0] + x_onehot = torch.zeros(n_batch, n_batch).to(self.device) + x_onehot.scatter_(1, self.batch_id.unique().unsqueeze(1), 1) return x_onehot \ No newline at end of file