Skip to content

Commit 9709c64

Browse files
MrWhatZitToYaalantigapre-commit-ci[bot]
authored
Add str method to datamodule (#20301)
* Add feature implementation to datamodule for str method First implementation scetch * Removed list / tuple case for datamodule str method * Added test cases for DataModule string function Added alternative Boring Data Module implementations Added test cases for all possible options Added additional check for NotImplementedError in string function of DataModule * Reverted accidental changes in DataModule * Updated dataloader str method Made changes to comply with requested suggestions Switched from hardcoded \n to more general os.linesep * Improvements to implementation of str method for datamodule Corrected the annotation for the internal function and the list that is suppsoed to store the information on the datasets * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Implementing str method for datamodule Fixed type annotation issue Reduced code size by using Sized object from abc library * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add string method to datamodule Switched from Dataset based implementation to Dataloader based implementation * Implementing str mehtod for dataloader Added missing size value to tuple in the error case instead of returning only a string * Implementing str fucntion for datamodule Adjusted test to match the new implementation requirenemnts Added necessary BoringModules for tests Fixed bugs and annotation issues in the str method * Implementing str method for datamodule Refactored code and made it more readable by implementing more abstarct fucntion methods Adjusted tests Removed debug statements Removed TODO comments * Finilized required adjustments for dataloader string proposal method Renamed varaibles to more sensible names to increase readability * Implementing str method Switched name from dataset to dataloader Switched name Prediction to Predict removed available keyword and instead write None if not available Switched from unknown to NA * Update src/lightning/pytorch/core/datamodule.py * Update src/lightning/pytorch/core/datamodule.py --------- Co-authored-by: Luca Antiga <luca.antiga@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Antiga <luca@lightning.ai>
1 parent 1e32ebf commit 9709c64

File tree

3 files changed

+267
-3
lines changed

3 files changed

+267
-3
lines changed

src/lightning/pytorch/core/datamodule.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
"""LightningDataModule for loading DataLoaders with ease."""
1515

1616
import inspect
17-
from collections.abc import Iterable
17+
import os
18+
from collections.abc import Iterable, Sized
1819
from typing import IO, Any, Optional, Union, cast
1920

2021
from lightning_utilities import apply_to_collection
@@ -244,3 +245,75 @@ def load_from_checkpoint(
244245
**kwargs,
245246
)
246247
return cast(Self, loaded)
248+
249+
def __str__(self) -> str:
250+
"""Return a string representation of the datasets that are set up.
251+
252+
Returns:
253+
A string representation of the datasets that are setup.
254+
255+
"""
256+
257+
class dataset_info:
258+
def __init__(self, available: bool, length: str) -> None:
259+
self.available = available
260+
self.length = length
261+
262+
def retrieve_dataset_info(loader: DataLoader) -> dataset_info:
263+
"""Helper function to compute dataset information."""
264+
dataset = loader.dataset
265+
size: str = str(len(dataset)) if isinstance(dataset, Sized) else "NA"
266+
267+
return dataset_info(True, size)
268+
269+
def loader_info(
270+
loader: Union[DataLoader, Iterable[DataLoader]],
271+
) -> Union[dataset_info, Iterable[dataset_info]]:
272+
"""Helper function to compute dataset information."""
273+
return apply_to_collection(loader, DataLoader, retrieve_dataset_info)
274+
275+
def extract_loader_info(methods: list[tuple[str, str]]) -> dict:
276+
"""Helper function to extract information for each dataloader method."""
277+
info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {}
278+
for loader_name, func_name in methods:
279+
loader_method = getattr(self, func_name, None)
280+
281+
try:
282+
loader = loader_method() # type: ignore
283+
info[loader_name] = loader_info(loader)
284+
except Exception:
285+
info[loader_name] = dataset_info(False, "")
286+
287+
return info
288+
289+
def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str:
290+
"""Helper function to format loader information."""
291+
output = []
292+
for loader_name, loader_info in info.items():
293+
# Single dataset
294+
if isinstance(loader_info, dataset_info):
295+
loader_info_formatted = "None" if not loader_info.available else f"size={loader_info.length}"
296+
# Iterable of datasets
297+
else:
298+
loader_info_formatted = " ; ".join(
299+
"None" if not loader_info_i.available else f"{i}. size={loader_info_i.length}"
300+
for i, loader_info_i in enumerate(loader_info, start=1)
301+
)
302+
303+
output.append(f"{{{loader_name}: {loader_info_formatted}}}")
304+
305+
return os.linesep.join(output)
306+
307+
# Available dataloader methods
308+
datamodule_loader_methods: list[tuple[str, str]] = [
309+
("Train dataloader", "train_dataloader"),
310+
("Validation dataloader", "val_dataloader"),
311+
("Test dataloader", "test_dataloader"),
312+
("Predict dataloader", "predict_dataloader"),
313+
]
314+
315+
# Retrieve information for each dataloader method
316+
dataloader_info = extract_loader_info(datamodule_loader_methods)
317+
# Format the information
318+
dataloader_str = format_loader_info(dataloader_info)
319+
return dataloader_str

src/lightning/pytorch/demos/boring_classes.py

+82-1
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from collections.abc import Iterator
14+
from collections.abc import Iterable, Iterator
1515
from typing import Any, Optional
1616

1717
import torch
1818
import torch.nn as nn
1919
import torch.nn.functional as F
20+
from lightning_utilities import apply_to_collection
2021
from torch import Tensor
2122
from torch.optim import Optimizer
2223
from torch.optim.lr_scheduler import LRScheduler
@@ -188,6 +189,86 @@ def predict_dataloader(self) -> DataLoader:
188189
return DataLoader(self.random_predict)
189190

190191

192+
class BoringDataModuleNoLen(LightningDataModule):
193+
"""
194+
.. warning:: This is meant for testing/debugging and is experimental.
195+
"""
196+
197+
def __init__(self) -> None:
198+
super().__init__()
199+
200+
def setup(self, stage: str) -> None:
201+
if stage == "fit":
202+
self.random_train = RandomIterableDataset(32, 512)
203+
204+
if stage in ("fit", "validate"):
205+
self.random_val = RandomIterableDataset(32, 128)
206+
207+
if stage == "test":
208+
self.random_test = RandomIterableDataset(32, 256)
209+
210+
if stage == "predict":
211+
self.random_predict = RandomIterableDataset(32, 64)
212+
213+
def train_dataloader(self) -> DataLoader:
214+
return DataLoader(self.random_train)
215+
216+
def val_dataloader(self) -> DataLoader:
217+
return DataLoader(self.random_val)
218+
219+
def test_dataloader(self) -> DataLoader:
220+
return DataLoader(self.random_test)
221+
222+
def predict_dataloader(self) -> DataLoader:
223+
return DataLoader(self.random_predict)
224+
225+
226+
class IterableBoringDataModule(LightningDataModule):
227+
def __init__(self) -> None:
228+
super().__init__()
229+
230+
def setup(self, stage: str) -> None:
231+
if stage == "fit":
232+
self.train_datasets = [
233+
RandomDataset(4, 16),
234+
RandomIterableDataset(4, 16),
235+
]
236+
237+
if stage in ("fit", "validate"):
238+
self.val_datasets = [
239+
RandomDataset(4, 32),
240+
RandomIterableDataset(4, 32),
241+
]
242+
243+
if stage == "test":
244+
self.test_datasets = [
245+
RandomDataset(4, 64),
246+
RandomIterableDataset(4, 64),
247+
]
248+
249+
if stage == "predict":
250+
self.predict_datasets = [
251+
RandomDataset(4, 128),
252+
RandomIterableDataset(4, 128),
253+
]
254+
255+
def train_dataloader(self) -> Iterable[DataLoader]:
256+
combined_train = apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x))
257+
return combined_train
258+
259+
def val_dataloader(self) -> DataLoader:
260+
combined_val = apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x))
261+
return combined_val
262+
263+
def test_dataloader(self) -> DataLoader:
264+
combined_test = apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x))
265+
return combined_test
266+
267+
def predict_dataloader(self) -> DataLoader:
268+
combined_predict = apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x))
269+
return combined_predict
270+
271+
191272
class ManualOptimBoringModel(BoringModel):
192273
"""
193274
.. warning:: This is meant for testing/debugging and is experimental.

tests/tests_pytorch/core/test_datamodules.py

+111-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
1415
import pickle
1516
from argparse import Namespace
1617
from dataclasses import dataclass
@@ -22,7 +23,12 @@
2223
import torch
2324
from lightning.pytorch import LightningDataModule, Trainer, seed_everything
2425
from lightning.pytorch.callbacks import ModelCheckpoint
25-
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
26+
from lightning.pytorch.demos.boring_classes import (
27+
BoringDataModule,
28+
BoringDataModuleNoLen,
29+
BoringModel,
30+
IterableBoringDataModule,
31+
)
2632
from lightning.pytorch.profilers.simple import SimpleProfiler
2733
from lightning.pytorch.trainer.states import TrainerFn
2834
from lightning.pytorch.utilities import AttributeDict
@@ -510,3 +516,107 @@ def prepare_data(self):
510516
durations = profiler.recorded_durations[key]
511517
assert len(durations) == 1
512518
assert durations[0] > 0
519+
520+
521+
def test_datamodule_string_not_available():
522+
dm = BoringDataModule()
523+
524+
expected_output = (
525+
f"{{Train dataloader: None}}{os.linesep}"
526+
f"{{Validation dataloader: None}}{os.linesep}"
527+
f"{{Test dataloader: None}}{os.linesep}"
528+
f"{{Predict dataloader: None}}"
529+
)
530+
out = str(dm)
531+
532+
assert out == expected_output
533+
534+
535+
def test_datamodule_string_fit_setup():
536+
dm = BoringDataModule()
537+
dm.setup(stage="fit")
538+
539+
expected_output = (
540+
f"{{Train dataloader: size=64}}{os.linesep}"
541+
f"{{Validation dataloader: size=64}}{os.linesep}"
542+
f"{{Test dataloader: None}}{os.linesep}"
543+
f"{{Predict dataloader: None}}"
544+
)
545+
output = str(dm)
546+
547+
assert expected_output == output
548+
549+
550+
def test_datamodule_string_validation_setup():
551+
dm = BoringDataModule()
552+
dm.setup(stage="validate")
553+
554+
expected_output = (
555+
f"{{Train dataloader: None}}{os.linesep}"
556+
f"{{Validation dataloader: size=64}}{os.linesep}"
557+
f"{{Test dataloader: None}}{os.linesep}"
558+
f"{{Predict dataloader: None}}"
559+
)
560+
output = str(dm)
561+
562+
assert expected_output == output
563+
564+
565+
def test_datamodule_string_test_setup():
566+
dm = BoringDataModule()
567+
dm.setup(stage="test")
568+
569+
expected_output = (
570+
f"{{Train dataloader: None}}{os.linesep}"
571+
f"{{Validation dataloader: None}}{os.linesep}"
572+
f"{{Test dataloader: size=64}}{os.linesep}"
573+
f"{{Predict dataloader: None}}"
574+
)
575+
output = str(dm)
576+
577+
assert expected_output == output
578+
579+
580+
def test_datamodule_string_predict_setup():
581+
dm = BoringDataModule()
582+
dm.setup(stage="predict")
583+
584+
expected_output = (
585+
f"{{Train dataloader: None}}{os.linesep}"
586+
f"{{Validation dataloader: None}}{os.linesep}"
587+
f"{{Test dataloader: None}}{os.linesep}"
588+
f"{{Predict dataloader: size=64}}"
589+
)
590+
output = str(dm)
591+
592+
assert expected_output == output
593+
594+
595+
def test_datamodule_string_no_len():
596+
dm = BoringDataModuleNoLen()
597+
dm.setup("fit")
598+
599+
expected_output = (
600+
f"{{Train dataloader: size=NA}}{os.linesep}"
601+
f"{{Validation dataloader: size=NA}}{os.linesep}"
602+
f"{{Test dataloader: None}}{os.linesep}"
603+
f"{{Predict dataloader: None}}"
604+
)
605+
output = str(dm)
606+
607+
assert output == expected_output
608+
609+
610+
def test_datamodule_string_iterable():
611+
dm = IterableBoringDataModule()
612+
dm.setup("fit")
613+
614+
expected_output = (
615+
f"{{Train dataloader: 1. size=16 ; 2. size=NA}}{os.linesep}"
616+
f"{{Validation dataloader: 1. size=32 ; 2. size=NA}}{os.linesep}"
617+
f"{{Test dataloader: None}}{os.linesep}"
618+
f"{{Predict dataloader: None}}"
619+
)
620+
output = str(dm)
621+
622+
assert output == expected_output

0 commit comments

Comments
 (0)