Skip to content

Commit

Permalink
Transfer Learning (#109)
Browse files Browse the repository at this point in the history
* Transfer learning first prototype working

* linting

* add unit tests for transfer learning
---------

Co-authored-by: egillax <egillax@gmail.com>
  • Loading branch information
lhjohn and egillax authored Feb 22, 2024
1 parent d2469f6 commit dcfc7fa
Show file tree
Hide file tree
Showing 24 changed files with 357 additions and 186 deletions.
18 changes: 9 additions & 9 deletions .github/workflows/R_CDM_check_hades.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,24 @@ jobs:
error-on: '"warning"'
check-dir: '"check"'

- name: Upload source package
if: success() && runner.os == 'macOS' && github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
uses: actions/upload-artifact@v2
with:
name: package_tarball
path: check/*.tar.gz

- name: Install covr
if: runner.os == 'Windows'
if: runner.os == 'ubuntu-22.04'
run: |
remotes::install_cran("covr")
shell: Rscript {0}

- name: Test coverage
if: runner.os == 'Windows'
if: runner.os == 'ubuntu-22.04'
run: covr::codecov()
shell: Rscript {0}

- name: Upload source package
if: success() && runner.os == 'macOS' && github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
uses: actions/upload-artifact@v2
with:
name: package_tarball
path: check/*.tar.gz

Release:
needs: R-CMD-Check

Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export(predictDeepEstimator)
export(setDefaultResNet)
export(setDefaultTransformer)
export(setEstimator)
export(setFinetuner)
export(setMultiLayerPerceptron)
export(setResNet)
export(setTransformer)
Expand Down
14 changes: 9 additions & 5 deletions R/Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ createDataset <- function(data, labels, plpModel = NULL) {
# sqlite object
attributes(data)$path <- attributes(data)$dbname
}
if (is.null(plpModel)) {
data <- dataset(
r_to_py(normalizePath(attributes(data)$path)),
r_to_py(labels$outcomeCount)
)
if (is.null(plpModel) && is.null(data$numericalIndex)) {
data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
r_to_py(labels$outcomeCount))
} else if (!is.null(data$numericalIndex)) {
numericalIndex <-
r_to_py(as.array(data$numericalIndex %>% dplyr::pull()))
data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
r_to_py(labels$outcomeCount),
numericalIndex)
} else {
numericalFeatures <-
r_to_py(as.array(which(plpModel$covariateImportance$isNumeric)))
Expand Down
111 changes: 59 additions & 52 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,26 @@ fitEstimator <- function(trainData,
if (!is.null(trainData$folds)) {
trainData$labels <- merge(trainData$labels, trainData$fold, by = "rowId")
}
mappedCovariateData <- PatientLevelPrediction::MapIds(
covariateData = trainData$covariateData,
cohort = trainData$labels
)

if (modelSettings$modelType == "Finetuner") {
# make sure to use same mapping from covariateIds to columns if finetuning
path <- modelSettings$param[[1]]$modelPath
oldCovImportance <- utils::read.csv(file.path(path,
"covariateImportance.csv"))
mapping <- oldCovImportance %>% dplyr::select("columnId", "covariateId")
numericalIndex <- which(oldCovImportance %>% dplyr::pull("isNumeric"))
mappedCovariateData <- PatientLevelPrediction::MapIds(
covariateData = trainData$covariateData,
cohort = trainData$labels,
mapping = mapping
)
mappedCovariateData$numericalIndex <- as.data.frame(numericalIndex)
} else {
mappedCovariateData <- PatientLevelPrediction::MapIds(
covariateData = trainData$covariateData,
cohort = trainData$labels
)
}

covariateRef <- mappedCovariateData$covariateRef

Expand Down Expand Up @@ -281,19 +297,13 @@ predictDeepEstimator <- function(plpModel,
# get predictions
prediction <- cohort
if (is.character(plpModel$model)) {
modelSettings <- plpModel$modelDesign$modelSettings
model <- torch$load(
file.path(
plpModel$model,
"DeepEstimatorModel.pt"
),
map_location = "cpu"
)
estimator <- createEstimator(
modelType = modelSettings$modelType,
modelParameters = model$model_parameters,
estimatorSettings = model$estimator_settings
)
model <- torch$load(file.path(plpModel$model,
"DeepEstimatorModel.pt"), map_location = "cpu")
estimator <-
createEstimator(modelParameters =
snakeCaseToCamelCaseNames(model$model_parameters),
estimatorSettings =
snakeCaseToCamelCaseNames(model$estimator_settings))
estimator$model$load_state_dict(model$model_state_dict)
prediction$value <- estimator$predict_proba(data)
} else {
Expand Down Expand Up @@ -323,11 +333,9 @@ gridCvDeep <- function(mappedData,
modelSettings,
modelLocation,
analysisPath) {
ParallelLogger::logInfo(paste0(
"Running hyperparameter search for ",
modelSettings$modelType, " model"
))

ParallelLogger::logInfo(paste0("Running hyperparameter search for ",
modelSettings$modelType,
" model"))
###########################################################################

paramSearch <- modelSettings$param
Expand Down Expand Up @@ -362,23 +370,19 @@ gridCvDeep <- function(mappedData,
collapse = " | "
))
currentModelParams <- paramSearch[[gridId]][modelSettings$modelParamNames]

attr(currentModelParams, "metaData")$names <-
modelSettings$modelParamNames
currentModelParams$modelType <- modelSettings$modelType
currentEstimatorSettings <-
fillEstimatorSettings(
modelSettings$estimatorSettings,
fitParams,
paramSearch[[gridId]]
)
currentEstimatorSettings$modelType <- modelSettings$modelType
fillEstimatorSettings(modelSettings$estimatorSettings,
fitParams,
paramSearch[[gridId]])
currentModelParams$catFeatures <- dataset$get_cat_features()$max()
currentModelParams$numFeatures <-
dataset$get_numerical_features()$max()
dataset$get_numerical_features()$len()
if (findLR) {
lrFinder <- createLRFinder(
modelType = modelSettings$modelType,
modelParameters = currentModelParams,
estimatorSettings = currentEstimatorSettings
)
lrFinder <- createLRFinder(modelParameters = currentModelParams,
estimatorSettings = currentEstimatorSettings)
lr <- lrFinder$get_lr(dataset)
ParallelLogger::logInfo(paste0("Auto learning rate selected as: ", lr))
currentEstimatorSettings$learningRate <- lr
Expand Down Expand Up @@ -457,19 +461,16 @@ gridCvDeep <- function(mappedData,

modelParams$catFeatures <- dataset$get_cat_features()$max()
modelParams$numFeatures <- dataset$get_numerical_features()$len()
modelParams$modelType <- modelSettings$modelType

estimatorSettings <- fillEstimatorSettings(
modelSettings$estimatorSettings,
fitParams,
finalParam
)
estimatorSettings$learningRate <- finalParam$learnSchedule$LRs[[1]]
estimator <- createEstimator(
modelType = modelSettings$modelType,
modelParameters = modelParams,
estimatorSettings = estimatorSettings
)

estimator <- createEstimator(modelParameters = modelParams,
estimatorSettings = estimatorSettings)
numericalIndex <- dataset$get_numerical_features()
estimator$fit_whole_training_set(dataset, finalParam$learnSchedule$LRs)

Expand Down Expand Up @@ -534,12 +535,22 @@ evalEstimatorSettings <- function(estimatorSettings) {
estimatorSettings
}

createEstimator <- function(modelType,
modelParameters,
createEstimator <- function(modelParameters,
estimatorSettings) {
path <- system.file("python", package = "DeepPatientLevelPrediction")

model <- reticulate::import_from_path(modelType, path = path)[[modelType]]
if (modelParameters$modelType == "Finetuner") {
estimatorSettings$finetune <- TRUE
plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
estimatorSettings$finetuneModelPath <-
file.path(normalizePath(plpModel$model), "DeepEstimatorModel.pt")
modelParameters$modelType <-
plpModel$modelDesign$modelSettings$modelType
}

model <-
reticulate::import_from_path(modelParameters$modelType,
path = path)[[modelParameters$modelType]]
estimator <- reticulate::import_from_path("Estimator", path = path)$Estimator

modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
Expand Down Expand Up @@ -573,14 +584,10 @@ doCrossvalidation <- function(dataset,

# -1 for python 0-based indexing
testDataset <- torch$utils$data$Subset(dataset,
indices =
as.integer(which(fold == i) - 1)
)
estimator <- createEstimator(
modelType = estimatorSettings$modelType,
modelParameters = modelSettings,
estimatorSettings = estimatorSettings
)
indices =
as.integer(which(fold == i) - 1))
estimator <- createEstimator(modelParameters = modelSettings,
estimatorSettings = estimatorSettings)
estimator$fit(trainDataset, testDataset)

ParallelLogger::logInfo("Calculating predictions on left out fold set...")
Expand Down
27 changes: 27 additions & 0 deletions R/HelperFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,33 @@ camelCaseToSnakeCase <- function(string) {
return(string)
}

#' Convert a camel case string to snake case
#'
#' @param string The string to be converted
#'
#' @return
#' A string
#'
snakeCaseToCamelCase <- function(string) {
string <- tolower(string)
for (letter in letters) {
string <- gsub(paste("_", letter, sep = ""), toupper(letter), string)
}
string <- gsub("_([0-9])", "\\1", string)
return(string)
}

#' Convert the names of an object from snake case to camel case
#'
#' @param object The object of which the names should be converted
#'
#' @return
#' The same object, but with converted names.
snakeCaseToCamelCaseNames <- function(object) {
names(object) <- snakeCaseToCamelCase(names(object))
return(object)
}

#' Convert the names of an object from camel case to snake case
#'
#' @param object The object of which the names should be converted
Expand Down
23 changes: 5 additions & 18 deletions R/LRFinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
createLRFinder <- function(modelType,
modelParameters,
createLRFinder <- function(modelParameters,
estimatorSettings,
lrSettings = NULL) {
path <- system.file("python", package = "DeepPatientLevelPrediction")
lrFinderClass <-
reticulate::import_from_path("LrFinder", path = path)$LrFinder



model <- reticulate::import_from_path(modelType, path = path)[[modelType]]
modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
estimatorSettings <- camelCaseToSnakeCaseNames(estimatorSettings)
estimatorSettings <- evalEstimatorSettings(estimatorSettings)

estimator <- createEstimator(modelParameters = modelParameters,
estimatorSettings = estimatorSettings)
if (!is.null(lrSettings)) {
lrSettings <- camelCaseToSnakeCaseNames(lrSettings)
}

lrFinder <- lrFinderClass(
model = model,
model_parameters = modelParameters,
estimator_settings = estimatorSettings,
lr_settings = lrSettings
)

lrFinder <- lrFinderClass(estimator = estimator,
lr_settings = lrSettings)
return(lrFinder)
}
8 changes: 4 additions & 4 deletions R/MLP.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,19 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
)]
}))
}
attr(param, "settings")$modelType <- "MLP"

results <- list(
fitFunction = "DeepPatientLevelPrediction::fitEstimator",
param = param,
estimatorSettings = estimatorSettings,
modelType = "MLP",
saveType = "file",
modelParamNames = c(
"numLayers", "sizeHidden",
"dropout", "sizeEmbedding"
)
),
modelType = "MLP"
)
attr(results$param, "settings")$modelType <- results$modelType


class(results) <- "modelSettings"

Expand Down
10 changes: 4 additions & 6 deletions R/ResNet.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,16 @@ setResNet <- function(numLayers = c(1:8),
)]
}))
}
attr(param, "settings")$modelType <- "ResNet"
results <- list(
fitFunction = "DeepPatientLevelPrediction::fitEstimator",
param = param,
estimatorSettings = estimatorSettings,
modelType = "ResNet",
saveType = "file",
modelParamNames = c(
"numLayers", "sizeHidden", "hiddenFactor",
"residualDropout", "hiddenDropout", "sizeEmbedding"
)
modelParamNames = c("numLayers", "sizeHidden", "hiddenFactor",
"residualDropout", "hiddenDropout", "sizeEmbedding"),
modelType = "ResNet"
)
attr(results$param, "settings")$modelType <- results$modelType

class(results) <- "modelSettings"

Expand Down
Loading

0 comments on commit dcfc7fa

Please sign in to comment.