diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index d8cba0c53e1fa..66f191920ece5 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -551,6 +551,13 @@ def clip_gradients_value( def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy") strategy_registry.register("deepspeed_stage_1", cls, description="DeepSpeed with ZeRO Stage 1 enabled", stage=1) + strategy_registry.register( + "deepspeed_stage_1_offload", + cls, + description="DeepSpeed with ZeRO Stage 1 and optimizer CPU Offload", + stage=1, + offload_optimizer=True, + ) strategy_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) strategy_registry.register( "deepspeed_stage_2_offload", diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 8f17607600c59..0f50c3e842173 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file ([#19213](https://github.com/Lightning-AI/lightning/pull/19213)) - The TQDM progress bar now respects the env variable `TQDM_MINITERS` for setting the refresh rate ([#19381](https://github.com/Lightning-AI/lightning/pull/19381)) - Added support for saving and loading stateful training DataLoaders ([#19361](https://github.com/Lightning-AI/lightning/pull/19361)) +- Added shortcut name `strategy='deepspeed_stage_1_offload'` to the strategy registry ([#19075](https://github.com/Lightning-AI/lightning/pull/19075)) - Added support for non-strict state-dict loading in Trainer via the new `LightningModule.strict_loading = True | False` attribute ([#19404](https://github.com/Lightning-AI/lightning/pull/19404)) ### Changed diff --git a/tests/tests_fabric/strategies/test_registry.py b/tests/tests_fabric/strategies/test_registry.py index a61225bd2c785..1865328cf59bf 100644 --- a/tests/tests_fabric/strategies/test_registry.py +++ b/tests/tests_fabric/strategies/test_registry.py @@ -44,6 +44,7 @@ def test_available_strategies_in_registry(): "ddp", "deepspeed", "deepspeed_stage_1", + "deepspeed_stage_1_offload", "deepspeed_stage_2", "deepspeed_stage_2_offload", "deepspeed_stage_3",