Skip to content

Commit b0e1ee2

Browse files
authored
map operator: Add support for nested folders (#19366)
1 parent 37a521c commit b0e1ee2

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

src/lightning/data/streaming/data_processor.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,14 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
190190
s3 = S3Client()
191191

192192
while True:
193-
local_filepath: Optional[str] = upload_queue.get()
193+
data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get()
194+
195+
tmpdir = None
196+
197+
if isinstance(data, str) or data is None:
198+
local_filepath = data
199+
else:
200+
tmpdir, local_filepath = data
194201

195202
# Terminate the process if we received a termination signal
196203
if local_filepath is None:
@@ -202,15 +209,25 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
202209

203210
if obj.scheme == "s3":
204211
try:
212+
if tmpdir is None:
213+
output_filepath = os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath))
214+
else:
215+
output_filepath = os.path.join(str(obj.path).lstrip("/"), local_filepath.replace(tmpdir, "")[1:])
216+
205217
s3.client.upload_file(
206218
local_filepath,
207219
obj.netloc,
208-
os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)),
220+
output_filepath,
209221
)
210222
except Exception as e:
211223
print(e)
212224
elif output_dir.path and os.path.isdir(output_dir.path):
213-
shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
225+
if tmpdir is None:
226+
shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
227+
else:
228+
output_filepath = os.path.join(output_dir.path, local_filepath.replace(tmpdir, "")[1:])
229+
os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
230+
shutil.copyfile(local_filepath, output_filepath)
214231
else:
215232
raise ValueError(f"The provided {output_dir.path} isn't supported.")
216233

@@ -435,12 +452,15 @@ def _create_cache(self) -> None:
435452
)
436453
self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index
437454

438-
def _try_upload(self, filepath: Optional[str]) -> None:
439-
if not filepath or (self.output_dir.url if self.output_dir.url else self.output_dir.path) is None:
455+
def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
456+
if not data or (self.output_dir.url if self.output_dir.url else self.output_dir.path) is None:
440457
return
441458

442-
assert os.path.exists(filepath), filepath
443-
self.to_upload_queues[self._counter % self.num_uploaders].put(filepath)
459+
if isinstance(data, str):
460+
assert os.path.exists(data), data
461+
else:
462+
assert os.path.exists(data[-1]), data
463+
self.to_upload_queues[self._counter % self.num_uploaders].put(data)
444464

445465
def _collect_paths(self) -> None:
446466
if self.input_dir.path is None:
@@ -582,7 +602,7 @@ def _handle_data_transform_recipe(self, index: int) -> None:
582602
filepaths.append(os.path.join(directory, filename))
583603

584604
for filepath in filepaths:
585-
self._try_upload(filepath)
605+
self._try_upload((output_dir, filepath))
586606

587607

588608
class DataWorkerProcess(BaseWorker, Process):

tests/tests_data/streaming/test_data_processor.py

+20
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,26 @@ def test_data_processing_map_weights_mismatch(monkeypatch, tmpdir):
931931
map(map_fn_index, list(range(5)), output_dir=output_dir, num_workers=1, reorder_files=True, weights=[1])
932932

933933

934+
def map_fn_index_folder(index, output_dir):
935+
os.makedirs(os.path.join(output_dir, str(index)))
936+
with open(os.path.join(output_dir, str(index), f"{index}.JPEG"), "w") as f:
937+
f.write("Hello")
938+
939+
940+
@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
941+
def test_data_processing_map_without_input_dir_and_folder(monkeypatch, tmpdir):
942+
cache_dir = os.path.join(tmpdir, "cache")
943+
output_dir = os.path.join(tmpdir, "target_dir")
944+
os.makedirs(output_dir, exist_ok=True)
945+
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
946+
monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)
947+
948+
map(map_fn_index_folder, list(range(5)), output_dir=output_dir, num_workers=1, reorder_files=True)
949+
950+
assert sorted(os.listdir(output_dir)) == ["0", "1", "2", "3", "4"]
951+
assert os.path.exists(os.path.join(output_dir, "0", "0.JPEG"))
952+
953+
934954
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
935955
def test_map_error_when_not_empty(monkeypatch, tmpdir):
936956
boto3 = mock.MagicMock()

0 commit comments

Comments
 (0)