Skip to content

Commit 4175e1a

Browse files
authored
Hot fix: Fix path resolution (#19508)
1 parent 39a86f8 commit 4175e1a

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

src/lightning/data/processing/data_processor.py

+22-14
Original file line number | Diff line number | Diff line change
@@ -338,6 +338,25 @@ def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
338338
return item_sizes
339339

340340

341+
def _to_path(element: str) -> str:
    """Return the path string to use for ``element``.

    When ``_IS_IN_STUDIO`` is set and ``element`` starts with ``/teamspace``,
    the string is returned untouched. In every other case it is converted
    with ``Path(element).resolve()`` and returned as a string.
    """
    if _IS_IN_STUDIO and element.startswith("/teamspace"):
        return element
    return str(Path(element).resolve())
343+
344+
345+
def _is_path(input_dir: Optional[str], element: Any) -> bool:
    """Tell whether ``element`` should be treated as a file path.

    Args:
        input_dir: Directory prefix used for the prefix check; it is only
            consulted when ``_IS_IN_STUDIO`` is set and it is not ``None``.
        element: Candidate value; anything that is not a ``str`` is rejected.

    Returns:
        ``True`` when the prefix check accepts ``element`` (either as given
        or after being made absolute), or when ``os.path.exists`` reports
        that it exists; ``False`` otherwise.
    """
    if not isinstance(element, str):
        return False

    # No prefix check applies here: existence on disk is the only criterion.
    if not _IS_IN_STUDIO or input_dir is None:
        return os.path.exists(element)

    # A matching prefix is accepted without touching the filesystem.
    if element.startswith(input_dir):
        return True

    # Retry the prefix check on the absolute form, then fall back to an
    # existence check on that same absolute path.
    absolute = str(Path(element).absolute())
    return absolute.startswith(input_dir) or os.path.exists(absolute)
358+
359+
341360
class BaseWorker:
342361
def __init__(
343362
self,
@@ -380,7 +399,6 @@ def __init__(
380399
self.remove_queue: Queue = Queue()
381400
self.progress_queue: Queue = progress_queue
382401
self.error_queue: Queue = error_queue
383-
self._collected_items = 0
384402
self._counter = 0
385403
self._last_time = time()
386404
self._index_counter = 0
@@ -503,22 +521,13 @@ def _collect_paths(self) -> None:
503521
for item in self.items:
504522
flattened_item, spec = tree_flatten(item)
505523

506-
def is_path(element: Any) -> bool:
507-
if not isinstance(element, str):
508-
return False
509-
510-
element: str = str(Path(element).resolve())
511-
if _IS_IN_STUDIO and self.input_dir.path is not None:
512-
if self.input_dir.path.startswith("/teamspace/studios/this_studio"):
513-
return os.path.exists(element)
514-
return element.startswith(self.input_dir.path)
515-
return os.path.exists(element)
516-
517524
# For speed reasons, we assume starting with `self.input_dir` is enough to be a real file.
518525
# Other alternative would be too slow.
519526
# TODO: Try using dictionary for higher accuracy.
520527
indexed_paths = {
521-
index: str(Path(element).resolve()) for index, element in enumerate(flattened_item) if is_path(element)
528+
index: _to_path(element)
529+
for index, element in enumerate(flattened_item)
530+
if _is_path(self.input_dir.path, element)
522531
}
523532

524533
if len(indexed_paths) == 0:
@@ -536,7 +545,6 @@ def is_path(element: Any) -> bool:
536545
self.paths.append(paths)
537546

538547
items.append(tree_unflatten(flattened_item, spec))
539-
self._collected_items += 1
540548

541549
self.items = items
542550

tests/tests_data/processing/test_data_processor.py

+23
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,11 @@
1919
DataTransformRecipe,
2020
_download_data_target,
2121
_get_item_filesizes,
22+
_is_path,
2223
_map_items_to_workers_sequentially,
2324
_map_items_to_workers_weighted,
2425
_remove_target,
26+
_to_path,
2527
_upload_fn,
2628
_wait_for_disk_usage_higher_than_threshold,
2729
_wait_for_file_to_exist,
@@ -1135,3 +1137,24 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression):
11351137
tensor = torchaudio.load(sample)
11361138
assert tensor[0].shape == torch.Size([1, 16000])
11371139
assert tensor[1] == 16000
1140+
1141+
1142+
def test_is_path_valid_in_studio(monkeypatch, tmpdir):
    """Check ``_is_path`` with ``_IS_IN_STUDIO`` patched to ``True``.

    Both a string under the studio root and an existing temporary file
    located elsewhere must be reported as paths.
    """
    studio_root = "/teamspace/studios/this_studio"

    # Create a real file so the second assertion has something to find.
    local_file = os.path.join(tmpdir, "a.png")
    with open(local_file, "w") as handle:
        handle.write("Hello World")

    monkeypatch.setattr(data_processor_module, "_IS_IN_STUDIO", True)

    assert _is_path(studio_root, "/teamspace/studios/this_studio/a.png")
    assert _is_path(studio_root, local_file)
1151+
1152+
1153+
@pytest.mark.skipif(sys.platform == "win32", reason="skip windows")
def test_to_path(tmpdir):
    """Check that ``_to_path`` returns both sample paths unchanged.

    Covers a ``/teamspace`` path and an existing temporary file path.
    """
    local_file = os.path.join(tmpdir, "a.png")
    with open(local_file, "w") as handle:
        handle.write("Hello World")

    studio_file = "/teamspace/studios/this_studio/a.png"
    assert _to_path(studio_file) == studio_file
    assert _to_path(local_file) == local_file

0 commit comments

Comments
 (0)