diff --git a/NEWS.md b/NEWS.md index c17af676..c287c99b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # mlr3torch dev +* fix: `LearnerTorchModel` can now be parallelized and trained with + encapsulation activated. + # mlr3torch 0.2.0 ## Breaking Changes @@ -15,9 +18,9 @@ * Optimizers now use the faster ('ignite') version of the optimizers, which leads to considerable speed improvements. -* The `jit_trace` parameter was added to `LearnerTorch`, which when set to +* The `jit_trace` parameter was added to `LearnerTorch`, which when set to `TRUE` can lead to significant speedups. - This should only be enabled for 'static' models, see the + This should only be enabled for 'static' models, see the [torch tutorial](https://torch.mlverse.org/docs/articles/torchscript) for more information. * Added parameter `num_interop_threads` to `LearnerTorch`. diff --git a/R/LearnerTorchModel.R b/R/LearnerTorchModel.R index 6dd43032..b64b0ac4 100644 --- a/R/LearnerTorchModel.R +++ b/R/LearnerTorchModel.R @@ -60,8 +60,8 @@ LearnerTorchModel = R6Class("LearnerTorchModel", #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(network = NULL, ingress_tokens = NULL, task_type, properties = NULL, optimizer = NULL, loss = NULL, callbacks = list(), packages = character(0), feature_types = NULL) { - # TODO: What about the learner properties? - if (!is.null(network)) self$network_stored = network + # we need to serialize here as otherwise encapsulation and parallelization fails + if (!is.null(network)) private$.network_stored = torch_serialize(assert_class(network, "nn_module")) if (!is.null(ingress_tokens)) self$ingress_tokens = ingress_tokens if (is.null(feature_types)) { feature_types = unname(mlr_reflections$task_feature_types) @@ -89,15 +89,6 @@ LearnerTorchModel = R6Class("LearnerTorchModel", } ), active = list( - #' @field network_stored (`nn_module` or `NULL`)\cr - #' The network that will be trained. - #' After calling `$train()`, this is `NULL`. - network_stored = function(rhs) { - if (!missing(rhs)) { - private$.network_stored = assert_class(rhs, "nn_module") - } - private$.network_stored - }, #' @field ingress_tokens (named `list()` with `TorchIngressToken` or `NULL`)\cr #' The ingress tokens. Must be non-`NULL` when calling `$train()`. ingress_tokens = function(rhs) { @@ -121,7 +112,12 @@ LearnerTorchModel = R6Class("LearnerTorchModel", if (is.null(private$.network_stored)) { stopf("No network stored, did you already train learner '%s' or did not specify a model?", self$id) } - network = private$.network_stored + network = if (test_class(private$.network_stored, "nn_module")) { + # optimization for PipeOpTorchModel, where we control the construction of LearnerTorchModel + private$.network_stored + } else { + torch_load(private$.network_stored) + } private$.network_stored = NULL network }, diff --git a/R/PipeOpTorchModel.R b/R/PipeOpTorchModel.R index 0c5828e9..5042cc12 100644 --- a/R/PipeOpTorchModel.R +++ b/R/PipeOpTorchModel.R @@ -69,7 +69,9 @@ PipeOpTorchModel = R6Class("PipeOpTorchModel", output_pointers = list(md$pointer), list_output = FALSE ) - private$.learner$network_stored = network + # Because we control the creation of the LearnerTorchModel, we know that it's fitted in the same + # process as the current .train function, hence, we can avoid the serialization round-trip + get_private(private$.learner, ".network_stored") = network private$.learner$ingress_tokens = md$ingress if (is.null(md$loss)) { diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd index 7f03f5c7..362d1a0d 100644 --- a/man/mlr_learners.mlp.Rd +++ b/man/mlr_learners.mlp.Rd @@ -107,11 +107,13 @@ Other Learner:
Inherited methods