Skip to content

Commit

Permalink
fix: place tensor on cpu before converting to R
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jan 17, 2024
1 parent 4d5d123 commit 7291b50
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 8 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ ByteCompile: no
VignetteBuilder: knitr
Encoding: UTF-8
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
Collate:
'CallbackSet.R'
'zzz.R'
Expand Down
7 changes: 4 additions & 3 deletions R/learner_torch_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ train_loop = function(ctx, cbs) {

ctx$last_loss = loss$item()
predictions[[length(predictions) + 1]] = y_hat$detach()
indices[[length(indices) + 1]] = as.numeric(batch$.index)
indices[[length(indices) + 1]] = as.integer(batch$.index$to(device = "cpu"))
ctx$optimizer$step()

call("on_batch_end")
Expand Down Expand Up @@ -245,11 +245,12 @@ encode_prediction_default = function(predict_tensor, predict_type, task) {
response = prob = NULL
if (task$task_type == "classif") {
if (predict_type == "prob") {
predict_tensor = nnf_softmax(predict_tensor, dim = 2L)
predict_tensor = with_no_grad(nnf_softmax(predict_tensor, dim = 2L))
}
# We still execute the argmax on the device before converting to R
response = as.integer(predict_tensor$argmax(dim = 2L))
response = as.integer(with_no_grad(predict_tensor$argmax(dim = 2L))$to(device = "cpu"))

predict_tensor = predict_tensor$to(device = "cpu")
if (predict_type == "prob") {
prob = as.matrix(predict_tensor)
colnames(prob) = task$class_names
Expand Down
8 changes: 7 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
auto_device = function(device = NULL) {
if (device == "auto") {
device = if (cuda_is_available()) "cuda" else "cpu"
device = if (cuda_is_available()) {
"cuda"
} else if (backends_mps_is_available()) {
"mps"
} else {
"cpu"
}
lg$debug("Auto-detected device '%s'.", device)
}
return(device)
Expand Down
3 changes: 2 additions & 1 deletion man-roxygen/paramset_torchlearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#' The number of epochs.
#' * `device` :: `character(1)`\cr
#' The device. One of `"auto"`, `"cpu"`, or `"cuda"` or other values defined in `mlr_reflections$torch$devices`.
#' The value is initialized to `"auto"`, which will select `"cuda"` if possible and `"cpu"` otherwise.
#' The value is initialized to `"auto"`, which will select `"cuda"` if possible, then try `"mps"` and otherwise
#' fall back to `"cpu"`.
#' * `measures_train` :: [`Measure`] or `list()` of [`Measure`]s.
#' Measures to be evaluated during training.
#' * `measures_valid` :: [`Measure`] or `list()` of [`Measure`]s.
Expand Down
3 changes: 2 additions & 1 deletion man/mlr_learners_torch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/mlr_pipeops_torch_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7291b50

Please sign in to comment.