Gradient accumulation in estimator and LRFinder (#113)
* gradient accumulation
egillax authored Apr 29, 2024
1 parent a438ca7 commit f343e72
Showing 11 changed files with 264 additions and 64 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/R_CDM_check_hades.yaml
@@ -67,7 +67,7 @@ jobs:
while read -r cmd
do
eval sudo $cmd
done < <(Rscript -e 'writeLines(remotes::system_requirements("ubuntu", "20.04"))')
done < <(Rscript -e 'writeLines(remotes::system_requirements("ubuntu", "22.04"))')
- uses: r-lib/actions/setup-r-dependencies@v2
with:
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -42,7 +42,7 @@ Remotes:
ohdsi/FeatureExtraction,
ohdsi/Eunomia,
ohdsi/ResultModelManager
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Encoding: UTF-8
Config/testthat/edition: 3
Config/reticulate:
3 changes: 1 addition & 2 deletions R/DeepPatientLevelPrediction.R
@@ -21,12 +21,11 @@
#' @description A package containing deep learning extensions for developing
#' prediction models using data in the OMOP CDM
#'
#' @docType package
#' @name DeepPatientLevelPrediction
#' @importFrom dplyr %>%
#' @importFrom reticulate r_to_py py_to_r
#' @importFrom rlang .data
NULL
"_PACKAGE"

.onLoad <- function(libname, pkgname) {
# use superassignment to update global reference
12 changes: 12 additions & 0 deletions R/Estimator.R
@@ -40,6 +40,8 @@
#' `name`,
#' `fun` needs to be a function that takes in prediction and labels and
#' outputs a score.
#' @param accumulationSteps how many steps to accumulate gradients before
#' updating weights
#' @param seed seed to initialize weights of model with
#' @export
setEstimator <- function(
@@ -59,6 +61,7 @@ setEstimator <- function(
params = list(patience = 4)
),
metric = "auc",
accumulationSteps = NULL,
seed = NULL) {
checkIsClass(learningRate, c("numeric", "character"))
if (inherits(learningRate, "character") && learningRate != "auto") {
@@ -74,6 +77,14 @@
checkIsClass(earlyStopping, c("list", "NULL"))
checkIsClass(metric, c("character", "list"))
checkIsClass(seed, c("numeric", "integer", "NULL"))

if (!is.null(accumulationSteps)) {
checkHigher(accumulationSteps, 0)
checkIsClass(accumulationSteps, c("numeric", "integer"))
if (batchSize %% accumulationSteps != 0) {
stop("Batch size should be divisible by accumulation steps")
}
}


if (length(learningRate) == 1 && learningRate == "auto") {
@@ -93,6 +104,7 @@
earlyStopping = earlyStopping,
findLR = findLR,
metric = metric,
accumulationSteps = accumulationSteps,
seed = seed[1]
)

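The new accumulationSteps argument above makes the estimator split every training batch into equally sized sub-batches, which is why batchSize must be divisible by it. A minimal Python sketch of the arithmetic this check enforces (names are illustrative, not the package's API):

```python
def sub_batch_size(batch_size: int, accumulation_steps: int) -> int:
    """Per-step sub-batch size when accumulating gradients.

    Mirrors the check added to setEstimator: the batch must split into
    equal sub-batches so each optimizer step sees a full batch's worth
    of gradient.
    """
    if accumulation_steps <= 0:
        raise ValueError("accumulation_steps must be positive")
    if batch_size % accumulation_steps != 0:
        raise ValueError("Batch size should be divisible by accumulation steps")
    return batch_size // accumulation_steps


# A batch of 1024 with 4 accumulation steps runs 4 forward/backward passes
# of 256 rows each before a single weight update.
assert sub_batch_size(1024, 4) == 256
```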
2 changes: 1 addition & 1 deletion inst/python/Dataset.py
@@ -119,7 +119,7 @@ def __getitem__(self, item):
batch = {"cat": self.cat[item, :], "num": self.num[item, :]}
else:
batch = {"cat": self.cat[item, :].squeeze(), "num": None}
if batch["cat"].dim() == 1 and not isinstance(item, list):
if batch["cat"].dim() == 1:
batch["cat"] = batch["cat"].unsqueeze(0)
if (
batch["num"] is not None
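The Dataset change above drops the isinstance check so that any one-dimensional categorical selection gets a batch dimension added. A standalone illustration of that unsqueeze step (illustrative values only):

```python
import torch

cat = torch.tensor([3, 7, 11])   # one row of categorical indices, shape (3,)
if cat.dim() == 1:
    cat = cat.unsqueeze(0)       # ensure a leading batch dimension -> shape (1, 3)
print(cat.shape)                 # torch.Size([1, 3])
```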
72 changes: 54 additions & 18 deletions inst/python/Estimator.py
@@ -17,7 +17,6 @@ def __init__(self, model, model_parameters, estimator_settings):
self.device = estimator_settings["device"]()
else:
self.device = estimator_settings["device"]

torch.manual_seed(seed=self.seed)
if "finetune" in estimator_settings.keys() and estimator_settings["finetune"]:
path = estimator_settings["finetune_model_path"]
@@ -41,7 +40,14 @@ def __init__(self, model, model_parameters, estimator_settings):
self.weight_decay = estimator_settings.get("weight_decay", 1e-5)
self.batch_size = int(estimator_settings.get("batch_size", 1024))
self.prefix = estimator_settings.get("prefix", self.model.name)


if estimator_settings["accumulation_steps"]:
self.accumulation_steps = int(estimator_settings["accumulation_steps"])
self.sub_batch_size = self.batch_size // self.accumulation_steps
else:
self.accumulation_steps = 1
self.sub_batch_size = self.batch_size

self.previous_epochs = int(estimator_settings.get("previous_epochs", 0))
self.model.to(device=self.device)

@@ -50,7 +56,7 @@ def __init__(self, model, model_parameters, estimator_settings):
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
self.criterion = estimator_settings["criterion"]()
self.criterion = estimator_settings["criterion"](reduction="sum")

if (
"metric" in estimator_settings.keys()
@@ -163,15 +169,22 @@ def fit_epoch(self, dataloader):
training_losses = torch.empty(len(dataloader))
self.model.train()
index = 0
self.optimizer.zero_grad()
for batch in tqdm(dataloader):
self.optimizer.zero_grad()
batch = batch_to_device(batch, device=self.device)
out = self.model(batch[0])
loss = self.criterion(out, batch[1])
loss.backward()

split_batch = self.split_batch(batch)
accumulated_loss = 0
all_out = []
for sub_batch in split_batch:
sub_batch = batch_to_device(sub_batch, device=self.device)
out = self.model(sub_batch[0])
all_out.append(out.detach())
loss = self.criterion(out.squeeze(), sub_batch[1])
loss.backward()
accumulated_loss += loss.detach()

self.optimizer.step()
training_losses[index] = loss.detach()
self.optimizer.zero_grad()
training_losses[index] = accumulated_loss / self.batch_size
index += 1
return training_losses.mean().item()
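Together with the switch to a criterion constructed with reduction="sum", the new fit_epoch accumulates gradients over sub-batches and performs one optimizer step per full batch; dividing the summed loss by the full batch size keeps the logged value per-example. A self-contained sketch of that pattern, with a made-up model and data (not the package's classes):

```python
import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# reduction="sum": summing sub-batch losses and dividing by the full batch
# size is then equivalent to one mean-reduced loss over the whole batch
criterion = nn.BCEWithLogitsLoss(reduction="sum")

batch_size, accumulation_steps = 1024, 4
sub_batch_size = batch_size // accumulation_steps

x = torch.randn(batch_size, 10)
y = torch.randint(0, 2, (batch_size,)).float()

optimizer.zero_grad()
accumulated_loss = 0.0
for x_sub, y_sub in zip(torch.split(x, sub_batch_size),
                        torch.split(y, sub_batch_size)):
    out = model(x_sub)
    loss = criterion(out.squeeze(), y_sub)
    loss.backward()                      # gradients add up across sub-batches
    accumulated_loss += loss.item()

optimizer.step()                         # one weight update per full batch
optimizer.zero_grad()
print(accumulated_loss / batch_size)     # per-example training loss for logging
```

Because sub-batch losses are summed rather than averaged, the gradient after the inner loop matches what a single pass over the full 1024-row batch would produce.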

@@ -183,11 +196,16 @@ def score(self, dataloader):
self.model.eval()
index = 0
for batch in tqdm(dataloader):
batch = batch_to_device(batch, device=self.device)
pred = self.model(batch[0])
predictions.append(pred)
targets.append(batch[1])
loss[index] = self.criterion(pred, batch[1])
split_batch = self.split_batch(batch)
accumulated_loss = 0
for sub_batch in split_batch:
sub_batch = batch_to_device(sub_batch, device=self.device)
pred = self.model(sub_batch[0])
predictions.append(pred)
targets.append(sub_batch[1])
accumulated_loss += self.criterion(pred.squeeze(), sub_batch[1]).detach()
loss[index] = accumulated_loss / self.batch_size

index += 1
mean_loss = loss.mean().item()
predictions = torch.concat(predictions)
@@ -260,6 +278,22 @@ def print_progress(self, scores, training_loss, delta_time, current_epoch):
)
return

def split_batch(self, batch):
if self.accumulation_steps > 1 and len(batch[0]["cat"]) > self.sub_batch_size:
data, labels = batch
split_data = {key: list(torch.split(value, self.sub_batch_size))
for key, value in data.items() if value is not None}
split_labels = list(torch.split(labels, self.sub_batch_size))

sub_batches = []
for i in range(self.accumulation_steps):
sub_batch = {key: value[i] for key, value in split_data.items()}
sub_batch = [sub_batch, split_labels[i]]
sub_batches.append(sub_batch)
else:
sub_batches = [batch]
return sub_batches
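The split_batch helper above carves a (features dict, labels) batch into sub-batches along the first dimension. A rough standalone equivalent, simplified to split by length rather than assuming a full-sized batch (not the package's code path):

```python
import torch

def split_batch(batch, sub_batch_size):
    """Split a (features_dict, labels) batch into sub-batches of sub_batch_size rows."""
    data, labels = batch
    split_data = {key: list(torch.split(value, sub_batch_size))
                  for key, value in data.items() if value is not None}
    split_labels = list(torch.split(labels, sub_batch_size))
    return [[{key: parts[i] for key, parts in split_data.items()}, split_labels[i]]
            for i in range(len(split_labels))]

batch = ({"cat": torch.randint(0, 5, (8, 3)), "num": None}, torch.zeros(8))
for sub in split_batch(batch, sub_batch_size=4):
    print(sub[0]["cat"].shape, sub[1].shape)   # torch.Size([4, 3]) torch.Size([4])
```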

def fit_whole_training_set(self, dataset, learning_rates=None):
dataloader = DataLoader(
dataset=dataset,
@@ -308,9 +342,11 @@ def predict_proba(self, dataset):
predictions = list()
self.model.eval()
for batch in tqdm(dataloader):
batch = batch_to_device(batch, device=self.device)
pred = self.model(batch[0])
predictions.append(torch.sigmoid(pred))
split_batch = self.split_batch(batch)
for sub_batch in split_batch:
sub_batch = batch_to_device(sub_batch, device=self.device)
pred = self.model(sub_batch[0])
predictions.append(torch.sigmoid(pred))
predictions = torch.concat(predictions).cpu().numpy()
return predictions

62 changes: 33 additions & 29 deletions inst/python/LrFinder.py
@@ -1,3 +1,4 @@
from os import walk
import random

import torch
@@ -20,26 +21,27 @@ def get_lr(self):

class LrFinder:
def __init__(self, estimator, lr_settings=None):
self.first_batch = None
if lr_settings is None:
lr_settings = {}
min_lr = lr_settings.get("min_lr", 1e-7)
max_lr = lr_settings.get("max_lr", 1)
num_lr = lr_settings.get("num_lr", 100)
smooth = lr_settings.get("smooth", 0.05)
divergence_threshold = lr_settings.get("divergence_threshold", 4)
self.min_lr = lr_settings.get("min_lr", 1e-7)
self.max_lr = lr_settings.get("max_lr", 1)
self.num_lr = lr_settings.get("num_lr", 100)
self.smooth = lr_settings.get("smooth", 0.05)
self.divergence_threshold = lr_settings.get("divergence_threshold", 4)
torch.manual_seed(seed=estimator.seed)
self.seed = estimator.seed

self.min_lr = min_lr
self.max_lr = max_lr
self.num_lr = num_lr
self.smooth = smooth
self.divergence_threshold = divergence_threshold
for group in estimator.optimizer.param_groups:
group['lr'] = self.min_lr

estimator.scheduler = ExponentialSchedulerPerBatch(
estimator.optimizer, self.max_lr, self.num_lr
)

if estimator.accumulation_steps > 1:
self.accumulation_steps = estimator.accumulation_steps
else:
self.accumulation_steps = 1
self.estimator = estimator
self.losses = None
self.loss_index = None
@@ -51,31 +53,33 @@ def get_lr(self, dataset):
random.seed(self.seed)
losses = torch.empty(size=(self.num_lr,), dtype=torch.float)
lrs = torch.empty(size=(self.num_lr,), dtype=torch.float)
self.estimator.optimizer.zero_grad()
best_loss = float("inf")
for i in tqdm(range(self.num_lr)):
self.estimator.optimizer.zero_grad()
loss_value = 0
random_batch = random.sample(batch_index, self.estimator.batch_size)
batch = dataset[random_batch]
batch = batch_to_device(batch, self.estimator.device)

out = self.estimator.model(batch[0])
loss = self.estimator.criterion(out, batch[1])
for j in range(self.accumulation_steps):
batch = dataset[random_batch[j * self.estimator.sub_batch_size : (j + 1) * self.estimator.sub_batch_size]]
batch = batch_to_device(batch, self.estimator.device)

out = self.estimator.model(batch[0])
loss = self.estimator.criterion(out, batch[1])
loss.backward()
loss_value += loss.item()
loss_value = loss_value / self.accumulation_steps
if self.smooth is not None and i != 0:
losses[i] = (
self.smooth * loss.item() + (1 - self.smooth) * losses[i - 1]
self.smooth * loss_value + (1 - self.smooth) * losses[i - 1]
)
else:
losses[i] = loss.item()
losses[i] = loss_value
lrs[i] = self.estimator.optimizer.param_groups[0]["lr"]

loss.backward()
self.estimator.optimizer.step()
self.estimator.scheduler.step()

if i == 0:
if losses[i] < best_loss:
best_loss = losses[i]
else:
if losses[i] < best_loss:
best_loss = losses[i]

if losses[i] > (self.divergence_threshold * best_loss):
print(
@@ -85,12 +89,12 @@ def get_lr(self, dataset):

# find LR where gradient is highest but before global minimum is reached
# I added -5 to make sure it is not still in the minimum
global_minimum = torch.argmin(losses)
global_minimum = torch.argmin(losses[:i])
gradient = torch.diff(losses[: (global_minimum - 5) + 1])
smallest_gradient = torch.argmin(gradient)
biggest_gradient = torch.argmax(gradient)

suggested_lr = lrs[smallest_gradient]
self.losses = losses
self.loss_index = smallest_gradient
suggested_lr = lrs[biggest_gradient]
self.losses = losses[:i]
self.loss_index = biggest_gradient
self.lrs = lrs
return suggested_lr.item()
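For context, the updated LrFinder runs a learning-rate range test: the rate grows exponentially each step, the loss on a randomly sampled batch is tracked (averaged over sub-batches when accumulation is enabled), and a rate is suggested from the shape of the loss curve. A compact sketch of the range-test idea with a stand-in model, data, and scheduler (not the package's classes, and using the common steepest-descent heuristic rather than the package's exact selection logic):

```python
import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-7)   # start at min_lr
criterion = nn.BCEWithLogitsLoss()

min_lr, max_lr, num_lr = 1e-7, 1.0, 100
gamma = (max_lr / min_lr) ** (1 / (num_lr - 1))              # exponential growth per step
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

losses, lrs = [], []
for _ in range(num_lr):
    x = torch.randn(64, 10)
    y = torch.randint(0, 2, (64,)).float()
    optimizer.zero_grad()
    loss = criterion(model(x).squeeze(), y)
    loss.backward()
    losses.append(loss.item())
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()

# common heuristic: suggest the lr at which the loss was decreasing fastest
steepest = torch.argmin(torch.diff(torch.tensor(losses)))
print(f"suggested lr: {lrs[int(steepest)]:.2e}")
```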
21 changes: 21 additions & 0 deletions man/DeepPatientLevelPrediction.Rd

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions man/setEstimator.Rd

Some generated files are not rendered by default.
