Skip to content

Commit

Permalink
correct, but 10+ times slower than manual construction in data-raw/ci…
Browse files Browse the repository at this point in the history
…far.R
  • Loading branch information
cxzhang4 committed Jan 6, 2025
1 parent 792a48d commit 958f779
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 111 deletions.
136 changes: 40 additions & 96 deletions R/TaskClassif_cifar.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,108 +29,55 @@
#' print(task_cifar100)
NULL

# Read the class labels of every record in one CIFAR binary batch file.
#
# Record layout (binary version of the dataset):
#   CIFAR-10 : <1 label byte>                      <3072 pixel bytes>
#   CIFAR-100: <1 coarse label byte><1 fine label byte><3072 pixel bytes>
#
# @param file_path Path to a single `.bin` batch file.
# @param type `10` for CIFAR-10 records; `100` for CIFAR-100 records
#   (any other value is treated like `10`, i.e. one label byte per record).
# @return Integer vector with one label per record. For `type = 100` the
#   fine label is returned, otherwise the single label byte.
#
# The number of records is derived from the file size instead of being
# guessed from `type` and from whether the path contains "test"; the latter
# misfires as soon as any directory in the path contains "test".
read_cifar_labels_batch = function(file_path, type = 10) {
  n_label_bytes = if (type == 100) 2L else 1L
  record_size = n_label_bytes + 32L * 32L * 3L

  n_bytes = file.size(file_path)
  if (is.na(n_bytes)) {
    stop("Cannot read CIFAR batch file: ", file_path, call. = FALSE)
  }
  if (n_bytes %% record_size != 0) {
    # A partial trailing record means a corrupt file or the wrong `type`.
    stop(
      "Size of '", file_path, "' (", n_bytes, " bytes) is not a multiple of the ",
      "record size (", record_size, " bytes) expected for type = ", type, ".",
      call. = FALSE
    )
  }
  n_records = n_bytes %/% record_size

  con = file(file_path, "rb")
  on.exit({close(con)}, add = TRUE)

  # One sequential read of the whole batch (at most ~154 MB for CIFAR-100
  # train.bin); this avoids one readBin() + seek() call per record.
  bytes = readBin(con, raw(), n = n_bytes)

  # 1-based position of the returned label byte inside each record: the last
  # label byte, i.e. the only label for CIFAR-10 and the fine label for
  # CIFAR-100.
  label_pos = (seq_len(n_records) - 1) * record_size + n_label_bytes

  # Decode as signed single bytes, exactly like the per-record readBin() did.
  readBin(bytes[label_pos], integer(), n = n_records, size = 1L)
}

# Read the i-th image of one CIFAR binary batch file.
#
# @param file_path Path to a single `.bin` batch file.
# @param i 1-based index of the record inside that file.
# @param type `10` (one label byte per record) or `100` (two label bytes).
# @return Integer array of dimension 32 x 32 x 3 (row, column, channel);
#   the three channel planes are read in the order they are stored.
read_cifar_image = function(file_path, i, type = 10) {
  n_label_bytes = 1 + as.integer(type == 100)
  plane_size = 32 * 32
  record_bytes = n_label_bytes + 3 * plane_size

  handle = file(file_path, "rb")
  on.exit({close(handle)}, add = TRUE)

  # jump over all preceding records, then over this record's label byte(s)
  seek(handle, (i - 1) * record_bytes, origin = "start")
  seek(handle, n_label_bytes, origin = "current")

  pixels = array(dim = c(32, 32, 3))
  for (channel in seq_len(3)) {
    plane = as.integer(readBin(handle, raw(), size = 1, n = plane_size, endian = "big"))
    # each plane is stored row by row
    pixels[, , channel] = matrix(plane, ncol = 32, byrow = TRUE)
  }

  pixels
}

constructor_cifar10 = function(path) {
require_namespaces("torchvision")

torchvision::cifar10_dataset(root = path, download = TRUE)
tv_ds_train = torchvision::cifar10_dataset(root = path, download = TRUE)
tv_data_train = tv_ds_train$.getitem(1:50000)

train_files = file.path(path, "cifar-10-batches-bin", sprintf("data_batch_%d.bin", 1:5))
test_file = file.path(path, "cifar-10-batches-bin", "test_batch.bin")
tv_ds_test = torchvision::cifar10_dataset(root = path, train = FALSE, download = FALSE)
tv_data_test = tv_ds_test$.getitem(1:10000)

labels = c(tv_data_train$y, tv_data_test$y)
images = array(c(tv_data_train$x, tv_data_test$x), dim = c(60000, 32, 32, 3))

# TODO: convert these to the meaningful names
train_labels_ints = unlist(map(train_files, read_cifar_labels_batch, type = 10))
class_names = readLines(file.path(path, "cifar-10-batches-bin", "batches.meta.txt"))
class_names = class_names[class_names != ""]

data.table(
class = factor(c(train_labels_ints, rep(NA, times = 10000)), labels = class_names),
file = c(rep(train_files, each = 10000),
rep(test_file, 10000)),
idx_in_file = c(rep(1:10000, 5),
1:10000),
split = factor(rep(c("train", "test"), c(50000, 10000))),
..row_id = seq_len(60000)
)
return(list(labels = labels, images = images, class_names = class_names))
}

