Skip to content

Commit

Permalink
docs(callbacks): add examples
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 authored Jan 2, 2025
1 parent e62d188 commit 1c2a749
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 64 deletions.
12 changes: 12 additions & 0 deletions R/CallbackSetCheckpoint.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@
#' @family Callback
#' @export
#' @include CallbackSet.R
#'
#' @examplesIf torch::torch_is_installed()
#' cb = t_clbk("checkpoint", freq = 1)
#' task = tsk("iris")
#'
#' pth = tempfile()
#' learner = lrn("classif.mlp", epochs = 3, batch_size = 1, callbacks = cb)
#' learner$param_set$set_values(cb.checkpoint.path = pth)
#'
#' learner$train(task)
#'
#' list.files(pth)
CallbackSetCheckpoint = R6Class("CallbackSetCheckpoint",
inherit = CallbackSet,
lock_objects = FALSE,
Expand Down
14 changes: 14 additions & 0 deletions R/CallbackSetHistory.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@
#'
#' @export
#' @include CallbackSet.R
#' @examplesIf torch::torch_is_installed()
#'
#' cb = t_clbk("history")
#' task = tsk("iris")
#'
#' learner = lrn("classif.mlp", epochs = 3, batch_size = 1,
#' callbacks = t_clbk("history"), validate = 0.3)
#' learner$param_set$set_values(
#' measures_train = msrs(c("classif.acc", "classif.ce")),
#' measures_valid = msr("classif.ce")
#' )
#' learner$train(task)
#'
#' print(learner$model$callbacks$history)
CallbackSetHistory = R6Class("CallbackSetHistory",
inherit = CallbackSet,
lock_objects = FALSE,
Expand Down
11 changes: 11 additions & 0 deletions R/CallbackSetProgress.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@
#' @family Callback
#' @include CallbackSet.R
#' @export
#' @examplesIf torch::torch_is_installed()
#' task = tsk("iris")
#'
#' learner = lrn("classif.mlp", epochs = 10, batch_size = 1,
#' callbacks = t_clbk("progress"), validate = 0.3)
#' learner$param_set$set_values(
#' measures_train = msrs(c("classif.acc", "classif.ce")),
#' measures_valid = msr("classif.ce")
#' )
#'
#' learner$train(task)
CallbackSetProgress = R6Class("CallbackSetProgress",
inherit = CallbackSet,
lock_objects = FALSE,
Expand Down
15 changes: 15 additions & 0 deletions R/CallbackSetUnfreeze.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@
#' @family Callback
#' @export
#' @include CallbackSet.R
#' @examplesIf torch::torch_is_installed()
#' task = tsk("iris")
#' cb = t_clbk("unfreeze")
#' mlp = lrn("classif.mlp", callbacks = cb,
#' cb.unfreeze.starting_weights = select_invert(
#' select_name(c("0.weight", "3.weight", "6.weight", "6.bias"))
#' ),
#' cb.unfreeze.unfreeze = data.table(
#' epoch = c(2, 5),
#' weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight")))
#' ),
#' epochs = 6, batch_size = 150, neurons = c(1, 1, 1)
#' )
#'
#' mlp$train(task)
CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze",
inherit = CallbackSet,
lock_objects = FALSE,
Expand Down
37 changes: 0 additions & 37 deletions attic/try-CallbackSetUnfreeze.R

This file was deleted.

27 changes: 0 additions & 27 deletions attic/try-Select.R

This file was deleted.

15 changes: 15 additions & 0 deletions man/mlr_callback_set.checkpoint.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/mlr_callback_set.history.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions man/mlr_callback_set.progress.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/mlr_callback_set.unfreeze.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1c2a749

Please sign in to comment.