From 61695cd3390f361baa8fe26565ff537a57b9b7e5 Mon Sep 17 00:00:00 2001
From: thomas
Date: Tue, 6 Feb 2024 21:07:25 +0000
Subject: [PATCH 1/2] update

---
 src/lightning/data/__init__.py                    |  2 +-
 .../{streaming => processing}/data_processor.py   |  0
 .../data/{streaming => processing}/functions.py   |  2 +-
 src/lightning/data/streaming/__init__.py          |  2 +-
 .../test_data_processor.py                        | 15 ++++++++-------
 .../{streaming => processing}/test_functions.py   |  2 +-
 tests/tests_data/streaming/test_dataset.py        |  3 ++-
 7 files changed, 14 insertions(+), 12 deletions(-)
 rename src/lightning/data/{streaming => processing}/data_processor.py (100%)
 rename src/lightning/data/{streaming => processing}/functions.py (99%)
 rename tests/tests_data/{streaming => processing}/test_data_processor.py (98%)
 rename tests/tests_data/{streaming => processing}/test_functions.py (94%)

diff --git a/src/lightning/data/__init__.py b/src/lightning/data/__init__.py
index 88b384e8a227b..21cd09bbe8d98 100644
--- a/src/lightning/data/__init__.py
+++ b/src/lightning/data/__init__.py
@@ -1,7 +1,7 @@
+from lightning.data.processing.functions import map, optimize, walk
 from lightning.data.streaming.combined import CombinedStreamingDataset
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
-from lightning.data.streaming.functions import map, optimize, walk
 
 __all__ = [
     "LightningDataset",
diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/processing/data_processor.py
similarity index 100%
rename from src/lightning/data/streaming/data_processor.py
rename to src/lightning/data/processing/data_processor.py
diff --git a/src/lightning/data/streaming/functions.py b/src/lightning/data/processing/functions.py
similarity index 99%
rename from src/lightning/data/streaming/functions.py
rename to src/lightning/data/processing/functions.py
index 7bc865ccb8d85..6939418ad617c 100644
--- a/src/lightning/data/streaming/functions.py
+++ b/src/lightning/data/processing/functions.py
@@ -22,9 +22,9 @@
 
 import torch
 
+from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.processing.readers import BaseReader
 from lightning.data.streaming.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
-from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.resolver import (
     Dir,
     _assert_dir_has_index_file,
diff --git a/src/lightning/data/streaming/__init__.py b/src/lightning/data/streaming/__init__.py
index 55b88dc2758c1..03ccd7a10cdc8 100644
--- a/src/lightning/data/streaming/__init__.py
+++ b/src/lightning/data/streaming/__init__.py
@@ -11,9 +11,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.cache import Cache
 from lightning.data.streaming.combined import CombinedStreamingDataset
-from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.item_loader import TokensLoader
diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/processing/test_data_processor.py
similarity index 98%
rename from tests/tests_data/streaming/test_data_processor.py
rename to tests/tests_data/processing/test_data_processor.py
index 671917abe18b1..06420130321a9 100644
--- a/tests/tests_data/streaming/test_data_processor.py
+++ b/tests/tests_data/processing/test_data_processor.py
@@ -10,10 +10,9 @@
 import pytest
 import torch
 from lightning import seed_everything
-from lightning.data.streaming import data_processor as data_processor_module
-from lightning.data.streaming import functions, resolver
-from lightning.data.streaming.cache import Cache, Dir
-from lightning.data.streaming.data_processor import (
+from lightning.data.processing import data_processor as data_processor_module
+from lightning.data.processing import functions
+from lightning.data.processing.data_processor import (
     DataChunkRecipe,
     DataProcessor,
     DataTransformRecipe,
@@ -26,7 +25,9 @@
     _wait_for_disk_usage_higher_than_threshold,
     _wait_for_file_to_exist,
 )
-from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
+from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
+from lightning.data.streaming import resolver
+from lightning.data.streaming.cache import Cache, Dir
 from lightning_utilities.core.imports import RequirementCache
 
 _PIL_AVAILABLE = RequirementCache("PIL")
@@ -162,7 +163,7 @@ def fn(*_, **__):
 
 
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
-@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
+@mock.patch("lightning.data.processing.data_processor._wait_for_disk_usage_higher_than_threshold")
 def test_download_data_target(wait_for_disk_usage_higher_than_threshold_mock, tmpdir):
     input_dir = os.path.join(tmpdir, "input_dir")
     os.makedirs(input_dir, exist_ok=True)
@@ -201,7 +202,7 @@ def fn(*_, **__):
 
 def test_wait_for_disk_usage_higher_than_threshold():
     disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
-    with mock.patch("lightning.data.streaming.data_processor.shutil.disk_usage", disk_usage_mock):
+    with mock.patch("lightning.data.processing.data_processor.shutil.disk_usage", disk_usage_mock):
         _wait_for_disk_usage_higher_than_threshold("/", 10, sleep_time=0)
     assert disk_usage_mock.call_count == 3
 
diff --git a/tests/tests_data/streaming/test_functions.py b/tests/tests_data/processing/test_functions.py
similarity index 94%
rename from tests/tests_data/streaming/test_functions.py
rename to tests/tests_data/processing/test_functions.py
index 10bf40caf7c2f..d0b581130d928 100644
--- a/tests/tests_data/streaming/test_functions.py
+++ b/tests/tests_data/processing/test_functions.py
@@ -4,7 +4,7 @@
 
 import pytest
 from lightning.data import walk
-from lightning.data.streaming.functions import _get_input_dir
+from lightning.data.processing.functions import _get_input_dir
 
 
 @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py
index db294f6ea564f..e48db3fab9f30 100644
--- a/tests/tests_data/streaming/test_dataset.py
+++ b/tests/tests_data/streaming/test_dataset.py
@@ -20,7 +20,8 @@
 import pytest
 import torch
 from lightning import seed_everything
-from lightning.data.streaming import Cache, functions
+from lightning.data.processing import functions
+from lightning.data.streaming import Cache
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import (
     _INDEX_FILENAME,

From 48958ccadcef0aee4a963958b89177e22b5c60bb Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 7 Feb 2024 23:39:48 +0000
Subject: [PATCH 2/2] update

---
 src/lightning/data/processing/readers.py   |  2 +-
 src/lightning/data/streaming/__init__.py   |  4 --
 src/lightning/data/streaming/shuffle.py    | 74 +-------------------
 src/lightning/data/utilities/shuffle.py    | 78 ++++++++++++++++++++++
 tests/tests_data/streaming/test_shuffle.py |  2 +-
 5 files changed, 81 insertions(+), 79 deletions(-)
 create mode 100644 src/lightning/data/utilities/shuffle.py

diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py
index d9a5b293bdba0..8519795b6207c 100644
--- a/src/lightning/data/processing/readers.py
+++ b/src/lightning/data/processing/readers.py
@@ -5,8 +5,8 @@
 
 from lightning_utilities.core.imports import RequirementCache
 
-from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks
 from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks
 
 _POLARS_AVAILABLE = RequirementCache("polars")
 _PYARROW_AVAILABLE = RequirementCache("pyarrow")
diff --git a/src/lightning/data/streaming/__init__.py b/src/lightning/data/streaming/__init__.py
index 03ccd7a10cdc8..2e6c49cfe6e59 100644
--- a/src/lightning/data/streaming/__init__.py
+++ b/src/lightning/data/streaming/__init__.py
@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.cache import Cache
 from lightning.data.streaming.combined import CombinedStreamingDataset
 from lightning.data.streaming.dataloader import StreamingDataLoader
@@ -20,11 +19,8 @@
 
 __all__ = [
     "Cache",
-    "DataProcessor",
     "StreamingDataset",
     "CombinedStreamingDataset",
     "StreamingDataLoader",
-    "DataTransformRecipe",
-    "DataChunkRecipe",
     "TokensLoader",
 ]
diff --git a/src/lightning/data/streaming/shuffle.py b/src/lightning/data/streaming/shuffle.py
index 3c6f91a97e3db..b0a48bd728eb4 100644
--- a/src/lightning/data/streaming/shuffle.py
+++ b/src/lightning/data/streaming/shuffle.py
@@ -19,6 +19,7 @@
 
 from lightning.data.streaming import Cache
 from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
 
 
 class Shuffle(ABC):
@@ -129,76 +130,3 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
 
     def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
         return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist()
-
-
-def _intra_node_chunk_shuffle(
-    distributed_env: _DistributedEnv,
-    chunks_per_ranks: List[List[int]],
-    seed: int,
-    current_epoch: int,
-) -> List[int]:
-    chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
-    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
-        chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
-            chunks_per_rank
-        )
-
-    # shuffle the chunks associated to the node
-    for i in range(len(chunk_indexes_per_nodes)):
-        # permute the indexes within the node
-        chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
-            chunk_indexes_per_nodes[i]
-        )
-
-    return [index for chunks in chunk_indexes_per_nodes for index in chunks]
-
-
-def _associate_chunks_and_internals_to_ranks(
-    distributed_env: _DistributedEnv,
-    indexes: Any,
-    chunk_intervals: Any,
-    drop_last: bool,
-) -> Tuple[List[List[int]], List[Any]]:
-    num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
-    num_items_per_ranks: List[int] = [
-        num_items // distributed_env.world_size + num_items % distributed_env.world_size
-        if rank == distributed_env.world_size - 1 and not drop_last
-        else num_items // distributed_env.world_size
-        for rank in range(distributed_env.world_size)
-    ]
-    chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
-    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
-
-    # 4. Assign the chunk & intervals to each rank
-    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
-        rank = 0
-
-        while True:
-            if rank == len(num_items_per_ranks):
-                break
-
-            items_left_to_assign = num_items_per_ranks[rank]
-
-            if items_left_to_assign == 0:
-                rank += 1
-                continue
-
-            items_in_chunk = chunk_interval[-1] - chunk_interval[0]
-
-            if items_in_chunk == 0:
-                break
-
-            if items_in_chunk > items_left_to_assign:
-                chunks_per_ranks[rank].append(chunk_index)
-                begin, end = chunk_interval
-                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
-                chunk_interval = (begin + items_left_to_assign, end)
-                num_items_per_ranks[rank] = 0
-                rank += 1
-            else:
-                chunks_per_ranks[rank].append(chunk_index)
-                intervals_per_ranks[rank].append(chunk_interval)
-                num_items_per_ranks[rank] -= items_in_chunk
-                break
-
-    return chunks_per_ranks, intervals_per_ranks
diff --git a/src/lightning/data/utilities/shuffle.py b/src/lightning/data/utilities/shuffle.py
new file mode 100644
index 0000000000000..8be6563096ed6
--- /dev/null
+++ b/src/lightning/data/utilities/shuffle.py
@@ -0,0 +1,78 @@
+from typing import Any, List, Tuple
+
+import numpy as np
+
+from lightning.data.utilities.env import _DistributedEnv
+
+
+def _intra_node_chunk_shuffle(
+    distributed_env: _DistributedEnv,
+    chunks_per_ranks: List[List[int]],
+    seed: int,
+    current_epoch: int,
+) -> List[int]:
+    chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
+    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
+        chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
+            chunks_per_rank
+        )
+
+    # shuffle the chunks associated with each node
+    for i in range(len(chunk_indexes_per_nodes)):
+        # permute the indexes within the node
+        chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
+            chunk_indexes_per_nodes[i]
+        )
+
+    return [index for chunks in chunk_indexes_per_nodes for index in chunks]
+
+
+def _associate_chunks_and_internals_to_ranks(
+    distributed_env: _DistributedEnv,
+    indexes: Any,
+    chunk_intervals: Any,
+    drop_last: bool,
+) -> Tuple[List[List[int]], List[Any]]:
+    num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
+    num_items_per_ranks: List[int] = [
+        num_items // distributed_env.world_size + num_items % distributed_env.world_size
+        if rank == distributed_env.world_size - 1 and not drop_last
+        else num_items // distributed_env.world_size
+        for rank in range(distributed_env.world_size)
+    ]
+    chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
+    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+
+    # assign the chunks & intervals to each rank
+    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
+        rank = 0
+
+        while True:
+            if rank == len(num_items_per_ranks):
+                break
+
+            items_left_to_assign = num_items_per_ranks[rank]
+
+            if items_left_to_assign == 0:
+                rank += 1
+                continue
+
+            items_in_chunk = chunk_interval[-1] - chunk_interval[0]
+
+            if items_in_chunk == 0:
+                break
+
+            if items_in_chunk > items_left_to_assign:
+                chunks_per_ranks[rank].append(chunk_index)
+                begin, end = chunk_interval
+                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
+                chunk_interval = (begin + items_left_to_assign, end)
+                num_items_per_ranks[rank] = 0
+                rank += 1
+            else:
+                chunks_per_ranks[rank].append(chunk_index)
+                intervals_per_ranks[rank].append(chunk_interval)
+                num_items_per_ranks[rank] -= items_in_chunk
+                break
+
+    return chunks_per_ranks, intervals_per_ranks
diff --git a/tests/tests_data/streaming/test_shuffle.py b/tests/tests_data/streaming/test_shuffle.py
index 519c61f8239a0..cb451ce73a1ec 100644
--- a/tests/tests_data/streaming/test_shuffle.py
+++ b/tests/tests_data/streaming/test_shuffle.py
@@ -1,5 +1,5 @@
-from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
 from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
 
 
 def test_intra_node_chunk_shuffle():
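
After both patches apply, the shuffle helpers are imported from lightning.data.utilities.shuffle, as the updated test_shuffle.py shows. Below is a minimal usage sketch (not part of the patches) of what _associate_chunks_and_internals_to_ranks computes; it assumes the patched package is importable and uses a SimpleNamespace stand-in for _DistributedEnv, since the helper only reads the world_size and num_nodes attributes and the real class may be constructed differently.

# Hypothetical sketch, not part of the patches above.
from types import SimpleNamespace

from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks

# Stand-in for _DistributedEnv: only `world_size` (and, for the intra-node
# shuffle helper, `num_nodes`) is read by these functions.
env = SimpleNamespace(world_size=2, num_nodes=1)

# Two chunks: chunk 0 holds items [0, 50), chunk 1 holds items [50, 150).
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
    distributed_env=env,
    indexes=[0, 1],
    chunk_intervals=[[0, 50], [50, 150]],
    drop_last=False,
)

# 150 items over 2 ranks -> 75 items each. Rank 0 gets all of chunk 0 plus the
# first 25 items of chunk 1; rank 1 gets the remaining 75 items of chunk 1,
# which is why chunk index 1 appears in both ranks' chunk lists.
assert chunks_per_ranks == [[0, 1], [1]]
assert intervals_per_ranks[0] == [[0, 50], [50, 75]]
# intervals_per_ranks[1] holds the reassigned remainder of chunk 1: (75, 150)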