load_task_cifar10 = function(id = "cifar10") {
cached_constructor = function(backend) {
data = cached(constructor_cifar10, "datasets", "cifar10")$data
data <- cached(constructor_cifar10, "datasets", "cifar10")$data

cifar10_ds_generator = torch::dataset(
initialize = function() {
self$.data = data
initialize = function(images) {
self$images = images
},
.getitem = function(idx) {
force(idx)

x = torch_tensor(read_cifar_image(self$.data$file[idx], self$.data$idx_in_file[idx]))
x = torch_tensor(self$images[idx, , , ])

return(list(x = x))
},
.length = function() {
nrow(self$.data)
dim(self$images)[1L]
}
)

cifar10_ds = cifar10_ds_generator()
cifar10_ds = cifar10_ds_generator(data$images)

dd = as_data_descriptor(cifar10_ds, list(x = c(NA, 32, 32, 3)))
lt = lazy_tensor(dd)

dt = cbind(data, data.table(image = lt))
dt = data.table(
class = factor(data$labels, labels = data$class_names),
image = lt,
split = factor(rep(c("train", "test"), c(50000, 10000))),
..row_id = seq_len(60000)
)

DataBackendDataTable$new(data = dt, primary_key = "..row_id")
}
Expand All @@ -154,8 +101,6 @@ load_task_cifar10 = function(id = "cifar10") {
backend$hash = "mlr3torch::mlr_tasks_cifar10"
task$man = "mlr3torch::mlr_tasks_cifar"

task$filter(1:50000)

return(task)
}

Expand All @@ -164,50 +109,51 @@ register_task("cifar10", load_task_cifar10)
constructor_cifar100 = function(path) {
require_namespaces("torchvision")

torchvision::cifar100_dataset(root = path, download = TRUE)
tv_ds_train = torchvision::cifar100_dataset(root = path, download = TRUE)
tv_data_train = tv_ds_train$.getitem(1:50000)

tv_ds_test = torchvision::cifar100_dataset(root = path, train = FALSE, download = FALSE)
tv_data_test = tv_ds_test$.getitem(1:10000)

train_file = file.path(path, "cifar-100-binary", "train.bin")
test_file = file.path(path, "cifar-100-binary", "test.bin")
labels = c(tv_data_train$y, tv_data_test$y)
images = array(c(tv_data_train$x, tv_data_test$x), dim = c(60000, 32, 32, 3))

train_labels_ints = read_cifar_labels_batch(train_file, type = 100)
class_names = readLines(file.path(path, "cifar-100-binary", "fine_label_names.txt"))

data.table(
class = factor(c(train_labels_ints, rep(NA, times = 10000)), labels = class_names),
file = c(rep(train_file, 50000),
rep(test_file, 10000)),
idx_in_file = c(1:50000, 1:10000),
split = factor(rep(c("train", "test"), c(50000, 10000))),
..row_id = seq_len(60000)
)

return(list(labels = labels, images = images, class_names = class_names))
}

load_task_cifar100 = function(id = "cifar100") {
cached_constructor = function(backend) {
data = cached(constructor_cifar100, "datasets", "cifar100")$data

cifar100_ds_generator = torch::dataset(
initialize = function() {
self$.data = data
initialize = function(images) {
self$images = images
},
.getitem = function(idx) {
force(idx)

x = torch_tensor(read_cifar_image(self$.data$file[idx], self$.data$idx_in_file[idx], type = 100))
x = torch_tensor(self$images[idx, , , ])

return(list(x = x))
},
.length = function() {
nrow(self$.data)
dim(self$images)[1L]
}
)

cifar100_ds = cifar100_ds_generator()
cifar100_ds = cifar100_ds_generator(data$images)

dd = as_data_descriptor(cifar100_ds, list(x = c(NA, 32, 32, 3)))
lt = lazy_tensor(dd)

dt = cbind(data, data.table(image = lt))
dt = data.table(
class = factor(data$labels, labels = data$class_names),
image = lt,
split = factor(rep(c("train", "test"), c(50000, 10000))),
..row_id = seq_len(60000)
)

DataBackendDataTable$new(data = dt, primary_key = "..row_id")
}
Expand All @@ -231,8 +177,6 @@ load_task_cifar100 = function(id = "cifar100") {
backend$hash = "mlr3torch::mlr_tasks_cifar100"
task$man = "mlr3torch::mlr_tasks_cifar"

task$filter(1:50000)

return(task)
}

Expand Down
20 changes: 9 additions & 11 deletions data-raw/cifar.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ library(torchvision)
constructor_cifar10 = function(path) {
require_namespaces("torchvision")

class_names = readLines(file.path(path, "cifar-10-batches-bin", "batches.meta.txt"))
class_names = class_names[class_names != ""]

tv_ds_train = torchvision::cifar10_dataset(root = path, download = TRUE)
tv_ds_train = torchvision::cifar10_dataset(root = path, train = TRUE, download = TRUE)
tv_data_train = tv_ds_train$.getitem(1:50000)

tv_ds_test = torchvision::cifar10_dataset(root = path, train = fALSE, download = FALSE)
tv_ds_test = torchvision::cifar10_dataset(root = path, train = FALSE, download = FALSE)
tv_data_test = tv_ds_test$.getitem(1:10000)

labels = c(tv_data_train$y, tv_data_test$y)
images = array(c(tv_data_train$x, tv_data_test$x), dim = c(60000, 32, 32, 3))

class_names = readLines(file.path(path, "cifar-10-batches-bin", "batches.meta.txt"))
class_names = class_names[class_names != ""]

return(list(labels = labels, images = images, class_names = class_names))
}

Expand Down Expand Up @@ -73,19 +73,17 @@ path = file.path(get_cache_dir(), "datasets", "cifar100", "raw")
constructor_cifar100 = function(path) {
require_namespaces("torchvision")

class_names = readLines(file.path(path, "cifar-100-binary", "fine_label_names.txt"))

class_names = class_names[class_names != ""]

tv_ds_train = torchvision::cifar100_dataset(root = path, download = TRUE)
tv_ds_train = torchvision::cifar100_dataset(root = path, train = TRUE, download = TRUE)
tv_data_train = tv_ds_train$.getitem(1:50000)

tv_ds_test = torchvision::cifar100_dataset(root = path, train = fALSE, download = FALSE)
tv_ds_test = torchvision::cifar100_dataset(root = path, train = FALSE, download = FALSE)
tv_data_test = tv_ds_test$.getitem(1:10000)

labels = c(tv_data_train$y, tv_data_test$y)
images = array(c(tv_data_train$x, tv_data_test$x), dim = c(60000, 32, 32, 3))

class_names = readLines(file.path(path, "cifar-100-binary", "fine_label_names.txt"))

return(list(labels = labels, images = images, class_names = class_names))
}

Expand Down
Binary file modified inst/col_info/cifar10.rds
Binary file not shown.
Binary file modified inst/col_info/cifar100.rds
Binary file not shown.
8 changes: 4 additions & 4 deletions tests/testthat/test_TaskClassif_cifar.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ test_that("CIFAR-10 works", {
withr::local_options(mlr3torch.cache = TRUE)
task = tsk("cifar10")

expect_equal(task$nrow, 50000)
expect_equal(task$nrow, 60000)

task$filter(1:10)
expect_equal(task$id, "cifar10")
Expand All @@ -17,14 +17,14 @@ test_that("CIFAR-10 works", {
expect_true("cifar-10-batches-bin" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar10", "raw")))
expect_true("data.rds" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar10")))
expect_equal(task$backend$nrow, 60000)
expect_equal(task$backend$ncol, 6)
expect_equal(task$backend$ncol, 4)
})

test_that("CIFAR-100 works", {
withr::local_options(mlr3torch.cache = TRUE)
task = tsk("cifar100")

expect_equal(task$nrow, 50000)
expect_equal(task$nrow, 60000)

task$filter(1:10)
expect_equal(task$id, "cifar100")
Expand All @@ -37,5 +37,5 @@ test_that("CIFAR-100 works", {
expect_true("cifar-100-binary" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar100", "raw")))
expect_true("data.rds" %in% list.files(file.path(get_cache_dir(), "datasets", "cifar100")))
expect_equal(task$backend$nrow, 60000)
expect_equal(task$backend$ncol, 6)
expect_equal(task$backend$ncol, 4)
})

0 comments on commit 958f779

Please sign in to comment.