
Add support for parallelizing processing parquet files across workers and nodes. #19400

Merged
merged 17 commits on Feb 5, 2024
update
tchaton committed Feb 5, 2024
commit ae015d20bb8533f2719f888ae640816c9260cb11
45 changes: 28 additions & 17 deletions src/lightning/data/processing/readers.py
@@ -17,6 +17,9 @@ class BaseReader(ABC):
     def get_num_nodes(self) -> int:
         return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1))
 
+    def get_node_rank(self) -> int:
+        return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))
+
     @abstractmethod
     def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]:
         """This method is meant to convert the items provided by the users into items to be processed by the
@@ -90,31 +93,39 @@ def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]:
         intervals = [(0, self._get_num_rows(item)) for item in items]
 
         world_size = self.get_num_nodes() * num_workers
+        node_rank = self.get_node_rank()
 
         fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes())
         parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks(
             fake_distributed_env, list(range(len(items))), intervals, False)
 
-        workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(world_size)]
+        workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)]
 
         iterator = enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker))
 
+        node_start = node_rank * num_workers
+        node_end = (node_rank + 1) * num_workers
+
         for worker_idx, (parquet_indexes, parquet_slices) in iterator:
-            if self.num_rows:
-                workers_user_items[worker_idx].extend([
-                    ParquetSlice(
-                        items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows
-                        if parquet_slice[1] > (parquet_slice_start + self.num_rows) else
-                        parquet_slice[1]
-                    )
-                    for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices)
-                    for parquet_slice_start in range(parquet_slice[0], parquet_slice[1] + self.num_rows, self.num_rows)
-                    if parquet_slice_start < parquet_slice[1]
-                ])
-            else:
-                workers_user_items[worker_idx].extend([
-                    ParquetSlice(items[parquet_index], *parquet_slice)
-                    for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices)
-                ])
+            if node_start <= worker_idx < node_end:
+                if self.num_rows:
+                    workers_user_items[worker_idx % num_workers].extend([
+                        ParquetSlice(
+                            items[parquet_index], parquet_slice_start, parquet_slice_start + self.num_rows
+                            if parquet_slice[1] > (parquet_slice_start + self.num_rows) else
+                            parquet_slice[1]
+                        )
+                        for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices)
+                        for parquet_slice_start in range(parquet_slice[0], parquet_slice[1] + self.num_rows, self.num_rows)
+                        if parquet_slice_start < parquet_slice[1]
+                    ])
+                else:
+                    workers_user_items[worker_idx % num_workers].extend([
+                        ParquetSlice(items[parquet_index], *parquet_slice)
+                        for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices)
+                    ])
 
         assert len(workers_user_items) == num_workers
         assert all(len(w) for w in workers_user_items)
 
         return workers_user_items
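
For context on the change above: the assignment is still computed for every worker in the whole cluster (world_size = num_nodes * num_workers, with the node count and node rank read from the DATA_OPTIMIZER_NUM_NODES and DATA_OPTIMIZER_NODE_RANK environment variables), but each node now keeps only the window of workers that belongs to its own rank and re-indexes it to local ids 0..num_workers-1. A minimal standalone sketch of that windowing idea, using a hypothetical helper name rather than the actual Lightning code:

import os
from typing import Any, List

def keep_local_window(global_assignments: List[List[Any]], num_workers: int) -> List[List[Any]]:
    # global_assignments holds one entry per worker across ALL nodes,
    # i.e. len(global_assignments) == num_nodes * num_workers.
    node_rank = int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))
    node_start = node_rank * num_workers
    node_end = (node_rank + 1) * num_workers
    # Keep only this node's slice, re-indexed to local worker ids 0..num_workers-1,
    # which is what the `worker_idx % num_workers` indexing in the diff achieves.
    return [global_assignments[i] for i in range(node_start, node_end)]

With 2 nodes and 4 workers per node, node rank 1 keeps global workers 4..7 as its local workers 0..3, so each node only processes its own share of the parquet slices.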
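
The nested comprehension in the self.num_rows branch is dense; what it does to each row interval is roughly the following (a sketch with a hypothetical helper, not code from this PR):

from typing import List, Tuple

def split_rows(start: int, end: int, num_rows: int) -> List[Tuple[int, int]]:
    # Cut the half-open row interval [start, end) into pieces of at most num_rows rows,
    # clamping the last piece to the interval end, mirroring the comprehension above.
    return [(s, min(s + num_rows, end)) for s in range(start, end, num_rows)]

# Example: split_rows(0, 2500, 1000) -> [(0, 1000), (1000, 2000), (2000, 2500)]

Each resulting (start, end) pair becomes one ParquetSlice for the same parquet file, so a single large file can be spread across several workers.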