From 1c2a7490e7c85e00ed7f3fb48dd75dedfa99dcc3 Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Thu, 2 Jan 2025 02:44:12 -0600 Subject: [PATCH] docs(callbacks): add examples --- R/CallbackSetCheckpoint.R | 12 ++++++++++ R/CallbackSetHistory.R | 14 +++++++++++ R/CallbackSetProgress.R | 11 +++++++++ R/CallbackSetUnfreeze.R | 15 ++++++++++++ attic/try-CallbackSetUnfreeze.R | 37 ------------------------------ attic/try-Select.R | 27 ---------------------- man/mlr_callback_set.checkpoint.Rd | 15 ++++++++++++ man/mlr_callback_set.history.Rd | 18 +++++++++++++++ man/mlr_callback_set.progress.Rd | 14 +++++++++++ man/mlr_callback_set.unfreeze.Rd | 18 +++++++++++++++ 10 files changed, 117 insertions(+), 64 deletions(-) delete mode 100644 attic/try-CallbackSetUnfreeze.R delete mode 100644 attic/try-Select.R diff --git a/R/CallbackSetCheckpoint.R b/R/CallbackSetCheckpoint.R index e9584d60..d11ae88d 100644 --- a/R/CallbackSetCheckpoint.R +++ b/R/CallbackSetCheckpoint.R @@ -19,6 +19,18 @@ #' @family Callback #' @export #' @include CallbackSet.R +#' +#' @examplesIf torch::torch_is_installed() +#' cb = t_clbk("checkpoint", freq = 1) +#' task = tsk("iris") +#' +#' pth = tempfile() +#' learner = lrn("classif.mlp", epochs = 3, batch_size = 1, callbacks = cb) +#' learner$param_set$set_values(cb.checkpoint.path = pth) +#' +#' learner$train(task) +#' +#' list.files(pth) CallbackSetCheckpoint = R6Class("CallbackSetCheckpoint", inherit = CallbackSet, lock_objects = FALSE, diff --git a/R/CallbackSetHistory.R b/R/CallbackSetHistory.R index f9f8176a..43c81c50 100644 --- a/R/CallbackSetHistory.R +++ b/R/CallbackSetHistory.R @@ -9,6 +9,20 @@ #' #' @export #' @include CallbackSet.R +#' @examplesIf torch::torch_is_installed() +#' +#' cb = t_clbk("history") +#' task = tsk("iris") +#' +#' learner = lrn("classif.mlp", epochs = 3, batch_size = 1, +#' callbacks = t_clbk("history"), validate = 0.3) +#' learner$param_set$set_values( +#' measures_train = msrs(c("classif.acc", "classif.ce")), +#' measures_valid = msr("classif.ce") +#' ) +#' learner$train(task) +#' +#' print(learner$model$callbacks$history) CallbackSetHistory = R6Class("CallbackSetHistory", inherit = CallbackSet, lock_objects = FALSE, diff --git a/R/CallbackSetProgress.R b/R/CallbackSetProgress.R index 4e85ec5d..d6b34341 100644 --- a/R/CallbackSetProgress.R +++ b/R/CallbackSetProgress.R @@ -8,6 +8,17 @@ #' @family Callback #' @include CallbackSet.R #' @export +#' @examplesIf torch::torch_is_installed() +#' task = tsk("iris") +#' +#' learner = lrn("classif.mlp", epochs = 10, batch_size = 1, +#' callbacks = t_clbk("progress"), validate = 0.3) +#' learner$param_set$set_values( +#' measures_train = msrs(c("classif.acc", "classif.ce")), +#' measures_valid = msr("classif.ce") +#' ) +#' +#' learner$train(task) CallbackSetProgress = R6Class("CallbackSetProgress", inherit = CallbackSet, lock_objects = FALSE, diff --git a/R/CallbackSetUnfreeze.R b/R/CallbackSetUnfreeze.R index 4bb43991..d6953de3 100644 --- a/R/CallbackSetUnfreeze.R +++ b/R/CallbackSetUnfreeze.R @@ -14,6 +14,21 @@ #' @family Callback #' @export #' @include CallbackSet.R +#' @examplesIf torch::torch_is_installed() +#' task = tsk("iris") +#' cb = t_clbk("unfreeze") +#' mlp = lrn("classif.mlp", callbacks = cb, +#' cb.unfreeze.starting_weights = select_invert( +#' select_name(c("0.weight", "3.weight", "6.weight", "6.bias")) +#' ), +#' cb.unfreeze.unfreeze = data.table( +#' epoch = c(2, 5), +#' weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight"))) +#' ), +#' epochs = 6, batch_size = 150, neurons = c(1, 1, 1) +#' ) +#' +#' mlp$train(task) CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze", inherit = CallbackSet, lock_objects = FALSE, diff --git a/attic/try-CallbackSetUnfreeze.R b/attic/try-CallbackSetUnfreeze.R deleted file mode 100644 index 8fef25ab..00000000 --- a/attic/try-CallbackSetUnfreeze.R +++ /dev/null @@ -1,37 +0,0 @@ -devtools::load_all() - -task = tsk("iris") - -mlp = lrn("classif.mlp", - epochs = 10, batch_size = 150, neurons = c(100, 200, 300) -) - -# sela = selector_all() -# sela(mlp$network$modules) - -mlp$train(task) - -# do this for each element in the parameters list -mlp$model$network$modules[["9"]]$parameters[[1]]$requires_grad_(TRUE) -mlp$model$network$modules[["9"]]$parameters[[2]]$requires_grad_(TRUE) - - -# construct a NN as a graph -module_1 = nn_linear(in_features = 3, out_features = 4, bias = TRUE) -activation = nn_sigmoid() -module_2 = nn_linear(4, 3, bias = TRUE) -softmax = nn_softmax(2) - -po_module_1 = po("module_1", module = module_1) -po_activation = po("module", id = "activation", activation) -po_module_2 = po("module_2", module = module_2) -po_softmax = po("module", id = "softmax", module = softmax) - -module_graph = po_module_1 %>>% - po_activation %>>% - po_module_2 %>>% - po_softmax - -module_graph$plot(html = TRUE) - -module_graph \ No newline at end of file diff --git a/attic/try-Select.R b/attic/try-Select.R deleted file mode 100644 index 81ba380b..00000000 --- a/attic/try-Select.R +++ /dev/null @@ -1,27 +0,0 @@ - -n_epochs = 10 - -task = tsk("iris") - -mlp = lrn("classif.mlp", - epochs = 10, batch_size = 150, neurons = c(100, 200, 300) -) -mlp$train(task) - -names(mlp$network$parameters) - -sela = select_all() -sela(names(mlp$network$parameters)) - -selg = select_grep("weight") -selg(names(mlp$network$parameters)) - -seln = select_name("0.weight") -seln(names(mlp$network$parameters)) - -seli = select_invert(select_name("0.weight")) -seli(names(mlp$network$parameters)) - -seln = select_none() -seln(names(mlp$network$parameters)) - diff --git a/man/mlr_callback_set.checkpoint.Rd b/man/mlr_callback_set.checkpoint.Rd index b953213f..da294675 100644 --- a/man/mlr_callback_set.checkpoint.Rd +++ b/man/mlr_callback_set.checkpoint.Rd @@ -12,6 +12,21 @@ The final network and optimizer are always stored. Saving the learner itself in the callback with a trained model is impossible, as the model slot is set \emph{after} the last callback step is executed. } +\examples{ +\dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +cb = t_clbk("checkpoint", freq = 1) +task = tsk("iris") +task$row_roles$use = 1 + +pth = tempfile() +learner = lrn("classif.mlp", epochs = 3, batch_size = 1, callbacks = cb) +learner$param_set$set_values(cb.checkpoint.path = pth) + +learner$train(task) + +list.files(pth) +\dontshow{\}) # examplesIf} +} \seealso{ Other Callback: \code{\link{TorchCallback}}, diff --git a/man/mlr_callback_set.history.Rd b/man/mlr_callback_set.history.Rd index ac9b71b6..c517bd41 100644 --- a/man/mlr_callback_set.history.Rd +++ b/man/mlr_callback_set.history.Rd @@ -9,6 +9,24 @@ Saves the training and validation history during training. The history is saved as a data.table in the \verb{$train} and \verb{$valid} slots. The first column is always \code{epoch}. } +\examples{ +\dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} + +cb = t_clbk("history") +task = tsk("iris") +task$filter(1:10) + +learner = lrn("classif.mlp", epochs = 3, batch_size = 1, + callbacks = t_clbk("history"), validate = 0.1) +learner$param_set$set_values( + measures_train = msrs(c("classif.acc", "classif.ce")), + measures_valid = msr("classif.ce") +) +learner$train(task) + +print(learner$model$callbacks$history) +\dontshow{\}) # examplesIf} +} \section{Super class}{ \code{\link[mlr3torch:CallbackSet]{mlr3torch::CallbackSet}} -> \code{CallbackSetHistory} } diff --git a/man/mlr_callback_set.progress.Rd b/man/mlr_callback_set.progress.Rd index b1552df8..51ec9b48 100644 --- a/man/mlr_callback_set.progress.Rd +++ b/man/mlr_callback_set.progress.Rd @@ -7,6 +7,20 @@ \description{ Prints a progress bar and the metrics for training and validation. } +\examples{ +\dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +task = tsk("iris") + +learner = lrn("classif.mlp", epochs = 10, batch_size = 1, + callbacks = t_clbk("progress"), validate = 0.1) +learner$param_set$set_values( +measures_train = msrs(c("classif.acc", "classif.ce")), + measures_valid = msr("classif.ce") +) + +learner$train(task) +\dontshow{\}) # examplesIf} +} \seealso{ Other Callback: \code{\link{TorchCallback}}, diff --git a/man/mlr_callback_set.unfreeze.Rd b/man/mlr_callback_set.unfreeze.Rd index 11681c00..83ffd7f8 100644 --- a/man/mlr_callback_set.unfreeze.Rd +++ b/man/mlr_callback_set.unfreeze.Rd @@ -7,6 +7,24 @@ \description{ Unfreeze some weights (parameters of the network) after some number of steps or epochs. } +\examples{ +\dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +task = tsk("iris") +cb = t_clbk("unfreeze") +mlp = lrn("classif.mlp", callbacks = cb, + cb.unfreeze.starting_weights = select_invert( + select_name(c("0.weight", "3.weight", "6.weight", "6.bias")) + ), + cb.unfreeze.unfreeze = data.table( + epoch = c(2, 5), + weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight"))) + ), + epochs = 6, batch_size = 150, neurons = c(1, 1, 1) +) + +mlp$train(task) +\dontshow{\}) # examplesIf} +} \seealso{ Other Callback: \code{\link{TorchCallback}},