fix(learner): LearnerTorchModel now works in parallel settings (#350)
sebffischer authored Feb 7, 2025
1 parent 4148c9e commit a5e7403
Showing 11 changed files with 50 additions and 20 deletions.
7 changes: 5 additions & 2 deletions NEWS.md
@@ -1,5 +1,8 @@
# mlr3torch dev

+* fix: `LearnerTorchModel` can now be parallelized and trained with
+  encapsulation activated.
+
# mlr3torch 0.2.0

## Breaking Changes
@@ -15,9 +18,9 @@

* Optimizers now use the faster ('ignite') version of the optimizers,
which leads to considerable speed improvements.
-* The `jit_trace` parameter was added to `LearnerTorch`, which when set to
+* The `jit_trace` parameter was added to `LearnerTorch`, which when set to
`TRUE` can lead to significant speedups.
-This should only be enabled for 'static' models, see the
+This should only be enabled for 'static' models, see the
[torch tutorial](https://torch.mlverse.org/docs/articles/torchscript)
for more information.
* Added parameter `num_interop_threads` to `LearnerTorch`.
20 changes: 8 additions & 12 deletions R/LearnerTorchModel.R
@@ -60,8 +60,8 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(network = NULL, ingress_tokens = NULL, task_type, properties = NULL, optimizer = NULL, loss = NULL,
callbacks = list(), packages = character(0), feature_types = NULL) {
-# TODO: What about the learner properties?
-if (!is.null(network)) self$network_stored = network
+# we need to serialize here as otherwise encapsulation and parallelization fails
+if (!is.null(network)) private$.network_stored = torch_serialize(assert_class(network, "nn_module"))
if (!is.null(ingress_tokens)) self$ingress_tokens = ingress_tokens
if (is.null(feature_types)) {
feature_types = unname(mlr_reflections$task_feature_types)
@@ -89,15 +89,6 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
}
),
active = list(
-#' @field network_stored (`nn_module` or `NULL`)\cr
-#' The network that will be trained.
-#' After calling `$train()`, this is `NULL`.
-network_stored = function(rhs) {
-if (!missing(rhs)) {
-private$.network_stored = assert_class(rhs, "nn_module")
-}
-private$.network_stored
-},
#' @field ingress_tokens (named `list()` with `TorchIngressToken` or `NULL`)\cr
#' The ingress tokens. Must be non-`NULL` when calling `$train()`.
ingress_tokens = function(rhs) {
@@ -121,7 +112,12 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
if (is.null(private$.network_stored)) {
stopf("No network stored, did you already train learner '%s' or did not specify a model?", self$id)
}
-network = private$.network_stored
+network = if (test_class(private$.network_stored, "nn_module")) {
+# optimization for PipeOpTorchModel, where we control the construction of LearnerTorchModel
+private$.network_stored
+} else {
+torch_load(private$.network_stored)
+}
private$.network_stored = NULL
network
},
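
The serialization round-trip introduced in this hunk can be sketched in isolation. This is a minimal illustration rather than code from the commit: it assumes the `torch` R package is installed, and `net`, `raw_net`, `restored`, and `same` are illustrative names. It relies on `torch_serialize()` returning a raw vector and on `torch_load()` accepting that raw vector, which is how the diff above uses the two functions.

```r
library(torch)

# Illustrative module; any nn_module would do here.
net = nn_linear(4, 2)

# torch_serialize() turns the module into a plain raw vector, which R can
# pass to another process like any other ordinary R object.
raw_net = torch_serialize(net)

# torch_load() rebuilds a module from the raw vector.
restored = torch_load(raw_net)

# The restored module should produce the same outputs as the original.
x = torch_randn(3, 4)
same = torch_allclose(net(x), restored(x))
```

Storing the raw vector instead of the live module costs one deserialization at train time; the `test_class()` branch in the hunk above skips that cost when a live `nn_module` was placed in the private field directly.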
4 changes: 3 additions & 1 deletion R/PipeOpTorchModel.R
@@ -69,7 +69,9 @@ PipeOpTorchModel = R6Class("PipeOpTorchModel",
output_pointers = list(md$pointer),
list_output = FALSE
)
-private$.learner$network_stored = network
+# Because we control the creation of the LearnerTorchModel, we know that it's fitted in the same
+# process as the current .train function, hence, we can avoid the serialization round-trip
+get_private(private$.learner, ".network_stored") = network
private$.learner$ingress_tokens = md$ingress

if (is.null(md$loss)) {
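
The `get_private(...) = network` line in this hunk uses R's replacement-function syntax to write into an R6 object's private environment from outside the class. The sketch below shows the general mechanism with a hypothetical helper, `set_private_sketch<-`, and a hypothetical `Holder` class. It is not the helper the package uses, and it assumes only the `R6` package. Because it reaches into `.__enclos_env__`, an R6 internal, it is best read as an illustration of the mechanism rather than a recommended pattern.

```r
library(R6)

# Hypothetical class with one private field and a public accessor.
Holder = R6Class("Holder",
  public = list(
    peek = function() private$.value
  ),
  private = list(.value = NULL)
)

# Hypothetical replacement function: `f(x, which) = value` is evaluated by R
# as x = `f<-`(x, which, value = value), so the function must return x.
`set_private_sketch<-` = function(x, which, value) {
  priv = x$.__enclos_env__$private   # private environment of the R6 object
  assign(which, value, envir = priv) # overwrite the existing private field
  x
}

h = Holder$new()
set_private_sketch(h, ".value") = 42
h$peek()
```

Writing the field this way bypasses any validation an active binding would perform, which is acceptable in the diff only because the caller controls how the learner is constructed.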
2 changes: 2 additions & 0 deletions man/mlr_learners.mlp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/mlr_learners.tab_resnet.Rd

4 changes: 3 additions & 1 deletion man/mlr_learners.torch_featureless.Rd

2 changes: 2 additions & 0 deletions man/mlr_learners.torchvision.Rd

2 changes: 2 additions & 0 deletions man/mlr_learners_torch.Rd

2 changes: 2 additions & 0 deletions man/mlr_learners_torch_image.Rd
