
Commit a41528c

Update tests for PyTorch 2.2.1 (#19521)
1 parent 623ec58

12 files changed: +3 -97 lines changed

tests/tests_fabric/plugins/precision/test_amp_integration.py (-8)

@@ -13,13 +13,10 @@
 # limitations under the License.
 """Integration tests for Automatic Mixed Precision (AMP) training."""
 
-import sys
-
 import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric import Fabric, seed_everything
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -41,11 +38,6 @@ def forward(self, x):
         return output
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.parametrize(
     ("accelerator", "precision", "expected_dtype"),
     [
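
Every removed marker above (and in the files below) has the same shape: a conditional xfail expecting DDP-related tests to fail on Windows with PyTorch >= 2.2, tracked in pytorch/pytorch#116056 and fixed in PyTorch 2.2.1, which is why the markers can now be dropped. For reference, a minimal self-contained sketch of the pattern; the local _torch_greater_equal_2_2 flag is a hypothetical stand-in for Lightning's internal _TORCH_GREATER_EQUAL_2_2, and test_example_ddp is illustrative:

import sys

import pytest
import torch
from packaging.version import Version

# Hypothetical stand-in for lightning.fabric.utilities.imports._TORCH_GREATER_EQUAL_2_2
_torch_greater_equal_2_2 = Version(torch.__version__) >= Version("2.2.0")


@pytest.mark.xfail(
    # Expect failure only on Windows with torch >= 2.2 (pytorch/pytorch#116056)
    sys.platform == "win32" and _torch_greater_equal_2_2,
    reason="Windows + DDP issue in PyTorch 2.2",
)
def test_example_ddp():
    ...

Because the condition is evaluated at collection time, the test runs normally everywhere except the affected platform/version combination.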

tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py (-7)

@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import sys
 
 import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric import Fabric
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -31,11 +29,6 @@ def __init__(self):
         self.register_buffer("buffer", torch.ones(3))
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
 def test_memory_sharing_disabled(strategy):
     """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race

tests/tests_fabric/strategies/test_ddp_integration.py (+1, -7)

@@ -12,27 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import sys
 from copy import deepcopy
 from unittest import mock
 from unittest.mock import Mock
 
 import pytest
 import torch
 from lightning.fabric import Fabric
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
 from tests_fabric.test_fabric import BoringModel
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.parametrize(
     "accelerator",
     [

tests/tests_fabric/utilities/test_distributed.py (-7)

@@ -1,6 +1,5 @@
 import functools
 import os
-import sys
 from functools import partial
 from pathlib import Path
 from unittest import mock
@@ -19,7 +18,6 @@
     _sync_ddp,
     is_shared_filesystem,
 )
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -121,11 +119,6 @@ def test_collective_operations(devices, process):
     spawn_launch(process, devices)
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.flaky(reruns=3)  # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
 def test_is_shared_filesystem(tmp_path, monkeypatch):
     # In the non-distributed case, every location is interpreted as 'shared'

tests/tests_fabric/utilities/test_spike.py (-6)

@@ -4,7 +4,6 @@
 import pytest
 import torch
 from lightning.fabric import Fabric
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
 
 
@@ -29,11 +28,6 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
     )
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.flaky(max_runs=3)
 @pytest.mark.parametrize(
     ("global_rank_spike", "num_devices", "spike_value", "finite_only"),

tests/tests_pytorch/callbacks/test_spike.py (-6)

@@ -3,7 +3,6 @@
 
 import pytest
 import torch
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.spike import SpikeDetection
@@ -47,11 +46,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.flaky(max_runs=3)
 @pytest.mark.parametrize(
     ("global_rank_spike", "num_devices", "spike_value", "finite_only"),

tests/tests_pytorch/loops/test_prediction_loop.py (-7)

@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import sys
 
 import pytest
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
@@ -52,11 +50,6 @@ def predict_step(self, batch, batch_idx):
     assert trainer.predict_loop.predictions == []
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.parametrize("use_distributed_sampler", [False, True])
 def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler):
     """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""

tests/tests_pytorch/models/test_amp.py (+1, -12)

@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import sys
 from unittest import mock
 
 import pytest
 import torch
 from lightning.fabric.plugins.environments import SLURMEnvironment
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from torch.utils.data import DataLoader
@@ -55,16 +53,7 @@ def _assert_autocast_enabled(self):
     [
         ("single_device", "16-mixed", 1),
         ("single_device", "bf16-mixed", 1),
-        pytest.param(
-            "ddp_spawn",
-            "16-mixed",
-            2,
-            marks=pytest.mark.xfail(
-                # https://github.com/pytorch/pytorch/issues/116056
-                sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-                reason="Windows + DDP issue in PyTorch 2.2",
-            ),
-        ),
+        ("ddp_spawn", "16-mixed", 2),
         pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)),
     ],
 )
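
The change above also shows why some cases used pytest.param: inside @pytest.mark.parametrize, a plain tuple and pytest.param(...) supply the same argument values, but only pytest.param can attach per-case marks. With the xfail gone, the "ddp_spawn"/"16-mixed" case collapses back to a plain tuple. A small hedged sketch of the two forms (names and values illustrative; pytest.mark.skip stands in for the repo's RunIf helper):

import pytest


@pytest.mark.parametrize(
    ("strategy", "precision", "devices"),
    [
        ("single_device", "16-mixed", 1),  # plain tuple: no per-case marks
        pytest.param(  # pytest.param: same values, plus a mark
            "ddp_spawn",
            "16-mixed",
            2,
            marks=pytest.mark.skip(reason="illustrative per-case mark"),
        ),
    ],
)
def test_example(strategy, precision, devices):
    assert devices >= 1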

tests/tests_pytorch/serve/test_servable_module_validator.py (-7)

@@ -1,9 +1,7 @@
-import sys
 from typing import Dict
 
 import pytest
 import torch
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator
@@ -38,11 +36,6 @@ def test_servable_module_validator():
     callback.on_train_start(Trainer(accelerator="cpu"), model)
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.flaky(reruns=3)
 def test_servable_module_validator_with_trainer(tmpdir):
     callback = ServableModuleValidator()

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py (-12)

@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import sys
 from multiprocessing import Process
 from unittest import mock
 from unittest.mock import ANY, Mock, call, patch
 
 import pytest
 import torch
 from lightning.fabric.plugins import ClusterEnvironment
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.strategies import DDPStrategy
@@ -196,11 +194,6 @@ def on_fit_start(self) -> None:
         assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 def test_memory_sharing_disabled():
     """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
     conditions on model updates."""
@@ -221,11 +214,6 @@ def test_check_for_missing_main_guard():
     launcher.launch(function=Mock())
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 def test_fit_twice_raises():
     model = BoringModel()
     trainer = Trainer(
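
The test_memory_sharing_disabled tests retained above guard a subtle failure mode: torch.multiprocessing can place tensors in shared memory, so parameters visible to several spawned workers could race on updates. A standalone sketch of what "shared" means for a module's tensors (background only, not the launcher's actual code):

import torch.nn as nn

model = nn.Linear(2, 2)
assert not model.weight.is_shared()  # fresh CPU tensors are process-local

model.share_memory()  # moves parameters and buffers into shared memory
assert model.weight.is_shared()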

tests/tests_pytorch/trainer/connectors/test_data_connector.py (-7)

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import sys
 from re import escape
 from typing import Sized
 from unittest import mock
@@ -20,7 +19,6 @@
 import lightning.fabric
 import pytest
 from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
@@ -125,11 +123,6 @@ def on_train_end(self):
         self.ctx.__exit__(None, None, None)
 
 
-@pytest.mark.xfail(
-    # https://github.com/pytorch/pytorch/issues/116056
-    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-    reason="Windows + DDP issue in PyTorch 2.2",
-)
 @pytest.mark.parametrize("num_workers", [0, 1, 2])
 def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
     """Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""

tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py (+1, -11)

@@ -15,15 +15,13 @@
 
 import collections
 import itertools
-import sys
 from re import escape
 from unittest import mock
 from unittest.mock import call
 
 import numpy as np
 import pytest
 import torch
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.pytorch import Trainer, callbacks
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
 from lightning.pytorch.core.module import LightningModule
@@ -348,15 +346,7 @@ def validation_step(self, batch, batch_idx):
     ("devices", "accelerator"),
     [
         (1, "cpu"),
-        pytest.param(
-            2,
-            "cpu",
-            marks=pytest.mark.xfail(
-                # https://github.com/pytorch/pytorch/issues/116056
-                sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
-                reason="Windows + DDP issue in PyTorch 2.2",
-            ),
-        ),
+        (2, "cpu"),
         pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)),
     ],
 )
