 import numpy as np
 import torch.nn.functional as F
-from torch.nn.utils import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
-# from diffusers.models.modeling_utils import ModelMixin
-# from diffusers.loaders import FromOriginalModelMixin
-# from diffusers.configuration_utils import ConfigMixin, register_to_config
 
 from .music_log_mel import LogMelSpectrogram
 
@@ -259,7 +255,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
 
         self.convs1 = nn.ModuleList(
             [
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -269,7 +265,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
                         padding=get_padding(kernel_size, dilation[0]),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -279,7 +275,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
                         padding=get_padding(kernel_size, dilation[1]),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -294,7 +290,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
 
         self.convs2 = nn.ModuleList(
             [
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -304,7 +300,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
                         padding=get_padding(kernel_size, 1),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -314,7 +310,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
                         padding=get_padding(kernel_size, 1),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -366,7 +362,7 @@ def __init__(
             prod(upsample_rates) == hop_length
         ), f"hop_length must be {prod(upsample_rates)}"
 
-        self.conv_pre = weight_norm(
+        self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
             ops.Conv1d(
                 num_mels,
                 upsample_initial_channel,
@@ -386,7 +382,7 @@ def __init__(
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
             c_cur = upsample_initial_channel // (2 ** (i + 1))
             self.ups.append(
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.ConvTranspose1d(
                         upsample_initial_channel // (2 ** i),
                         upsample_initial_channel // (2 ** (i + 1)),
@@ -421,7 +417,7 @@ def __init__(
                 self.resblocks.append(ResBlock1(ch, k, d))
 
         self.activation_post = post_activation()
-        self.conv_post = weight_norm(
+        self.conv_post = torch.nn.utils.parametrizations.weight_norm(
            ops.Conv1d(
                ch,
                1,
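For reference, a minimal sketch (not part of this commit) of the parametrization-based weight norm API that the diff migrates to; the layer sizes below are illustrative only:

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

# Wrap a layer: the weight is re-expressed as magnitude * direction and
# registered under module.parametrizations.weight instead of weight_g/weight_v.
conv = weight_norm(nn.Conv1d(16, 16, kernel_size=3, padding=1))

x = torch.randn(1, 16, 128)
y = conv(x)  # forward pass behaves the same as with the deprecated helper

# For inference, bake the normalized weight back into a plain tensor
# (this is what the aliased remove_weight_norm import above resolves to).
remove_parametrizations(conv, "weight")
```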