Lightning Data: Refactor files #19424

Merged · 3 commits · Feb 8, 2024
src/lightning/data/__init__.py (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
+from lightning.data.processing.functions import map, optimize, walk
 from lightning.data.streaming.combined import CombinedStreamingDataset
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
-from lightning.data.streaming.functions import map, optimize, walk
 
 __all__ = [
     "LightningDataset",
@@ -22,9 +22,9 @@
 
 import torch
 
+from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.processing.readers import BaseReader
 from lightning.data.streaming.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
-from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.resolver import (
     Dir,
     _assert_dir_has_index_file,
src/lightning/data/processing/readers.py (1 addition, 1 deletion)

@@ -5,8 +5,8 @@
 
 from lightning_utilities.core.imports import RequirementCache
 
-from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks
 from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks
 
 _POLARS_AVAILABLE = RequirementCache("polars")
 _PYARROW_AVAILABLE = RequirementCache("pyarrow")
src/lightning/data/streaming/__init__.py (0 additions, 4 deletions)

@@ -13,18 +13,14 @@
 
 from lightning.data.streaming.cache import Cache
 from lightning.data.streaming.combined import CombinedStreamingDataset
-from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.item_loader import TokensLoader
 
 __all__ = [
     "Cache",
-    "DataProcessor",
     "StreamingDataset",
     "CombinedStreamingDataset",
     "StreamingDataLoader",
-    "DataTransformRecipe",
-    "DataChunkRecipe",
     "TokensLoader",
 ]
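
Taken together, these moves change the public import paths. A minimal sketch of the layout after this PR, assembled from the diffs above (not an exhaustive list):

```python
# Data-processing entry points now live under lightning.data.processing.
from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.processing.functions import map, optimize, walk

# Streaming primitives remain under lightning.data.streaming.
from lightning.data.streaming import Cache, CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader

# The top-level re-exports are unchanged.
from lightning.data import map, optimize, walk
```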
src/lightning/data/streaming/shuffle.py (1 addition, 73 deletions)

@@ -19,6 +19,7 @@
 
 from lightning.data.streaming import Cache
 from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
 
 
 class Shuffle(ABC):
@@ -129,76 +130,3 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
 
     def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
         return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist()
-
-
-def _intra_node_chunk_shuffle(
-    distributed_env: _DistributedEnv,
-    chunks_per_ranks: List[List[int]],
-    seed: int,
-    current_epoch: int,
-) -> List[int]:
-    chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
-    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
-        chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
-            chunks_per_rank
-        )
-
-    # shuffle the chunks associated to the node
-    for i in range(len(chunk_indexes_per_nodes)):
-        # permute the indexes within the node
-        chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
-            chunk_indexes_per_nodes[i]
-        )
-
-    return [index for chunks in chunk_indexes_per_nodes for index in chunks]
-
-
-def _associate_chunks_and_internals_to_ranks(
-    distributed_env: _DistributedEnv,
-    indexes: Any,
-    chunk_intervals: Any,
-    drop_last: bool,
-) -> Tuple[List[List[int]], List[Any]]:
-    num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
-    num_items_per_ranks: List[int] = [
-        num_items // distributed_env.world_size + num_items % distributed_env.world_size
-        if rank == distributed_env.world_size - 1 and not drop_last
-        else num_items // distributed_env.world_size
-        for rank in range(distributed_env.world_size)
-    ]
-    chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
-    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
-
-    # 4. Assign the chunk & intervals to each rank
-    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
-        rank = 0
-
-        while True:
-            if rank == len(num_items_per_ranks):
-                break
-
-            items_left_to_assign = num_items_per_ranks[rank]
-
-            if items_left_to_assign == 0:
-                rank += 1
-                continue
-
-            items_in_chunk = chunk_interval[-1] - chunk_interval[0]
-
-            if items_in_chunk == 0:
-                break
-
-            if items_in_chunk > items_left_to_assign:
-                chunks_per_ranks[rank].append(chunk_index)
-                begin, end = chunk_interval
-                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
-                chunk_interval = (begin + items_left_to_assign, end)
-                num_items_per_ranks[rank] = 0
-                rank += 1
-            else:
-                chunks_per_ranks[rank].append(chunk_index)
-                intervals_per_ranks[rank].append(chunk_interval)
-                num_items_per_ranks[rank] -= items_in_chunk
-                break
-
-    return chunks_per_ranks, intervals_per_ranks
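
The two helpers deleted above move verbatim into the new `lightning.data.utilities.shuffle` module below. For reference, a minimal usage sketch (not part of this PR) of `_intra_node_chunk_shuffle`: it regroups the per-rank chunk lists by node and permutes each node's group with a RandomState seeded by `seed + current_epoch`. A `SimpleNamespace` stands in for `_DistributedEnv` here, since the helper only reads `num_nodes`:

```python
from types import SimpleNamespace

from lightning.data.utilities.shuffle import _intra_node_chunk_shuffle

env = SimpleNamespace(num_nodes=2)  # stand-in for _DistributedEnv

# Four ranks over two nodes: ranks 0-1 map to node 0, ranks 2-3 to node 1.
chunks_per_ranks = [[0, 1], [2, 3], [4, 5], [6, 7]]

shuffled = _intra_node_chunk_shuffle(env, chunks_per_ranks, seed=42, current_epoch=0)
# Chunks 0-3 stay grouped on node 0 and chunks 4-7 on node 1; only the order
# within each node's group changes, deterministically in seed + current_epoch.
```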
src/lightning/data/utilities/shuffle.py (new file, 78 additions)

@@ -0,0 +1,78 @@
+from typing import Any, List, Tuple
+
+import numpy as np
+
+from lightning.data.utilities.env import _DistributedEnv
+
+
+def _intra_node_chunk_shuffle(
+    distributed_env: _DistributedEnv,
+    chunks_per_ranks: List[List[int]],
+    seed: int,
+    current_epoch: int,
+) -> List[int]:
+    chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
+    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
+        chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
+            chunks_per_rank
+        )
+
+    # shuffle the chunks associated to the node
+    for i in range(len(chunk_indexes_per_nodes)):
+        # permute the indexes within the node
+        chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
+            chunk_indexes_per_nodes[i]
+        )
+
+    return [index for chunks in chunk_indexes_per_nodes for index in chunks]
+
+
+def _associate_chunks_and_internals_to_ranks(
+    distributed_env: _DistributedEnv,
+    indexes: Any,
+    chunk_intervals: Any,
+    drop_last: bool,
+) -> Tuple[List[List[int]], List[Any]]:
+    num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
+    num_items_per_ranks: List[int] = [
+        num_items // distributed_env.world_size + num_items % distributed_env.world_size
+        if rank == distributed_env.world_size - 1 and not drop_last
+        else num_items // distributed_env.world_size
+        for rank in range(distributed_env.world_size)
+    ]
+    chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
+    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+
+    # 4. Assign the chunk & intervals to each rank
+    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
+        rank = 0
+
+        while True:
+            if rank == len(num_items_per_ranks):
+                break
+
+            items_left_to_assign = num_items_per_ranks[rank]
+
+            if items_left_to_assign == 0:
+                rank += 1
+                continue
+
+            items_in_chunk = chunk_interval[-1] - chunk_interval[0]
+
+            if items_in_chunk == 0:
+                break
+
+            if items_in_chunk > items_left_to_assign:
+                chunks_per_ranks[rank].append(chunk_index)
+                begin, end = chunk_interval
+                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
+                chunk_interval = (begin + items_left_to_assign, end)
+                num_items_per_ranks[rank] = 0
+                rank += 1
+            else:
+                chunks_per_ranks[rank].append(chunk_index)
+                intervals_per_ranks[rank].append(chunk_interval)
+                num_items_per_ranks[rank] -= items_in_chunk
+                break
+
+    return chunks_per_ranks, intervals_per_ranks
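
Similarly, a minimal sketch (again, not part of the diff) of how `_associate_chunks_and_internals_to_ranks` behaves: chunks are assigned to ranks in order, and a chunk's interval is split whenever the current rank's item budget runs out. Only `world_size` is read from the env, so a `SimpleNamespace` stand-in suffices:

```python
from types import SimpleNamespace

from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks

env = SimpleNamespace(world_size=2)  # stand-in for _DistributedEnv

# Three chunks holding 4, 4, and 2 items: 10 items total, split 5/5 over 2 ranks.
indexes = [0, 1, 2]
chunk_intervals = [[0, 4], [0, 4], [0, 2]]

chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
    env, indexes, chunk_intervals, drop_last=False
)
# chunks_per_ranks == [[0, 1], [1, 2]]: chunk 1 is split across both ranks.
# Rank 0 receives all of chunk 0 plus item 0 of chunk 1; rank 1 receives
# items 1-3 of chunk 1 plus all of chunk 2.
```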
@@ -10,10 +10,9 @@
 import pytest
 import torch
 from lightning import seed_everything
-from lightning.data.streaming import data_processor as data_processor_module
-from lightning.data.streaming import functions, resolver
-from lightning.data.streaming.cache import Cache, Dir
-from lightning.data.streaming.data_processor import (
+from lightning.data.processing import data_processor as data_processor_module
+from lightning.data.processing import functions
+from lightning.data.processing.data_processor import (
     DataChunkRecipe,
     DataProcessor,
     DataTransformRecipe,
@@ -26,7 +25,9 @@
     _wait_for_disk_usage_higher_than_threshold,
     _wait_for_file_to_exist,
 )
-from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
+from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
+from lightning.data.streaming import resolver
+from lightning.data.streaming.cache import Cache, Dir
 from lightning_utilities.core.imports import RequirementCache
 
 _PIL_AVAILABLE = RequirementCache("PIL")
@@ -162,7 +163,7 @@ def fn(*_, **__):
 
 
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
-@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
+@mock.patch("lightning.data.processing.data_processor._wait_for_disk_usage_higher_than_threshold")
 def test_download_data_target(wait_for_disk_usage_higher_than_threshold_mock, tmpdir):
     input_dir = os.path.join(tmpdir, "input_dir")
     os.makedirs(input_dir, exist_ok=True)
@@ -201,7 +202,7 @@ def fn(*_, **__):
 
 def test_wait_for_disk_usage_higher_than_threshold():
     disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
-    with mock.patch("lightning.data.streaming.data_processor.shutil.disk_usage", disk_usage_mock):
+    with mock.patch("lightning.data.processing.data_processor.shutil.disk_usage", disk_usage_mock):
         _wait_for_disk_usage_higher_than_threshold("/", 10, sleep_time=0)
     assert disk_usage_mock.call_count == 3
@@ -4,7 +4,7 @@
 
 import pytest
 from lightning.data import walk
-from lightning.data.streaming.functions import _get_input_dir
+from lightning.data.processing.functions import _get_input_dir
 
 
 @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
tests/tests_data/streaming/test_dataset.py (2 additions, 1 deletion)

@@ -20,7 +20,8 @@
 import pytest
 import torch
 from lightning import seed_everything
-from lightning.data.streaming import Cache, functions
+from lightning.data.processing import functions
+from lightning.data.streaming import Cache
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import (
     _INDEX_FILENAME,
tests/tests_data/streaming/test_shuffle.py (1 addition, 1 deletion)

@@ -1,5 +1,5 @@
-from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
 from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
 
 
 def test_intra_node_chunk_shuffle():