Commit

update exception handling

andrewkho committed Oct 15, 2024
1 parent f3aef79 commit f1d00da
Showing 6 changed files with 57 additions and 46 deletions.
2 changes: 0 additions & 2 deletions torchdata/nodes/_apply_udf.py
@@ -10,8 +10,6 @@ def _apply_udf(worker_id, in_q, out_q, in_order, udf, stop_event):
while True:
if stop_event.is_set() and in_q.empty():
break
else:
print(worker_id, stop_event, in_q)

try: # TODO: implement in-order execution
x = in_q.get(block=True, timeout=1.0)
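The hunk above only drops a leftover debug print; the rest of the worker loop sits outside this diff. For orientation, here is an illustrative sketch (not the actual torchdata code) of what an _apply_udf-style worker does under the new scheme: run the UDF, and if it raises, wrap the failure in ExceptionWrapper and forward it on the output queue so the consuming node can re-raise it later.

import queue

from torch._utils import ExceptionWrapper


def _apply_udf_sketch(worker_id, in_q, out_q, in_order, udf, stop_event):
    # Illustrative worker loop; `in_order` is ignored here, mirroring the TODO above.
    while True:
        if stop_event.is_set() and in_q.empty():
            break
        try:
            x = in_q.get(block=True, timeout=1.0)
        except queue.Empty:
            continue
        if isinstance(x, (StopIteration, ExceptionWrapper)):
            out_q.put(x)  # forward sentinels and upstream errors untouched
            continue
        try:
            out_q.put(udf(x))
        except Exception:
            # Capture the exception (and its traceback) so the consuming thread or
            # process can call .reraise() once it pops this item off out_q.
            out_q.put(ExceptionWrapper(where=f"in _apply_udf worker {worker_id}"))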
9 changes: 1 addition & 8 deletions torchdata/nodes/base_node.py
@@ -1,8 +1,6 @@
import threading
from typing import Generic, Iterator, TypeVar

import torch.utils.data
from torch._utils import ExceptionWrapper


T = TypeVar("T")
@@ -18,9 +16,4 @@ def iterator(self) -> Iterator[T]:
raise NotImplementedError()

def __iter__(self) -> Iterator[T]:
# Do not override this method, override iterator() instead.
for x in self.iterator():
if isinstance(x, ExceptionWrapper) and threading.main_thread() is threading.current_thread():
# We re-raise exceptions as early as possible once we're in the main thread
x.reraise()
yield x
yield from self.iterator()
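After this change, BaseNode.__iter__ simply delegates to iterator(); the ExceptionWrapper check and main-thread re-raise are gone from the base class and move into the individual nodes below. A minimal, self-contained sketch of the wrap-and-reraise contract those nodes rely on (plain stdlib plus torch._utils, not library code):

import queue
import threading

from torch._utils import ExceptionWrapper

q = queue.Queue()


def worker() -> None:
    try:
        raise ValueError("boom")  # stand-in for a failing background task
    except Exception:
        # Captures the active exception and its traceback for a later re-raise.
        q.put(ExceptionWrapper(where="in a background worker"))


t = threading.Thread(target=worker)
t.start()
t.join()

item = q.get()
if isinstance(item, ExceptionWrapper):
    item.reraise()  # raises the original ValueError here, in the main thread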
2 changes: 1 addition & 1 deletion torchdata/nodes/batch.py
@@ -4,7 +4,7 @@


class Batcher(BaseNode[List[T]]):
def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True) -> None:
def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True):
self.source = source
self.batch_size = batch_size
self.drop_last = drop_last
46 changes: 26 additions & 20 deletions torchdata/nodes/map.py
@@ -2,7 +2,7 @@
# import multiprocessing as mp
import queue
import threading
from typing import Callable, Iterator, List, Literal, TypeVar, Union
from typing import Callable, Iterator, List, Literal, Optional, TypeVar, Union

import torch.multiprocessing as mp

@@ -29,10 +29,7 @@ def __init__(

def iterator(self) -> Iterator[T]:
for item in self.source:
if isinstance(item, ExceptionWrapper):
yield item
else:
yield self.map_fn(item)
yield self.map_fn(item)


class ParallelMapper(BaseNode[T]):
@@ -59,29 +56,32 @@ def __init__(
self._read_thread = threading.Thread(
target=_populate_queue, args=(self.source, self.in_q, self._stop, self.sem)
)
self._method = method

self._map_threads: List[Union[threading.Thread, mp.Process]] = []
for worker_id in range(self.num_workers):
args = (
worker_id,
self.in_q,
self.out_q,
self.in_order,
self.udf,
self._stop if method == "thread" else self._mp_stop,
)
self._map_threads.append(
threading.Thread(target=_apply_udf, args=args)
if method == "thread"
else mp.Process(target=_apply_udf, args=args)
)

def iterator(self) -> Iterator[T]:
if not self._started:
for worker_id in range(self.num_workers):
args = (
worker_id,
self.in_q,
self.out_q,
self.in_order,
self.udf,
self._stop if self._method == "thread" else self._mp_stop,
)
self._map_threads.append(
threading.Thread(target=_apply_udf, args=args)
if self._method == "thread"
else mp.Process(target=_apply_udf, args=args)
)
self._read_thread.start()
for t in self._map_threads:
t.start()
self._started = True

def iterator(self) -> Iterator[T]:
exception: Optional[ExceptionWrapper] = None
while True:
if self._stop.is_set():
yield from self._flush_queues()
Expand All @@ -95,10 +95,16 @@ def iterator(self) -> Iterator[T]:
if isinstance(item, StopIteration):
yield from self._flush_queues()
break
elif isinstance(item, ExceptionWrapper):
exception = item
break
yield item

self._stop.set()
self._mp_stop.set()
if exception is not None:
exception.reraise()
self._shutdown()

def _flush_queues(self):
while self.sem._value < 2 * self.num_workers:
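The reworked ParallelMapper.iterator establishes the consumer-side pattern this commit applies across nodes: pop items off the output queue, treat StopIteration as an end-of-stream sentinel, capture any ExceptionWrapper, signal the workers to stop, and only then re-raise in the calling thread. A condensed standalone version of that loop shape (illustrative; the real method also flushes queues and manages a semaphore):

import queue
from typing import Iterator, List, Optional

from torch._utils import ExceptionWrapper


def drain(out_q: queue.Queue, stop_events: List) -> Iterator:
    # Condensed version of the consumer pattern in ParallelMapper.iterator.
    exception: Optional[ExceptionWrapper] = None
    while True:
        try:
            item = out_q.get(block=True, timeout=0.1)
        except queue.Empty:
            continue
        if isinstance(item, StopIteration):
            break
        elif isinstance(item, ExceptionWrapper):
            exception = item
            break
        yield item
    for ev in stop_events:
        ev.set()  # ask the workers to wind down before surfacing the error
    if exception is not None:
        exception.reraise()  # the original traceback resurfaces in the caller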
23 changes: 16 additions & 7 deletions torchdata/nodes/pin_memory.py
@@ -125,17 +125,26 @@ def iterator(self) -> Iterator[T]:
self.read_thread.start()
self.pin_memory_thread.start()
self._started = True

exception: Optional[ExceptionWrapper] = None
while True:
try:
value = self.out_q.get(block=True, timeout=0.1)
self.sem.release()
if isinstance(value, StopIteration):
break
yield value
if isinstance(value, ExceptionWrapper):
break
item = self.out_q.get(block=True, timeout=0.1)
except queue.Empty:
continue

self.sem.release()
if isinstance(item, StopIteration):
break
elif isinstance(item, ExceptionWrapper):
exception = item
break
yield item

self._populate_queue_stop_event.set()
self._pin_memory_stop_event.set()
if exception is not None:
exception.reraise()
self._shutdown()

def __del__(self):
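PinMemory's consumer loop now mirrors ParallelMapper's. The producer half is not part of this diff; the following is a hypothetical sketch of what a pin-memory worker could look like under the same convention, assuming items are tensors (the real thread presumably handles nested structures and has its own helpers):

import queue

from torch._utils import ExceptionWrapper


def _pin_memory_loop_sketch(in_q, out_q, stop_event):
    # Hypothetical producer counterpart to the consumer loop above: pin each
    # item and forward it; wrap failures so the main thread can re-raise them.
    while not stop_event.is_set():
        try:
            item = in_q.get(block=True, timeout=0.1)
        except queue.Empty:
            continue
        if isinstance(item, (StopIteration, ExceptionWrapper)):
            out_q.put(item)  # forward sentinels and upstream errors untouched
            continue
        try:
            out_q.put(item.pin_memory())  # assumes a Tensor for simplicity
        except Exception:
            out_q.put(ExceptionWrapper(where="in pin memory thread"))
            break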
21 changes: 13 additions & 8 deletions torchdata/nodes/prefetch.py
@@ -27,18 +27,23 @@ def iterator(self) -> Iterator[T]:
args=(self.source, self.q, self._stop_event, self.sem),
)
self.thread.start()

exception: Optional[ExceptionWrapper] = None
while True:
try:
value = self.q.get(block=True, timeout=0.1)
self.sem.release()
if isinstance(value, StopIteration):
break
yield value
if isinstance(value, ExceptionWrapper):
self._stop_event.set()
break
item = self.q.get(block=True, timeout=0.1)
except queue.Empty:
continue
self.sem.release()
if isinstance(item, StopIteration):
break
elif isinstance(item, ExceptionWrapper):
exception = item
break
yield item
self._stop_event.set()
if exception is not None:
exception.reraise()
self._shutdown()

def __del__(self):
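Prefetcher's iterator takes the same shape as PinMemory's above. As a usage-level illustration of the behavior the commit converges on (a background producer wraps its failure, the consumer stops the worker and re-raises on the main thread), here is a small standalone imitation of the pattern, not the torchdata Prefetcher itself:

import queue
import threading

from torch._utils import ExceptionWrapper


def produce(source, q, stop_event):
    # Background producer: forward items, wrap any failure, then signal the end.
    try:
        for x in source:
            if stop_event.is_set():
                return
            q.put(x)
        q.put(StopIteration())
    except Exception:
        q.put(ExceptionWrapper(where="in prefetch thread"))


def prefetch(source):
    q, stop_event = queue.Queue(maxsize=8), threading.Event()
    threading.Thread(target=produce, args=(source, q, stop_event), daemon=True).start()
    exception = None
    while True:
        try:
            item = q.get(block=True, timeout=0.1)
        except queue.Empty:
            continue
        if isinstance(item, StopIteration):
            break
        elif isinstance(item, ExceptionWrapper):
            exception = item
            break
        yield item
    stop_event.set()
    if exception is not None:
        exception.reraise()


def flaky(n):
    for i in range(n):
        if i == 3:
            raise RuntimeError("source failed at item 3")
        yield i


try:
    list(prefetch(flaky(10)))
except RuntimeError as e:
    print("re-raised on the main thread:", e)  # traceback points back into flaky()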
