From da449b4285ac6bb51ac7d64813a9cde26e3284b9 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 8 Feb 2025 11:51:55 +0100 Subject: [PATCH] fix(jit): jit_trace now works with batch-norm (#355) Otherwise, the variance is an nan tensor during tracing which -- for some reason -- does not get updated. Resolves Issue #354 --- NEWS.md | 1 + R/utils.R | 10 ++++++++-- tests/testthat/test_PipeOpTorchBatchNorm.R | 13 +++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index c287c99b..6f0a2968 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ * fix: `LearnerTorchModel` can now be parallelized and trained with encapsulation activated. +* fix: `jit_trace` now works in combination with batch normalization. # mlr3torch 0.2.0 diff --git a/R/utils.R b/R/utils.R index 4b443759..f20bad18 100644 --- a/R/utils.R +++ b/R/utils.R @@ -238,10 +238,16 @@ CallbacksNone = function() { get_example_batch = function(dl) { ds = dl$dataset + if (length(ds) < 2) { + stopf("Dataset needs to contain at least 2 observations") + } if (!is.null(ds$.getbatch)) { - ds$.getbatch(1) + ds$.getbatch(1:2) } else { - ds$.getitem(1)$unsqueeze(1) + torch_stack(list( + ds$.getitem(1), + ds$.getitem(2) + )) } } diff --git a/tests/testthat/test_PipeOpTorchBatchNorm.R b/tests/testthat/test_PipeOpTorchBatchNorm.R index f2f5b534..038f3216 100644 --- a/tests/testthat/test_PipeOpTorchBatchNorm.R +++ b/tests/testthat/test_PipeOpTorchBatchNorm.R @@ -40,3 +40,16 @@ test_that("PipeOpTorchBatchNorm3D paramtest", { res = expect_paramset(po("nn_batch_norm3d"), nn_batch_norm3d, exclude = "num_features") expect_paramtest(res) }) + +test_that("jit_trace works (#354)", { + graph = po("torch_ingress_num") %>>% + nn("batch_norm1d") %>>% + nn("head") %>>% + po("torch_loss", t_loss("cross_entropy")) %>>% + po("torch_optimizer", t_opt("adamw")) %>>% + po("torch_model_classif", epochs = 1, batch_size = 50) + lrn = as_learner(graph) + task = tsk("iris") + lrn$train(task) + expect_prediction(lrn$predict(task)) +})