From 7291b50c880c8bf89f7380e7d64656b016d61e21 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 17 Jan 2024 15:42:50 +0100 Subject: [PATCH] fix: place tensor on cpu before converting to R --- DESCRIPTION | 2 +- R/learner_torch_methods.R | 7 ++++--- R/utils.R | 8 +++++++- man-roxygen/paramset_torchlearner.R | 3 ++- man/mlr_learners_torch.Rd | 3 ++- man/mlr_pipeops_torch_model.Rd | 3 ++- 6 files changed, 18 insertions(+), 8 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index c30cccbf..6d8baf19 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -71,7 +71,7 @@ ByteCompile: no VignetteBuilder: knitr Encoding: UTF-8 Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.2.3.9000 Collate: 'CallbackSet.R' 'zzz.R' diff --git a/R/learner_torch_methods.R b/R/learner_torch_methods.R index 96a516da..4a5cc780 100644 --- a/R/learner_torch_methods.R +++ b/R/learner_torch_methods.R @@ -153,7 +153,7 @@ train_loop = function(ctx, cbs) { ctx$last_loss = loss$item() predictions[[length(predictions) + 1]] = y_hat$detach() - indices[[length(indices) + 1]] = as.numeric(batch$.index) + indices[[length(indices) + 1]] = as.integer(batch$.index$to(device = "cpu")) ctx$optimizer$step() call("on_batch_end") @@ -245,11 +245,12 @@ encode_prediction_default = function(predict_tensor, predict_type, task) { response = prob = NULL if (task$task_type == "classif") { if (predict_type == "prob") { - predict_tensor = nnf_softmax(predict_tensor, dim = 2L) + predict_tensor = with_no_grad(nnf_softmax(predict_tensor, dim = 2L)) } # We still execute the argmax on the device before converting to R - response = as.integer(predict_tensor$argmax(dim = 2L)) + response = as.integer(with_no_grad(predict_tensor$argmax(dim = 2L))$to(device = "cpu")) + predict_tensor = predict_tensor$to(device = "cpu") if (predict_type == "prob") { prob = as.matrix(predict_tensor) colnames(prob) = task$class_names diff --git a/R/utils.R b/R/utils.R index ba53fe54..5ab34061 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1,6 +1,12 @@ auto_device = function(device = NULL) { if (device == "auto") { - device = if (cuda_is_available()) "cuda" else "cpu" + device = if (cuda_is_available()) { + "cuda" + } else if (backends_mps_is_available()) { + "mps" + } else { + "cpu" + } lg$debug("Auto-detected device '%s'.", device) } return(device) diff --git a/man-roxygen/paramset_torchlearner.R b/man-roxygen/paramset_torchlearner.R index 89f5bb76..0d8b803f 100644 --- a/man-roxygen/paramset_torchlearner.R +++ b/man-roxygen/paramset_torchlearner.R @@ -6,7 +6,8 @@ #' The number of epochs. #' * `device` :: `character(1)`\cr #' The device. One of `"auto"`, `"cpu"`, or `"cuda"` or other values defined in `mlr_reflections$torch$devices`. -#' The value is initialized to `"auto"`, which will select `"cuda"` if possible and `"cpu"` otherwise. +#' The value is initialized to `"auto"`, which will select `"cuda"` if possible, then try `"mps"` and otherwise +#' fall back to `"cpu"`. #' * `measures_train` :: [`Measure`] or `list()` of [`Measure`]s. #' Measures to be evaluated during training. #' * `measures_valid` :: [`Measure`] or `list()` of [`Measure`]s. diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 55d72e08..6975304b 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -24,7 +24,8 @@ The batch size. The number of epochs. \item \code{device} :: \code{character(1)}\cr The device. One of \code{"auto"}, \code{"cpu"}, or \code{"cuda"} or other values defined in \code{mlr_reflections$torch$devices}. -The value is initialized to \code{"auto"}, which will select \code{"cuda"} if possible and \code{"cpu"} otherwise. +The value is initialized to \code{"auto"}, which will select \code{"cuda"} if possible, then try \code{"mps"} and otherwise +fall back to \code{"cpu"}. \item \code{measures_train} :: \code{\link{Measure}} or \code{list()} of \code{\link{Measure}}s. Measures to be evaluated during training. \item \code{measures_valid} :: \code{\link{Measure}} or \code{list()} of \code{\link{Measure}}s. diff --git a/man/mlr_pipeops_torch_model.Rd b/man/mlr_pipeops_torch_model.Rd index db642859..b72653cf 100644 --- a/man/mlr_pipeops_torch_model.Rd +++ b/man/mlr_pipeops_torch_model.Rd @@ -29,7 +29,8 @@ The batch size. The number of epochs. \item \code{device} :: \code{character(1)}\cr The device. One of \code{"auto"}, \code{"cpu"}, or \code{"cuda"} or other values defined in \code{mlr_reflections$torch$devices}. -The value is initialized to \code{"auto"}, which will select \code{"cuda"} if possible and \code{"cpu"} otherwise. +The value is initialized to \code{"auto"}, which will select \code{"cuda"} if possible, then try \code{"mps"} and otherwise +fall back to \code{"cpu"}. \item \code{measures_train} :: \code{\link{Measure}} or \code{list()} of \code{\link{Measure}}s. Measures to be evaluated during training. \item \code{measures_valid} :: \code{\link{Measure}} or \code{list()} of \code{\link{Measure}}s.