|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import inspect |
14 | 15 | import os
|
15 | 16 | import sys
|
16 | 17 | from functools import partial, update_wrapper
|
17 | 18 | from types import MethodType
|
18 |
| -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union |
| 19 | +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union |
19 | 20 |
|
20 | 21 | import torch
|
| 22 | +import yaml |
21 | 23 | from lightning_utilities.core.imports import RequirementCache
|
22 | 24 | from lightning_utilities.core.rank_zero import _warn
|
23 | 25 | from torch.optim import Optimizer
|
|
27 | 29 | from lightning.fabric.utilities.cloud_io import get_filesystem
|
28 | 30 | from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
|
29 | 31 | from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
|
| 32 | +from lightning.pytorch.core.mixins.hparams_mixin import _given_hyperparameters_context |
30 | 33 | from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
31 | 34 | from lightning.pytorch.utilities.model_helpers import is_overridden
|
32 | 35 | from lightning.pytorch.utilities.rank_zero import rank_zero_warn
|
33 | 36 |
|
34 |
| -_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.26.1") |
| 37 | +_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.5") |
35 | 38 |
|
36 | 39 | if _JSONARGPARSE_SIGNATURES_AVAILABLE:
|
37 | 40 | import docstring_parser
|
|
50 | 53 | locals()["ArgumentParser"] = object
|
51 | 54 | locals()["Namespace"] = object
|
52 | 55 |
|
| 56 | +ModuleType = TypeVar("ModuleType") |
| 57 | + |
53 | 58 |
|
54 | 59 | class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
|
55 | 60 | def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
|
@@ -381,6 +386,7 @@ def __init__(
|
381 | 386 |
|
382 | 387 | self._set_seed()
|
383 | 388 |
|
| 389 | + self._add_instantiators() |
384 | 390 | self.before_instantiate_classes()
|
385 | 391 | self.instantiate_classes()
|
386 | 392 |
|
@@ -527,6 +533,22 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
|
527 | 533 | else:
|
528 | 534 | self.config = parser.parse_args(args)
|
529 | 535 |
|
| 536 | + def _add_instantiators(self) -> None: |
| 537 | + self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False)) |
| 538 | + if "subcommand" in self.config: |
| 539 | + self.config_dump = self.config_dump[self.config.subcommand] |
| 540 | + |
| 541 | + self.parser.add_instantiator( |
| 542 | + _InstantiatorFn(cli=self, key="model"), |
| 543 | + _get_module_type(self._model_class), |
| 544 | + subclasses=self.subclass_mode_model, |
| 545 | + ) |
| 546 | + self.parser.add_instantiator( |
| 547 | + _InstantiatorFn(cli=self, key="data"), |
| 548 | + _get_module_type(self._datamodule_class), |
| 549 | + subclasses=self.subclass_mode_data, |
| 550 | + ) |
| 551 | + |
530 | 552 | def before_instantiate_classes(self) -> None:
|
531 | 553 | """Implement to run some code before instantiating the classes."""
|
532 | 554 |
|
@@ -755,3 +777,33 @@ def _get_short_description(component: object) -> Optional[str]:
|
755 | 777 | return docstring.short_description
|
756 | 778 | except (ValueError, docstring_parser.ParseError) as ex:
|
757 | 779 | rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
|
| 780 | + |
| 781 | + |
| 782 | +def _get_module_type(value: Union[Callable, type]) -> type: |
| 783 | + if callable(value) and not isinstance(value, type): |
| 784 | + return inspect.signature(value).return_annotation |
| 785 | + return value |
| 786 | + |
| 787 | + |
| 788 | +class _InstantiatorFn: |
| 789 | + def __init__(self, cli: LightningCLI, key: str) -> None: |
| 790 | + self.cli = cli |
| 791 | + self.key = key |
| 792 | + |
| 793 | + def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: |
| 794 | + with _given_hyperparameters_context( |
| 795 | + hparams=self.cli.config_dump.get(self.key, {}), |
| 796 | + instantiator="lightning.pytorch.cli.instantiate_module", |
| 797 | + ): |
| 798 | + return class_type(*args, **kwargs) |
| 799 | + |
| 800 | + |
| 801 | +def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: |
| 802 | + parser = ArgumentParser(exit_on_error=False) |
| 803 | + if "class_path" in config: |
| 804 | + parser.add_subclass_arguments(class_type, "module") |
| 805 | + else: |
| 806 | + parser.add_class_arguments(class_type, "module") |
| 807 | + cfg = parser.parse_object({"module": config}) |
| 808 | + init = parser.instantiate_classes(cfg) |
| 809 | + return init.module |
0 commit comments