From 7d5c13ca24c0deb1903ecbffb3652b5e82222987 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 2 Feb 2024 18:19:58 +0000 Subject: [PATCH 01/14] update --- src/lightning/data/processing/__init__.py | 0 src/lightning/data/processing/readers.py | 109 ++++++++++++++++++ .../data/streaming/data_processor.py | 32 +++-- src/lightning/data/streaming/functions.py | 5 + 4 files changed, 136 insertions(+), 10 deletions(-) create mode 100644 src/lightning/data/processing/__init__.py create mode 100644 src/lightning/data/processing/readers.py diff --git a/src/lightning/data/processing/__init__.py b/src/lightning/data/processing/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py new file mode 100644 index 0000000000000..eec1587ee9fb1 --- /dev/null +++ b/src/lightning/data/processing/readers.py @@ -0,0 +1,109 @@ +import os +from abc import ABC, abstractmethod +from lightning_utilities.core.imports import RequirementCache +from typing import Any, List, Optional +from dataclasses import dataclass +from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks +from lightning.data.utilities.env import _DistributedEnv + + +_POLARS_AVAILABLE = RequirementCache("polars") +_PYARROW_AVAILABLE = RequirementCache("pyarrow") + + +class BaseReader(ABC): + + def get_num_nodes(self) -> int: + return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1)) + + @abstractmethod + def to_workers_user_items(self, items: List[Any], num_workers: int) -> List[List[Any]]: + pass + + @abstractmethod + def read(self, item: Any) -> Any: + pass + + +@dataclass +class ParquetSlice: + filepath: str + start: int + end: int + + +class ParquetReader(BaseReader): + + def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None: + self.num_rows = num_rows + self.to_pandas = to_pandas + + def _get_num_rows(self, path: str) -> int: + # TODO: There is a bug in polars. This leads to read_parquet to hang. 
+ # if _POLARS_AVAILABLE: + # import polars as pol + # df = pol.scan_parquet(path) + # num_rows = df.select(pol.len()).collect().item() + # return num_rows + + if _PYARROW_AVAILABLE: + import pyarrow.dataset as ds + df = ds.dataset(path).scanner() + return df.count_rows() + + raise RuntimeError("Please, install either pyarrow or polars.") + + def read(self, item: ParquetSlice) -> Any: + if _POLARS_AVAILABLE: + import polars as pol + df = pol.read_parquet(item.filepath, row_index_offset=item.start, n_rows=item.end - item.start, parallel="row_groups") + + if self.to_pandas: + df = df.to_pandas() + + return df + + if _PYARROW_AVAILABLE: + import pyarrow.dataset as ds + + df = ds.dataset(item.filepath).scanner() + + df = df.take([item.start, item.end]) + + if self.to_pandas: + df.to_pandas() + + return df + + raise RuntimeError("Please, install either pyarrow or polars.") + + + def to_workers_user_items(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]: + intervals = [(0, self._get_num_rows(item)) for item in items] + + world_size = self.get_num_nodes() * num_workers + + fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes()) + parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks(fake_distributed_env, list(range(len(items))), intervals, False) + + workers_user_items = [[] for _ in range(world_size)] + + for worker_idx, (parquet_indexes, parquet_slices) in enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker)): + if self.num_rows: + workers_user_items[worker_idx].extend([ + ParquetSlice( + items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows + if parquet_slice[1] > (parquet_slice_start + self.num_rows) else + parquet_slice[1] + ) + for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) + for parquet_slice_start in range(parquet_slice[0], parquet_slice[1] + self.num_rows, self.num_rows) + if parquet_slice_start < parquet_slice[1] + ]) + else: + workers_user_items[worker_idx].extend([ + ParquetSlice(items[parquet_index], *parquet_slice) + for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) + ]) + + return workers_user_items diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index a7a49710ad45a..196294ce7dd07 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -34,6 +34,7 @@ from lightning.data.streaming.resolver import _resolve_dir from lightning.data.utilities.broadcast import broadcast_object from lightning.data.utilities.packing import _pack_greedily +from lightning.data.processing.readers import BaseReader if _TORCH_GREATER_EQUAL_2_1_0: from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads @@ -158,8 +159,9 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) elif os.path.isfile(path): - os.makedirs(os.path.dirname(local_path), exist_ok=True) - shutil.copyfile(path, local_path) + if not path.startswith("/teamspace/studios/this_studio"): + os.makedirs(os.path.dirname(local_path), exist_ok=True) + shutil.copyfile(path, local_path) else: raise ValueError(f"The provided {input_dir.url} isn't supported.") @@ -340,6 +342,7 @@ def __init__( num_downloaders: int, num_uploaders: int, remove: bool, + reader: Optional[BaseReader] = None, ) -> None: """The BaseWorker is responsible to process the user data.""" 
self.worker_index = worker_index @@ -353,6 +356,7 @@ def __init__( self.num_downloaders = num_downloaders self.num_uploaders = num_uploaders self.remove = remove + self.reader = reader self.paths: List[List[str]] = [] self.remover: Optional[Process] = None self.downloaders: List[Process] = [] @@ -392,7 +396,7 @@ def _loop(self) -> None: num_downloader_finished = 0 while True: - index = self.ready_to_process_queue.get() + index = self.ready_to_process_queue.get(timeout=1) if index is None: num_downloader_finished += 1 @@ -433,7 +437,7 @@ def _loop(self) -> None: self.progress_queue.put((self.worker_index, self._counter)) self._last_time = time() - if self.remove and self.input_dir.path is not None: + if self.remove and self.input_dir.path is not None and self.reader is None: self.remove_queue.put(self.paths[index]) try: @@ -476,7 +480,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None: self.to_upload_queues[self._counter % self.num_uploaders].put(data) def _collect_paths(self) -> None: - if self.input_dir.path is None: + if self.input_dir.path is None or self.reader is not None: for index in range(len(self.items)): self.ready_to_process_queue.put(index) for _ in range(self.num_downloaders): @@ -513,7 +517,7 @@ def is_path(element: Any) -> bool: paths = [] for index, path in indexed_paths.items(): paths.append(path) - if self.input_dir: + if self.input_dir and not self.input_dir.path.startswith("/teamspace/studios/this_studio"): path = path.replace(self.input_dir.path, self.cache_data_dir) flattened_item[index] = path @@ -525,8 +529,9 @@ def is_path(element: Any) -> bool: self.items = items def _start_downloaders(self) -> None: - if self.input_dir.path is None: + if self.input_dir.path is None or self.reader is not None: return + for _ in range(self.num_downloaders): to_download_queue: Queue = Queue() p = Process( @@ -609,7 +614,8 @@ def _handle_data_chunk_recipe_end(self) -> None: def _handle_data_transform_recipe(self, index: int) -> None: # Don't use a context manager to avoid deleting files that are being uploaded. output_dir = tempfile.mkdtemp() - item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir), len(self.items) - 1 == index) + item = self.items[index] if self.reader is None else self.reader.read(self.items[index]) + item_data = self.data_recipe.prepare_item(item, str(output_dir), len(self.items) - 1 == index) if item_data is not None: raise ValueError( "When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything." @@ -792,6 +798,7 @@ def __init__( random_seed: Optional[int] = 42, reorder_files: bool = True, weights: Optional[List[int]] = None, + reader: Optional[BaseReader] = None, ): """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make training faster. 
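The reader hook threaded through the worker above reduces to two methods: split the user-provided items across workers up front, then read each item lazily inside the worker instead of downloading it. A minimal sketch of a custom reader under that contract (the LineReader name and its one-file-per-item policy are illustrative, not part of this patch; the split method is renamed to items_to_workers later in this series):

from typing import Any, List

from lightning.data.processing.readers import BaseReader


class LineReader(BaseReader):
    # Hypothetical reader: one text file per item, read lazily in the worker process.

    def to_workers_user_items(self, items: List[Any], num_workers: int) -> List[List[Any]]:
        # Round-robin the input files across the worker processes.
        return [items[worker_idx::num_workers] for worker_idx in range(num_workers)]

    def read(self, item: Any) -> Any:
        # Called in the worker right before prepare_item, in place of the download path.
        with open(item) as f:
            return f.read().splitlines()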
@@ -825,6 +832,7 @@ def __init__( self.stop_queues: List[Queue] = [] self.reorder_files = reorder_files self.weights = weights + self.reader = reader # Ensure the input dir is the same across all nodes self.input_dir = broadcast_object("input_dir", self.input_dir) @@ -853,7 +861,10 @@ def run(self, data_recipe: DataRecipe) -> None: if not isinstance(user_items, list): raise ValueError("The `prepare_structure` should return a list of item metadata.") - if self.weights is not None: + if self.reader: + workers_user_items = self.reader.to_workers_user_items(user_items, self.num_workers) + + elif self.weights is not None: if len(self.weights) != len(user_items): raise ValueError("The provided weights length should match the inputs' length.") workers_user_items = _map_items_to_workers_weighted( @@ -880,7 +891,7 @@ def run(self, data_recipe: DataRecipe) -> None: self._cleanup_cache() - print(f"Starting {self.num_workers} workers") + print(f"Starting {self.num_workers} workers with {num_items} items.") if self.input_dir is None and self.src_resolver is not None and self.input_dir: self.input_dir = self.src_resolver(self.input_dir) @@ -988,6 +999,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L self.num_downloaders, self.num_uploaders, self.delete_cached_files, + self.reader, ) worker.start() workers.append(worker) diff --git a/src/lightning/data/streaming/functions.py b/src/lightning/data/streaming/functions.py index bd6616cca6e1f..98ee3bc92c4e8 100644 --- a/src/lightning/data/streaming/functions.py +++ b/src/lightning/data/streaming/functions.py @@ -31,6 +31,7 @@ _execute, _resolve_dir, ) +from lightning.data.processing.readers import BaseReader if _TORCH_GREATER_EQUAL_2_1_0: from torch.utils._pytree import tree_flatten @@ -157,6 +158,7 @@ def map( num_downloaders: Optional[int] = None, reorder_files: bool = True, error_when_not_empty: bool = False, + reader: Optional[BaseReader] = None, ) -> None: """This function map a callbable over a collection of files possibly in a distributed way. @@ -203,6 +205,7 @@ def map( num_downloaders=num_downloaders, reorder_files=reorder_files, weights=weights, + reader=reader, ) return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) return _execute( @@ -225,6 +228,7 @@ def optimize( machine: Optional[str] = None, num_downloaders: Optional[int] = None, reorder_files: bool = True, + reader: Optional[BaseReader] = None, ) -> None: """This function converts a dataset into chunks possibly in a distributed way. 
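For reference, the new reader argument is exercised end to end by the tests added later in this series (PATCH 07/14); a usage sketch along the same lines, with hypothetical local parquet paths:

import os

from lightning.data import map
from lightning.data.processing.readers import ParquetReader


def fn(df, output_dir):
    # Each call receives one ParquetSlice worth of rows (a polars DataFrame when to_pandas=False).
    first_value, num_rows = df.row(0)[0], len(df)
    with open(os.path.join(output_dir, f"{first_value}_{num_rows}"), "w") as f:
        f.write("hello world")


map(
    fn,
    inputs=["data/0.parquet", "data/1.parquet"],  # hypothetical inputs
    output_dir="output_dir",
    reader=ParquetReader(num_rows=2048, to_pandas=False),
    num_workers=2,
)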
@@ -274,6 +278,7 @@ def optimize( fast_dev_run=fast_dev_run, num_downloaders=num_downloaders, reorder_files=reorder_files, + reader=reader, ) return data_processor.run( LambdaDataChunkRecipe( From 77ed46ad9075ed5b716910f709f5c5e2a27dc44c Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Sat, 3 Feb 2024 14:49:12 +0000 Subject: [PATCH 02/14] update --- _notebooks | 1 - src/lightning/data/processing/image.py | 50 +++++++++++++++++++ src/lightning/data/processing/utilities.py | 32 ++++++++++++ .../data/streaming/data_processor.py | 4 +- 4 files changed, 84 insertions(+), 3 deletions(-) delete mode 160000 _notebooks create mode 100644 src/lightning/data/processing/image.py create mode 100644 src/lightning/data/processing/utilities.py diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 543a8d8200662..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 543a8d82006620906dc9eb669eab18d06ebe6863 diff --git a/src/lightning/data/processing/image.py b/src/lightning/data/processing/image.py new file mode 100644 index 0000000000000..5dcecf2da2561 --- /dev/null +++ b/src/lightning/data/processing/image.py @@ -0,0 +1,50 @@ +import urllib +import io + +def is_disallowed(headers, user_agent_token, disallowed_header_directives): + """Check if HTTP headers contain an X-Robots-Tag directive disallowing usage""" + for values in headers.get_all("X-Robots-Tag", []): + try: + uatoken_directives = values.split(":", 1) + directives = [x.strip().lower() for x in uatoken_directives[-1].split(",")] + ua_token = uatoken_directives[0].lower() if len(uatoken_directives) == 2 else None + if (ua_token is None or ua_token == user_agent_token) and any( + x in disallowed_header_directives for x in directives + ): + return True + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"Failed to parse X-Robots-Tag: {values}: {err}") + return False + + +def download_image(url, timeout = 10, user_agent_token="img2dataset", disallowed_header_directives = ["noai", "noimageai", "noindex", "noimageindex"]): + """Download an image with urllib""" + url + img_stream = None + user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0" + if user_agent_token: + user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/rom1504/img2dataset)" + try: + request = urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string}) + with urllib.request.urlopen(request, timeout=timeout) as r: + if disallowed_header_directives and is_disallowed( + r.headers, + user_agent_token, + disallowed_header_directives, + ): + return key, None, "Use of image disallowed by X-Robots-Tag directive" + img_stream = io.BytesIO(r.read()) + return img_stream, None + except Exception as err: # pylint: disable=broad-except + if img_stream is not None: + img_stream.close() + return None, str(err) + + +def download_image_with_retry(retries, url, timeout = 10, user_agent_token="img2dataset", disallowed_header_directives = []): + for _ in range(retries + 1): + img_stream, err = download_image(url, timeout, user_agent_token, disallowed_header_directives) + if img_stream is not None: + return img_stream, err + return None, err \ No newline at end of file diff --git a/src/lightning/data/processing/utilities.py b/src/lightning/data/processing/utilities.py new file mode 100644 index 0000000000000..e0e984f0e8160 --- /dev/null +++ b/src/lightning/data/processing/utilities.py @@ -0,0 +1,32 @@ +import os + + +class SuppressStdoutStderr: 
+ """ + A context manager for doing a "deep suppression" of stdout and stderr in + Python, i.e. will suppress all print, even if the print originates in a + compiled C/Fortran sub-function. + This will not suppress raised exceptions, since exceptions are printed + to stderr just before a script exits, and after the context manager has + exited (at least, I think that is why it lets exceptions through). + + """ + + def __init__(self): + # Open a pair of null files + self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] + # Save the actual stdout (1) and stderr (2) file descriptors. + self.save_fds = [os.dup(1), os.dup(2)] + + def __enter__(self): + # Assign the null pointers to stdout and stderr. + os.dup2(self.null_fds[0], 1) + os.dup2(self.null_fds[1], 2) + + def __exit__(self, *_): + # Re-assign the real stdout/stderr back to (1) and (2) + os.dup2(self.save_fds[0], 1) + os.dup2(self.save_fds[1], 2) + # Close all file descriptors + for fd in self.null_fds + self.save_fds: + os.close(fd) \ No newline at end of file diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index 196294ce7dd07..8ce43f6dcc6e1 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -588,7 +588,7 @@ def _start_uploaders(self) -> None: def _handle_data_chunk_recipe(self, index: int) -> None: try: - self._current_item = self.items[index] + self._current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index]) item_data_or_generator = self.data_recipe.prepare_item(self._current_item) if isinstance(item_data_or_generator, types.GeneratorType): for item_data in item_data_or_generator: @@ -601,7 +601,7 @@ def _handle_data_chunk_recipe(self, index: int) -> None: self._try_upload(chunk_filepath) self._index_counter += 1 except Exception as e: - raise RuntimeError(f"Failed processing {self._current_item}") from e + raise RuntimeError(f"Failed processing {self.items[index]}") from e def _handle_data_chunk_recipe_end(self) -> None: chunks_filepaths = self.cache.done() From c7517cf904c7b340d39c9a537d5b327f278e52e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 Feb 2024 14:50:43 +0000 Subject: [PATCH 03/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/processing/image.py | 13 +++++++------ src/lightning/data/processing/readers.py | 15 ++++++++------- src/lightning/data/processing/utilities.py | 14 ++++++-------- src/lightning/data/streaming/data_processor.py | 2 +- src/lightning/data/streaming/functions.py | 2 +- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/lightning/data/processing/image.py b/src/lightning/data/processing/image.py index 5dcecf2da2561..d197cea17e585 100644 --- a/src/lightning/data/processing/image.py +++ b/src/lightning/data/processing/image.py @@ -1,8 +1,9 @@ -import urllib import io +import urllib + def is_disallowed(headers, user_agent_token, disallowed_header_directives): - """Check if HTTP headers contain an X-Robots-Tag directive disallowing usage""" + """Check if HTTP headers contain an X-Robots-Tag directive disallowing usage.""" for values in headers.get_all("X-Robots-Tag", []): try: uatoken_directives = values.split(":", 1) @@ -18,8 +19,8 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives): return False -def download_image(url, timeout = 10, 
user_agent_token="img2dataset", disallowed_header_directives = ["noai", "noimageai", "noindex", "noimageindex"]): - """Download an image with urllib""" +def download_image(url, timeout=10, user_agent_token="img2dataset", disallowed_header_directives=["noai", "noimageai", "noindex", "noimageindex"]): + """Download an image with urllib.""" url img_stream = None user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0" @@ -42,9 +43,9 @@ def download_image(url, timeout = 10, user_agent_token="img2dataset", disallowed return None, str(err) -def download_image_with_retry(retries, url, timeout = 10, user_agent_token="img2dataset", disallowed_header_directives = []): +def download_image_with_retry(retries, url, timeout=10, user_agent_token="img2dataset", disallowed_header_directives=[]): for _ in range(retries + 1): img_stream, err = download_image(url, timeout, user_agent_token, disallowed_header_directives) if img_stream is not None: return img_stream, err - return None, err \ No newline at end of file + return None, err diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index eec1587ee9fb1..d275c98825b33 100644 --- a/src/lightning/data/processing/readers.py +++ b/src/lightning/data/processing/readers.py @@ -1,12 +1,13 @@ import os from abc import ABC, abstractmethod -from lightning_utilities.core.imports import RequirementCache -from typing import Any, List, Optional from dataclasses import dataclass +from typing import Any, List, Optional + +from lightning_utilities.core.imports import RequirementCache + from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks from lightning.data.utilities.env import _DistributedEnv - _POLARS_AVAILABLE = RequirementCache("polars") _PYARROW_AVAILABLE = RequirementCache("pyarrow") @@ -76,11 +77,11 @@ def read(self, item: ParquetSlice) -> Any: return df raise RuntimeError("Please, install either pyarrow or polars.") - + def to_workers_user_items(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]: intervals = [(0, self._get_num_rows(item)) for item in items] - + world_size = self.get_num_nodes() * num_workers fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes()) @@ -92,8 +93,8 @@ def to_workers_user_items(self, items: Any, num_workers: int) -> List[List[Parqu if self.num_rows: workers_user_items[worker_idx].extend([ ParquetSlice( - items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows - if parquet_slice[1] > (parquet_slice_start + self.num_rows) else + items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows + if parquet_slice[1] > (parquet_slice_start + self.num_rows) else parquet_slice[1] ) for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) diff --git a/src/lightning/data/processing/utilities.py b/src/lightning/data/processing/utilities.py index e0e984f0e8160..6649c3f058501 100644 --- a/src/lightning/data/processing/utilities.py +++ b/src/lightning/data/processing/utilities.py @@ -2,13 +2,11 @@ class SuppressStdoutStderr: - """ - A context manager for doing a "deep suppression" of stdout and stderr in - Python, i.e. will suppress all print, even if the print originates in a - compiled C/Fortran sub-function. - This will not suppress raised exceptions, since exceptions are printed - to stderr just before a script exits, and after the context manager has - exited (at least, I think that is why it lets exceptions through). 
+ """A context manager for doing a "deep suppression" of stdout and stderr in Python, i.e. will suppress all print, + even if the print originates in a compiled C/Fortran sub-function. + + This will not suppress raised exceptions, since exceptions are printed to stderr just before a script exits, and + after the context manager has exited (at least, I think that is why it lets exceptions through). """ @@ -29,4 +27,4 @@ def __exit__(self, *_): os.dup2(self.save_fds[1], 2) # Close all file descriptors for fd in self.null_fds + self.save_fds: - os.close(fd) \ No newline at end of file + os.close(fd) diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index 8ce43f6dcc6e1..6653d16328490 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -20,6 +20,7 @@ from tqdm.auto import tqdm as _tqdm from lightning import seed_everything +from lightning.data.processing.readers import BaseReader from lightning.data.streaming import Cache from lightning.data.streaming.cache import Dir from lightning.data.streaming.client import S3Client @@ -34,7 +35,6 @@ from lightning.data.streaming.resolver import _resolve_dir from lightning.data.utilities.broadcast import broadcast_object from lightning.data.utilities.packing import _pack_greedily -from lightning.data.processing.readers import BaseReader if _TORCH_GREATER_EQUAL_2_1_0: from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads diff --git a/src/lightning/data/streaming/functions.py b/src/lightning/data/streaming/functions.py index 98ee3bc92c4e8..bbdcdbc2ca764 100644 --- a/src/lightning/data/streaming/functions.py +++ b/src/lightning/data/streaming/functions.py @@ -22,6 +22,7 @@ import torch +from lightning.data.processing.readers import BaseReader from lightning.data.streaming.constants import _TORCH_GREATER_EQUAL_2_1_0 from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from lightning.data.streaming.resolver import ( @@ -31,7 +32,6 @@ _execute, _resolve_dir, ) -from lightning.data.processing.readers import BaseReader if _TORCH_GREATER_EQUAL_2_1_0: from torch.utils._pytree import tree_flatten From 90ff250efe0eb8d158ff953fa56b0bbe503d27f4 Mon Sep 17 00:00:00 2001 From: thomas Date: Sat, 3 Feb 2024 18:24:51 +0000 Subject: [PATCH 04/14] update --- _notebooks | 1 + src/lightning/data/processing/image.py | 37 +++++++++++----- src/lightning/data/processing/readers.py | 50 ++++++++++++++-------- src/lightning/data/processing/utilities.py | 32 -------------- 4 files changed, 60 insertions(+), 60 deletions(-) create mode 160000 _notebooks delete mode 100644 src/lightning/data/processing/utilities.py diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..543a8d8200662 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 543a8d82006620906dc9eb669eab18d06ebe6863 diff --git a/src/lightning/data/processing/image.py b/src/lightning/data/processing/image.py index 5dcecf2da2561..ae35dba0be74a 100644 --- a/src/lightning/data/processing/image.py +++ b/src/lightning/data/processing/image.py @@ -1,8 +1,12 @@ -import urllib import io +import traceback +import urllib + +# Credit to the https://github.com/rom1504/pytorch Github repo +# The code was taken from there. 
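A quick sketch of how the retry wrapper introduced in this patch is meant to be called (the URL is a placeholder):

from lightning.data.processing.image import download_image

img_stream, err = download_image("https://example.com/image.jpg", retries=2, timeout=10)
if img_stream is not None:
    data = img_stream.getvalue()  # raw bytes, e.g. to feed PIL.Image.open(img_stream) or store directly
else:
    print(f"Download failed: {err}")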
def is_disallowed(headers, user_agent_token, disallowed_header_directives): - """Check if HTTP headers contain an X-Robots-Tag directive disallowing usage""" + """Check if HTTP headers contain an X-Robots-Tag directive disallowing usage.""" for values in headers.get_all("X-Robots-Tag", []): try: uatoken_directives = values.split(":", 1) @@ -18,22 +22,27 @@ def is_disallowed(headers, user_agent_token, disallowed_header_directives): return False -def download_image(url, timeout = 10, user_agent_token="img2dataset", disallowed_header_directives = ["noai", "noimageai", "noindex", "noimageindex"]): - """Download an image with urllib""" +def _download_image( + url, + timeout=10, + user_agent_token="pytorch-lightning", + disallowed_header_directives=["noai", "noimageai", "noindex", "noimageindex"] +): + """Download an image with urllib.""" url img_stream = None user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0" if user_agent_token: - user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/rom1504/img2dataset)" + user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)" try: - request = urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string}) - with urllib.request.urlopen(request, timeout=timeout) as r: + request = urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string}) # noqa S310 + with urllib.request.urlopen(request, timeout=timeout) as r: # noqa S310 if disallowed_header_directives and is_disallowed( r.headers, user_agent_token, disallowed_header_directives, ): - return key, None, "Use of image disallowed by X-Robots-Tag directive" + return None, "Use of image disallowed by X-Robots-Tag directive" img_stream = io.BytesIO(r.read()) return img_stream, None except Exception as err: # pylint: disable=broad-except @@ -42,9 +51,15 @@ def download_image(url, timeout = 10, user_agent_token="img2dataset", disallowed return None, str(err) -def download_image_with_retry(retries, url, timeout = 10, user_agent_token="img2dataset", disallowed_header_directives = []): +def download_image( + url, + retries=0, + timeout=10, + user_agent_token="pytorch-lightning", + disallowed_header_directives=[] +): for _ in range(retries + 1): - img_stream, err = download_image(url, timeout, user_agent_token, disallowed_header_directives) + img_stream, err = _download_image(url, timeout, user_agent_token, disallowed_header_directives) if img_stream is not None: return img_stream, err - return None, err \ No newline at end of file + return None, err diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index eec1587ee9fb1..d1e0913c39fef 100644 --- a/src/lightning/data/processing/readers.py +++ b/src/lightning/data/processing/readers.py @@ -1,12 +1,13 @@ import os from abc import ABC, abstractmethod -from lightning_utilities.core.imports import RequirementCache -from typing import Any, List, Optional from dataclasses import dataclass +from typing import Any, List, Optional + +from lightning_utilities.core.imports import RequirementCache + from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks from lightning.data.utilities.env import _DistributedEnv - _POLARS_AVAILABLE = RequirementCache("polars") _PYARROW_AVAILABLE = RequirementCache("pyarrow") @@ -18,15 +19,19 @@ def get_num_nodes(self) -> int: @abstractmethod def to_workers_user_items(self, items: List[Any], num_workers: int) -> 
List[List[Any]]: + """This method is meant to convert the items provided by the users into items to be processed by the + workers.""" pass @abstractmethod def read(self, item: Any) -> Any: + """Read the data associated to an item.""" pass @dataclass class ParquetSlice: + """Keep track of a parquet file slice with its filepath, start and end.""" filepath: str start: int end: int @@ -38,25 +43,33 @@ def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> No self.num_rows = num_rows self.to_pandas = to_pandas - def _get_num_rows(self, path: str) -> int: - # TODO: There is a bug in polars. This leads to read_parquet to hang. - # if _POLARS_AVAILABLE: - # import polars as pol - # df = pol.scan_parquet(path) - # num_rows = df.select(pol.len()).collect().item() - # return num_rows + if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE: + raise ModuleNotFoundError("Please, run: `pip install pyarrow polars`") + def _get_num_rows(self, path: str) -> int: if _PYARROW_AVAILABLE: import pyarrow.dataset as ds df = ds.dataset(path).scanner() return df.count_rows() + # FIXED: There is a bug in polars. This leads to read_parquet to hang. + if _POLARS_AVAILABLE: + import polars as pol + df = pol.scan_parquet(path) + num_rows = df.select(pol.len()).collect().item() + return num_rows + raise RuntimeError("Please, install either pyarrow or polars.") def read(self, item: ParquetSlice) -> Any: if _POLARS_AVAILABLE: import polars as pol - df = pol.read_parquet(item.filepath, row_index_offset=item.start, n_rows=item.end - item.start, parallel="row_groups") + df = pol.read_parquet( + item.filepath, + row_index_offset=item.start, + n_rows=item.end - item.start, + parallel="row_groups" + ) if self.to_pandas: df = df.to_pandas() @@ -76,24 +89,27 @@ def read(self, item: ParquetSlice) -> Any: return df raise RuntimeError("Please, install either pyarrow or polars.") - + def to_workers_user_items(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]: intervals = [(0, self._get_num_rows(item)) for item in items] - + world_size = self.get_num_nodes() * num_workers fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes()) - parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks(fake_distributed_env, list(range(len(items))), intervals, False) + parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks( + fake_distributed_env, list(range(len(items))), intervals, False) workers_user_items = [[] for _ in range(world_size)] - for worker_idx, (parquet_indexes, parquet_slices) in enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker)): + iterator = enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker)) + + for worker_idx, (parquet_indexes, parquet_slices) in iterator: if self.num_rows: workers_user_items[worker_idx].extend([ ParquetSlice( - items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows - if parquet_slice[1] > (parquet_slice_start + self.num_rows) else + items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows + if parquet_slice[1] > (parquet_slice_start + self.num_rows) else parquet_slice[1] ) for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) diff --git a/src/lightning/data/processing/utilities.py b/src/lightning/data/processing/utilities.py deleted file mode 100644 index e0e984f0e8160..0000000000000 --- a/src/lightning/data/processing/utilities.py +++ /dev/null @@ -1,32 +0,0 @@ -import os - - -class 
SuppressStdoutStderr: - """ - A context manager for doing a "deep suppression" of stdout and stderr in - Python, i.e. will suppress all print, even if the print originates in a - compiled C/Fortran sub-function. - This will not suppress raised exceptions, since exceptions are printed - to stderr just before a script exits, and after the context manager has - exited (at least, I think that is why it lets exceptions through). - - """ - - def __init__(self): - # Open a pair of null files - self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] - # Save the actual stdout (1) and stderr (2) file descriptors. - self.save_fds = [os.dup(1), os.dup(2)] - - def __enter__(self): - # Assign the null pointers to stdout and stderr. - os.dup2(self.null_fds[0], 1) - os.dup2(self.null_fds[1], 2) - - def __exit__(self, *_): - # Re-assign the real stdout/stderr back to (1) and (2) - os.dup2(self.save_fds[0], 1) - os.dup2(self.save_fds[1], 2) - # Close all file descriptors - for fd in self.null_fds + self.save_fds: - os.close(fd) \ No newline at end of file From 895c6c36f2f9edbc5c2d365b42e7b42e0c9d746b Mon Sep 17 00:00:00 2001 From: thomas Date: Sat, 3 Feb 2024 18:26:06 +0000 Subject: [PATCH 05/14] update --- src/lightning/data/processing/utilities.py | 30 ---------------------- 1 file changed, 30 deletions(-) delete mode 100644 src/lightning/data/processing/utilities.py diff --git a/src/lightning/data/processing/utilities.py b/src/lightning/data/processing/utilities.py deleted file mode 100644 index 6649c3f058501..0000000000000 --- a/src/lightning/data/processing/utilities.py +++ /dev/null @@ -1,30 +0,0 @@ -import os - - -class SuppressStdoutStderr: - """A context manager for doing a "deep suppression" of stdout and stderr in Python, i.e. will suppress all print, - even if the print originates in a compiled C/Fortran sub-function. - - This will not suppress raised exceptions, since exceptions are printed to stderr just before a script exits, and - after the context manager has exited (at least, I think that is why it lets exceptions through). - - """ - - def __init__(self): - # Open a pair of null files - self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] - # Save the actual stdout (1) and stderr (2) file descriptors. - self.save_fds = [os.dup(1), os.dup(2)] - - def __enter__(self): - # Assign the null pointers to stdout and stderr. - os.dup2(self.null_fds[0], 1) - os.dup2(self.null_fds[1], 2) - - def __exit__(self, *_): - # Re-assign the real stdout/stderr back to (1) and (2) - os.dup2(self.save_fds[0], 1) - os.dup2(self.save_fds[1], 2) - # Close all file descriptors - for fd in self.null_fds + self.save_fds: - os.close(fd) From f7ef40e4dbe1ac8a5ebbb263be912d008f2cf342 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 5 Feb 2024 14:01:43 +0000 Subject: [PATCH 06/14] update --- src/lightning/data/processing/image.py | 62 +++++++++--------------- src/lightning/data/processing/readers.py | 2 +- 2 files changed, 23 insertions(+), 41 deletions(-) diff --git a/src/lightning/data/processing/image.py b/src/lightning/data/processing/image.py index ae35dba0be74a..904c9568cc602 100644 --- a/src/lightning/data/processing/image.py +++ b/src/lightning/data/processing/image.py @@ -1,65 +1,47 @@ import io -import traceback -import urllib +from typing import Optional, Tuple -# Credit to the https://github.com/rom1504/pytorch Github repo -# The code was taken from there. 
+from lightning_utilities.core.imports import RequirementCache -def is_disallowed(headers, user_agent_token, disallowed_header_directives): - """Check if HTTP headers contain an X-Robots-Tag directive disallowing usage.""" - for values in headers.get_all("X-Robots-Tag", []): - try: - uatoken_directives = values.split(":", 1) - directives = [x.strip().lower() for x in uatoken_directives[-1].split(",")] - ua_token = uatoken_directives[0].lower() if len(uatoken_directives) == 2 else None - if (ua_token is None or ua_token == user_agent_token) and any( - x in disallowed_header_directives for x in directives - ): - return True - except Exception as err: # pylint: disable=broad-except - traceback.print_exc() - print(f"Failed to parse X-Robots-Tag: {values}: {err}") - return False +_HTTPX_AVAILABLE = RequirementCache("httpx") +# Credit to the https://github.com/rom1504/pytorch Github repo +# The code was taken from there. def _download_image( - url, - timeout=10, - user_agent_token="pytorch-lightning", - disallowed_header_directives=["noai", "noimageai", "noindex", "noimageindex"] -): + url: str, + timeout: int = 10, + user_agent_token: str = "pytorch-lightning", +) -> Tuple[Optional[io.BytesIO], Optional[Exception]]: """Download an image with urllib.""" url img_stream = None user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0" if user_agent_token: user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)" + import httpx + try: - request = urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string}) # noqa S310 - with urllib.request.urlopen(request, timeout=timeout) as r: # noqa S310 - if disallowed_header_directives and is_disallowed( - r.headers, - user_agent_token, - disallowed_header_directives, - ): - return None, "Use of image disallowed by X-Robots-Tag directive" + with httpx.Client(http2=True) as client: + r = client.get(url, headers={"User-Agent": user_agent_string}, timeout=timeout) img_stream = io.BytesIO(r.read()) return img_stream, None except Exception as err: # pylint: disable=broad-except if img_stream is not None: img_stream.close() - return None, str(err) + return None, err def download_image( - url, - retries=0, - timeout=10, - user_agent_token="pytorch-lightning", - disallowed_header_directives=[] -): + url: str, + retries: int = 0, + timeout: int = 10, + user_agent_token: str = "pytorch-lightning", +) -> Tuple[Optional[io.BytesIO], Optional[Exception]]: + if not _HTTPX_AVAILABLE: + raise ModuleNotFoundError("Please, run: `pip install httpx`.") for _ in range(retries + 1): - img_stream, err = _download_image(url, timeout, user_agent_token, disallowed_header_directives) + img_stream, err = _download_image(url, timeout, user_agent_token) if img_stream is not None: return img_stream, err return None, err diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index d1e0913c39fef..4e9954acbfd25 100644 --- a/src/lightning/data/processing/readers.py +++ b/src/lightning/data/processing/readers.py @@ -100,7 +100,7 @@ def to_workers_user_items(self, items: Any, num_workers: int) -> List[List[Parqu parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks( fake_distributed_env, list(range(len(items))), intervals, False) - workers_user_items = [[] for _ in range(world_size)] + workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(world_size)] iterator = 
enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker)) From 7c0829c52f640efd2a0f5e45fa3a92ec951bc6b7 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 5 Feb 2024 17:50:46 +0000 Subject: [PATCH 07/14] update --- requirements/data/test.txt | 2 + src/lightning/data/processing/readers.py | 7 +-- .../data/streaming/data_processor.py | 4 ++ tests/tests_data/processing/__init__.py | 0 tests/tests_data/processing/test_readers.py | 55 +++++++++++++++++++ 5 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 tests/tests_data/processing/__init__.py create mode 100644 tests/tests_data/processing/test_readers.py diff --git a/requirements/data/test.txt b/requirements/data/test.txt index 38439e2d6705a..901e2b5d8c161 100644 --- a/requirements/data/test.txt +++ b/requirements/data/test.txt @@ -5,3 +5,5 @@ pytest-timeout ==2.1.0 pytest-rerunfailures ==12.0 pytest-random-order ==1.1.0 viztracer +pyarrow +polars diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index 4e9954acbfd25..7708c63e0907b 100644 --- a/src/lightning/data/processing/readers.py +++ b/src/lightning/data/processing/readers.py @@ -64,12 +64,7 @@ def _get_num_rows(self, path: str) -> int: def read(self, item: ParquetSlice) -> Any: if _POLARS_AVAILABLE: import polars as pol - df = pol.read_parquet( - item.filepath, - row_index_offset=item.start, - n_rows=item.end - item.start, - parallel="row_groups" - ) + df = pol.scan_parquet(item.filepath).slice(item.start, item.end).collect() if self.to_pandas: df = df.to_pandas() diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index 6653d16328490..6179aec69e70c 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -816,6 +816,7 @@ def __init__( Set this to ``False`` if the order in which samples are processed should be preserved. weights: Provide a list of weights associated to the inputs. This is used to evenly split the work among the workers. + reader: Map the inputs to worker inputs and provides a read method to read a slice of the data. 
""" self.input_dir = _resolve_dir(input_dir) @@ -834,6 +835,9 @@ def __init__( self.weights = weights self.reader = reader + if self.reader is not None and self.weights is not None: + raise ValueError("Either the reader or the weights needs to be defined.") + # Ensure the input dir is the same across all nodes self.input_dir = broadcast_object("input_dir", self.input_dir) diff --git a/tests/tests_data/processing/__init__.py b/tests/tests_data/processing/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_data/processing/test_readers.py b/tests/tests_data/processing/test_readers.py new file mode 100644 index 0000000000000..32f7b2ad0c3c0 --- /dev/null +++ b/tests/tests_data/processing/test_readers.py @@ -0,0 +1,55 @@ +import os + +import pytest +from lightning.data import map +from lightning.data.processing.readers import _POLARS_AVAILABLE, _PYARROW_AVAILABLE, BaseReader, ParquetReader + + +class DummyReader(BaseReader): + + def to_workers_user_items(self, items, num_workers: int): + return [[(worker_idx, idx, item) for idx, item in enumerate(items)] for worker_idx in range(num_workers)] + + def read(self, item): + return item + + +def fn(data: str, output_dir): + worker_idx, idx, _ = data + + with open(os.path.join(output_dir, f"{worker_idx}_{idx}"), "w") as f: + f.write("hello world") + + +def test_reader(tmpdir): + map(fn, list(range(3)), output_dir=str(tmpdir), reader=DummyReader(), num_workers=2) + assert sorted(os.listdir(tmpdir)) == ['0_0', '0_1', '0_2', '1_0', '1_1', '1_2'] + + +def map_parquet(df, output_dir): + filename = f"{df.row(0)[0]}_{len(df)}" + + with open(os.path.join(output_dir, filename), "w") as f: + f.write("hello world") + +@pytest.mark.skipif(not _POLARS_AVAILABLE and not _PYARROW_AVAILABLE, reason="polars and pyarrow are required") +def test_parquet_reader(tmpdir): + import polars as pol + + inputs = [] + + for i in range(3): + parquet_path = os.path.join(tmpdir, f"{i}.parquet") + df = pol.DataFrame(list(range(i * 10, (i + 1) * 10))) + df.write_parquet(parquet_path) + inputs.append(parquet_path) + + map( + map_parquet, + inputs=inputs, + output_dir=os.path.join(tmpdir, "output_dir"), + reader=ParquetReader(num_rows=10, to_pandas=False), + num_workers=2 + ) + + assert sorted(os.listdir(os.path.join(tmpdir, "output_dir"))) == ['0_10', '10_5', '15_5', '20_10'] From 25ad0c9a369320688095c4e2e51c5c2bf4d46016 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 5 Feb 2024 19:02:55 +0000 Subject: [PATCH 08/14] update --- src/lightning/data/streaming/data_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index 6179aec69e70c..abf9356805ec4 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -396,7 +396,7 @@ def _loop(self) -> None: num_downloader_finished = 0 while True: - index = self.ready_to_process_queue.get(timeout=1) + index = self.ready_to_process_queue.get() if index is None: num_downloader_finished += 1 From 9851e3d20003df1e1aa34a32e7e5b124d509ccef Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 5 Feb 2024 19:38:38 +0000 Subject: [PATCH 09/14] update --- tests/tests_data/processing/test_readers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/tests_data/processing/test_readers.py b/tests/tests_data/processing/test_readers.py index 32f7b2ad0c3c0..6c611aef8d480 100644 --- a/tests/tests_data/processing/test_readers.py 
+++ b/tests/tests_data/processing/test_readers.py @@ -1,4 +1,5 @@ import os +import sys import pytest from lightning.data import map @@ -32,7 +33,10 @@ def map_parquet(df, output_dir): with open(os.path.join(output_dir, filename), "w") as f: f.write("hello world") -@pytest.mark.skipif(not _POLARS_AVAILABLE and not _PYARROW_AVAILABLE, reason="polars and pyarrow are required") +@pytest.mark.skipif( + not _POLARS_AVAILABLE and not _PYARROW_AVAILABLE and sys.platform == "linux", + reason="polars and pyarrow are required" +) def test_parquet_reader(tmpdir): import polars as pol From 9b0eaed1f61728e9b81f45b69f6d03633a10d4f6 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 5 Feb 2024 20:19:02 +0000 Subject: [PATCH 10/14] update --- tests/tests_data/processing/test_readers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_data/processing/test_readers.py b/tests/tests_data/processing/test_readers.py index 6c611aef8d480..3ef3fd13c2863 100644 --- a/tests/tests_data/processing/test_readers.py +++ b/tests/tests_data/processing/test_readers.py @@ -34,7 +34,7 @@ def map_parquet(df, output_dir): f.write("hello world") @pytest.mark.skipif( - not _POLARS_AVAILABLE and not _PYARROW_AVAILABLE and sys.platform == "linux", + (not _POLARS_AVAILABLE and not _PYARROW_AVAILABLE) or sys.platform == "linux", reason="polars and pyarrow are required" ) def test_parquet_reader(tmpdir): From 39dae841ddcc84600318016b5c27a6d5adcd7c89 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 5 Feb 2024 21:01:58 +0000 Subject: [PATCH 11/14] update --- src/lightning/data/processing/image.py | 4 ++-- src/lightning/data/processing/readers.py | 4 ++-- src/lightning/data/streaming/data_processor.py | 2 +- tests/tests_data/processing/test_readers.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lightning/data/processing/image.py b/src/lightning/data/processing/image.py index 904c9568cc602..0bf110b40a37e 100644 --- a/src/lightning/data/processing/image.py +++ b/src/lightning/data/processing/image.py @@ -5,8 +5,8 @@ _HTTPX_AVAILABLE = RequirementCache("httpx") -# Credit to the https://github.com/rom1504/pytorch Github repo -# The code was taken from there. +# Credit to the https://github.com/rom1504/img2dataset Github repo +# The code was taken from there. It has a MIT License. 
def _download_image( url: str, diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index 7708c63e0907b..c979d39d3839d 100644 --- a/src/lightning/data/processing/readers.py +++ b/src/lightning/data/processing/readers.py @@ -18,7 +18,7 @@ def get_num_nodes(self) -> int: return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1)) @abstractmethod - def to_workers_user_items(self, items: List[Any], num_workers: int) -> List[List[Any]]: + def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]: """This method is meant to convert the items provided by the users into items to be processed by the workers.""" pass @@ -86,7 +86,7 @@ def read(self, item: ParquetSlice) -> Any: raise RuntimeError("Please, install either pyarrow or polars.") - def to_workers_user_items(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]: + def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]: intervals = [(0, self._get_num_rows(item)) for item in items] world_size = self.get_num_nodes() * num_workers diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index abf9356805ec4..3fdebf29b5448 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -866,7 +866,7 @@ def run(self, data_recipe: DataRecipe) -> None: raise ValueError("The `prepare_structure` should return a list of item metadata.") if self.reader: - workers_user_items = self.reader.to_workers_user_items(user_items, self.num_workers) + workers_user_items = self.reader.items_to_workers(user_items, self.num_workers) elif self.weights is not None: if len(self.weights) != len(user_items): diff --git a/tests/tests_data/processing/test_readers.py b/tests/tests_data/processing/test_readers.py index 3ef3fd13c2863..31e0b9ec62109 100644 --- a/tests/tests_data/processing/test_readers.py +++ b/tests/tests_data/processing/test_readers.py @@ -8,7 +8,7 @@ class DummyReader(BaseReader): - def to_workers_user_items(self, items, num_workers: int): + def items_to_workers(self, items, num_workers: int): return [[(worker_idx, idx, item) for idx, item in enumerate(items)] for worker_idx in range(num_workers)] def read(self, item): From ae015d20bb8533f2719f888ae640816c9260cb11 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 5 Feb 2024 21:20:13 +0000 Subject: [PATCH 12/14] update --- src/lightning/data/processing/readers.py | 45 +++++++++++++++--------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index c979d39d3839d..a7d1f75e2e1f3 100644 --- a/src/lightning/data/processing/readers.py +++ b/src/lightning/data/processing/readers.py @@ -17,6 +17,9 @@ class BaseReader(ABC): def get_num_nodes(self) -> int: return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1)) + def get_node_rank(self) -> int: + return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0)) + @abstractmethod def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]: """This method is meant to convert the items provided by the users into items to be processed by the @@ -90,31 +93,39 @@ def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSli intervals = [(0, self._get_num_rows(item)) for item in items] world_size = self.get_num_nodes() * num_workers + node_rank = self.get_node_rank() fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes()) 
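# Worked example for the slicing below: with num_rows=2048 and a single 5,000-row parquet
# file, the loop produces ParquetSlice(file, 0, 2048), ParquetSlice(file, 2048, 4096) and
# ParquetSlice(file, 4096, 5000); only the slices whose worker index satisfies
# node_start <= worker_idx < node_end are kept on this node.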
parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks( fake_distributed_env, list(range(len(items))), intervals, False) - workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(world_size)] + workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)] iterator = enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker)) + node_start = node_rank * num_workers + node_end = (node_rank + 1) * num_workers + for worker_idx, (parquet_indexes, parquet_slices) in iterator: - if self.num_rows: - workers_user_items[worker_idx].extend([ - ParquetSlice( - items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows - if parquet_slice[1] > (parquet_slice_start + self.num_rows) else - parquet_slice[1] - ) - for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) - for parquet_slice_start in range(parquet_slice[0], parquet_slice[1] + self.num_rows, self.num_rows) - if parquet_slice_start < parquet_slice[1] - ]) - else: - workers_user_items[worker_idx].extend([ - ParquetSlice(items[parquet_index], *parquet_slice) - for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) - ]) + if node_start <= worker_idx < node_end: + if self.num_rows: + workers_user_items[worker_idx % num_workers].extend([ + ParquetSlice( + items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows + if parquet_slice[1] > (parquet_slice_start + self.num_rows) else + parquet_slice[1] + ) + for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) + for parquet_slice_start in range(parquet_slice[0], parquet_slice[1] + self.num_rows, self.num_rows) + if parquet_slice_start < parquet_slice[1] + ]) + else: + workers_user_items[worker_idx % num_workers].extend([ + ParquetSlice(items[parquet_index], *parquet_slice) + for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) + ]) + + assert len(workers_user_items) == num_workers + assert all(len(w) for w in workers_user_items) return workers_user_items From 8fc93da36d151a00074717a91e621204eeee956b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 21:21:06 +0000 Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/processing/readers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index a7d1f75e2e1f3..da6c04be90d59 100644 --- a/src/lightning/data/processing/readers.py +++ b/src/lightning/data/processing/readers.py @@ -125,7 +125,7 @@ def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSli for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) ]) - assert len(workers_user_items) == num_workers + assert len(workers_user_items) == num_workers assert all(len(w) for w in workers_user_items) return workers_user_items From 0b3c09cc0e48b400e046fb641e3f8831fef4c49b Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 5 Feb 2024 21:25:58 +0000 Subject: [PATCH 14/14] update --- src/lightning/data/processing/readers.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index da6c04be90d59..d9a5b293bdba0 100644 --- a/src/lightning/data/processing/readers.py +++ 
b/src/lightning/data/processing/readers.py @@ -96,33 +96,33 @@ def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSli node_rank = self.get_node_rank() fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes()) - parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks( + parquet_indexes_per_worker, p_slices_per_worker = _associate_chunks_and_internals_to_ranks( fake_distributed_env, list(range(len(items))), intervals, False) workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)] - iterator = enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker)) + iterator = enumerate(zip(parquet_indexes_per_worker, p_slices_per_worker)) node_start = node_rank * num_workers node_end = (node_rank + 1) * num_workers - for worker_idx, (parquet_indexes, parquet_slices) in iterator: + for worker_idx, (parquet_indexes, p_slices) in iterator: if node_start <= worker_idx < node_end: if self.num_rows: workers_user_items[worker_idx % num_workers].extend([ ParquetSlice( - items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows - if parquet_slice[1] > (parquet_slice_start + self.num_rows) else - parquet_slice[1] + items[parquet_index], p_slice_start, p_slice_start + self.num_rows + if p_slice[1] > (p_slice_start + self.num_rows) else + p_slice[1] ) - for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) - for parquet_slice_start in range(parquet_slice[0], parquet_slice[1] + self.num_rows, self.num_rows) - if parquet_slice_start < parquet_slice[1] + for parquet_index, p_slice in zip(parquet_indexes, p_slices) + for p_slice_start in range(p_slice[0], p_slice[1] + self.num_rows, self.num_rows) + if p_slice_start < p_slice[1] ]) else: workers_user_items[worker_idx % num_workers].extend([ - ParquetSlice(items[parquet_index], *parquet_slice) - for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices) + ParquetSlice(items[parquet_index], *p_slice) + for parquet_index, p_slice in zip(parquet_indexes, p_slices) ]) assert len(workers_user_items) == num_workers
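Taken together, the reader-side splitting can be exercised on its own; a sketch assuming pyarrow/polars are installed and two hypothetical parquet files exist:

import os

from lightning.data.processing.readers import ParquetReader

# Hypothetical 2-node run; these env vars are the ones read by get_num_nodes() / get_node_rank().
os.environ["DATA_OPTIMIZER_NUM_NODES"] = "2"
os.environ["DATA_OPTIMIZER_NODE_RANK"] = "0"

reader = ParquetReader(num_rows=2048, to_pandas=False)
work = reader.items_to_workers(["data/0.parquet", "data/1.parquet"], num_workers=4)

# One bucket of ParquetSlice objects per local worker, matching the assertion above.
assert len(work) == 4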