Dear Authors,

I am using your repository in my re-implementation of the CLAM model. You can see it referenced in the environment file here: https://github.com/mahmoodlab/CLAM/blob/master/env.yml.

I have a question about support for different devices in PyTorch. If `x` and `y` are on the `"cuda"` device, `SmoothTopkSVM` cannot be computed because the `labels` variable stays on the `"cpu"` device. I reused the example code from a previous issue and added a few lines that select the device:

```python
import torch
from topk.svm import SmoothTopkSVM

# hyper-parameters
topk = 10    # top-10 classification
alpha = 1    # margin parameter of hinge loss, 1 is a default value
tau = 0.1    # smoothing temperature parameter (see paper for details)

# define the device (fall back to CPU so the script also runs without a GPU)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# random data
batch_size = 3
n_dim = 20
n_classes = 100
x = torch.randn(batch_size, n_dim).to(device)
y = torch.randint(0, n_classes, size=(batch_size,)).long().to(device)

# model
model = torch.nn.Linear(n_dim, n_classes).to(device)

# loss function
loss_fn = SmoothTopkSVM(n_classes, alpha=alpha, tau=tau, k=topk).to(device)

# forward pass
loss = loss_fn(model(x), y)

# backward pass
loss.backward()

# print gradient
print(model.weight.grad)
```
If I run this on a GPU, I get the error

```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

It comes from this line, where the `labels` variable is not on `"cuda"`: https://github.com/oval-group/smooth-topk/blob/master/topk/utils.py#L23.
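One common way this kind of mismatch arises in PyTorch (an assumption here; it is not confirmed that this is what happens inside `SmoothTopkSVM`) is that `nn.Module.to(device)` moves parameters and registered buffers, but not tensors stored as plain attributes. The two classes below are hypothetical and only illustrate that difference:

```python
import torch
from torch import nn


class PlainAttr(nn.Module):
    """Hypothetical module that stores `labels` as a plain tensor attribute."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.labels = torch.arange(n_classes)  # NOT moved by .to(device)


class WithBuffer(nn.Module):
    """Hypothetical module that registers `labels` as a buffer."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.register_buffer("labels", torch.arange(n_classes))  # moved by .to(device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
plain = PlainAttr(5).to(device)
buffered = WithBuffer(5).to(device)
print(plain.labels.device)     # always cpu
print(buffered.labels.device)  # follows `device`
```

On a GPU machine, `plain.labels` stays on the CPU while `buffered.labels` lands on `cuda:0`, which is exactly the situation the error message above describes.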
I could fix it with:
and also here: https://github.com/oval-group/smooth-topk/blob/master/topk/functional.py#L35 with:
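Independently of how the loss module stores `labels`, a tensor that is built during the forward pass can be created directly on the input's device. The sketch below is a minimal illustration under that assumption: `delta_matrix` is a hypothetical helper computing a zero-one loss matrix, not the code in `topk/utils.py` or `topk/functional.py` and not the patch from this issue.

```python
import torch


def delta_matrix(y: torch.Tensor, n_classes: int, alpha: float = 1.0) -> torch.Tensor:
    """Hypothetical zero-one loss matrix: entry (i, j) is alpha if j != y[i], else 0.

    `labels` is created on y's device, so the comparison never mixes
    CPU and CUDA tensors.
    """
    labels = torch.arange(n_classes, device=y.device)
    # broadcast (batch, 1) against (1, n_classes) -> (batch, n_classes)
    return alpha * torch.ne(y[:, None], labels[None, :]).float()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
y = torch.tensor([2, 0], device=device)
d = delta_matrix(y, n_classes=3)
print(d.tolist())  # [[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]]
```

Passing `device=y.device` at creation time avoids an extra host-to-device copy, whereas calling `.to(y.device)` on an existing CPU tensor copies it on every call.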
Thank you for your time and help.
Best regards :)