Skip to content

Commit

Permalink
fix(jit): jit_trace now works with batch-norm (#355)
Browse files Browse the repository at this point in the history
Otherwise, the variance is an nan tensor during tracing which
-- for some reason -- does not get updated.

Resolves Issue #354
  • Loading branch information
sebffischer authored Feb 8, 2025
1 parent b45600d commit da449b4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

* fix: `LearnerTorchModel` can now be parallelized and trained with
encapsulation activated.
* fix: `jit_trace` now works in combination with batch normalization.

# mlr3torch 0.2.0

Expand Down
10 changes: 8 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,16 @@ CallbacksNone = function() {

get_example_batch = function(dl) {
ds = dl$dataset
if (length(ds) < 2) {
stopf("Dataset needs to contain at least 2 observations")
}
if (!is.null(ds$.getbatch)) {
ds$.getbatch(1)
ds$.getbatch(1:2)
} else {
ds$.getitem(1)$unsqueeze(1)
torch_stack(list(
ds$.getitem(1),
ds$.getitem(2)
))
}
}

Expand Down
13 changes: 13 additions & 0 deletions tests/testthat/test_PipeOpTorchBatchNorm.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,16 @@ test_that("PipeOpTorchBatchNorm3D paramtest", {
res = expect_paramset(po("nn_batch_norm3d"), nn_batch_norm3d, exclude = "num_features")
expect_paramtest(res)
})

test_that("jit_trace works (#354)", {
graph = po("torch_ingress_num") %>>%
nn("batch_norm1d") %>>%
nn("head") %>>%
po("torch_loss", t_loss("cross_entropy")) %>>%
po("torch_optimizer", t_opt("adamw")) %>>%
po("torch_model_classif", epochs = 1, batch_size = 50)
lrn = as_learner(graph)
task = tsk("iris")
lrn$train(task)
expect_prediction(lrn$predict(task))
})

0 comments on commit da449b4

Please sign in to comment.