From f1d00da0c0847fe83e8c6607db8bd19fe9255238 Mon Sep 17 00:00:00 2001
From: andrewkho
Date: Mon, 14 Oct 2024 17:44:05 -0700
Subject: [PATCH] update exception handling

---
 torchdata/nodes/_apply_udf.py |  2 --
 torchdata/nodes/base_node.py  |  9 +------
 torchdata/nodes/batch.py      |  2 +-
 torchdata/nodes/map.py        | 46 ++++++++++++++++++++---------------
 torchdata/nodes/pin_memory.py | 23 ++++++++++++------
 torchdata/nodes/prefetch.py   | 21 ++++++++++------
 6 files changed, 57 insertions(+), 46 deletions(-)

diff --git a/torchdata/nodes/_apply_udf.py b/torchdata/nodes/_apply_udf.py
index 1e5546d07..0a661ce31 100644
--- a/torchdata/nodes/_apply_udf.py
+++ b/torchdata/nodes/_apply_udf.py
@@ -10,8 +10,6 @@ def _apply_udf(worker_id, in_q, out_q, in_order, udf, stop_event):
     while True:
         if stop_event.is_set() and in_q.empty():
             break
-        else:
-            print(worker_id, stop_event, in_q)
         try:
             # TODO: implement in-order execution
             x = in_q.get(block=True, timeout=1.0)
diff --git a/torchdata/nodes/base_node.py b/torchdata/nodes/base_node.py
index d1b210341..277b5e0e2 100644
--- a/torchdata/nodes/base_node.py
+++ b/torchdata/nodes/base_node.py
@@ -1,8 +1,6 @@
-import threading
 from typing import Generic, Iterator, TypeVar

 import torch.utils.data
-from torch._utils import ExceptionWrapper

 T = TypeVar("T")

@@ -18,9 +16,4 @@ def iterator(self) -> Iterator[T]:
         raise NotImplementedError()

     def __iter__(self) -> Iterator[T]:
-        # Do not override this method, override iterator() instead.
-        for x in self.iterator():
-            if isinstance(x, ExceptionWrapper) and threading.main_thread() is threading.current_thread():
-                # We re-raise exceptions as early as possible once we're in the main thread
-                x.reraise()
-            yield x
+        yield from self.iterator()
diff --git a/torchdata/nodes/batch.py b/torchdata/nodes/batch.py
index bb260dfd1..780033915 100644
--- a/torchdata/nodes/batch.py
+++ b/torchdata/nodes/batch.py
@@ -4,7 +4,7 @@


 class Batcher(BaseNode[List[T]]):
-    def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True) -> None:
+    def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True):
         self.source = source
         self.batch_size = batch_size
         self.drop_last = drop_last
diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py
index 3a3078877..132059550 100644
--- a/torchdata/nodes/map.py
+++ b/torchdata/nodes/map.py
@@ -2,7 +2,7 @@
 # import multiprocessing as mp
 import queue
 import threading
-from typing import Callable, Iterator, List, Literal, TypeVar, Union
+from typing import Callable, Iterator, List, Literal, Optional, TypeVar, Union

 import torch.multiprocessing as mp

@@ -29,10 +29,7 @@ def __init__(

     def iterator(self) -> Iterator[T]:
         for item in self.source:
-            if isinstance(item, ExceptionWrapper):
-                yield item
-            else:
-                yield self.map_fn(item)
+            yield self.map_fn(item)


 class ParallelMapper(BaseNode[T]):
@@ -59,29 +56,32 @@ def __init__(
         self._read_thread = threading.Thread(
             target=_populate_queue, args=(self.source, self.in_q, self._stop, self.sem)
         )
+        self._method = method
         self._map_threads: List[Union[threading.Thread, mp.Process]] = []
-        for worker_id in range(self.num_workers):
-            args = (
-                worker_id,
-                self.in_q,
-                self.out_q,
-                self.in_order,
-                self.udf,
-                self._stop if method == "thread" else self._mp_stop,
-            )
-            self._map_threads.append(
-                threading.Thread(target=_apply_udf, args=args)
-                if method == "thread"
-                else mp.Process(target=_apply_udf, args=args)
-            )
+
+    def iterator(self) -> Iterator[T]:
         if not self._started:
+            for worker_id in range(self.num_workers):
+                args = (
+                    worker_id,
+                    self.in_q,
+                    self.out_q,
+                    self.in_order,
+                    self.udf,
+                    self._stop if self._method == "thread" else self._mp_stop,
+                )
+                self._map_threads.append(
+                    threading.Thread(target=_apply_udf, args=args)
+                    if self._method == "thread"
+                    else mp.Process(target=_apply_udf, args=args)
+                )
             self._read_thread.start()
             for t in self._map_threads:
                 t.start()
             self._started = True

-    def iterator(self) -> Iterator[T]:
+        exception: Optional[ExceptionWrapper] = None
         while True:
             if self._stop.is_set():
                 yield from self._flush_queues()
@@ -95,10 +95,16 @@ def iterator(self) -> Iterator[T]:
             if isinstance(item, StopIteration):
                 yield from self._flush_queues()
                 break
+            elif isinstance(item, ExceptionWrapper):
+                exception = item
+                break
             yield item

         self._stop.set()
         self._mp_stop.set()
+        if exception is not None:
+            exception.reraise()
+        self._shutdown()

     def _flush_queues(self):
         while self.sem._value < 2 * self.num_workers:
diff --git a/torchdata/nodes/pin_memory.py b/torchdata/nodes/pin_memory.py
index e9e747ebc..98d32a02f 100644
--- a/torchdata/nodes/pin_memory.py
+++ b/torchdata/nodes/pin_memory.py
@@ -125,17 +125,26 @@ def iterator(self) -> Iterator[T]:
             self.read_thread.start()
             self.pin_memory_thread.start()
             self._started = True
+
+        exception: Optional[ExceptionWrapper] = None
         while True:
             try:
-                value = self.out_q.get(block=True, timeout=0.1)
-                self.sem.release()
-                if isinstance(value, StopIteration):
-                    break
-                yield value
-                if isinstance(value, ExceptionWrapper):
-                    break
+                item = self.out_q.get(block=True, timeout=0.1)
             except queue.Empty:
                 continue
+
+            self.sem.release()
+            if isinstance(item, StopIteration):
+                break
+            elif isinstance(item, ExceptionWrapper):
+                exception = item
+                break
+            yield item
+
+        self._populate_queue_stop_event.set()
+        self._pin_memory_stop_event.set()
+        if exception is not None:
+            exception.reraise()
         self._shutdown()

     def __del__(self):
diff --git a/torchdata/nodes/prefetch.py b/torchdata/nodes/prefetch.py
index 456f6dec7..de35d33ef 100644
--- a/torchdata/nodes/prefetch.py
+++ b/torchdata/nodes/prefetch.py
@@ -27,18 +27,23 @@ def iterator(self) -> Iterator[T]:
             args=(self.source, self.q, self._stop_event, self.sem),
         )
         self.thread.start()
+
+        exception: Optional[ExceptionWrapper] = None
         while True:
             try:
-                value = self.q.get(block=True, timeout=0.1)
-                self.sem.release()
-                if isinstance(value, StopIteration):
-                    break
-                yield value
-                if isinstance(value, ExceptionWrapper):
-                    self._stop_event.set()
-                    break
+                item = self.q.get(block=True, timeout=0.1)
             except queue.Empty:
                 continue
+            self.sem.release()
+            if isinstance(item, StopIteration):
+                break
+            elif isinstance(item, ExceptionWrapper):
+                exception = item
+                break
+            yield item
+        self._stop_event.set()
+        if exception is not None:
+            exception.reraise()
         self._shutdown()

     def __del__(self):
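
Note (not part of the patch): below is a minimal, self-contained sketch of the producer/consumer pattern these changes converge on, in case it helps review. A worker thread wraps any user-code failure in torch._utils.ExceptionWrapper and puts it on the output queue; the consumer stashes the wrapper, signals the worker to stop, and only then re-raises in the caller's thread, mirroring the updated iterator() methods. The _worker, consume, and udf names are illustrative only and are not torchdata APIs.

    import queue
    import threading

    from torch._utils import ExceptionWrapper


    def _worker(in_q, out_q, udf, stop_event):
        # Mirrors _apply_udf: wrap any user-code failure so the consumer can
        # see it, instead of letting the worker thread die silently.
        while not (stop_event.is_set() and in_q.empty()):
            try:
                x = in_q.get(block=True, timeout=0.1)
            except queue.Empty:
                continue
            try:
                out_q.put(udf(x))
            except Exception:
                out_q.put(ExceptionWrapper(where="in _worker"))
                break


    def consume(udf, items):
        # Mirrors the updated iterator() methods: stash the wrapper, signal
        # the worker to stop, then re-raise in the caller's (main) thread.
        in_q, out_q = queue.Queue(), queue.Queue()
        stop_event = threading.Event()
        t = threading.Thread(target=_worker, args=(in_q, out_q, udf, stop_event))
        t.start()
        for x in items:
            in_q.put(x)
        exception, results = None, []
        try:
            for _ in items:
                item = out_q.get(block=True, timeout=5.0)
                if isinstance(item, ExceptionWrapper):
                    exception = item
                    break
                results.append(item)
        finally:
            stop_event.set()
            t.join()
        if exception is not None:
            exception.reraise()
        return results


    if __name__ == "__main__":
        def udf(x):
            if x == 3:
                raise ValueError("boom")
            return x * 2

        try:
            consume(udf, [1, 2, 3, 4])
        except ValueError as e:
            print("re-raised in main thread:", e)

Re-raising only after the stop events are set (rather than yielding the wrapper downstream, as the removed BaseNode.__iter__ logic did) keeps shutdown deterministic and surfaces the original traceback directly to the caller.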