Skip to content

Commit b1eceb1

Browse files
Bordalantiga
authored andcommitted
bump: Torch 2.5 (#20351)
* bump: Torch `2.5.0` * push docker * docker * 2.5.1 and mypy * update USE_DISTRIBUTED=0 test * also for pytorch lightning no distributed * set USE_LIBUV=0 on windows * try drop pickle warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable compiling update_metrics * bump 2.2.x to bugfix * disable also log in logger connector (also calls metric) * more point release bumps * remove unloved type ignore and print some more on exit * update checkgroup * minor versions * shortened version in build-pl * pytorch 2.4 is with python 3.11 * 2.1 and 2.3 without patch release * for 2.4.1: docker with 3.11 test with 3.12 --------- Co-authored-by: Thomas Viehmann <tv.code@beamnet.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 61a403a)
1 parent d62b53a commit b1eceb1

26 files changed

+117
-88
lines changed

.azure/gpu-benchmarks.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
variables:
4747
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
4848
container:
49-
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
49+
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
5050
options: "--gpus=all --shm-size=32g"
5151
strategy:
5252
matrix:

.azure/gpu-tests-fabric.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
6161
PACKAGE_NAME: "fabric"
6262
"Lightning | latest":
63-
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
63+
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
6464
PACKAGE_NAME: "lightning"
6565
workspace:
6666
clean: all

.azure/gpu-tests-pytorch.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
5454
PACKAGE_NAME: "pytorch"
5555
"Lightning | latest":
56-
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
56+
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
5757
PACKAGE_NAME: "lightning"
5858
pool: lit-rtx-3090
5959
variables:

.github/checkgroup.yml

+26-18
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,22 @@ subprojects:
2121
checks:
2222
- "pl-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
2323
- "pl-cpu (macOS-14, lightning, 3.10, 2.1)"
24-
- "pl-cpu (macOS-14, lightning, 3.11, 2.2)"
24+
- "pl-cpu (macOS-14, lightning, 3.11, 2.2.2)"
2525
- "pl-cpu (macOS-14, lightning, 3.11, 2.3)"
26-
- "pl-cpu (macOS-14, lightning, 3.12, 2.4)"
26+
- "pl-cpu (macOS-14, lightning, 3.12, 2.4.1)"
27+
- "pl-cpu (macOS-14, lightning, 3.12, 2.5.1)"
2728
- "pl-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)"
2829
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
29-
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
30+
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)"
3031
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
31-
- "pl-cpu (ubuntu-20.04, lightning, 3.12, 2.4)"
32+
- "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)"
33+
- "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)"
3234
- "pl-cpu (windows-2022, lightning, 3.9, 2.1, oldest)"
3335
- "pl-cpu (windows-2022, lightning, 3.10, 2.1)"
34-
- "pl-cpu (windows-2022, lightning, 3.11, 2.2)"
36+
- "pl-cpu (windows-2022, lightning, 3.11, 2.2.2)"
3537
- "pl-cpu (windows-2022, lightning, 3.11, 2.3)"
36-
- "pl-cpu (windows-2022, lightning, 3.12, 2.4)"
38+
- "pl-cpu (windows-2022, lightning, 3.12, 2.4.1)"
39+
- "pl-cpu (windows-2022, lightning, 3.12, 2.5.1)"
3740
- "pl-cpu (macOS-14, pytorch, 3.9, 2.1)"
3841
- "pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1)"
3942
- "pl-cpu (windows-2022, pytorch, 3.9, 2.1)"
@@ -141,15 +144,17 @@ subprojects:
141144
- "!*.md"
142145
- "!**/*.md"
143146
checks:
144-
- "build-cuda (3.11, 2.1, 12.1.0)"
145-
- "build-cuda (3.11, 2.2, 12.1.0)"
146-
- "build-cuda (3.11, 2.3, 12.1.0)"
147-
- "build-cuda (3.12, 2.4, 12.1.0)"
147+
- "build-cuda (3.10, 2.1.2, 12.1.0)"
148+
- "build-cuda (3.11, 2.2.2, 12.1.0)"
149+
- "build-cuda (3.11, 2.3.1, 12.1.0)"
150+
- "build-cuda (3.11, 2.4.1, 12.1.0)"
151+
- "build-cuda (3.12, 2.5.1, 12.1.0)"
148152
#- "build-NGC"
149-
- "build-pl (3.11, 2.1, 12.1.0)"
153+
- "build-pl (3.10, 2.1, 12.1.0)"
150154
- "build-pl (3.11, 2.2, 12.1.0)"
151155
- "build-pl (3.11, 2.3, 12.1.0)"
152-
- "build-pl (3.12, 2.4, 12.1.0)"
156+
- "build-pl (3.11, 2.4, 12.1.0)"
157+
- "build-pl (3.12, 2.5, 12.1.0)"
153158

