From 6d3e65fe4c70c19ddf3a03bb48d9bfede4795170 Mon Sep 17 00:00:00 2001 From: egillax Date: Fri, 22 Dec 2023 16:46:57 +0100 Subject: [PATCH] fitEstimator can be called from PLP --- R/MLP.R | 2 +- R/ResNet.R | 2 +- R/Transformer.R | 2 +- tests/testthat/test-MLP.R | 2 +- tests/testthat/test-ResNet.R | 2 +- tests/testthat/test-Transformer.R | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/R/MLP.R b/R/MLP.R index 771e244..cc309c8 100644 --- a/R/MLP.R +++ b/R/MLP.R @@ -96,7 +96,7 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8), attr(param, "settings")$modelType <- "MLP" results <- list( - fitFunction = "fitEstimator", + fitFunction = "DeepPatientLevelPrediction::fitEstimator", param = param, estimatorSettings = estimatorSettings, modelType = "MLP", diff --git a/R/ResNet.R b/R/ResNet.R index 32e7f3a..c12c6bb 100644 --- a/R/ResNet.R +++ b/R/ResNet.R @@ -139,7 +139,7 @@ setResNet <- function(numLayers = c(1:8), } attr(param, "settings")$modelType <- "ResNet" results <- list( - fitFunction = "fitEstimator", + fitFunction = "DeepPatientLevelPrediction::fitEstimator", param = param, estimatorSettings = estimatorSettings, modelType = "ResNet", diff --git a/R/Transformer.R b/R/Transformer.R index cd4a968..6992d47 100644 --- a/R/Transformer.R +++ b/R/Transformer.R @@ -182,7 +182,7 @@ setTransformer <- function(numBlocks = 3, } attr(param, "settings")$modelType <- "Transformer" results <- list( - fitFunction = "fitEstimator", + fitFunction = "DeepPatientLevelPrediction::fitEstimator", param = param, estimatorSettings = estimatorSettings, modelType = "Transformer", diff --git a/tests/testthat/test-MLP.R b/tests/testthat/test-MLP.R index a470664..bdcf67a 100644 --- a/tests/testthat/test-MLP.R +++ b/tests/testthat/test-MLP.R @@ -18,7 +18,7 @@ modelSettings <- setMultiLayerPerceptron( test_that("setMultiLayerPerceptron works", { testthat::expect_s3_class(object = modelSettings, class = "modelSettings") - testthat::expect_equal(modelSettings$fitFunction, "fitEstimator") + testthat::expect_equal(modelSettings$fitFunction, "DeepPatientLevelPrediction::fitEstimator") testthat::expect_true(length(modelSettings$param) > 0) diff --git a/tests/testthat/test-ResNet.R b/tests/testthat/test-ResNet.R index 76e3d9f..aa88973 100644 --- a/tests/testthat/test-ResNet.R +++ b/tests/testthat/test-ResNet.R @@ -18,7 +18,7 @@ resSet <- setResNet( test_that("setResNet works", { testthat::expect_s3_class(object = resSet, class = "modelSettings") - testthat::expect_equal(resSet$fitFunction, "fitEstimator") + testthat::expect_equal(resSet$fitFunction, "DeepPatientLevelPrediction::fitEstimator") testthat::expect_true(length(resSet$param) > 0) diff --git a/tests/testthat/test-Transformer.R b/tests/testthat/test-Transformer.R index 62fdf12..8526e61 100644 --- a/tests/testthat/test-Transformer.R +++ b/tests/testthat/test-Transformer.R @@ -15,7 +15,7 @@ settings <- setTransformer( test_that("Transformer settings work", { testthat::expect_s3_class(object = settings, class = "modelSettings") - testthat::expect_equal(settings$fitFunction, "fitEstimator") + testthat::expect_equal(settings$fitFunction, "DeepPatientLevelPrediction::fitEstimator") testthat::expect_true(length(settings$param) > 0) testthat::expect_error(setTransformer( numBlocks = 1, dimToken = 50,