Skip to content

Commit

Permalink
passes local check
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 committed Jan 7, 2025
1 parent b1b31a6 commit a932955
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 135 deletions.
8 changes: 5 additions & 3 deletions R/TaskClassif_cifar.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#' @include aaa.R
#'
#' @description
#' The CIFAR-10 and CIFAR-100 subsets of the 80 million tiny images dataset.
#' The data is obtained from [`torchvision::cifar10_dataset()`] or [`torchvision::cifar100:dataset()`].
#' The CIFAR-10 and CIFAR-100 subsets of the 80 million tiny images dataset. TODO: explain the subsets. explain the difference.
#' The data is obtained from [`torchvision::cifar10_dataset()`] (or `torchvision::cifar100_dataset()`).
#'
#' @section Construction:
#' ```
Expand All @@ -22,7 +22,7 @@
#'
#' @references
#' `r format_bib("cifar2009")`
#' @examplesIf torch::torch_is_installed()
#' @examples
#' task_cifar10 = tsk("cifar10")
#' task_cifar100 = tsk("cifar100")
#' print(task_cifar10)
Expand Down Expand Up @@ -117,6 +117,8 @@ constructor_cifar100 = function(path) {

labels = c(d_train$y, d_test$y)
images = array(NA, dim = c(60000, 3, 32, 32))
# original data has channel dimension at the end
perm_idx = c(1, 4, 2, 3)
images[1:50000, , , ] = aperm(d_train$x, perm_idx, resize = TRUE)
images[50001:60000, , , ] = aperm(d_test$x, perm_idx, resize = TRUE)

Expand Down
1 change: 1 addition & 0 deletions R/TaskClassif_melanoma.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#'
#' @references
#' `r format_bib("melanoma2021")`
#' @examples
#' task = tsk("melanoma")
#' task
NULL
Expand Down
52 changes: 0 additions & 52 deletions attic/task_manual_construct-cifar.R

This file was deleted.

37 changes: 0 additions & 37 deletions attic/try-CallbackSetUnfreeze.R

This file was deleted.

27 changes: 0 additions & 27 deletions attic/try-Select.R

This file was deleted.

32 changes: 18 additions & 14 deletions data-raw/cifar.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ library(torchvision)
constructor_cifar10 = function(path) {
require_namespaces("torchvision")

tv_ds_train = torchvision::cifar10_dataset(root = path, train = TRUE, download = TRUE)
tv_data_train = tv_ds_train$.getitem(1:50000)
d_train = torchvision::cifar10_dataset(root = path, train = TRUE, download = TRUE)

tv_ds_test = torchvision::cifar10_dataset(root = path, train = FALSE, download = FALSE)
tv_data_test = tv_ds_test$.getitem(1:10000)
d_test = torchvision::cifar10_dataset(root = path, train = FALSE, download = FALSE)

labels = c(tv_data_train$y, tv_data_test$y)
images = array(c(tv_data_train$x, tv_data_test$x), dim = c(60000, 32, 32, 3))
labels = c(d_train$y, d_test$y)
images = array(NA, dim = c(60000, 3, 32, 32))
# original data has channel dimension at the end
perm_idx = c(1, 4, 2, 3)
images[1:50000, , , ] = aperm(d_train$x, perm_idx, resize = TRUE)
images[50001:60000, , , ] = aperm(d_test$x, perm_idx, resize = TRUE)

class_names = readLines(file.path(path, "cifar-10-batches-bin", "batches.meta.txt"))
class_names = class_names[class_names != ""]
Expand Down Expand Up @@ -47,7 +49,7 @@ cifar10_ds_generator = torch::dataset(

cifar10_ds = cifar10_ds_generator(data$images)

dd = as_data_descriptor(cifar10_ds, list(x = c(NA, 32, 32, 3)))
dd = as_data_descriptor(cifar10_ds, list(x = c(NA, 3, 32, 32)))
lt = lazy_tensor(dd)

tsk_dt = data.table(
Expand All @@ -73,14 +75,16 @@ path = file.path(get_cache_dir(), "datasets", "cifar100", "raw")
constructor_cifar100 = function(path) {
require_namespaces("torchvision")

tv_ds_train = torchvision::cifar100_dataset(root = path, train = TRUE, download = TRUE)
tv_data_train = tv_ds_train$.getitem(1:50000)
d_train = torchvision::cifar100_dataset(root = path, train = TRUE, download = TRUE)

tv_ds_test = torchvision::cifar100_dataset(root = path, train = FALSE, download = FALSE)
tv_data_test = tv_ds_test$.getitem(1:10000)
d_test = torchvision::cifar100_dataset(root = path, train = FALSE, download = FALSE)

labels = c(tv_data_train$y, tv_data_test$y)
images = array(c(tv_data_train$x, tv_data_test$x), dim = c(60000, 32, 32, 3))
labels = c(d_train$y, d_test$y)
images = array(NA, dim = c(60000, 3, 32, 32))
# original data has channel dimension at the end
perm_idx = c(1, 4, 2, 3)
images[1:50000, , , ] = aperm(d_train$x, perm_idx, resize = TRUE)
images[50001:60000, , , ] = aperm(d_test$x, perm_idx, resize = TRUE)

class_names = readLines(file.path(path, "cifar-100-binary", "fine_label_names.txt"))

Expand All @@ -107,7 +111,7 @@ cifar100_ds_generator = torch::dataset(

cifar100_ds = cifar100_ds_generator(data$images)

dd = as_data_descriptor(cifar100_ds, list(x = c(NA, 32, 32, 3)))
dd = as_data_descriptor(cifar100_ds, list(x = c(NA, 3, 32, 32)))
lt = lazy_tensor(dd)

dt = data.table(
Expand Down
4 changes: 2 additions & 2 deletions man/mlr_tasks_cifar.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a932955

Please sign in to comment.