
Commit b097a4d

Improve data processing to enable downloading LAION 400M (#19452)
1 parent 3c5a465 commit b097a4d

File tree

15 files changed (+275 -255 lines)


requirements/data/test.txt

+1 -1

@@ -5,5 +5,5 @@ pytest-timeout ==2.1.0
 pytest-rerunfailures ==12.0
 pytest-random-order ==1.1.0
 viztracer
+pandas
 pyarrow
-polars

src/lightning/data/__init__.py

+7

@@ -1,3 +1,5 @@
+from lightning_utilities.core.imports import RequirementCache
+
 from lightning.data.processing.functions import map, optimize, walk
 from lightning.data.streaming.combined import CombinedStreamingDataset
 from lightning.data.streaming.dataloader import StreamingDataLoader
@@ -13,3 +15,8 @@
     "optimize",
     "walk",
 ]
+
+if RequirementCache('lightning_sdk'):
+    from lightning_sdk import Machine  # noqa: F401
+
+    __all__.append("Machine")
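
For context, a minimal consumer-side sketch (not part of the commit) of how downstream code might probe for the conditionally exported Machine symbol; the only assumption is that lightning_sdk may or may not be installed:

# Hedged sketch: Machine is only re-exported by lightning.data when lightning_sdk is installed.
try:
    from lightning.data import Machine
except ImportError:
    Machine = None  # optional dependency missing

if Machine is None:
    print("lightning_sdk is not installed; Machine is unavailable.")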

src/lightning/data/processing/data_processor.py

+11 -6

@@ -372,7 +372,6 @@ def __init__(
         self._counter = 0
         self._last_time = time()
         self._index_counter = 0
-        self._current_item: Any = None

     def run(self) -> None:
         try:
@@ -477,6 +476,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
             assert os.path.exists(data), data
         else:
             assert os.path.exists(data[-1]), data
+
         self.to_upload_queues[self._counter % self.num_uploaders].put(data)

     def _collect_paths(self) -> None:
@@ -588,8 +588,8 @@ def _start_uploaders(self) -> None:

     def _handle_data_chunk_recipe(self, index: int) -> None:
         try:
-            self._current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
-            item_data_or_generator = self.data_recipe.prepare_item(self._current_item)
+            current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
+            item_data_or_generator = self.data_recipe.prepare_item(current_item)
             if isinstance(item_data_or_generator, types.GeneratorType):
                 for item_data in item_data_or_generator:
                     if item_data is not None:
@@ -713,14 +713,19 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
             size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
             num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
             data_format = tree_unflatten(config["config"]["data_format"], treespec_loads(config["config"]["data_spec"]))
+            num_chunks = len(config["chunks"])
+
+            # The platform can't store more than 1024 entries.
+            # Note: This isn't really used right now, so it is fine to skip if too big.
+            num_bytes_per_chunk = [c["chunk_size"] for c in config["chunks"]] if num_chunks < 1024 else []

             return _Result(
                 size=size,
                 num_bytes=num_bytes,
                 data_format=data_format,
                 compression=config["config"]["compression"],
                 num_chunks=len(config["chunks"]),
-                num_bytes_per_chunk=[c["chunk_size"] for c in config["chunks"]],
+                num_bytes_per_chunk=num_bytes_per_chunk,
             )
         return _Result(
             size=size,
@@ -866,9 +871,9 @@ def run(self, data_recipe: DataRecipe) -> None:
             raise ValueError("The `prepare_structure` should return a list of item metadata.")

         if self.reader:
-            workers_user_items = self.reader.items_to_workers(user_items, self.num_workers)
+            user_items = self.reader.remap_items(user_items, self.num_workers)

-        elif self.weights is not None:
+        if self.weights is not None:
             if len(self.weights) != len(user_items):
                 raise ValueError("The provided weights length should match the inputs' length.")
             workers_user_items = _map_items_to_workers_weighted(
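
The `_done` change above can be read in isolation. A standalone sketch of the new capping rule, using a toy config dict (an assumption for illustration, not the real chunk metadata):

# Toy config standing in for the real index metadata read by _done.
config = {"chunks": [{"chunk_size": 128, "chunk_bytes": 1 << 20, "dim": None} for _ in range(2000)]}

num_chunks = len(config["chunks"])

# The platform can't store more than 1024 entries, so the per-chunk sizes
# are only reported when the chunk count stays below that limit.
num_bytes_per_chunk = [c["chunk_size"] for c in config["chunks"]] if num_chunks < 1024 else []

assert num_bytes_per_chunk == []  # 2000 chunks -> the list is skipped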

src/lightning/data/processing/dns.py

-47
This file was deleted.

src/lightning/data/processing/functions.py

+1 -1

@@ -24,8 +24,8 @@

 from lightning.data.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
-from lightning.data.processing.dns import optimize_dns_context
 from lightning.data.processing.readers import BaseReader
+from lightning.data.processing.utilities import optimize_dns_context
 from lightning.data.streaming.resolver import (
     Dir,
     _assert_dir_has_index_file,

src/lightning/data/processing/image.py

-47
This file was deleted.
src/lightning/data/processing/readers.py

+48 -80
@@ -1,17 +1,13 @@
+import contextlib
 import os
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Any, List, Optional
+from typing import Any, List

 from lightning_utilities.core.imports import RequirementCache
+from tqdm import tqdm

-from lightning.data.utilities.env import _DistributedEnv
-from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks
-
-_POLARS_AVAILABLE = RequirementCache("polars")
 _PYARROW_AVAILABLE = RequirementCache("pyarrow")

-
 class BaseReader(ABC):

     def get_num_nodes(self) -> int:
@@ -21,9 +17,8 @@ def get_node_rank(self) -> int:
         return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))

     @abstractmethod
-    def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]:
-        """This method is meant to convert the items provided by the users into items to be processed by the
-        workers."""
+    def remap_items(self, items: List[Any], num_workers: int) -> List[Any]:
+        """This method is meant to remap the items provided by the users into items more adapted to be distributed."""
         pass

     @abstractmethod
@@ -32,100 +27,73 @@ def read(self, item: Any) -> Any:
         pass


-@dataclass
-class ParquetSlice:
-    """Keep track of a parquet file slice with its filepath, start and end."""
-    filepath: str
-    start: int
-    end: int
-
-
 class ParquetReader(BaseReader):

-    def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None:
+    def __init__(self, cache_folder: str, num_rows: int = 65536, to_pandas: bool = True) -> None:
+        super().__init__()
+        self.cache_folder = cache_folder
         self.num_rows = num_rows
         self.to_pandas = to_pandas

-        if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
-            raise ModuleNotFoundError("Please, run: `pip install pyarrow polars`")

-    def _get_num_rows(self, path: str) -> int:
-        if _PYARROW_AVAILABLE:
-            import pyarrow.dataset as ds
-            df = ds.dataset(path).scanner()
-            return df.count_rows()

-        # FIXED: There is a bug in polars. This leads to read_parquet to hang.
-        if _POLARS_AVAILABLE:
-            import polars as pol
-            df = pol.scan_parquet(path)
-            num_rows = df.select(pol.len()).collect().item()
-            return num_rows
+        if not _PYARROW_AVAILABLE:
+            raise ModuleNotFoundError("Please, run: `pip install pyarrow`")

-        raise RuntimeError("Please, install either pyarrow or polars.")

-    def read(self, item: ParquetSlice) -> Any:
-        if _POLARS_AVAILABLE:
-            import polars as pol
-            df = pol.scan_parquet(item.filepath).slice(item.start, item.end).collect()
+        self.parquet_file = None

-            if self.to_pandas:
-                df = df.to_pandas()
+    def _get_num_rows(self, path: str) -> int:
+        import pyarrow.dataset as ds

-            return df
+        df = ds.dataset(path).scanner()
+        return df.count_rows()

-        if _PYARROW_AVAILABLE:
-            import pyarrow.dataset as ds
+    def read(self, filepath: str) -> Any:
+        import pyarrow as pa
+        import pyarrow.parquet as pq

-            df = ds.dataset(item.filepath).scanner()
+        # Try to force dellocation to avoid memory leak
+        with contextlib.suppress(Exception):
+            pa.jemalloc_set_decay_ms(0)

-            df = df.take([item.start, item.end])
+        # close the previous parquet file to release the memory
+        if self.parquet_file is not None:
+            self.parquet_file.close()
+            self.parquet_file = None

-            if self.to_pandas:
-                df.to_pandas()
+        self.parquet_file = pq.ParquetFile(filepath, memory_map=True)
+        return self.parquet_file

-            return df
+    def remap_items(self, filepaths: List[str], _: int) -> List[str]:
+        import pyarrow.parquet as pq

-        raise RuntimeError("Please, install either pyarrow or polars.")
+        print("Starting resharding the parquet files for optimized processing.")

+        new_items = []

-    def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]:
-        intervals = [(0, self._get_num_rows(item)) for item in items]
+        cache_folder = os.path.join(self.cache_folder, f"{self.num_rows}")
+        os.makedirs(cache_folder, exist_ok=True)

-        world_size = self.get_num_nodes() * num_workers
-        node_rank = self.get_node_rank()
+        for filepath in filepaths:
+            num_rows = self._get_num_rows(filepath)

-        fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes())
-        parquet_indexes_per_worker, p_slices_per_worker = _associate_chunks_and_internals_to_ranks(
-            fake_distributed_env, list(range(len(items))), intervals, False)
+            table = None
+            parquet_filename = os.path.basename(filepath)

-        workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)]
+            for start in tqdm(range(0, num_rows, self.num_rows)):
+                end = min(start + self.num_rows, num_rows)
+                chunk_filepath = os.path.join(cache_folder, f"{start}_{end}_{parquet_filename}")
+                new_items.append(chunk_filepath)

-        iterator = enumerate(zip(parquet_indexes_per_worker, p_slices_per_worker))
+                if os.path.exists(chunk_filepath):
+                    continue

-        node_start = node_rank * num_workers
-        node_end = (node_rank + 1) * num_workers
+                if table is None:
+                    table = pq.read_table(filepath, memory_map=True)

-        for worker_idx, (parquet_indexes, p_slices) in iterator:
-            if node_start <= worker_idx < node_end:
-                if self.num_rows:
-                    workers_user_items[worker_idx % num_workers].extend([
-                        ParquetSlice(
-                            items[parquet_index], p_slice_start, p_slice_start + self.num_rows
-                            if p_slice[1] > (p_slice_start + self.num_rows) else
-                            p_slice[1]
-                        )
-                        for parquet_index, p_slice in zip(parquet_indexes, p_slices)
-                        for p_slice_start in range(p_slice[0], p_slice[1] + self.num_rows, self.num_rows)
-                        if p_slice_start < p_slice[1]
-                    ])
-                else:
-                    workers_user_items[worker_idx % num_workers].extend([
-                        ParquetSlice(items[parquet_index], *p_slice)
-                        for parquet_index, p_slice in zip(parquet_indexes, p_slices)
-                    ])
+                pq.write_table(table[start: end], chunk_filepath)

-        assert len(workers_user_items) == num_workers
-        assert all(len(w) for w in workers_user_items)
+        print("Finished resharding the parquet files for optimized processing.")

-        return workers_user_items
+        return new_items
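
To make the new ParquetReader.remap_items flow concrete, here is a minimal standalone sketch of the same resharding idea; it assumes pyarrow is installed and that a local input.parquet file and a cache/65536 folder are used (illustrative names, not part of the commit):

# Hedged sketch of the resharding idea: split one large parquet file into
# fixed-size row slices on disk, so each worker later reads a small file
# instead of a row range inside a shared file.
import os

import pyarrow.parquet as pq

num_rows_per_shard = 65536  # mirrors the new `num_rows` default
cache_folder = "cache/65536"  # hypothetical cache location
os.makedirs(cache_folder, exist_ok=True)

table = pq.read_table("input.parquet", memory_map=True)  # assumed input file

for start in range(0, table.num_rows, num_rows_per_shard):
    end = min(start + num_rows_per_shard, table.num_rows)
    shard_path = os.path.join(cache_folder, f"{start}_{end}_input.parquet")
    if not os.path.exists(shard_path):
        # Explicit slice of the pyarrow Table covering rows [start, end).
        pq.write_table(table.slice(start, end - start), shard_path)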

0 commit comments