154159
# SECTION: lightning_fabric
155160

@@ -168,19 +173,22 @@ subprojects:
168173
checks:
169174
- "fabric-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
170175
- "fabric-cpu (macOS-14, lightning, 3.10, 2.1)"
171-
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2)"
176+
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2.2)"
172177
- "fabric-cpu (macOS-14, lightning, 3.11, 2.3)"
173-
- "fabric-cpu (macOS-14, lightning, 3.12, 2.4)"
178+
- "fabric-cpu (macOS-14, lightning, 3.12, 2.4.1)"
179+
- "fabric-cpu (macOS-14, lightning, 3.12, 2.5.1)"
174180
- "fabric-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)"
175181
- "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
176-
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
182+
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)"
177183
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
178-
- "fabric-cpu (ubuntu-20.04, lightning, 3.12, 2.4)"
184+
- "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)"
185+
- "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)"
179186
- "fabric-cpu (windows-2022, lightning, 3.9, 2.1, oldest)"
180187
- "fabric-cpu (windows-2022, lightning, 3.10, 2.1)"
181-
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2)"
188+
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2.2)"
182189
- "fabric-cpu (windows-2022, lightning, 3.11, 2.3)"
183-
- "fabric-cpu (windows-2022, lightning, 3.12, 2.4)"
190+
- "fabric-cpu (windows-2022, lightning, 3.12, 2.4.1)"
191+
- "fabric-cpu (windows-2022, lightning, 3.12, 2.5.1)"
184192
- "fabric-cpu (macOS-14, fabric, 3.9, 2.1)"
185193
- "fabric-cpu (ubuntu-20.04, fabric, 3.9, 2.1)"
186194
- "fabric-cpu (windows-2022, fabric, 3.9, 2.1)"

.github/workflows/ci-tests-fabric.yml

+9-6
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,18 @@ jobs:
4343
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
4444
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
4545
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
46-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
47-
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
48-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
46+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
47+
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
48+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
4949
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5050
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5151
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
52-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
53-
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
54-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
52+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
53+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
54+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
55+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
56+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
57+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
5558
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
5659
- { os: "macOS-13", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
5760
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }

.github/workflows/ci-tests-pytorch.yml

+9-6
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,18 @@ jobs:
4747
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
4848
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
4949
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
50-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
51-
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
52-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
50+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
51+
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
52+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
5353
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5454
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5555
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
56-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
57-
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
58-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
56+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
57+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
58+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
59+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
60+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
61+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
5962
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
6063
- { os: "macOS-13", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
6164
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }

.github/workflows/docker-build.yml

