Skip to content

Commit

Permalink
Remove dataloader v2 usage in data/benchmarks (#1351)
Browse files Browse the repository at this point in the history
* lint

* run pre-commit again
  • Loading branch information
divyanshk authored Nov 1, 2024
1 parent 7b0de83 commit b255542
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 123 deletions.
60 changes: 0 additions & 60 deletions benchmarks/torchvision_classification/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@
# LICENSE file in the root directory of this source tree.

import itertools
import os
import random
from functools import partial
from pathlib import Path

import torch
import torch.distributed as dist
import torchvision
from PIL import Image
from torchdata.datapipes.iter import FileLister, IterDataPipe


# TODO: maybe infinite buffer can / is already natively supported by torchdata?
Expand All @@ -24,26 +21,6 @@
IMAGENET_TEST_LEN = 50_000


class _LenSetter(IterDataPipe):
# TODO: Ideally, we woudn't need this extra class
def __init__(self, dp, root):
self.dp = dp

if "train" in str(root):
self.size = IMAGENET_TRAIN_LEN
elif "val" in str(root):
self.size = IMAGENET_TEST_LEN
else:
raise ValueError("oops?")

def __iter__(self):
yield from self.dp

def __len__(self):
# TODO The // world_size part shouldn't be needed. See https://github.com/pytorch/data/issues/533
return self.size // dist.get_world_size()


def _decode(path, root, category_to_int):
category = Path(path).relative_to(root).parts[0]

Expand All @@ -58,22 +35,6 @@ def _apply_tranforms(img_and_label, transforms):
return transforms(img), label


def make_dp(root, transforms):

root = Path(root).expanduser().resolve()
categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
category_to_int = {category: i for (i, category) in enumerate(categories)}

dp = FileLister(str(root), recursive=True, masks=["*.JPEG"])

dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
dp = dp.map(partial(_decode, root=root, category_to_int=category_to_int))
dp = dp.map(partial(_apply_tranforms, transforms=transforms))

dp = _LenSetter(dp, root=root)
return dp


class PreLoadedMapStyle:
# All the data is pre-loaded and transformed in __init__, so the DataLoader should be crazy fast.
# This is just to assess how fast a model could theoretically be trained if there was no data bottleneck at all.
Expand All @@ -89,27 +50,6 @@ def __getitem__(self, idx):
return self.samples[idx % len(self.samples)]


class _PreLoadedDP(IterDataPipe):
# Same as above, but this is a DataPipe
def __init__(self, root, transforms, buffer_size=100):
dataset = torchvision.datasets.ImageFolder(root, transform=transforms)
self.size = len(dataset)
self.samples = [dataset[torch.randint(0, len(dataset), size=(1,)).item()] for i in range(buffer_size)]
# Note: the rng might be different across DDP workers so they'll all have different samples.
# But we don't care about accuracy here so whatever.

def __iter__(self):
for idx in range(self.size):
yield self.samples[idx % len(self.samples)]


def make_pre_loaded_dp(root, transforms):
dp = _PreLoadedDP(root=root, transforms=transforms)
dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
dp = _LenSetter(dp, root=root)
return dp


class MapStyleToIterable(torch.utils.data.IterableDataset):
# This converts a MapStyle dataset into an iterable one.
# Not sure this kind of Iterable dataset is actually useful to benchmark. It
Expand Down
144 changes: 81 additions & 63 deletions benchmarks/torchvision_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import torchvision
import utils
from torch import nn
from torchdata.dataloader2 import adapter, DataLoader2, MultiProcessingReadingService


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args):
Expand Down Expand Up @@ -111,23 +110,19 @@ def create_data_loaders(args):
train_dir = os.path.join(dataset_dir, "train")
val_dir = os.path.join(dataset_dir, "val")

val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
val_resize_size, val_crop_size, train_crop_size = (
args.val_resize_size,
args.val_crop_size,
args.train_crop_size,
)

if args.no_transforms:
train_preset = val_preset = helpers.no_transforms
else:
train_preset = presets.ClassificationPresetTrain(crop_size=train_crop_size)
val_preset = presets.ClassificationPresetEval(crop_size=val_crop_size, resize_size=val_resize_size)

if args.ds_type == "dp":
builder = helpers.make_pre_loaded_dp if args.preload_ds else helpers.make_dp
train_dataset = builder(train_dir, transforms=train_preset)
val_dataset = builder(val_dir, transforms=val_preset)

