Skip to content

Commit a6c0a31

Browse files
authored
Fix infinite recursion error in precision plugin graveyard (#19542)
1 parent 7880c11 commit a6c0a31

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5151
- Fixed support for Remote Stop and Remote Abort with NeptuneLogger ([#19130](https://github.com/Lightning-AI/pytorch-lightning/pull/19130))
5252

5353

54-
-
54+
- Fixed infinite recursion error in precision plugin graveyard ([#19542](https://github.com/Lightning-AI/pytorch-lightning/pull/19542))
55+
5556

5657

5758
## [2.2.0] - 2024-02-08

src/lightning/pytorch/_graveyard/precision.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def init(self: type, *args: Any, **kwargs: Any) -> None:
5050
f"The `{deprecated_name}` is deprecated."
5151
f" Use `lightning.pytorch.plugins.precision.{new_class.__name__}` instead."
5252
)
53-
super(type(self), self).__init__(*args, **kwargs)
53+
new_class.__init__(self, *args, **kwargs) # type: ignore[misc]
5454

5555
return type(deprecated_name, (new_class,), {"__init__": init})
5656

tests/tests_pytorch/graveyard/test_precision.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import pytest
2+
3+
14
def test_precision_plugin_renamed_imports():
25
# base class
36
from lightning.pytorch.plugins import PrecisionPlugin as PrecisionPlugin2
@@ -9,6 +12,10 @@ def test_precision_plugin_renamed_imports():
912
assert issubclass(PrecisionPlugin1, Precision)
1013
assert issubclass(PrecisionPlugin2, Precision)
1114

15+
for plugin_cls in [PrecisionPlugin0, PrecisionPlugin1, PrecisionPlugin2]:
16+
with pytest.warns(DeprecationWarning, match="The `PrecisionPlugin` is deprecated"):
17+
plugin_cls()
18+
1219
# bitsandbytes
1320
from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin as BnbPlugin2
1421
from lightning.pytorch.plugins.precision import BitsandbytesPrecisionPlugin as BnbPlugin1
@@ -39,6 +46,10 @@ def test_precision_plugin_renamed_imports():
3946
assert issubclass(DoublePlugin1, DoublePrecision)
4047
assert issubclass(DoublePlugin2, DoublePrecision)
4148

49+
for plugin_cls in [DoublePlugin0, DoublePlugin1, DoublePlugin2]:
50+
with pytest.warns(DeprecationWarning, match="The `DoublePrecisionPlugin` is deprecated"):
51+
plugin_cls()
52+
4253
# fsdp
4354
from lightning.pytorch.plugins import FSDPPrecisionPlugin as FSDPPlugin2
4455
from lightning.pytorch.plugins.precision import FSDPPrecisionPlugin as FSDPPlugin1
@@ -49,6 +60,10 @@ def test_precision_plugin_renamed_imports():
4960
assert issubclass(FSDPPlugin1, FSDPPrecision)
5061
assert issubclass(FSDPPlugin2, FSDPPrecision)
5162

63+
for plugin_cls in [FSDPPlugin0, FSDPPlugin1, FSDPPlugin2]:
64+
with pytest.warns(DeprecationWarning, match="The `FSDPPrecisionPlugin` is deprecated"):
65+
plugin_cls(precision="16-mixed")
66+
5267
# half
5368
from lightning.pytorch.plugins import HalfPrecisionPlugin as HalfPlugin2
5469
from lightning.pytorch.plugins.precision import HalfPrecisionPlugin as HalfPlugin1
@@ -59,6 +74,10 @@ def test_precision_plugin_renamed_imports():
5974
assert issubclass(HalfPlugin1, HalfPrecision)
6075
assert issubclass(HalfPlugin2, HalfPrecision)
6176

77+
for plugin_cls in [HalfPlugin0, HalfPlugin1, HalfPlugin2]:
78+
with pytest.warns(DeprecationWarning, match="The `HalfPrecisionPlugin` is deprecated"):
79+
plugin_cls()
80+
6281
# mixed
6382
from lightning.pytorch.plugins import MixedPrecisionPlugin as MixedPlugin2
6483
from lightning.pytorch.plugins.precision import MixedPrecisionPlugin as MixedPlugin1
@@ -69,6 +88,10 @@ def test_precision_plugin_renamed_imports():
6988
assert issubclass(MixedPlugin1, MixedPrecision)
7089
assert issubclass(MixedPlugin2, MixedPrecision)
7190

91+
for plugin_cls in [MixedPlugin0, MixedPlugin1, MixedPlugin2]:
92+
with pytest.warns(DeprecationWarning, match="The `MixedPrecisionPlugin` is deprecated"):
93+
plugin_cls(precision="bf16-mixed", device="cuda:0")
94+
7295
# transformer_engine
7396
from lightning.pytorch.plugins import TransformerEnginePrecisionPlugin as TEPlugin2
7497
from lightning.pytorch.plugins.precision import TransformerEnginePrecisionPlugin as TEPlugin1

0 commit comments

Comments
 (0)