+15-7
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ jobs:
4343
include:
4444
# We only release one docker image per PyTorch version.
4545
# Make sure the matrix here matches the one below.
46-
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
46+
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
4747
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
4848
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
49-
- { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" }
49+
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
50+
- { python_version: "3.12", pytorch_version: "2.5", cuda_version: "12.1.0" }
5051
steps:
5152
- uses: actions/checkout@v4
5253
with:
@@ -103,10 +104,11 @@ jobs:
103104
include:
104105
# These are the base images for PL release docker images.
105106
# Make sure the matrix here matches the one above.
106-
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
107-
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
108-
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
109-
- { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" }
107+
- { python_version: "3.10", pytorch_version: "2.1.2", cuda_version: "12.1.0" }
108+
- { python_version: "3.11", pytorch_version: "2.2.2", cuda_version: "12.1.0" }
109+
- { python_version: "3.11", pytorch_version: "2.3.1", cuda_version: "12.1.0" }
110+
- { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.0" }
111+
- { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.0" }
110112
steps:
111113
- uses: actions/checkout@v4
112114
- uses: docker/setup-buildx-action@v3
@@ -115,6 +117,12 @@ jobs:
115117
with:
116118
username: ${{ secrets.DOCKER_USERNAME }}
117119
password: ${{ secrets.DOCKER_PASSWORD }}
120+
121+
- name: shorten Torch version
122+
run: |
123+
# convert 1.10.2 to 1.10
124+
pt_version=$(echo ${{ matrix.pytorch_version }} | cut -d. -f1,2)
125+
echo "PT_VERSION=$pt_version" >> $GITHUB_ENV
118126
- uses: docker/build-push-action@v6
119127
with:
120128
build-args: |
@@ -123,7 +131,7 @@ jobs:
123131
CUDA_VERSION=${{ matrix.cuda_version }}
124132
file: dockers/base-cuda/Dockerfile
125133
push: ${{ env.PUSH_NIGHTLY }}
126-
tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }}"
134+
tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ env.PT_VERSION }}-cuda${{ matrix.cuda_version }}"
127135
timeout-minutes: 95
128136
- uses: ravsamhq/notify-slack-action@v2
129137
if: failure() && env.PUSH_NIGHTLY == 'true'

requirements/fabric/base.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

4-
torch >=2.1.0, <2.5.0
4+
torch >=2.1.0, <2.6.0
55
fsspec[http] >=2022.5.0, <2024.4.0
66
packaging >=20.0, <=23.1
77
typing-extensions >=4.4.0, <4.10.0

requirements/fabric/examples.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

4-
torchvision >=0.16.0, <0.20.0
5-
torchmetrics >=0.10.0, <1.3.0
4+
torchvision >=0.16.0, <0.21.0
5+
torchmetrics >=0.10.0, <1.5.0
66
lightning-utilities >=0.8.0, <0.12.0

requirements/fabric/test.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ pytest-rerunfailures ==12.0
77
pytest-random-order ==1.1.0
88
click ==8.1.7
99
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
10-
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
10+
torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version

requirements/pytorch/base.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

4-
torch >=2.1.0, <2.5.0
4+
torch >=2.1.0, <2.6.0
55
tqdm >=4.57.0, <4.67.0
66
PyYAML >=5.4, <6.1.0
77
fsspec[http] >=2022.5.0, <2024.4.0
8-
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
8+
torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version
99
packaging >=20.0, <=23.1
1010
typing-extensions >=4.4.0, <4.10.0
1111
lightning-utilities >=0.10.0, <0.12.0

requirements/pytorch/examples.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

44
requests <2.32.0
5-
torchvision >=0.16.0, <0.20.0
5+
torchvision >=0.16.0, <0.21.0
66
ipython[all] <8.15.0
7-
torchmetrics >=0.10.0, <1.3.0
7+
torchmetrics >=0.10.0, <1.5.0
88
lightning-utilities >=0.8.0, <0.12.0

requirements/typing.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mypy==1.11.0
2-
torch==2.4.1
2+
torch==2.5.1
33

44
types-Markdown
55
types-PyYAML

src/lightning/fabric/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
import os
5+
import sys
56

67
from lightning_utilities.core.imports import package_available
78

@@ -26,6 +27,10 @@
2627
# https://github.com/pytorch/pytorch/issues/83973
2728
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
2829

30+
# see https://github.com/pytorch/pytorch/issues/139990
31+
if sys.platform == "win32":
32+
os.environ["USE_LIBUV"] = "0"
33+
2934

3035
from lightning.fabric.fabric import Fabric # noqa: E402
3136
from lightning.fabric.utilities.seed import seed_everything # noqa: E402

