
Commit 623ec58

load_from_checkpoint support for LightningCLI when using dependency injection (#18105)
1 parent a6273d1 commit 623ec58

File tree

8 files changed: +182 -9 lines
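Before the per-file diffs, here is a minimal end-to-end sketch (not part of the commit) of the workflow this change enables: a model whose optimizer is supplied via dependency injection can be restored with `load_from_checkpoint`, because the CLI now records the injected arguments when `save_hyperparameters()` is called. The class name, script name, and checkpoint path below are illustrative placeholders; training data and `training_step` are omitted.

# sketch.py -- illustrative usage only; names and paths are placeholders
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import LightningCLI, OptimizerCallable


class DemoModel(LightningModule):
    def __init__(self, optimizer: OptimizerCallable = torch.optim.Adam):
        super().__init__()
        # With this commit, save_hyperparameters() also records the injected
        # callable, making the checkpoint self-describing.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 2)
        self.optimizer = optimizer

    def configure_optimizers(self):
        # The injected callable receives the parameters at configure time.
        return self.optimizer(self.parameters())


if __name__ == "__main__":
    # e.g. `python sketch.py --optimizer=torch.optim.SGD`
    cli = LightningCLI(DemoModel, run=False)
    # After `cli.trainer.fit(cli.model, ...)` writes a checkpoint, it can be
    # restored without re-supplying the CLI config:
    # model = DemoModel.load_from_checkpoint("lightning_logs/version_0/checkpoints/last.ckpt")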

docs/source-pytorch/cli/lightning_cli_advanced_3.rst

+10
@@ -197,6 +197,7 @@ Since the init parameters of the model have as a type hint a class, in the confi
                 decoder: Instance of a module for decoding
             """
             super().__init__()
+            self.save_hyperparameters()
             self.encoder = encoder
             self.decoder = decoder
 
@@ -216,6 +217,13 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
 
 It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.
 
+.. tip::
+
+    By having ``self.save_hyperparameters()`` it becomes possible to load the model from a checkpoint. Simply do
+    ``ModelClass.load_from_checkpoint("path/to/checkpoint.ckpt")``. In the case of using ``subclass_mode_model=True``,
+    then load it like ``LightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")``. ``save_hyperparameters`` is
+    optional and can be safely removed if there is no need to load from a checkpoint.
+
 
 Fixed optimizer and scheduler
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -279,6 +287,7 @@ An example of a model that uses two optimizers is the following:
     class MyModel(LightningModule):
         def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
             super().__init__()
+            self.save_hyperparameters()
             self.optimizer1 = optimizer1
             self.optimizer2 = optimizer2
 
@@ -318,6 +327,7 @@ that uses dependency injection for an optimizer and a learning scheduler is:
             scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
         ):
             super().__init__()
+            self.save_hyperparameters()
             self.optimizer = optimizer
             self.scheduler = scheduler
 
requirements/pytorch/extra.txt

+1 -1
@@ -5,7 +5,7 @@
 matplotlib>3.1, <3.9.0
 omegaconf >=2.0.5, <2.4.0
 hydra-core >=1.0.5, <1.4.0
-jsonargparse[signatures] >=4.26.1, <4.28.0
+jsonargparse[signatures] >=4.27.5, <4.28.0
 rich >=12.3.0, <13.6.0
 tensorboardX >=2.2, <2.7.0  # min version is set by torch.onnx missing attribute
 bitsandbytes ==0.41.0  # strict

src/lightning/pytorch/CHANGELOG.md

+3
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `ModelSummary` and `RichModelSummary` callbacks now display the training mode of each layer in the column "Mode" ([#19468](https://github.com/Lightning-AI/lightning/pull/19468))
 
+- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))
+
 -
 
 -
@@ -64,6 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added shortcut name `strategy='deepspeed_stage_1_offload'` to the strategy registry ([#19075](https://github.com/Lightning-AI/lightning/pull/19075))
 - Added support for non-strict state-dict loading in Trainer via the new `LightningModule.strict_loading = True | False` attribute ([#19404](https://github.com/Lightning-AI/lightning/pull/19404))
 
+
 ### Changed
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))

src/lightning/pytorch/cli.py

+54 -2
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 import sys
 from functools import partial, update_wrapper
 from types import MethodType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union
 
 import torch
+import yaml
 from lightning_utilities.core.imports import RequirementCache
 from lightning_utilities.core.rank_zero import _warn
 from torch.optim import Optimizer
@@ -27,11 +29,12 @@
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
+from lightning.pytorch.core.mixins.hparams_mixin import _given_hyperparameters_context
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
-_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.26.1")
+_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.5")
 
 if _JSONARGPARSE_SIGNATURES_AVAILABLE:
     import docstring_parser
@@ -50,6 +53,8 @@
     locals()["ArgumentParser"] = object
     locals()["Namespace"] = object
 
+ModuleType = TypeVar("ModuleType")
+
 
 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
@@ -381,6 +386,7 @@ def __init__(
 
         self._set_seed()
 
+        self._add_instantiators()
         self.before_instantiate_classes()
         self.instantiate_classes()
 
@@ -527,6 +533,22 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         else:
             self.config = parser.parse_args(args)
 
+    def _add_instantiators(self) -> None:
+        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
+        if "subcommand" in self.config:
+            self.config_dump = self.config_dump[self.config.subcommand]
+
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="model"),
+            _get_module_type(self._model_class),
+            subclasses=self.subclass_mode_model,
+        )
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="data"),
+            _get_module_type(self._datamodule_class),
+            subclasses=self.subclass_mode_data,
+        )
+
     def before_instantiate_classes(self) -> None:
         """Implement to run some code before instantiating the classes."""
 
@@ -755,3 +777,33 @@ def _get_short_description(component: object) -> Optional[str]:
         return docstring.short_description
     except (ValueError, docstring_parser.ParseError) as ex:
         rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
+
+
+def _get_module_type(value: Union[Callable, type]) -> type:
+    if callable(value) and not isinstance(value, type):
+        return inspect.signature(value).return_annotation
+    return value
+
+
+class _InstantiatorFn:
+    def __init__(self, cli: LightningCLI, key: str) -> None:
+        self.cli = cli
+        self.key = key
+
+    def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+        with _given_hyperparameters_context(
+            hparams=self.cli.config_dump.get(self.key, {}),
+            instantiator="lightning.pytorch.cli.instantiate_module",
+        ):
+            return class_type(*args, **kwargs)
+
+
+def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
+    parser = ArgumentParser(exit_on_error=False)
+    if "class_path" in config:
+        parser.add_subclass_arguments(class_type, "module")
+    else:
+        parser.add_class_arguments(class_type, "module")
+    cfg = parser.parse_object({"module": config})
+    init = parser.instantiate_classes(cfg)
+    return init.module
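For readers unfamiliar with the jsonargparse calls used in `instantiate_module` above, the following standalone sketch (not part of the diff) shows the same rebuild-from-config mechanism, with an illustrative config dict shaped like the `hyper_parameters` entry the CLI now stores; it assumes `jsonargparse[signatures]` and torch are installed.

# Illustrative only: rebuild an object from a checkpoint-style config dict.
import torch
from jsonargparse import ArgumentParser

# Subclass-mode shape: the concrete class is named by "class_path".
config = {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05}}

parser = ArgumentParser(exit_on_error=False)
# "class_path" present -> accept any subclass of the given base type
parser.add_subclass_arguments(torch.nn.Module, "module")
cfg = parser.parse_object({"module": config})
module = parser.instantiate_classes(cfg).module

assert isinstance(module, torch.nn.LeakyReLU)
assert module.negative_slope == 0.05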

src/lightning/pytorch/core/mixins/hparams_mixin.py

+20 -3
@@ -15,7 +15,9 @@
 import inspect
 import types
 from argparse import Namespace
-from typing import Any, List, MutableMapping, Optional, Sequence, Union
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union
 
 from lightning.fabric.utilities.data import AttributeDict
 from lightning.pytorch.utilities.parsing import save_hyperparameters
@@ -24,6 +26,20 @@
 _ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
 
 
+_given_hyperparameters: ContextVar = ContextVar("_given_hyperparameters", default=None)
+
+
+@contextmanager
+def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator[None]:
+    hparams = hparams.copy()
+    hparams["_instantiator"] = instantiator
+    token = _given_hyperparameters.set(hparams)
+    try:
+        yield
+    finally:
+        _given_hyperparameters.reset(token)
+
+
 class HyperparametersMixin:
     __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]
 
@@ -105,12 +121,13 @@ class ``__init__`` to be ignored
 
         """
         self._log_hyperparams = logger
+        given_hparams = _given_hyperparameters.get()
         # the frame needs to be created in this file.
-        if not frame:
+        if given_hparams is None and not frame:
             current_frame = inspect.currentframe()
             if current_frame:
                 frame = current_frame.f_back
-        save_hyperparameters(self, *args, ignore=ignore, frame=frame)
+        save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)
 
     def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
         hp = self._to_hparams_dict(hp)
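The hunk above is the heart of the change: instead of always inspecting the caller's frame, `save_hyperparameters` prefers hyperparameters that the CLI has placed in a context variable. Below is a minimal, self-contained sketch of that ContextVar pattern with illustrative names; it is not the Lightning internals themselves.

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, Optional

_given: ContextVar[Optional[dict]] = ContextVar("_given", default=None)


@contextmanager
def given_context(values: dict) -> Iterator[None]:
    # Set the context variable for the duration of the block, then restore it.
    token = _given.set(dict(values))
    try:
        yield
    finally:
        _given.reset(token)


def save_values() -> dict:
    # Prefer externally provided values; otherwise fall back to another source
    # (Lightning falls back to frame inspection).
    provided = _given.get()
    return provided if provided is not None else {"source": "frame-inspection"}


with given_context({"lr": 0.1}):
    assert save_values() == {"lr": 0.1}
assert save_values() == {"source": "frame-inspection"}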

src/lightning/pytorch/core/saving.py

+8 -1
@@ -151,11 +151,18 @@ def _load_state(
     _cls_kwargs.update(cls_kwargs_loaded)
     _cls_kwargs.update(cls_kwargs_new)
 
+    instantiator = None
+    instantiator_path = _cls_kwargs.pop("_instantiator", None)
+    if instantiator_path is not None:
+        # import custom instantiator
+        module_path, name = instantiator_path.rsplit(".", 1)
+        instantiator = getattr(__import__(module_path, fromlist=[name]), name)
+
     if not cls_spec.varkw:
         # filter kwargs according to class init unless it allows any argument via kwargs
         _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
 
-    obj = cls(**_cls_kwargs)
+    obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)
 
     if isinstance(obj, pl.LightningDataModule):
         if obj.__class__.__qualname__ in checkpoint:
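The `_instantiator` entry stored in the checkpoint is just a dotted path; the hunk above turns it back into a callable with `rsplit` and `__import__`. A standalone sketch of that resolution step, using a stdlib function as the illustrative target:

import os.path


def import_by_path(path: str):
    """Resolve a dotted path such as 'pkg.module.attr' to the object it names."""
    module_path, name = path.rsplit(".", 1)
    # A non-empty fromlist makes __import__ return the submodule itself.
    return getattr(__import__(module_path, fromlist=[name]), name)


assert import_by_path("os.path.join") is os.path.join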

src/lightning/pytorch/utilities/parsing.py

+8 -2
@@ -140,7 +140,11 @@ def collect_init_args(
 
 
 def save_hyperparameters(
-    obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
+    obj: Any,
+    *args: Any,
+    ignore: Optional[Union[Sequence[str], str]] = None,
+    frame: Optional[types.FrameType] = None,
+    given_hparams: Optional[Dict[str, Any]] = None,
 ) -> None:
     """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`"""
 
@@ -156,7 +160,9 @@ def save_hyperparameters(
     if not isinstance(frame, types.FrameType):
         raise AttributeError("There is no `frame` available while being required.")
 
-    if is_dataclass(obj):
+    if given_hparams is not None:
+        init_args = given_hparams
+    elif is_dataclass(obj):
         init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
     else:
         init_args = {}
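After this change, `save_hyperparameters` effectively has three sources for the init arguments, tried in order. A toy sketch of that precedence (illustrative helper, not Lightning's actual function):

def pick_init_args(given_hparams, dataclass_fields, frame_locals):
    # 1) values handed over by the CLI instantiator context
    if given_hparams is not None:
        return given_hparams
    # 2) fields of a dataclass model
    if dataclass_fields is not None:
        return dataclass_fields
    # 3) default: arguments collected from the caller's frame
    return frame_locals or {}


assert pick_init_args({"lr": 0.1}, {"lr": 0.2}, {"lr": 0.3}) == {"lr": 0.1}
assert pick_init_args(None, {"lr": 0.2}, {"lr": 0.3}) == {"lr": 0.2}
assert pick_init_args(None, None, {"lr": 0.3}) == {"lr": 0.3}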

tests/tests_pytorch/test_cli.py

+78
@@ -833,6 +833,84 @@ def configure_optimizers(self):
     assert init[1]["lr_scheduler"].gamma == 0.3
 
 
+class TestModelSaveHparams(BoringModel):
+    def __init__(
+        self,
+        optimizer: OptimizerCallable = torch.optim.Adam,
+        scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
+        activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05),
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.activation = activation
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer(self.parameters())
+        scheduler = self.scheduler(optimizer)
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+
+def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
+    with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]):
+        cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False)
+        cli.trainer.fit(cli.model)
+
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+    expected = {
+        "_instantiator": "lightning.pytorch.cli.instantiate_module",
+        "optimizer": "torch.optim.Adam",
+        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
+        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
+    }
+    assert hparams == expected
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
+    assert checkpoint_path.is_file()
+    ckpt = torch.load(checkpoint_path)
+    assert ckpt["hyper_parameters"] == expected
+
+    model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
+    assert isinstance(model, TestModelSaveHparams)
+    assert isinstance(model.activation, torch.nn.LeakyReLU)
+    assert model.activation.negative_slope == 0.05
+    optimizer, lr_scheduler = model.configure_optimizers().values()
+    assert isinstance(optimizer, torch.optim.Adam)
+    assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)
+
+
+def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(cleandir):
+    with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1", "--model=TestModelSaveHparams"]):
+        cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
+        cli.trainer.fit(cli.model)
+
+    expected = {
+        "_instantiator": "lightning.pytorch.cli.instantiate_module",
+        "class_path": f"{__name__}.TestModelSaveHparams",
+        "init_args": {
+            "optimizer": "torch.optim.Adam",
+            "scheduler": "torch.optim.lr_scheduler.ConstantLR",
+            "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
+        },
+    }
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
+    assert checkpoint_path.is_file()
+    ckpt = torch.load(checkpoint_path)
+    assert ckpt["hyper_parameters"] == expected
+
+    model = LightningModule.load_from_checkpoint(checkpoint_path)
+    assert isinstance(model, TestModelSaveHparams)
+    assert isinstance(model.activation, torch.nn.LeakyReLU)
+    assert model.activation.negative_slope == 0.05
+    optimizer, lr_scheduler = model.configure_optimizers().values()
+    assert isinstance(optimizer, torch.optim.Adam)
+    assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):
