Skip to content

Commit

Permalink
in manual file 100 hould be ok, still need to modify 10
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 committed Jan 6, 2025
1 parent 0c7fbd3 commit 4b756d5
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions data-raw/cifar.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,40 +92,43 @@ constructor_cifar100 = function(path) {
# )
tv_ds = torchvision::cifar100_dataset(root = path, download = TRUE)
class_names = readLines(file.path(path, "cifar-100-binary", "fine_label_names.txt"))

# TODO: use $.getitem() instead of $x
# data.table(
# class = factor(as.integer(tv_ds$y), labels = class_names),
# image = self$x,
# split = factor(rep(c("train", "test"), c(50000, 10000))),
# ..row_id = seq_len(60000)
# )

tv_data = tv_ds$.getitem(1:60000)

tv_data$class_names = class_names

tv_data
}

data = constructor_cifar100(path)

cifar100_ds_generator = torch::dataset(
initialize = function() {
self$.data = data
initialize = function(data) {
self$.img_arr = data$x
},
.getitem = function(idx) {
force(idx)

x = torch_tensor(read_cifar_image(self$.data$file[idx], self$.data$idx_in_file[idx], type = 100))
x = torch_tensor(self$.img_arr[i, , , ])

return(list(x = x))
},
.length = function() {
nrow(self$.data)
dim(self$.img_arr)[1]
}
)

cifar100_ds = cifar100_ds_generator()
cifar100_ds = cifar100_ds_generator(data)

dd = as_data_descriptor(cifar100_ds, list(x = c(NA, 32, 32, 3)))
lt = lazy_tensor(dd)

dt = cbind(data, data.table(image = lt))
dt = data.table(
class = factor(data$y, labels = data$class_names),
image = lt,
split = factor(rep(c("train", "test"), c(50000, 10000))),
..row_id = seq_len(60000)
)

task = as_task_classif(dt, target = "class")

Expand Down

0 comments on commit 4b756d5

Please sign in to comment.