src/lightning/pytorch/core/module.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def log(
531531
logger=logger,
532532
on_step=on_step,
533533
on_epoch=on_epoch,
534-
reduce_fx=reduce_fx, # type: ignore[arg-type]
534+
reduce_fx=reduce_fx,
535535
enable_graph=enable_graph,
536536
add_dataloader_idx=add_dataloader_idx,
537537
batch_size=batch_size,
@@ -1405,7 +1405,9 @@ def forward(self, x):
14051405
input_sample = self._apply_batch_transfer_handler(input_sample)
14061406

14071407
file_path = str(file_path) if isinstance(file_path, Path) else file_path
1408-
torch.onnx.export(self, input_sample, file_path, **kwargs)
1408+
# PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
1409+
# BytesIO does work, too.
1410+
torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
14091411
self.train(mode)
14101412

14111413
@torch.no_grad()

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

+2
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], m
351351

352352
return batch_size
353353

354+
@torch.compiler.disable
354355
def log(
355356
self,
356357
fx: str,
@@ -413,6 +414,7 @@ def log(
413414
batch_size = self._extract_batch_size(self[key], batch_size, meta)
414415
self.update_metrics(key, value, batch_size)
415416

417+
@torch.compiler.disable
416418
def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
417419
result_metric = self[key]
418420
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`

tests/run_standalone_tests.sh

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ function show_batched_output {
4848
# heuristic: stop if there's mentions of errors. this can prevent false negatives when only some of the ranks fail
4949
if perl -nle 'print if /error|(?<!(?-i)on_)exception|traceback|(?<!(?-i)x)failed/i' standalone_test_output.txt | grep -qv -f testnames.txt; then
5050
echo "Potential error! Stopping."
51+
perl -nle 'print if /error|(?<!(?-i)on_)exception|traceback|(?<!(?-i)x)failed/i' standalone_test_output.txt
5152
rm standalone_test_output.txt
5253
exit 1
5354
fi

tests/tests_fabric/utilities/test_imports.py

+12
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ def test_import_fabric_with_torch_dist_unavailable():
2323
code = dedent(
2424
"""
2525
import torch
26+
try:
27+
# PyTorch 2.5 relies on torch,distributed._composable.fsdp not
28+
# existing with USE_DISTRIBUTED=0
29+
import torch._dynamo.variables.functions
30+
torch._dynamo.variables.functions._fsdp_param_group = None
31+
except ImportError:
32+
pass
2633
2734
# pretend torch.distributed not available
2835
for name in list(torch.distributed.__dict__.keys()):
@@ -31,6 +38,11 @@ def test_import_fabric_with_torch_dist_unavailable():
3138
3239
torch.distributed.is_available = lambda: False
3340
41+
# needed for Dynamo in PT 2.5+ compare the torch.distributed source
42+
class _ProcessGroupStub:
43+
pass
44+
torch.distributed.ProcessGroup = _ProcessGroupStub
45+
3446
import lightning.fabric
3547
"""
3648
)

tests/tests_pytorch/callbacks/test_early_stopping.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,13 @@
1515
import math
1616
import os
1717
import pickle
18-
from contextlib import nullcontext
1918
from typing import List, Optional
2019
from unittest import mock
2120
from unittest.mock import Mock
2221

2322
import cloudpickle
2423
import pytest
2524
import torch
26-
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
2725
from lightning.pytorch import Trainer, seed_everything
2826
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
2927
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -193,13 +191,11 @@ def test_pickling():
193191
early_stopping = EarlyStopping(monitor="foo")
194192

195193
early_stopping_pickled = pickle.dumps(early_stopping)
196-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
197-
early_stopping_loaded = pickle.loads(early_stopping_pickled)
194+
early_stopping_loaded = pickle.loads(early_stopping_pickled)
198195
assert vars(early_stopping) == vars(early_stopping_loaded)
199196

200197
early_stopping_pickled = cloudpickle.dumps(early_stopping)
201-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
202-
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
198+
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
203199
assert vars(early_stopping) == vars(early_stopping_loaded)
204200

205201

0 commit comments

Comments
 (0)