From 04cd731ba429bce83206facc683709fe25e15c52 Mon Sep 17 00:00:00 2001
From: Henrik
Date: Fri, 2 Aug 2024 09:40:12 +0200
Subject: [PATCH] Add interface for model initialization (#124)

* Add interface for model initialization

* Separate finetuner from estimator

* Add missing seed to test case
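Model construction now goes through an InitStrategy object passed in
estimator_settings["init_strategy"]; when the key is absent the estimator
falls back to DefaultInitStrategy. The sketch below shows how the interface
could be extended outside this patch; the class WarmStartInitStrategy and
the settings key "warm_start_path" are illustrative and not part of the
package:

    # Illustrative only: a strategy that warm-starts from a checkpoint
    # saved by the Estimator but leaves every parameter trainable.
    import torch
    from InitStrategy import InitStrategy

    class WarmStartInitStrategy(InitStrategy):
        def initialize(self, model, model_parameters, estimator_settings):
            checkpoint = torch.load(estimator_settings["warm_start_path"],
                                    map_location="cpu")
            instance = model(**checkpoint["model_parameters"])
            instance.load_state_dict(checkpoint["model_state_dict"])
            return instance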
---
 R/Estimator.R                       | 10 ----------
 R/TransferLearning.R                | 14 ++++++++++++--
 inst/python/Estimator.py            | 17 ++++++-----------
 inst/python/InitStrategy.py         | 24 ++++++++++++++++++++++++
 tests/testthat/test-Finetuner.R     |  7 +++++--
 tests/testthat/test-TrainingCache.R |  4 ++--
 6 files changed, 49 insertions(+), 27 deletions(-)
 create mode 100644 inst/python/InitStrategy.py

diff --git a/R/Estimator.R b/R/Estimator.R
index ece685c..19f414e 100644
--- a/R/Estimator.R
+++ b/R/Estimator.R
@@ -468,16 +468,6 @@ evalEstimatorSettings <- function(estimatorSettings) {
 createEstimator <- function(modelParameters,
                             estimatorSettings) {
   path <- system.file("python", package = "DeepPatientLevelPrediction")
-
-  if (modelParameters$modelType == "Finetuner") {
-    estimatorSettings$finetune <- TRUE
-    plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
-    estimatorSettings$finetuneModelPath <-
-      normalizePath(file.path(plpModel$model, "DeepEstimatorModel.pt"))
-    modelParameters$modelType <-
-      plpModel$modelDesign$modelSettings$modelType
-  }
-
   model <- reticulate::import_from_path(modelParameters$modelType,
                                         path = path)[[modelParameters$modelType]]
 
diff --git a/R/TransferLearning.R b/R/TransferLearning.R
index 95d1c04..e47a942 100644
--- a/R/TransferLearning.R
+++ b/R/TransferLearning.R
@@ -42,6 +42,16 @@ setFinetuner <- function(modelPath,
                  modelPath))
   }
 
+  plpModel <- PatientLevelPrediction::loadPlpModel(modelPath)
+  estimatorSettings$finetuneModelPath <-
+    normalizePath(file.path(plpModel$model, "DeepEstimatorModel.pt"))
+  modelType <-
+    plpModel$modelDesign$modelSettings$modelType
+
+  path <- system.file("python", package = "DeepPatientLevelPrediction")
+  estimatorSettings$initStrategy <-
+    reticulate::import_from_path("InitStrategy",
+                                 path = path)$FinetuneInitStrategy()
   param <- list()
   param[[1]] <- list(modelPath = modelPath)
 
@@ -52,9 +62,9 @@ setFinetuner <- function(modelPath,
     estimatorSettings = estimatorSettings,
     saveType = "file",
     modelParamNames = c("modelPath"),
-    modelType = "Finetuner"
+    modelType = modelType
   )
-  attr(results$param, "settings")$modelType <- results$modelType
+  attr(results$param, "settings")$modelType <- "Finetuner"
 
   class(results) <- "modelSettings"
 
diff --git a/inst/python/Estimator.py b/inst/python/Estimator.py
index f0e020d..1b6ac18 100644
--- a/inst/python/Estimator.py
+++ b/inst/python/Estimator.py
@@ -6,7 +6,7 @@
 from tqdm import tqdm
 
 from gpu_memory_cleanup import memory_cleanup
-
+from InitStrategy import InitStrategy, DefaultInitStrategy
 
 class Estimator:
     """
@@ -20,17 +20,12 @@ def __init__(self, model, model_parameters, estimator_settings):
         else:
             self.device = estimator_settings["device"]
 
         torch.manual_seed(seed=self.seed)
-        if "finetune" in estimator_settings.keys() and estimator_settings["finetune"]:
-            path = estimator_settings["finetune_model_path"]
-            fitted_estimator = torch.load(path, map_location="cpu")
-            fitted_parameters = fitted_estimator["model_parameters"]
-            self.model = model(**fitted_parameters)
-            self.model.load_state_dict(fitted_estimator["model_state_dict"])
-            for param in self.model.parameters():
-                param.requires_grad = False
-            self.model.reset_head()
+
+        if "init_strategy" in estimator_settings:
+            self.model = estimator_settings["init_strategy"].initialize(model, model_parameters, estimator_settings)
         else:
-            self.model = model(**model_parameters)
+            self.model = DefaultInitStrategy().initialize(model, model_parameters, estimator_settings)
+
         self.model_parameters = model_parameters
         self.estimator_settings = estimator_settings
 
diff --git a/inst/python/InitStrategy.py b/inst/python/InitStrategy.py
new file mode 100644
index 0000000..d723d4e
--- /dev/null
+++ b/inst/python/InitStrategy.py
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+import torch
+
+class InitStrategy(ABC):
+    @abstractmethod
+    def initialize(self, model, model_parameters, estimator_settings):
+        pass
+
+class DefaultInitStrategy(InitStrategy):
+    def initialize(self, model, model_parameters, estimator_settings):
+        return model(**model_parameters)
+
+class FinetuneInitStrategy(InitStrategy):
+    def initialize(self, model, model_parameters, estimator_settings):
+        path = estimator_settings["finetune_model_path"]
+        fitted_estimator = torch.load(path, map_location="cpu")
+        fitted_parameters = fitted_estimator["model_parameters"]
+        model_instance = model(**fitted_parameters)
+        model_instance.load_state_dict(fitted_estimator["model_state_dict"])
+        for param in model_instance.parameters():
+            param.requires_grad = False
+        model_instance.reset_head()
+        return model_instance
+
diff --git a/tests/testthat/test-Finetuner.R b/tests/testthat/test-Finetuner.R
index a8f7cb3..e6489e2 100644
--- a/tests/testthat/test-Finetuner.R
+++ b/tests/testthat/test-Finetuner.R
@@ -6,6 +6,10 @@ fineTunerSettings <- setFinetuner(
     epochs = 1)
 )
 
+plpModel <- PatientLevelPrediction::loadPlpModel(file.path(fitEstimatorPath,
+                                                           "plpModel"))
+modelType <- plpModel$modelDesign$modelSettings$modelType
+
 test_that("Finetuner settings work", {
   expect_equal(fineTunerSettings$param[[1]]$modelPath,
                file.path(fitEstimatorPath, "plpModel"))
@@ -14,7 +18,7 @@ test_that("Finetuner settings work", {
   expect_equal(fineTunerSettings$estimatorSettings$epochs, 1)
   expect_equal(fineTunerSettings$fitFunction, "fitEstimator")
   expect_equal(fineTunerSettings$saveType, "file")
-  expect_equal(fineTunerSettings$modelType, "Finetuner")
+  expect_equal(fineTunerSettings$modelType, modelType)
   expect_equal(fineTunerSettings$modelParamNames, "modelPath")
   expect_equal(class(fineTunerSettings), "modelSettings")
   expect_equal(attr(fineTunerSettings$param, "settings")$modelType, "Finetuner")
@@ -44,7 +48,6 @@ test_that("Finetuner fitEstimator works", {
 
   fineTunedModel <- torch$load(file.path(fineTunerResults$model,
                                          "DeepEstimatorModel.pt"))
-  expect_true(fineTunedModel$estimator_settings$finetune)
   expect_equal(fineTunedModel$estimator_settings$finetune_model_path,
                normalizePath(file.path(fitEstimatorPath, "plpModel",
                                        "model", "DeepEstimatorModel.pt")))
 
diff --git a/tests/testthat/test-TrainingCache.R b/tests/testthat/test-TrainingCache.R
index ec5b063..5663beb 100644
--- a/tests/testthat/test-TrainingCache.R
+++ b/tests/testthat/test-TrainingCache.R
@@ -10,10 +10,10 @@ resNetSettings <- setResNet(numLayers = c(1, 2, 4),
                             device = "cpu",
                             batchSize = 64,
                             epochs = 1,
-                            seed = NULL),
+                            seed = 42),
                             hyperParamSearch = "random",
                             randomSample = 3,
-                            randomSampleSeed = NULL)
+                            randomSampleSeed = 42)
 
 trainCache <- trainingCache$new(testLoc)
 paramSearch <- resNetSettings$param
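
Usage sketch for reviewers (not part of the diff above; DummyModel is a
made-up stand-in for the package's model classes):

    # Checks the dispatch that Estimator.__init__ now performs: when no
    # "init_strategy" key is set, DefaultInitStrategy builds the model.
    import torch.nn as nn
    from InitStrategy import DefaultInitStrategy

    class DummyModel(nn.Module):
        def __init__(self, size):
            super().__init__()
            self.layer = nn.Linear(size, 1)

        def forward(self, x):
            return self.layer(x)

    estimator_settings = {}  # no "init_strategy" -> default construction
    strategy = estimator_settings.get("init_strategy", DefaultInitStrategy())
    model = strategy.initialize(DummyModel, {"size": 8}, estimator_settings)
    assert isinstance(model, DummyModel)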