1
+ import pytest
2
+
3
+
1
4
def test_precision_plugin_renamed_imports ():
2
5
# base class
3
6
from lightning .pytorch .plugins import PrecisionPlugin as PrecisionPlugin2
@@ -9,6 +12,10 @@ def test_precision_plugin_renamed_imports():
9
12
assert issubclass (PrecisionPlugin1 , Precision )
10
13
assert issubclass (PrecisionPlugin2 , Precision )
11
14
15
+ for plugin_cls in [PrecisionPlugin0 , PrecisionPlugin1 , PrecisionPlugin2 ]:
16
+ with pytest .warns (DeprecationWarning , match = "The `PrecisionPlugin` is deprecated" ):
17
+ plugin_cls ()
18
+
12
19
# bitsandbytes
13
20
from lightning .pytorch .plugins import BitsandbytesPrecisionPlugin as BnbPlugin2
14
21
from lightning .pytorch .plugins .precision import BitsandbytesPrecisionPlugin as BnbPlugin1
@@ -39,6 +46,10 @@ def test_precision_plugin_renamed_imports():
39
46
assert issubclass (DoublePlugin1 , DoublePrecision )
40
47
assert issubclass (DoublePlugin2 , DoublePrecision )
41
48
49
+ for plugin_cls in [DoublePlugin0 , DoublePlugin1 , DoublePlugin2 ]:
50
+ with pytest .warns (DeprecationWarning , match = "The `DoublePrecisionPlugin` is deprecated" ):
51
+ plugin_cls ()
52
+
42
53
# fsdp
43
54
from lightning .pytorch .plugins import FSDPPrecisionPlugin as FSDPPlugin2
44
55
from lightning .pytorch .plugins .precision import FSDPPrecisionPlugin as FSDPPlugin1
@@ -49,6 +60,10 @@ def test_precision_plugin_renamed_imports():
49
60
assert issubclass (FSDPPlugin1 , FSDPPrecision )
50
61
assert issubclass (FSDPPlugin2 , FSDPPrecision )
51
62
63
+ for plugin_cls in [FSDPPlugin0 , FSDPPlugin1 , FSDPPlugin2 ]:
64
+ with pytest .warns (DeprecationWarning , match = "The `FSDPPrecisionPlugin` is deprecated" ):
65
+ plugin_cls (precision = "16-mixed" )
66
+
52
67
# half
53
68
from lightning .pytorch .plugins import HalfPrecisionPlugin as HalfPlugin2
54
69
from lightning .pytorch .plugins .precision import HalfPrecisionPlugin as HalfPlugin1
@@ -59,6 +74,10 @@ def test_precision_plugin_renamed_imports():
59
74
assert issubclass (HalfPlugin1 , HalfPrecision )
60
75
assert issubclass (HalfPlugin2 , HalfPrecision )
61
76
77
+ for plugin_cls in [HalfPlugin0 , HalfPlugin1 , HalfPlugin2 ]:
78
+ with pytest .warns (DeprecationWarning , match = "The `HalfPrecisionPlugin` is deprecated" ):
79
+ plugin_cls ()
80
+
62
81
# mixed
63
82
from lightning .pytorch .plugins import MixedPrecisionPlugin as MixedPlugin2
64
83
from lightning .pytorch .plugins .precision import MixedPrecisionPlugin as MixedPlugin1
@@ -69,6 +88,10 @@ def test_precision_plugin_renamed_imports():
69
88
assert issubclass (MixedPlugin1 , MixedPrecision )
70
89
assert issubclass (MixedPlugin2 , MixedPrecision )
71
90
91
+ for plugin_cls in [MixedPlugin0 , MixedPlugin1 , MixedPlugin2 ]:
92
+ with pytest .warns (DeprecationWarning , match = "The `MixedPrecisionPlugin` is deprecated" ):
93
+ plugin_cls (precision = "bf16-mixed" , device = "cuda:0" )
94
+
72
95
# transformer_engine
73
96
from lightning .pytorch .plugins import TransformerEnginePrecisionPlugin as TEPlugin2
74
97
from lightning .pytorch .plugins .precision import TransformerEnginePrecisionPlugin as TEPlugin1
0 commit comments