diff --git a/requirements/data/test.txt b/requirements/data/test.txt
index 38439e2d6705a..901e2b5d8c161 100644
--- a/requirements/data/test.txt
+++ b/requirements/data/test.txt
@@ -5,3 +5,5 @@ pytest-timeout ==2.1.0
 pytest-rerunfailures ==12.0
 pytest-random-order ==1.1.0
 viztracer
+pyarrow
+polars
diff --git a/src/lightning/data/processing/__init__.py b/src/lightning/data/processing/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/src/lightning/data/processing/image.py b/src/lightning/data/processing/image.py
new file mode 100644
index 0000000000000..0bf110b40a37e
--- /dev/null
+++ b/src/lightning/data/processing/image.py
@@ -0,0 +1,47 @@
+import io
+from typing import Optional, Tuple
+
+from lightning_utilities.core.imports import RequirementCache
+
+_HTTPX_AVAILABLE = RequirementCache("httpx")
+
+# Credit to the https://github.com/rom1504/img2dataset GitHub repo.
+# The code was taken from there. It has an MIT License.
+
+
+def _download_image(
+    url: str,
+    timeout: int = 10,
+    user_agent_token: str = "pytorch-lightning",
+) -> Tuple[Optional[io.BytesIO], Optional[Exception]]:
+    """Download an image with httpx."""
+    img_stream = None
+    user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
+    if user_agent_token:
+        user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)"
+    import httpx
+
+    try:
+        with httpx.Client(http2=True) as client:
+            r = client.get(url, headers={"User-Agent": user_agent_string}, timeout=timeout)
+            img_stream = io.BytesIO(r.read())
+        return img_stream, None
+    except Exception as err:  # pylint: disable=broad-except
+        if img_stream is not None:
+            img_stream.close()
+        return None, err
+
+
+def download_image(
+    url: str,
+    retries: int = 0,
+    timeout: int = 10,
+    user_agent_token: str = "pytorch-lightning",
+) -> Tuple[Optional[io.BytesIO], Optional[Exception]]:
+    if not _HTTPX_AVAILABLE:
+        raise ModuleNotFoundError("Please run: `pip install httpx`.")
+    for _ in range(retries + 1):
+        img_stream, err = _download_image(url, timeout, user_agent_token)
+        if img_stream is not None:
+            return img_stream, err
+    return None, err
diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py
new file mode 100644
index 0000000000000..d9a5b293bdba0
--- /dev/null
+++ b/src/lightning/data/processing/readers.py
@@ -0,0 +1,131 @@
+import os
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, List, Optional
+
+from lightning_utilities.core.imports import RequirementCache
+
+from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks
+from lightning.data.utilities.env import _DistributedEnv
+
+_POLARS_AVAILABLE = RequirementCache("polars")
+_PYARROW_AVAILABLE = RequirementCache("pyarrow")
+
+
+class BaseReader(ABC):
+
+    def get_num_nodes(self) -> int:
+        return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1))
+
+    def get_node_rank(self) -> int:
+        return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))
+
+    @abstractmethod
+    def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]:
+        """Convert the items provided by the user into lists of items to be processed by the
+        workers."""
+        pass
+
+    @abstractmethod
+    def read(self, item: Any) -> Any:
+        """Read the data associated with an item."""
+        pass
+
+
+@dataclass
+class ParquetSlice:
+    """Keep track of a parquet file slice with its filepath, start and end."""
+    filepath: str
+    start: int
+    end: int
+
+
+class ParquetReader(BaseReader):
+
+    def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None:
+        self.num_rows = num_rows
+        self.to_pandas = to_pandas
+
+        if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
+            raise ModuleNotFoundError("Please run: `pip install pyarrow polars`")
+
+    def _get_num_rows(self, path: str) -> int:
+        if _PYARROW_AVAILABLE:
+            import pyarrow.dataset as ds
+            df = ds.dataset(path).scanner()
+            return df.count_rows()
+
+        # NOTE: A bug in polars causes `read_parquet` to hang, so `scan_parquet` is used instead.
+        if _POLARS_AVAILABLE:
+            import polars as pol
+            df = pol.scan_parquet(path)
+            num_rows = df.select(pol.len()).collect().item()
+            return num_rows
+
+        raise RuntimeError("Please install either pyarrow or polars.")
+
+    def read(self, item: ParquetSlice) -> Any:
+        if _POLARS_AVAILABLE:
+            import polars as pol
+            df = pol.scan_parquet(item.filepath).slice(item.start, item.end - item.start).collect()
+
+            if self.to_pandas:
+                df = df.to_pandas()
+
+            return df
+
+        if _PYARROW_AVAILABLE:
+            import pyarrow.dataset as ds
+
+            df = ds.dataset(item.filepath).scanner()
+
+            df = df.take(list(range(item.start, item.end)))
+
+            if self.to_pandas:
+                df = df.to_pandas()
+
+            return df
+
+        raise RuntimeError("Please install either pyarrow or polars.")
+
+
+    def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[ParquetSlice]]:
+        intervals = [(0, self._get_num_rows(item)) for item in items]
+
+        world_size = self.get_num_nodes() * num_workers
+        node_rank = self.get_node_rank()
+
+        fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes())
+        parquet_indexes_per_worker, p_slices_per_worker = _associate_chunks_and_internals_to_ranks(
+            fake_distributed_env, list(range(len(items))), intervals, False)
+
+        workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)]
+
+        iterator = enumerate(zip(parquet_indexes_per_worker, p_slices_per_worker))
+
+        node_start = node_rank * num_workers
+        node_end = (node_rank + 1) * num_workers
+
+        for worker_idx, (parquet_indexes, p_slices) in iterator:
+            if node_start <= worker_idx < node_end:
+                if self.num_rows:
+                    workers_user_items[worker_idx % num_workers].extend([
+                        ParquetSlice(
+                            items[parquet_index], p_slice_start, p_slice_start + self.num_rows
+                            if p_slice[1] > (p_slice_start + self.num_rows) else
+                            p_slice[1]
+                        )
+                        for parquet_index, p_slice in zip(parquet_indexes, p_slices)
+                        for p_slice_start in range(p_slice[0], p_slice[1] + self.num_rows, self.num_rows)
+                        if p_slice_start < p_slice[1]
+                    ])
+                else:
+                    workers_user_items[worker_idx % num_workers].extend([
+                        ParquetSlice(items[parquet_index], *p_slice)
+                        for parquet_index, p_slice in zip(parquet_indexes, p_slices)
+                    ])
+
+        assert len(workers_user_items) == num_workers
+        assert all(len(w) for w in workers_user_items)
+
+        return workers_user_items
diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py
index a7a49710ad45a..3fdebf29b5448 100644
--- a/src/lightning/data/streaming/data_processor.py
+++ b/src/lightning/data/streaming/data_processor.py
@@ -20,6 +20,7 @@
 from tqdm.auto import tqdm as _tqdm
 
 from lightning import seed_everything
+from lightning.data.processing.readers import BaseReader
 from lightning.data.streaming import Cache
 from lightning.data.streaming.cache import Dir
 from lightning.data.streaming.client import S3Client
@@ -158,8 +159,9 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
                         s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
 
                elif os.path.isfile(path):
-                    os.makedirs(os.path.dirname(local_path), exist_ok=True)
-                    shutil.copyfile(path, local_path)
+                    if not path.startswith("/teamspace/studios/this_studio"):
+                        os.makedirs(os.path.dirname(local_path), exist_ok=True)
+                        shutil.copyfile(path, local_path)
                 else:
                     raise ValueError(f"The provided {input_dir.url} isn't supported.")
 
@@ -340,6 +342,7 @@ def __init__(
         num_downloaders: int,
         num_uploaders: int,
         remove: bool,
+        reader: Optional[BaseReader] = None,
     ) -> None:
         """The BaseWorker is responsible to process the user data."""
         self.worker_index = worker_index
@@ -353,6 +356,7 @@
         self.num_downloaders = num_downloaders
         self.num_uploaders = num_uploaders
         self.remove = remove
+        self.reader = reader
         self.paths: List[List[str]] = []
         self.remover: Optional[Process] = None
         self.downloaders: List[Process] = []
@@ -433,7 +437,7 @@ def _loop(self) -> None:
                     self.progress_queue.put((self.worker_index, self._counter))
                     self._last_time = time()
 
-                if self.remove and self.input_dir.path is not None:
+                if self.remove and self.input_dir.path is not None and self.reader is None:
                     self.remove_queue.put(self.paths[index])
 
                 try:
@@ -476,7 +480,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
             self.to_upload_queues[self._counter % self.num_uploaders].put(data)
 
     def _collect_paths(self) -> None:
-        if self.input_dir.path is None:
+        if self.input_dir.path is None or self.reader is not None:
             for index in range(len(self.items)):
                 self.ready_to_process_queue.put(index)
             for _ in range(self.num_downloaders):
@@ -513,7 +517,7 @@ def is_path(element: Any) -> bool:
             paths = []
             for index, path in indexed_paths.items():
                 paths.append(path)
-                if self.input_dir:
+                if self.input_dir and not self.input_dir.path.startswith("/teamspace/studios/this_studio"):
                     path = path.replace(self.input_dir.path, self.cache_data_dir)
                 flattened_item[index] = path
 
@@ -525,8 +529,9 @@ def is_path(element: Any) -> bool:
         self.items = items
 
     def _start_downloaders(self) -> None:
-        if self.input_dir.path is None:
+        if self.input_dir.path is None or self.reader is not None:
             return
+
         for _ in range(self.num_downloaders):
             to_download_queue: Queue = Queue()
             p = Process(
@@ -583,7 +588,7 @@ def _start_uploaders(self) -> None:
 
     def _handle_data_chunk_recipe(self, index: int) -> None:
         try:
-            self._current_item = self.items[index]
+            self._current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
            item_data_or_generator = self.data_recipe.prepare_item(self._current_item)
             if isinstance(item_data_or_generator, types.GeneratorType):
                 for item_data in item_data_or_generator:
@@ -596,7 +601,7 @@ def _handle_data_chunk_recipe(self, index: int) -> None:
                     self._try_upload(chunk_filepath)
                     self._index_counter += 1
         except Exception as e:
-            raise RuntimeError(f"Failed processing {self._current_item}") from e
+            raise RuntimeError(f"Failed processing {self.items[index]}") from e
 
     def _handle_data_chunk_recipe_end(self) -> None:
         chunks_filepaths = self.cache.done()
@@ -609,7 +614,8 @@
     def _handle_data_transform_recipe(self, index: int) -> None:
         # Don't use a context manager to avoid deleting files that are being uploaded.
         output_dir = tempfile.mkdtemp()
-        item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir), len(self.items) - 1 == index)
+        item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
+        item_data = self.data_recipe.prepare_item(item, str(output_dir), len(self.items) - 1 == index)
         if item_data is not None:
             raise ValueError(
                 "When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything."
@@ -792,6 +798,7 @@ def __init__(
         random_seed: Optional[int] = 42,
         reorder_files: bool = True,
         weights: Optional[List[int]] = None,
+        reader: Optional[BaseReader] = None,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to
        make training faster.
@@ -809,6 +816,7 @@
                 Set this to ``False`` if the order in which samples are processed should be preserved.
             weights: Provide a list of weights associated to the inputs. This is used to evenly split the work
                 among the workers.
+            reader: Maps the inputs to worker inputs and provides a read method to read a slice of the data.
         """
 
         self.input_dir = _resolve_dir(input_dir)
@@ -825,6 +833,10 @@ def __init__(
         self.stop_queues: List[Queue] = []
         self.reorder_files = reorder_files
         self.weights = weights
+        self.reader = reader
+
+        if self.reader is not None and self.weights is not None:
+            raise ValueError("Either the `reader` or the `weights` can be defined, but not both.")
 
         # Ensure the input dir is the same across all nodes
         self.input_dir = broadcast_object("input_dir", self.input_dir)
@@ -853,7 +865,10 @@ def run(self, data_recipe: DataRecipe) -> None:
         if not isinstance(user_items, list):
             raise ValueError("The `prepare_structure` should return a list of item metadata.")
 
-        if self.weights is not None:
+        if self.reader:
+            workers_user_items = self.reader.items_to_workers(user_items, self.num_workers)
+
+        elif self.weights is not None:
             if len(self.weights) != len(user_items):
                 raise ValueError("The provided weights length should match the inputs' length.")
             workers_user_items = _map_items_to_workers_weighted(
@@ -880,7 +895,7 @@ def run(self, data_recipe: DataRecipe) -> None:
 
         self._cleanup_cache()
 
-        print(f"Starting {self.num_workers} workers")
+        print(f"Starting {self.num_workers} workers with {num_items} items.")
 
         if self.input_dir is None and self.src_resolver is not None and self.input_dir:
             self.input_dir = self.src_resolver(self.input_dir)
@@ -988,6 +1003,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
                 self.num_downloaders,
                 self.num_uploaders,
                 self.delete_cached_files,
+                self.reader,
             )
             worker.start()
             workers.append(worker)
diff --git a/src/lightning/data/streaming/functions.py b/src/lightning/data/streaming/functions.py
index bd6616cca6e1f..bbdcdbc2ca764 100644
--- a/src/lightning/data/streaming/functions.py
+++ b/src/lightning/data/streaming/functions.py
@@ -22,6 +22,7 @@
 
 import torch
 
+from lightning.data.processing.readers import BaseReader
 from lightning.data.streaming.constants import _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.resolver import (
@@ -157,6 +158,7 @@ def map(
     num_downloaders: Optional[int] = None,
     reorder_files: bool = True,
     error_when_not_empty: bool = False,
+    reader: Optional[BaseReader] = None,
 ) -> None:
     """This function map a callbable over a collection of files possibly in a distributed way.
 
@@ -203,6 +205,7 @@ def map(
             num_downloaders=num_downloaders,
             reorder_files=reorder_files,
             weights=weights,
+            reader=reader,
         )
         return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
     return _execute(
@@ -225,6 +228,7 @@ def optimize(
     machine: Optional[str] = None,
     num_downloaders: Optional[int] = None,
     reorder_files: bool = True,
+    reader: Optional[BaseReader] = None,
 ) -> None:
     """This function converts a dataset into chunks possibly in a distributed way.
 
@@ -274,6 +278,7 @@ def optimize(
             fast_dev_run=fast_dev_run,
             num_downloaders=num_downloaders,
             reorder_files=reorder_files,
+            reader=reader,
         )
         return data_processor.run(
             LambdaDataChunkRecipe(
diff --git a/tests/tests_data/processing/__init__.py b/tests/tests_data/processing/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tests_data/processing/test_readers.py b/tests/tests_data/processing/test_readers.py
new file mode 100644
index 0000000000000..31e0b9ec62109
--- /dev/null
+++ b/tests/tests_data/processing/test_readers.py
@@ -0,0 +1,59 @@
+import os
+import sys
+
+import pytest
+from lightning.data import map
+from lightning.data.processing.readers import _POLARS_AVAILABLE, _PYARROW_AVAILABLE, BaseReader, ParquetReader
+
+
+class DummyReader(BaseReader):
+
+    def items_to_workers(self, items, num_workers: int):
+        return [[(worker_idx, idx, item) for idx, item in enumerate(items)] for worker_idx in range(num_workers)]
+
+    def read(self, item):
+        return item
+
+
+def fn(data: str, output_dir):
+    worker_idx, idx, _ = data
+
+    with open(os.path.join(output_dir, f"{worker_idx}_{idx}"), "w") as f:
+        f.write("hello world")
+
+
+def test_reader(tmpdir):
+    map(fn, list(range(3)), output_dir=str(tmpdir), reader=DummyReader(), num_workers=2)
+    assert sorted(os.listdir(tmpdir)) == ['0_0', '0_1', '0_2', '1_0', '1_1', '1_2']
+
+
+def map_parquet(df, output_dir):
+    filename = f"{df.row(0)[0]}_{len(df)}"
+
+    with open(os.path.join(output_dir, filename), "w") as f:
+        f.write("hello world")
+
+@pytest.mark.skipif(
+    (not _POLARS_AVAILABLE and not _PYARROW_AVAILABLE) or sys.platform == "linux",
+    reason="polars and pyarrow are required"
+)
+def test_parquet_reader(tmpdir):
+    import polars as pol
+
+    inputs = []
+
+    for i in range(3):
+        parquet_path = os.path.join(tmpdir, f"{i}.parquet")
+        df = pol.DataFrame(list(range(i * 10, (i + 1) * 10)))
+        df.write_parquet(parquet_path)
+        inputs.append(parquet_path)
+
+    map(
+        map_parquet,
+        inputs=inputs,
+        output_dir=os.path.join(tmpdir, "output_dir"),
+        reader=ParquetReader(num_rows=10, to_pandas=False),
+        num_workers=2
+    )
+
+    assert sorted(os.listdir(os.path.join(tmpdir, "output_dir"))) == ['0_10', '10_5', '15_5', '20_10']
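Usage sketch (not part of the patch): a minimal example of how the `reader` argument added above could be combined with `map` and `ParquetReader`, mirroring the test. The `process_slice` function, the input parquet paths, the output directory, and the output filename scheme are hypothetical; `num_rows=2048` and `to_pandas=False` follow the reader's defaults and the test above.

import os

import polars as pol

from lightning.data import map
from lightning.data.processing.readers import ParquetReader


def process_slice(df, output_dir):
    # With `to_pandas=False`, each item is a polars DataFrame holding up to
    # `num_rows` rows of one parquet file; here it is simply written back out.
    df.write_parquet(os.path.join(output_dir, f"slice_{df.row(0)[0]}.parquet"))


if __name__ == "__main__":
    inputs = ["data/0.parquet", "data/1.parquet"]  # hypothetical input files
    map(
        process_slice,
        inputs=inputs,
        output_dir="output_dir",
        reader=ParquetReader(num_rows=2048, to_pandas=False),
        num_workers=2,
    )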