Skip to content

Commit c10fd22

Browse files
tchatonthomas
and
thomas
authored
BC: Switch map operator arguments order (#19345)
update Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
1 parent 012f68d commit c10fd22

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

src/lightning/data/streaming/data_processor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def _handle_data_chunk_recipe_end(self) -> None:
563563
def _handle_data_transform_recipe(self, index: int) -> None:
564564
# Don't use a context manager to avoid deleting files that are being uploaded.
565565
output_dir = tempfile.mkdtemp()
566-
item_data = self.data_recipe.prepare_item(str(output_dir), self.items[index])
566+
item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir))
567567
if item_data is not None:
568568
raise ValueError(
569569
"When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything."
@@ -753,7 +753,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
753753
"""
754754

755755
@abstractmethod
756-
def prepare_item(self, output_dir: str, item_metadata: T) -> None: # type: ignore
756+
def prepare_item(self, item_metadata: T, output_dir: str) -> None: # type: ignore
757757
"""Use your item metadata to process your files and save the file outputs into `output_dir`."""
758758

759759

src/lightning/data/streaming/functions.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,24 @@ def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
8282
params = inspect.signature(_fn).parameters
8383
self._contains_device = "device" in params
8484

85-
def prepare_structure(self, input_dir: Optional[str]) -> Any:
85+
def prepare_structure(self, _: Optional[str]) -> Any:
8686
return self._inputs
8787

88-
def prepare_item(self, output_dir: str, item_metadata: Any) -> None: # type: ignore
88+
def prepare_item(self, item_metadata: Any, output_dir: str) -> None: # type: ignore
8989
if self._contains_device and self._device is None:
9090
self._find_device()
9191

9292
if isinstance(self._fn, (FunctionType, partial)):
9393
if self._contains_device:
94-
self._fn(output_dir, item_metadata, self._device)
94+
self._fn(item_metadata, output_dir, self._device)
9595
else:
96-
self._fn(output_dir, item_metadata)
96+
self._fn(item_metadata, output_dir)
9797

9898
elif callable(self._fn):
9999
if self._contains_device:
100-
self._fn.__call__(output_dir, item_metadata, self._device) # type: ignore
100+
self._fn.__call__(item_metadata, output_dir, self._device) # type: ignore
101101
else:
102-
self._fn.__call__(output_dir, item_metadata) # type: ignore
102+
self._fn.__call__(item_metadata, output_dir) # type: ignore
103103
else:
104104
raise ValueError(f"The provided {self._fn} isn't supported.")
105105

tests/tests_data/streaming/test_data_processor.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def prepare_structure(self, input_dir: str):
581581
filepaths = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)]
582582
return [filepath for filepath in filepaths if os.path.isfile(filepath)]
583583

584-
def prepare_item(self, output_dir: str, filepath: Any) -> None:
584+
def prepare_item(self, filepath: Any, output_dir: str) -> None:
585585
from PIL import Image
586586

587587
img = Image.open(filepath)
@@ -628,7 +628,7 @@ def test_data_process_transform(monkeypatch, tmpdir):
628628
assert img.size == (12, 12)
629629

630630

631-
def map_fn(output_dir, filepath):
631+
def map_fn(filepath, output_dir):
632632
from PIL import Image
633633

634634
img = Image.open(filepath)
@@ -833,7 +833,7 @@ def fn(output_dir, item, device):
833833

834834
data_recipe = LambdaDataTransformRecipe(fn, range(1))
835835

836-
data_recipe.prepare_item("", 1)
836+
data_recipe.prepare_item(1, "")
837837
assert called
838838

839839

@@ -847,13 +847,13 @@ def test_lambda_transform_recipe_class(monkeypatch):
847847
called = False
848848

849849
class Transform:
850-
def __call__(self, output_dir, item, device):
850+
def __call__(self, item, output_dir, device):
851851
nonlocal called
852852
assert device == "cuda:2"
853853
called = True
854854

855855
data_recipe = LambdaDataTransformRecipe(Transform(), range(1))
856-
data_recipe.prepare_item("", 1)
856+
data_recipe.prepare_item(1, "")
857857
assert called
858858

859859

@@ -894,7 +894,7 @@ def test_get_item_filesizes(tmp_path):
894894
_get_item_filesizes([str(tmp_path / "empty_file")])
895895

896896

897-
def map_fn_index(output_dir, index):
897+
def map_fn_index(index, output_dir):
898898
with open(os.path.join(output_dir, f"{index}.JPEG"), "w") as f:
899899
f.write("Hello")
900900

0 commit comments

Comments
 (0)