Skip to content

Commit 012f68d

Browse files
authored
StreamingDataloader: Add profiling support (#19338)
1 parent 925357d commit 012f68d

File tree

3 files changed

+144
-1
lines changed

3 files changed

+144
-1
lines changed

requirements/data/test.txt

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ pytest-cov ==4.1.0
44
pytest-timeout ==2.1.0
55
pytest-rerunfailures ==12.0
66
pytest-random-order ==1.1.0
7+
viztracer

src/lightning/data/streaming/dataloader.py

+121-1
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,99 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
347347
return _MultiProcessingDataLoaderIterPatch(self)
348348

349349

350+
def _wrapper(fetcher: Any, func: Callable, tracer: Any, profile: int, profile_dir: str) -> Callable:
351+
counter = 0
352+
353+
def wrap(*args: Any, **kwargs: Any) -> Any:
354+
nonlocal counter
355+
result = func(*args, **kwargs)
356+
357+
if tracer.enable and counter == profile:
358+
tracer.stop()
359+
tracer.save()
360+
print(
361+
f"Saved {os.path.join(profile_dir, 'result.json')} file after {profile} batches."
362+
"Use chrome://tracing/ to view it."
363+
)
364+
fetcher.fetch = func
365+
366+
counter += 1
367+
return result
368+
369+
return wrap
370+
371+
372+
class _ProfileWorkerLoop:
373+
"""Wrap the PyTorch DataLoader WorkerLoop to add profiling."""
374+
375+
def __init__(self, profile: Union[int, bool], profile_dir: Optional[str] = None):
376+
self._profile = profile
377+
self._profile_dir = profile_dir if profile_dir else os.getcwd()
378+
379+
def __call__(
380+
self,
381+
dataset_kind: Any,
382+
dataset: Any,
383+
index_queue: Any,
384+
data_queue: Any,
385+
done_event: Any,
386+
auto_collation: Any,
387+
collate_fn: Any,
388+
drop_last: Any,
389+
base_seed: Any,
390+
init_fn: Any,
391+
worker_id: Any,
392+
*args: Any,
393+
**kwargs: Any,
394+
) -> None:
395+
from torch.utils.data._utils import worker
396+
from viztracer import VizTracer
397+
398+
if worker_id == 0:
399+
output_file = os.path.join(self._profile_dir, "result.json")
400+
401+
if os.path.exists(output_file):
402+
os.remove(output_file)
403+
404+
tracer = VizTracer(output_file=output_file, verbose=0)
405+
tracer.start()
406+
407+
# Reload to remove the patching
408+
reloaded_worker = reload(worker)
409+
create_fetcher = _DatasetKind.create_fetcher
410+
fetcher = None
411+
412+
def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
413+
nonlocal fetcher
414+
fetcher = create_fetcher(*args, **kwargs)
415+
416+
if worker_id == 0 and isinstance(self._profile, int):
417+
fetcher.fetch = _wrapper(fetcher, fetcher.fetch, tracer, self._profile, self._profile_dir)
418+
return fetcher
419+
420+
_DatasetKind.create_fetcher = create_fetcher_fn # type: ignore
421+
422+
reloaded_worker._worker_loop(
423+
dataset_kind,
424+
dataset,
425+
index_queue,
426+
data_queue,
427+
done_event,
428+
auto_collation,
429+
collate_fn,
430+
drop_last,
431+
base_seed,
432+
init_fn,
433+
worker_id,
434+
*args,
435+
**kwargs,
436+
)
437+
438+
if worker_id == 0 and isinstance(self._profile, bool):
439+
tracer.stop()
440+
tracer.save()
441+
442+
350443
class _StreamingMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
351444
def __init__(self, loader: DataLoader) -> None:
352445
self._loader = loader
@@ -355,6 +448,15 @@ def __init__(self, loader: DataLoader) -> None:
355448
if self._loader._latest_worker_idx > 0
356449
else []
357450
)
451+
self._num_workers = loader.num_workers
452+
453+
distributed_env = _DistributedEnv.detect()
454+
455+
if self._loader._profile_bactches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
456+
from torch.utils.data._utils import worker
457+
458+
worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_bactches, self._loader._profile_dir)
459+
358460
super().__init__(loader)
359461

360462
def _try_put_index(self) -> None:
@@ -388,6 +490,9 @@ def __init__(
388490
*args: Any,
389491
batch_size: int = 1,
390492
num_workers: int = 0,
493+
profile_bactches: Union[bool, int] = False,
494+
profile_dir: Optional[str] = None,
495+
prefetch_factor: Optional[int] = None,
391496
**kwargs: Any,
392497
) -> None: # pyright: ignore
393498
if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -396,17 +501,32 @@ def __init__(
396501
f" Found {dataset}."
397502
)
398503

504+
if profile_bactches and not _VIZ_TRACKER_AVAILABLE:
505+
raise ModuleNotFoundError("To use profile_bactches, viztracer is required. Run `pip install viztracer`")
506+
507+
if profile_bactches and num_workers == 0:
508+
raise ValueError("Profiling is supported only with num_workers >= 1.")
509+
399510
self.current_epoch = 0
400511
self.batch_size = batch_size
401512
self.num_workers = num_workers
513+
self._profile_bactches = profile_bactches
514+
self._profile_dir = profile_dir
402515
self._num_samples_yielded_streaming = 0
403516
self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
404517
self.rng_state: Optional[Any] = None
405518
self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
406519
self._worker_idx_iter: Optional[Any] = None
407520
self._latest_worker_idx = 0
408521
self.restore = False
409-
super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs) # type: ignore
522+
super().__init__(
523+
dataset,
524+
*args,
525+
batch_size=batch_size,
526+
num_workers=num_workers,
527+
prefetch_factor=(10 if num_workers > 0 else None) if prefetch_factor is None else prefetch_factor,
528+
**kwargs,
529+
) # type: ignore
410530

411531
def __iter__(self) -> Any:
412532
if not self.restore:

tests/tests_data/streaming/test_dataloader.py

+22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import os
2+
3+
import pytest
14
import torch
25
from lightning.data.streaming import CombinedStreamingDataset, StreamingDataLoader
6+
from lightning.data.streaming import dataloader as streaming_dataloader_module
37
from torch import tensor
48

59

@@ -70,3 +74,21 @@ def test_streaming_dataloader():
7074
"latest_worker_idx": 0,
7175
"num_samples_yielded": {0: [11, 9]},
7276
}
77+
78+
79+
@pytest.mark.parametrize("profile", [2, True])
80+
def test_dataloader_profiling(profile, tmpdir, monkeypatch):
81+
monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True)
82+
83+
dataset = TestCombinedStreamingDataset(
84+
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
85+
)
86+
dataloader = StreamingDataLoader(
87+
dataset, batch_size=2, profile_bactches=profile, profile_dir=str(tmpdir), num_workers=1
88+
)
89+
dataloader_iter = iter(dataloader)
90+
batches = []
91+
for batch in dataloader_iter:
92+
batches.append(batch)
93+
94+
assert os.path.exists(os.path.join(tmpdir, "result.json"))

0 commit comments

Comments
 (0)