diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py
index 433e6ab25f8b7..a087b2a862d8b 100644
--- a/src/lightning/pytorch/utilities/types.py
+++ b/src/lightning/pytorch/utilities/types.py
@@ -121,6 +121,7 @@ class OptimizerLRSchedulerConfig(TypedDict):
         Sequence[Optimizer],
         Tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]],
         OptimizerLRSchedulerConfig,
+        Sequence[OptimizerLRSchedulerConfig]
     ]
 ]
 
diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
index d44d4b009c095..60b987c88b625 100644
--- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py
+++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
@@ -231,6 +231,16 @@ def test_optimizer_return_options(tmpdir):
     assert opt[0] == opt_a
     assert lr_sched[0] == ref_lr_sched
 
+    # opt list of dictionaries
+    model.automatic_optimization = False
+    model.configure_optimizers = lambda: [
+        {"optimizer": opt_a, "lr_scheduler": scheduler_a}, {"optimizer": opt_b, "lr_scheduler": scheduler_a}
+    ]
+    opt, lr_sched = _init_optimizers_and_lr_schedulers(model)
+    assert len(opt) == len(lr_sched) == 2
+    assert opt == [opt_a, opt_b]
+    assert lr_sched == [ref_lr_sched, ref_lr_sched]
+
 
 def test_none_optimizer(tmpdir):
     model = BoringModel()