@@ -347,6 +347,99 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
         return _MultiProcessingDataLoaderIterPatch(self)


+def _wrapper(fetcher: Any, func: Callable, tracer: Any, profile: int, profile_dir: str) -> Callable:
+    counter = 0
+
+    def wrap(*args: Any, **kwargs: Any) -> Any:
+        nonlocal counter
+        result = func(*args, **kwargs)
+
+        if tracer.enable and counter == profile:
+            tracer.stop()
+            tracer.save()
+            print(
+                f"Saved {os.path.join(profile_dir, 'result.json')} file after {profile} batches."
+                " Use chrome://tracing/ to view it."
+            )
+            # Un-patch: restore the original fetch so later batches run untraced.
+            fetcher.fetch = func
+
+        counter += 1
+        return result
+
+    return wrap
+
+
+class _ProfileWorkerLoop:
+    """Wrap the PyTorch DataLoader WorkerLoop to add profiling."""
+
+    def __init__(self, profile: Union[int, bool], profile_dir: Optional[str] = None):
+        self._profile = profile
+        self._profile_dir = profile_dir if profile_dir else os.getcwd()
+
+    def __call__(
+        self,
+        dataset_kind: Any,
+        dataset: Any,
+        index_queue: Any,
+        data_queue: Any,
+        done_event: Any,
+        auto_collation: Any,
+        collate_fn: Any,
+        drop_last: Any,
+        base_seed: Any,
+        init_fn: Any,
+        worker_id: Any,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        from torch.utils.data._utils import worker
+        from viztracer import VizTracer
+
+        if worker_id == 0:
+            output_file = os.path.join(self._profile_dir, "result.json")
+
+            if os.path.exists(output_file):
+                os.remove(output_file)
+
+            tracer = VizTracer(output_file=output_file, verbose=0)
+            tracer.start()
+
+        # Reload to remove the patching
+        reloaded_worker = reload(worker)
+        create_fetcher = _DatasetKind.create_fetcher
+        fetcher = None
+
+        def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
+            nonlocal fetcher
+            fetcher = create_fetcher(*args, **kwargs)
+
+            # Note: bool is a subclass of int, so exclude it here; profile=True
+            # means "trace the whole run" and is handled after the worker loop.
+            if worker_id == 0 and isinstance(self._profile, int) and not isinstance(self._profile, bool):
+                fetcher.fetch = _wrapper(fetcher, fetcher.fetch, tracer, self._profile, self._profile_dir)
+            return fetcher
+
+        _DatasetKind.create_fetcher = create_fetcher_fn  # type: ignore
+
+        reloaded_worker._worker_loop(
+            dataset_kind,
+            dataset,
+            index_queue,
+            data_queue,
+            done_event,
+            auto_collation,
+            collate_fn,
+            drop_last,
+            base_seed,
+            init_fn,
+            worker_id,
+            *args,
+            **kwargs,
+        )
+
+        if worker_id == 0 and isinstance(self._profile, bool):
+            tracer.stop()
+            tracer.save()
+
+
 class _StreamingMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
     def __init__(self, loader: DataLoader) -> None:
         self._loader = loader
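The `_wrapper` above patches a fetcher's bound `fetch` method and swaps the original back in once the requested number of batches has been traced. A minimal, self-contained sketch of that self-unpatching pattern; `FakeFetcher` and `wrap_until` are illustrative names, not part of this diff:

from typing import Any, Callable


class FakeFetcher:
    """Stand-in for torch's _BaseDatasetFetcher (illustrative only)."""

    def fetch(self, index: int) -> int:
        return index * 2


def wrap_until(obj: Any, func: Callable, n: int) -> Callable:
    counter = 0

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        nonlocal counter
        result = func(*args, **kwargs)
        counter += 1
        if counter == n:
            # Restore the original bound method: later calls bypass the wrapper.
            obj.fetch = func
        return result

    return wrapped


fetcher = FakeFetcher()
fetcher.fetch = wrap_until(fetcher, fetcher.fetch, n=3)
assert [fetcher.fetch(i) for i in range(5)] == [0, 2, 4, 6, 8]
assert fetcher.fetch.__name__ == "fetch"  # the instance attribute was swapped back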
@@ -355,6 +448,15 @@ def __init__(self, loader: DataLoader) -> None:
             if self._loader._latest_worker_idx > 0
             else []
         )
+        self._num_workers = loader.num_workers
+
+        distributed_env = _DistributedEnv.detect()
+
+        if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
+            from torch.utils.data._utils import worker
+
+            # Patch the worker loop before super().__init__ spawns the workers.
+            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)
+
         super().__init__(loader)

     def _try_put_index(self) -> None:
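This patch works because PyTorch resolves `worker._worker_loop` as a module attribute when it spawns workers, so rebinding it before `super().__init__(loader)` redirects every worker into the profiling wrapper; inside the worker, `reload(worker)` then re-executes the module source to recover the unpatched loop. A hedged sketch of that patch-then-reload dance, using the stdlib `json` module purely for illustration:

import json
from importlib import reload

original = json.loads
json.loads = lambda s: ("wrapped", original(s))  # patch the module attribute
assert json.loads("{}") == ("wrapped", {})       # every lookup now sees the wrapper

fresh = reload(json)              # re-executes json's source in place
assert fresh is json              # same module object, attributes rebound
assert fresh.loads("{}") == {}    # the original definition is back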
@@ -388,6 +490,9 @@ def __init__(
         *args: Any,
         batch_size: int = 1,
         num_workers: int = 0,
+        profile_batches: Union[bool, int] = False,
+        profile_dir: Optional[str] = None,
+        prefetch_factor: Optional[int] = None,
         **kwargs: Any,
     ) -> None:  # pyright: ignore
         if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -396,17 +501,32 @@ def __init__(
                 f" Found {dataset}."
             )

+        if profile_batches and not _VIZ_TRACKER_AVAILABLE:
+            raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`.")
+
+        if profile_batches and num_workers == 0:
+            raise ValueError("Profiling is supported only with num_workers >= 1.")
+
         self.current_epoch = 0
         self.batch_size = batch_size
         self.num_workers = num_workers
+        self._profile_batches = profile_batches
+        self._profile_dir = profile_dir
         self._num_samples_yielded_streaming = 0
         self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
         self.rng_state: Optional[Any] = None
         self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
         self._worker_idx_iter: Optional[Any] = None
         self._latest_worker_idx = 0
         self.restore = False
-        super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs)  # type: ignore
+        super().__init__(
+            dataset,
+            *args,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            # Default to prefetching 10 batches per worker unless overridden.
+            prefetch_factor=(10 if num_workers > 0 else None) if prefetch_factor is None else prefetch_factor,
+            **kwargs,
+        )  # type: ignore

     def __iter__(self) -> Any:
         if not self.restore:
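For context, an end-to-end usage sketch of the new flags. It assumes the public `litdata` entry points and a hypothetical dataset location; only `profile_batches`, `profile_dir`, and `prefetch_factor` come from this diff:

from litdata import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset("s3://my-bucket/optimized")  # hypothetical path
loader = StreamingDataLoader(
    dataset,
    batch_size=32,
    num_workers=4,              # profiling raises if num_workers == 0
    profile_batches=10,         # trace the first 10 batches fetched by worker 0
    profile_dir="/tmp/traces",  # result.json is written here (defaults to cwd)
    # prefetch_factor left unset, so it defaults to 10 with workers enabled
)

for batch in loader:
    ...  # after 10 batches, open /tmp/traces/result.json in chrome://tracing/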