
Commit 6c9ab82

Merge branch 'master' into 19433_typing_optional_str_in_csv_logger
2 parents 4528025 + a41528c

File tree: 20 files changed (+185, -106 lines)


docs/source-pytorch/cli/lightning_cli_advanced_3.rst  (+10)

@@ -197,6 +197,7 @@ Since the init parameters of the model have as a type hint a class, in the confi
                 decoder: Instance of a module for decoding
             """
             super().__init__()
+            self.save_hyperparameters()
             self.encoder = encoder
             self.decoder = decoder
 
@@ -216,6 +217,13 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
 
 It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.
 
+.. tip::
+
+    By having ``self.save_hyperparameters()`` it becomes possible to load the model from a checkpoint. Simply do
+    ``ModelClass.load_from_checkpoint("path/to/checkpoint.ckpt")``. In the case of using ``subclass_mode_model=True``,
+    then load it like ``LightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")``. ``save_hyperparameters`` is
+    optional and can be safely removed if there is no need to load from a checkpoint.
+
 
 Fixed optimizer and scheduler
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -279,6 +287,7 @@ An example of a model that uses two optimizers is the following:
     class MyModel(LightningModule):
         def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
             super().__init__()
+            self.save_hyperparameters()
             self.optimizer1 = optimizer1
             self.optimizer2 = optimizer2
 
@@ -318,6 +327,7 @@ that uses dependency injection for an optimizer and a learning scheduler is:
             scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
         ):
             super().__init__()
+            self.save_hyperparameters()
             self.optimizer = optimizer
             self.scheduler = scheduler
 
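A minimal sketch of the workflow described in the tip above, assuming a hypothetical model class and checkpoint path (not part of this diff): calling ``self.save_hyperparameters()`` in ``__init__`` records the init arguments in the checkpoint, so the model can later be restored with ``load_from_checkpoint``.

    import torch
    from lightning.pytorch import LightningModule


    class ExampleModel(LightningModule):
        def __init__(self, hidden_dim: int = 64, lr: float = 1e-3):
            super().__init__()
            # Stores hidden_dim and lr in the checkpoint under "hyper_parameters".
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(32, hidden_dim)


    # After a Trainer run has produced "path/to/checkpoint.ckpt" (hypothetical path),
    # the saved hyperparameters let the class be rebuilt without repeating them:
    model = ExampleModel.load_from_checkpoint("path/to/checkpoint.ckpt")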
requirements/pytorch/extra.txt  (+1 -1)

@@ -5,7 +5,7 @@
 matplotlib>3.1, <3.9.0
 omegaconf >=2.0.5, <2.4.0
 hydra-core >=1.0.5, <1.4.0
-jsonargparse[signatures] >=4.26.1, <4.28.0
+jsonargparse[signatures] >=4.27.5, <4.28.0
 rich >=12.3.0, <13.6.0
 tensorboardX >=2.2, <2.7.0  # min version is set by torch.onnx missing attribute
 bitsandbytes ==0.41.0  # strict

src/lightning/pytorch/CHANGELOG.md  (+3)

@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `ModelSummary` and `RichModelSummary` callbacks now display the training mode of each layer in the column "Mode" ([#19468](https://github.com/Lightning-AI/lightning/pull/19468))
 
+- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))
+
 -
 
 -
@@ -64,6 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added shortcut name `strategy='deepspeed_stage_1_offload'` to the strategy registry ([#19075](https://github.com/Lightning-AI/lightning/pull/19075))
 - Added support for non-strict state-dict loading in Trainer via the new `LightningModule.strict_loading = True | False` attribute ([#19404](https://github.com/Lightning-AI/lightning/pull/19404))
 
+
 ### Changed
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))

src/lightning/pytorch/cli.py  (+54 -2)

@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 import sys
 from functools import partial, update_wrapper
 from types import MethodType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union
 
 import torch
+import yaml
 from lightning_utilities.core.imports import RequirementCache
 from lightning_utilities.core.rank_zero import _warn
 from torch.optim import Optimizer
@@ -27,11 +29,12 @@
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
+from lightning.pytorch.core.mixins.hparams_mixin import _given_hyperparameters_context
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
-_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.26.1")
+_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.5")
 
 if _JSONARGPARSE_SIGNATURES_AVAILABLE:
     import docstring_parser
@@ -50,6 +53,8 @@
     locals()["ArgumentParser"] = object
     locals()["Namespace"] = object
 
+ModuleType = TypeVar("ModuleType")
+
 
 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
@@ -381,6 +386,7 @@ def __init__(
 
         self._set_seed()
 
+        self._add_instantiators()
         self.before_instantiate_classes()
         self.instantiate_classes()
 
@@ -527,6 +533,22 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         else:
             self.config = parser.parse_args(args)
 
+    def _add_instantiators(self) -> None:
+        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
+        if "subcommand" in self.config:
+            self.config_dump = self.config_dump[self.config.subcommand]
+
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="model"),
+            _get_module_type(self._model_class),
+            subclasses=self.subclass_mode_model,
+        )
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="data"),
+            _get_module_type(self._datamodule_class),
+            subclasses=self.subclass_mode_data,
+        )
+
     def before_instantiate_classes(self) -> None:
         """Implement to run some code before instantiating the classes."""
 
@@ -755,3 +777,33 @@ def _get_short_description(component: object) -> Optional[str]:
         return docstring.short_description
     except (ValueError, docstring_parser.ParseError) as ex:
         rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
+
+
+def _get_module_type(value: Union[Callable, type]) -> type:
+    if callable(value) and not isinstance(value, type):
+        return inspect.signature(value).return_annotation
+    return value
+
+
+class _InstantiatorFn:
+    def __init__(self, cli: LightningCLI, key: str) -> None:
+        self.cli = cli
+        self.key = key
+
+    def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+        with _given_hyperparameters_context(
+            hparams=self.cli.config_dump.get(self.key, {}),
+            instantiator="lightning.pytorch.cli.instantiate_module",
+        ):
+            return class_type(*args, **kwargs)
+
+
+def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
+    parser = ArgumentParser(exit_on_error=False)
+    if "class_path" in config:
+        parser.add_subclass_arguments(class_type, "module")
+    else:
+        parser.add_class_arguments(class_type, "module")
+    cfg = parser.parse_object({"module": config})
+    init = parser.instantiate_classes(cfg)
+    return init.module
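The new ``instantiate_module`` helper above re-creates a module from the hyperparameters that ``_InstantiatorFn`` recorded at CLI time. A standalone sketch of that round trip with jsonargparse, using a hypothetical plain class instead of a ``LightningModule``:

    from jsonargparse import ArgumentParser


    class ToyEncoder:  # hypothetical stand-in for the CLI's model class
        def __init__(self, hidden_dim: int = 16, dropout: float = 0.1):
            self.hidden_dim = hidden_dim
            self.dropout = dropout


    # Mirrors instantiate_module(): parse the recorded init arguments back through
    # jsonargparse and let it instantiate the class from them.
    parser = ArgumentParser(exit_on_error=False)
    parser.add_class_arguments(ToyEncoder, "module")
    cfg = parser.parse_object({"module": {"hidden_dim": 32, "dropout": 0.2}})
    init = parser.instantiate_classes(cfg)
    encoder = init.module
    print(encoder.hidden_dim)  # 32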

src/lightning/pytorch/core/mixins/hparams_mixin.py  (+20 -3)

@@ -15,7 +15,9 @@
 import inspect
 import types
 from argparse import Namespace
-from typing import Any, List, MutableMapping, Optional, Sequence, Union
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union
 
 from lightning.fabric.utilities.data import AttributeDict
 from lightning.pytorch.utilities.parsing import save_hyperparameters
@@ -24,6 +26,20 @@
 _ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
 
 
+_given_hyperparameters: ContextVar = ContextVar("_given_hyperparameters", default=None)
+
+
+@contextmanager
+def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator[None]:
+    hparams = hparams.copy()
+    hparams["_instantiator"] = instantiator
+    token = _given_hyperparameters.set(hparams)
+    try:
+        yield
+    finally:
+        _given_hyperparameters.reset(token)
+
+
 class HyperparametersMixin:
     __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]
 
@@ -105,12 +121,13 @@ class ``__init__`` to be ignored
 
         """
         self._log_hyperparams = logger
+        given_hparams = _given_hyperparameters.get()
         # the frame needs to be created in this file.
-        if not frame:
+        if given_hparams is None and not frame:
             current_frame = inspect.currentframe()
             if current_frame:
                 frame = current_frame.f_back
-        save_hyperparameters(self, *args, ignore=ignore, frame=frame)
+        save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)
 
     def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
         hp = self._to_hparams_dict(hp)
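The ``_given_hyperparameters_context`` manager above relies on a ``ContextVar`` so that ``save_hyperparameters()`` can pick up CLI-provided hyperparameters without inspecting the caller's frame. A self-contained sketch of that pattern with hypothetical names:

    from contextlib import contextmanager
    from contextvars import ContextVar
    from typing import Iterator

    # Hypothetical stand-alone illustration of the ContextVar pattern used above.
    _current_hparams: ContextVar = ContextVar("_current_hparams", default=None)


    @contextmanager
    def hparams_context(hparams: dict) -> Iterator[None]:
        token = _current_hparams.set(dict(hparams))
        try:
            yield
        finally:
            # Restore the previous value even if instantiation raises.
            _current_hparams.reset(token)


    def read_hparams():
        # Returns None outside the context, the provided dict inside it.
        return _current_hparams.get()


    with hparams_context({"hidden_dim": 32}):
        assert read_hparams() == {"hidden_dim": 32}
    assert read_hparams() is None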

src/lightning/pytorch/core/saving.py  (+8 -1)

@@ -151,11 +151,18 @@ def _load_state(
     _cls_kwargs.update(cls_kwargs_loaded)
     _cls_kwargs.update(cls_kwargs_new)
 
+    instantiator = None
+    instantiator_path = _cls_kwargs.pop("_instantiator", None)
+    if instantiator_path is not None:
+        # import custom instantiator
+        module_path, name = instantiator_path.rsplit(".", 1)
+        instantiator = getattr(__import__(module_path, fromlist=[name]), name)
+
     if not cls_spec.varkw:
         # filter kwargs according to class init unless it allows any argument via kwargs
         _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
 
-    obj = cls(**_cls_kwargs)
+    obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)
 
     if isinstance(obj, pl.LightningDataModule):
         if obj.__class__.__qualname__ in checkpoint:
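The loader above resolves the dotted path stored under ``_instantiator`` with ``__import__``. A small sketch of that resolution step, using a standard-library attribute only to demonstrate the mechanism:

    import os


    def resolve_dotted_path(path: str):
        """Resolve "package.module.attr" into the attribute, as done for ``_instantiator``."""
        module_path, name = path.rsplit(".", 1)
        return getattr(__import__(module_path, fromlist=[name]), name)


    # Standard-library example used purely to demonstrate the mechanism:
    assert resolve_dotted_path("os.path.join") is os.path.join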

src/lightning/pytorch/utilities/parsing.py  (+8 -2)

@@ -140,7 +140,11 @@ def collect_init_args(
 
 
 def save_hyperparameters(
-    obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
+    obj: Any,
+    *args: Any,
+    ignore: Optional[Union[Sequence[str], str]] = None,
+    frame: Optional[types.FrameType] = None,
+    given_hparams: Optional[Dict[str, Any]] = None,
 ) -> None:
     """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`"""
 
@@ -156,7 +160,9 @@ def save_hyperparameters(
     if not isinstance(frame, types.FrameType):
         raise AttributeError("There is no `frame` available while being required.")
 
-    if is_dataclass(obj):
+    if given_hparams is not None:
+        init_args = given_hparams
+    elif is_dataclass(obj):
         init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
     else:
         init_args = {}
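The updated ``save_hyperparameters`` utility gives an explicitly provided ``given_hparams`` dict precedence over dataclass fields and caller-frame inspection. A simplified sketch of that precedence, not the Lightning implementation itself:

    import inspect
    from dataclasses import fields, is_dataclass
    from typing import Any, Dict, Optional


    def collect_hparams(obj: Any, given_hparams: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Precedence mirrors the change above: explicit dict > dataclass fields > caller locals.
        if given_hparams is not None:
            return dict(given_hparams)
        if is_dataclass(obj):
            return {f.name: getattr(obj, f.name) for f in fields(obj)}
        frame = inspect.currentframe().f_back
        return {k: v for k, v in frame.f_locals.items() if k != "self"}


    print(collect_hparams(object(), given_hparams={"lr": 0.01}))  # {'lr': 0.01}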

tests/tests_fabric/plugins/precision/test_amp_integration.py  (-8)

@@ -13,13 +13,10 @@
 # limitations under the License.
 """Integration tests for Automatic Mixed Precision (AMP) training."""
 
-import sys
-
 import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric import Fabric, seed_everything
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -41,11 +38,6 @@ def forward(self, x):
         return output
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.parametrize(
     ("accelerator", "precision", "expected_dtype"),
     [

tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py  (-7)

@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import sys
 
 import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric import Fabric
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -31,11 +29,6 @@ def __init__(self):
         self.register_buffer("buffer", torch.ones(3))
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
 def test_memory_sharing_disabled(strategy):
     """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race

tests/tests_fabric/strategies/test_ddp_integration.py  (+1 -7)

@@ -12,27 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import sys
 from copy import deepcopy
 from unittest import mock
 from unittest.mock import Mock
 
 import pytest
 import torch
 from lightning.fabric import Fabric
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
 from tests_fabric.test_fabric import BoringModel
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.parametrize(
     "accelerator",
     [

tests/tests_fabric/utilities/test_distributed.py  (-7)

@@ -1,6 +1,5 @@
 import functools
 import os
-import sys
 from functools import partial
 from pathlib import Path
 from unittest import mock
@@ -19,7 +18,6 @@
     _sync_ddp,
     is_shared_filesystem,
 )
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -121,11 +119,6 @@ def test_collective_operations(devices, process):
     spawn_launch(process, devices)
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.flaky(reruns=3)  # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
 def test_is_shared_filesystem(tmp_path, monkeypatch):
     # In the non-distributed case, every location is interpreted as 'shared'

tests/tests_fabric/utilities/test_spike.py  (-6)

@@ -4,7 +4,6 @@
 import pytest
 import torch
 from lightning.fabric import Fabric
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
 
 
@@ -29,11 +28,6 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
     )
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.flaky(max_runs=3)
 @pytest.mark.parametrize(
     ("global_rank_spike", "num_devices", "spike_value", "finite_only"),
