Skip to content

Commit d08e6cd

Browse files
authored
Add walk operator (#19333)
1 parent d02009a commit d08e6cd

File tree

3 files changed

+69
-3
lines changed

3 files changed

+69
-3
lines changed

src/lightning/data/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from lightning.data.streaming.combined import CombinedStreamingDataset
22
from lightning.data.streaming.dataloader import StreamingDataLoader
33
from lightning.data.streaming.dataset import StreamingDataset
4-
from lightning.data.streaming.functions import map, optimize
4+
from lightning.data.streaming.functions import map, optimize, walk
55

66
__all__ = [
77
"LightningDataset",
@@ -11,4 +11,5 @@
1111
"LightningIterableDataset",
1212
"map",
1313
"optimize",
14+
"walk",
1415
]

src/lightning/data/streaming/functions.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import concurrent.futures
1415
import inspect
1516
import os
1617
from datetime import datetime
1718
from functools import partial
1819
from pathlib import Path
1920
from types import FunctionType
20-
from typing import Any, Callable, Dict, Optional, Sequence, Union
21+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
2122

2223
import torch
2324

@@ -286,3 +287,51 @@ def optimize(
286287
num_nodes,
287288
machine,
288289
)
290+
291+
292+
def _listdir(folder: str) -> Tuple[str, List[str]]:
293+
return folder, os.listdir(folder)
294+
295+
296+
class walk:
297+
"""This class is an optimized version of os.walk for listing files and folders from cloud filesystem.
298+
299+
Note: The order of files and folders yielded aren't depth-first anymore due to the asynchronous listing call.
300+
301+
"""
302+
303+
def __init__(self, folder: str, max_workers: Optional[int] = os.cpu_count()) -> None:
304+
self.folders = [folder]
305+
self.max_workers = max_workers or 1
306+
self.futures: List[concurrent.futures.Future] = []
307+
308+
def __iter__(self) -> Any:
309+
"""This function queues the folders to perform listdir across multiple workers."""
310+
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
311+
while len(self.folders):
312+
folder = self.folders.pop(0)
313+
future = executor.submit(_listdir, folder)
314+
self.futures.append(future)
315+
316+
while self.futures:
317+
for future in concurrent.futures.as_completed(self.futures):
318+
filenames = []
319+
folders = []
320+
321+
folder, files_or_folders = future.result()
322+
self.futures = [f for f in self.futures if f != future]
323+
324+
for file_or_folder in files_or_folders:
325+
if os.path.isfile(os.path.join(folder, file_or_folder)):
326+
filenames.append(file_or_folder)
327+
else:
328+
folders.append(file_or_folder)
329+
self.folders.append(os.path.join(folder, file_or_folder))
330+
331+
yield folder, folders, filenames
332+
333+
while len(self.folders) and len(self.futures) <= self.max_workers * 2:
334+
folder = self.folders.pop(0)
335+
future = executor.submit(_listdir, folder)
336+
self.futures.append(future)
337+
return

tests/tests_data/streaming/test_functions.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import os
12
import sys
23
from unittest import mock
34

45
import pytest
5-
from lightning.data.streaming.functions import _get_input_dir, os
6+
from lightning.data import walk
7+
from lightning.data.streaming.functions import _get_input_dir
68

79

810
@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
@@ -19,3 +21,17 @@ def fn(*_, **__):
1921

2022
with pytest.raises(ValueError, match="The provided item didn't contain any filepaths."):
2123
assert _get_input_dir(["", "/teamspace/studios/asd/b"])
24+
25+
26+
def test_walk(tmpdir):
    """Check that ``walk`` reports the same folders and files as ``os.walk``."""
    # Build five sub-folders, each holding five small text files.
    for folder_idx in range(5):
        subfolder = os.path.join(tmpdir, str(folder_idx))
        os.makedirs(subfolder, exist_ok=True)
        for file_idx in range(5):
            with open(os.path.join(subfolder, f"{file_idx}.txt"), "w") as handle:
                handle.write("hello world !")

    # ``walk`` gives no ordering guarantee, so both results are sorted before being compared.
    expected = sorted(os.walk(tmpdir))
    actual = sorted(walk(tmpdir))
    assert expected == actual

0 commit comments

Comments
 (0)