+import contextlib
 import os
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Any, List, Optional
+from typing import Any, List

 from lightning_utilities.core.imports import RequirementCache
+from tqdm import tqdm

-from lightning.data.utilities.env import _DistributedEnv
-from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks
-
-_POLARS_AVAILABLE = RequirementCache("polars")
 _PYARROW_AVAILABLE = RequirementCache("pyarrow")

-
 class BaseReader(ABC):

     def get_num_nodes(self) -> int:
@@ -21,9 +17,8 @@ def get_node_rank(self) -> int:
         return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))

     @abstractmethod
-    def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]:
-        """This method is meant to convert the items provided by the users into items to be processed by the
-        workers."""
+    def remap_items(self, items: List[Any], num_workers: int) -> List[Any]:
+        """Remap the user-provided items into items better suited for distribution across workers."""
         pass

     @abstractmethod
@@ -32,100 +27,73 @@ def read(self, item: Any) -> Any:
         pass


-@dataclass
-class ParquetSlice:
-    """Keep track of a parquet file slice with its filepath, start and end."""
-    filepath: str
-    start: int
-    end: int
-
-
 class ParquetReader(BaseReader):

-    def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None:
+    def __init__(self, cache_folder: str, num_rows: int = 65536, to_pandas: bool = True) -> None:
+        super().__init__()
+        self.cache_folder = cache_folder
         self.num_rows = num_rows
         self.to_pandas = to_pandas

-        if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
-            raise ModuleNotFoundError("Please, run: `pip install pyarrow polars`")

-    def _get_num_rows(self, path: str) -> int:
-        if _PYARROW_AVAILABLE:
-            import pyarrow.dataset as ds
-            df = ds.dataset(path).scanner()
-            return df.count_rows()

-        # FIXED: There is a bug in polars. This leads to read_parquet to hang.
-        if _POLARS_AVAILABLE:
-            import polars as pol
-            df = pol.scan_parquet(path)
-            num_rows = df.select(pol.len()).collect().item()
-            return num_rows
+        if not _PYARROW_AVAILABLE:
+            raise ModuleNotFoundError("Please, run: `pip install pyarrow`")

-        raise RuntimeError("Please, install either pyarrow or polars.")

-    def read(self, item: ParquetSlice) -> Any:
-        if _POLARS_AVAILABLE:
-            import polars as pol
-            df = pol.scan_parquet(item.filepath).slice(item.start, item.end).collect()
+        self.parquet_file = None

-            if self.to_pandas:
-                df = df.to_pandas()
+    def _get_num_rows(self, path: str) -> int:
+        import pyarrow.dataset as ds

-            return df
+        df = ds.dataset(path).scanner()
+        return df.count_rows()

-        if _PYARROW_AVAILABLE:
-            import pyarrow.dataset as ds
+    def read(self, filepath: str) -> Any:
+        import pyarrow as pa
+        import pyarrow.parquet as pq

-            df = ds.dataset(item.filepath).scanner()
+        # Try to force deallocation to avoid a memory leak
+        with contextlib.suppress(Exception):
+            pa.jemalloc_set_decay_ms(0)

-            df = df.take([item.start, item.end])
+        # Close the previous parquet file to release its memory
+        if self.parquet_file is not None:
+            self.parquet_file.close()
+            self.parquet_file = None

-            if self.to_pandas:
-                df.to_pandas()
+        self.parquet_file = pq.ParquetFile(filepath, memory_map=True)
+        return self.parquet_file

-            return df
+    def remap_items(self, filepaths: List[str], _: int) -> List[str]:
+        import pyarrow.parquet as pq

-        raise RuntimeError("Please, install either pyarrow or polars.")
+        print("Starting to reshard the parquet files for optimized processing.")

+        new_items = []

-    def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]:
-        intervals = [(0, self._get_num_rows(item)) for item in items]
+        cache_folder = os.path.join(self.cache_folder, f"{self.num_rows}")
+        os.makedirs(cache_folder, exist_ok=True)

-        world_size = self.get_num_nodes() * num_workers
-        node_rank = self.get_node_rank()
+        for filepath in filepaths:
+            num_rows = self._get_num_rows(filepath)

-        fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes())
-        parquet_indexes_per_worker, p_slices_per_worker = _associate_chunks_and_internals_to_ranks(
-            fake_distributed_env, list(range(len(items))), intervals, False)
+            table = None
+            parquet_filename = os.path.basename(filepath)

-        workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)]
+            for start in tqdm(range(0, num_rows, self.num_rows)):
+                end = min(start + self.num_rows, num_rows)
+                chunk_filepath = os.path.join(cache_folder, f"{start}_{end}_{parquet_filename}")
+                new_items.append(chunk_filepath)

-        iterator = enumerate(zip(parquet_indexes_per_worker, p_slices_per_worker))
+                if os.path.exists(chunk_filepath):
+                    continue

-        node_start = node_rank * num_workers
-        node_end = (node_rank + 1) * num_workers
+                if table is None:
+                    table = pq.read_table(filepath, memory_map=True)

-        for worker_idx, (parquet_indexes, p_slices) in iterator:
-            if node_start <= worker_idx < node_end:
-                if self.num_rows:
-                    workers_user_items[worker_idx % num_workers].extend([
-                        ParquetSlice(
-                            items[parquet_index], p_slice_start, p_slice_start + self.num_rows
-                            if p_slice[1] > (p_slice_start + self.num_rows) else
-                            p_slice[1]
-                        )
-                        for parquet_index, p_slice in zip(parquet_indexes, p_slices)
-                        for p_slice_start in range(p_slice[0], p_slice[1] + self.num_rows, self.num_rows)
-                        if p_slice_start < p_slice[1]
-                    ])
-                else:
-                    workers_user_items[worker_idx % num_workers].extend([
-                        ParquetSlice(items[parquet_index], *p_slice)
-                        for parquet_index, p_slice in zip(parquet_indexes, p_slices)
-                    ])
+                pq.write_table(table[start:end], chunk_filepath)

-        assert len(workers_user_items) == num_workers
-        assert all(len(w) for w in workers_user_items)
+        print("Finished resharding the parquet files for optimized processing.")

-        return workers_user_items
+        return new_items
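
For reference, a minimal sketch of how the reshard-then-read flow introduced by this diff might be driven end to end. The module path, file names, cache folder, and worker count below are illustrative assumptions and not part of the change; only `pyarrow` and `tqdm` are required.

# Hypothetical usage sketch (paths, cache folder, and worker count are made up).
from lightning.data.processing.readers import ParquetReader  # assumed module path

reader = ParquetReader(cache_folder="/cache/parquet", num_rows=65536, to_pandas=False)

# remap_items() reshards each input parquet file into fixed-size chunk files
# under "<cache_folder>/<num_rows>/" and returns the chunk file paths.
chunk_paths = reader.remap_items(["data/part-0.parquet", "data/part-1.parquet"], 4)

# read() memory-maps one chunk at a time and closes the previously opened file
# first, so at most one ParquetFile is held open per reader.
for chunk_path in chunk_paths:
    parquet_file = reader.read(chunk_path)
    print(chunk_path, parquet_file.metadata.num_rows)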