train_sampler = val_sampler = None
train_shuffle = True

elif args.ds_type == "iterable":
if args.ds_type == "iterable":
train_dataset = torchvision.datasets.ImageFolder(train_dir, transform=train_preset)
train_dataset = helpers.MapStyleToIterable(train_dataset, shuffle=True)

Expand All @@ -149,45 +144,22 @@ def create_data_loaders(args):
else:
raise ValueError(f"Invalid value for args.ds_type ({args.ds_type})")

data_loader_arg = args.data_loader.lower()
if data_loader_arg == "v1":
train_data_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=train_shuffle,
sampler=train_sampler,
num_workers=args.workers,
pin_memory=True,
drop_last=True,
)
val_data_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size,
sampler=val_sampler,
num_workers=args.workers,
pin_memory=True,
)
elif data_loader_arg == "v2":
if args.ds_type != "dp":
raise ValueError("DataLoader2 only works with datapipes.")

# Note: we are batching and collating here *after the transforms*, which is consistent with DLV1.
# But maybe it would be more efficient to do that before, so that the transforms can work on batches??

train_dataset = train_dataset.batch(args.batch_size, drop_last=True).collate()
train_data_loader = DataLoader2(
train_dataset,
datapipe_adapter_fn=adapter.Shuffle(),
reading_service=MultiProcessingReadingService(num_workers=args.workers),
)

val_dataset = val_dataset.batch(args.batch_size, drop_last=True).collate() # TODO: Do we need drop_last here?
val_data_loader = DataLoader2(
val_dataset,
reading_service=MultiProcessingReadingService(num_workers=args.workers),
)
else:
raise ValueError(f"invalid data-loader param. Got {args.data_loader}")
train_data_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=train_shuffle,
sampler=train_sampler,
num_workers=args.workers,
pin_memory=True,
drop_last=True,
)
val_data_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size,
sampler=val_sampler,
num_workers=args.workers,
pin_memory=True,
)

return train_data_loader, val_data_loader, train_sampler

Expand Down Expand Up @@ -266,17 +238,47 @@ def get_args_parser(add_help=True):

parser.add_argument("--fs", default="fsx", type=str)
parser.add_argument("--model", default="resnet18", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
"-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
"--device",
default="cuda",
type=str,
help="device (Use cuda or cpu Default: cuda)",
)
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
parser.add_argument(
"-j", "--workers", default=12, type=int, metavar="N", help="number of data loading workers (default: 16)"
"-b",
"--batch-size",
default=32,
type=int,
help="images per gpu, the total batch size is $NGPU x batch_size",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"-j",
"--workers",
default=12,
type=int,
metavar="N",
help="number of data loading workers (default: 16)",
)
parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument(
"--lr-step-size",
default=30,
type=int,
help="decrease lr every step-size epochs",
)
parser.add_argument(
"--lr-gamma",
default=0.1,
type=float,
help="decrease lr by a factor of lr-gamma",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")

parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
Expand All @@ -291,27 +293,43 @@ def get_args_parser(add_help=True):

# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
parser.add_argument(
"--dist-url",
default="env://",
type=str,
help="url used to set up distributed training",
)

parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
"--use-deterministic-algorithms",
action="store_true",
help="Forces the use of deterministic algorithms only.",
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
"--val-resize-size",
default=256,
type=int,
help="the resize size used for validation (default: 256)",
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
"--val-crop-size",
default=224,
type=int,
help="the central crop size used for validation (default: 224)",
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
"--train-crop-size",
default=224,
type=int,
help="the random crop size used for training (default: 224)",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

parser.add_argument(
"--ds-type",
default="mapstyle",
type=str,
help="'dp' or 'iterable' or 'mapstyle' (for regular indexable datasets)",
help="'iterable' or 'mapstyle' (for regular indexable datasets)",
)

parser.add_argument(
Expand Down Expand Up @@ -341,7 +359,7 @@ def get_args_parser(add_help=True):
"--data-loader",
default="V1",
type=str,
help="'V1' or 'V2'. V2 only works for datapipes",
help="'V1' or 'V2'. Last stable release with DataloaderV2 is 0.9.0.",
)

return parser
Expand Down

0 comments on commit b255542

Please sign in to comment.