diff --git a/vamb/cluster.py b/vamb/cluster.py index b06f5fb0..9ea8e560 100644 --- a/vamb/cluster.py +++ b/vamb/cluster.py @@ -349,7 +349,7 @@ def _smaller_indices(tensor, kept_mask, threshold, cuda): # If it's on GPU, we remove the already clustered points at this step. if cuda: - return _torch.nonzero((tensor <= threshold) & kept_mask).flatten() + return _torch.nonzero((tensor <= threshold) & kept_mask).flatten().cpu() else: arr = tensor.numpy() indices = (arr <= threshold).nonzero()[0]