Commit 640c47e

Fix torch warning about deprecated function. (#8075)
Drop support for torch versions below 2.2 on the audio VAEs.
1 parent 31e9e36 commit 640c47e

2 files changed (+11, -21 lines)
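The warning being fixed comes from the old torch.nn.utils.weight_norm entry point, which recent PyTorch releases deprecate in favor of torch.nn.utils.parametrizations.weight_norm. Below is a minimal sketch of the two call styles, assuming PyTorch 2.2 or newer (the support floor this commit adopts); it is illustrative and not part of the diff.

import torch
import torch.nn as nn

conv = nn.Conv1d(8, 8, kernel_size=3, padding=1)

# Old entry point: still works, but newer PyTorch releases emit a
# deprecation warning pointing at the parametrizations variant.
# conv = torch.nn.utils.weight_norm(conv)

# New entry point used throughout this commit: reparametrizes conv.weight
# as a magnitude and a direction tensor via the parametrization framework.
conv = torch.nn.utils.parametrizations.weight_norm(conv)

x = torch.randn(1, 8, 16)
print(conv(x).shape)  # torch.Size([1, 8, 16])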

comfy/ldm/ace/vae/music_vocoder.py

Lines changed: 9 additions & 13 deletions

@@ -8,11 +8,7 @@
 
 import numpy as np
 import torch.nn.functional as F
-from torch.nn.utils import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
-# from diffusers.models.modeling_utils import ModelMixin
-# from diffusers.loaders import FromOriginalModelMixin
-# from diffusers.configuration_utils import ConfigMixin, register_to_config
 
 from .music_log_mel import LogMelSpectrogram
 
@@ -259,7 +255,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
 
         self.convs1 = nn.ModuleList(
             [
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -269,7 +265,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
                         padding=get_padding(kernel_size, dilation[0]),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -279,7 +275,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
                         padding=get_padding(kernel_size, dilation[1]),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -294,7 +290,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
 
         self.convs2 = nn.ModuleList(
             [
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -304,7 +300,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
                         padding=get_padding(kernel_size, 1),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -314,7 +310,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
                         padding=get_padding(kernel_size, 1),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -366,7 +362,7 @@ def __init__(
             prod(upsample_rates) == hop_length
         ), f"hop_length must be {prod(upsample_rates)}"
 
-        self.conv_pre = weight_norm(
+        self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
             ops.Conv1d(
                 num_mels,
                 upsample_initial_channel,
@@ -386,7 +382,7 @@ def __init__(
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
             c_cur = upsample_initial_channel // (2 ** (i + 1))
             self.ups.append(
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.ConvTranspose1d(
                         upsample_initial_channel // (2**i),
                         upsample_initial_channel // (2 ** (i + 1)),
@@ -421,7 +417,7 @@ def __init__(
                 self.resblocks.append(ResBlock1(ch, k, d))
 
         self.activation_post = post_activation()
-        self.conv_post = weight_norm(
+        self.conv_post = torch.nn.utils.parametrizations.weight_norm(
             ops.Conv1d(
                 ch,
                 1,
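One detail worth noting about the hunks above: music_vocoder.py keeps its remove_parametrizations import (aliased as remove_weight_norm), and that API is the matching way to detach weight norm from modules wrapped with the new call. A small sketch, assuming PyTorch 2.2+; the Conv1d here is a stand-in, not a module from this file.

import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import remove_parametrizations

# Wrap a plain Conv1d with the parametrizations-based weight norm.
conv = torch.nn.utils.parametrizations.weight_norm(nn.Conv1d(4, 4, 3, padding=1))
print(parametrize.is_parametrized(conv))  # True: weight norm is attached

# Fold the normalized weight back into conv.weight and drop the parametrization.
remove_parametrizations(conv, "weight")
print(parametrize.is_parametrized(conv))  # False: back to a plain Conv1d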

comfy/ldm/audio/autoencoder.py

Lines changed: 2 additions & 8 deletions

@@ -75,16 +75,10 @@ def forward(self, x):
         return x
 
 def WNConv1d(*args, **kwargs):
-    try:
-        return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
-    except:
-        return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
+    return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
 
 def WNConvTranspose1d(*args, **kwargs):
-    try:
-        return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
-    except:
-        return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
+    return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
 
 def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
     if activation == "elu":
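The deleted try/except blocks were the old compatibility path: on older PyTorch builds where the parametrizations-based helper is missing, the first call raised and the code fell back to the legacy weight_norm (per the removed "support pytorch 2.1 and older" comment). With the 2.2 floor stated in the commit message, that fallback is no longer needed. Purely as an illustration (not code from ComfyUI), an explicit guard expressing the same floor might look like:

import torch

# Hypothetical guard, not part of this commit: reject torch builds older
# than the 2.2 support floor mentioned in the commit message.
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
if (major, minor) < (2, 2):
    raise RuntimeError(f"audio VAEs assume torch >= 2.2, found {torch.__version__}")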

0 commit comments