1
1
import math
2
- from vector_quantize_pytorch import GroupedResidualFSQ
2
+ from typing import List
3
3
4
4
import torch
5
5
import torch .nn as nn
6
6
import torch .nn .functional as F
7
+ from vector_quantize_pytorch import GroupedResidualFSQ
7
8
8
9
class ConvNeXtBlock (nn .Module ):
9
10
def __init__ (
10
11
self ,
11
12
dim : int ,
12
13
intermediate_dim : int ,
13
- kernel , dilation ,
14
+ kernel : int , dilation : int ,
14
15
layer_scale_init_value : float = 1e-6 ,
15
16
):
16
17
# ConvNeXt Block copied from Vocos.
@@ -32,25 +33,31 @@ def __init__(
32
33
33
34
    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        """ConvNeXt block: dwconv -> LayerNorm -> pointwise MLP, plus residual.

        Args:
            x: feature map laid out as (B, C, T) per the transpose comments below.
            cond: accepted for interface compatibility but unused in this body.

        Returns:
            Tensor of the same shape as ``x``.

        NOTE(review): the in-place ops (``transpose_``, ``*=``) and the ``del``
        statements are a peak-memory optimization and appear to assume
        inference-only use; ``y *= self.gamma`` overwrites the tensor autograd
        would need for gamma's gradient — confirm this path is never trained.
        """
        residual = x

        y = self.dwconv(x)
        y.transpose_(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(y)
        del y
        y = self.pwconv1(x)
        del x
        x = self.act(y)
        del y
        y = self.pwconv2(x)
        del x

        if self.gamma is not None:
            y *= self.gamma  # learned layer-scale (see layer_scale_init_value)
        y.transpose_(1, 2)  # (B, T, C) -> (B, C, T)

        x = y + residual
        del y

        return x
47
-
48
55
49
56
50
57
class GFSQ (nn .Module ):
51
58
52
59
def __init__ (self ,
53
- dim , levels , G , R , eps = 1e-5 , transpose = True
60
+ dim : int , levels : List [ int ] , G : int , R : int , eps = 1e-5 , transpose = True
54
61
):
55
62
super (GFSQ , self ).__init__ ()
56
63
self .quantizer = GroupedResidualFSQ (
@@ -67,19 +74,19 @@ def __init__(self,
67
74
68
75
def _embed (self , x : torch .Tensor ):
69
76
if self .transpose :
70
- x = x . transpose (1 ,2 )
77
+ x . transpose_ (1 , 2 )
71
78
"""
72
79
x = rearrange(
73
80
x, "b t (g r) -> g b t r", g = self.G, r = self.R,
74
81
)
75
82
"""
76
- x .view (- 1 , self .G , self .R ).permute (2 , 0 , 1 , 3 )
83
+ x = x .view (x . size ( 0 ), x . size ( 1 ) , self .G , self .R ).permute (2 , 0 , 1 , 3 )
77
84
feat = self .quantizer .get_output_from_indices (x )
78
- return feat .transpose (1 ,2 ) if self .transpose else feat
85
+ return feat .transpose_ (1 ,2 ) if self .transpose else feat
79
86
80
87
def forward (self , x ,):
81
88
if self .transpose :
82
- x = x . transpose (1 ,2 )
89
+ x . transpose_ (1 ,2 )
83
90
feat , ind = self .quantizer (x )
84
91
"""
85
92
ind = rearrange(
@@ -92,19 +99,20 @@ def forward(self, x,):
92
99
embed_onehot = embed_onehot_tmp .to (x .dtype )
93
100
del embed_onehot_tmp
94
101
e_mean = torch .mean (embed_onehot , dim = [0 ,1 ])
95
- e_mean = e_mean / (e_mean .sum (dim = 1 ) + self .eps ).unsqueeze (1 )
102
+ # e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
103
+ torch .div (e_mean , (e_mean .sum (dim = 1 ) + self .eps ).unsqueeze (1 ), out = e_mean )
96
104
perplexity = torch .exp (- torch .sum (e_mean * torch .log (e_mean + self .eps ), dim = 1 ))
97
105
98
106
return (
99
107
torch .zeros (perplexity .shape , dtype = x .dtype , device = x .device ),
100
- feat .transpose (1 ,2 ) if self .transpose else feat ,
108
+ feat .transpose_ (1 ,2 ) if self .transpose else feat ,
101
109
perplexity ,
102
110
None ,
103
- ind .transpose (1 ,2 ) if self .transpose else ind ,
111
+ ind .transpose_ (1 ,2 ) if self .transpose else ind ,
104
112
)
105
-
113
+
106
114
class DVAEDecoder (nn .Module ):
107
- def __init__ (self , idim , odim ,
115
+ def __init__ (self , idim : int , odim : int ,
108
116
n_layer = 12 , bn_dim = 64 , hidden = 256 ,
109
117
kernel = 7 , dilation = 2 , up = False
110
118
):
@@ -121,14 +129,16 @@ def __init__(self, idim, odim,
121
129
122
130
def forward (self , input , conditioning = None ):
123
131
# B, T, C
124
- x = input .transpose (1 , 2 )
125
- x = self .conv_in (x )
132
+ x = input .transpose_ (1 , 2 )
133
+ y = self .conv_in (x )
134
+ del x
126
135
for f in self .decoder_block :
127
- x = f (x , conditioning )
128
-
129
- x = self .conv_out (x )
130
- return x .transpose (1 , 2 )
131
-
136
+ y = f (y , conditioning )
137
+
138
+ x = self .conv_out (y )
139
+ del y
140
+ return x .transpose_ (1 , 2 )
141
+
132
142
133
143
class DVAE (nn .Module ):
134
144
def __init__ (
@@ -144,20 +154,21 @@ def __init__(
144
154
else :
145
155
self .vq_layer = None
146
156
147
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Decode features (or quantizer indices) into a mel-like output.

        Args:
            inp: channel-first feature tensor. When ``self.vq_layer`` is
                 configured, it is decoded through the quantizer first;
                 otherwise it is detached and cloned and used as-is.

        Returns:
            ``out_conv``'s output scaled in place by ``self.coef``.

        NOTE(review): the in-place ``transpose_`` calls act on tensors created
        in this body (``permute(...).flatten(2)`` yields a copy), and
        ``torch.mul(..., out=dec_out)`` breaks autograd — this path looks
        inference-only; confirm it is never trained through.
        """
        if self.vq_layer is not None:
            vq_feats = self.vq_layer._embed(inp)
        else:
            vq_feats = inp.detach().clone()

        # Interleave the two channel halves:
        # (B, C, T) -> (B, 2, C//2, T) -> (B, C//2, T, 2) -> (B, C//2, 2*T)
        vq_feats = vq_feats.view(
            (vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
        ).permute(0, 2, 3, 1).flatten(2)

        dec_out = self.out_conv(
            self.decoder(
                input=vq_feats.transpose_(1, 2),
            ).transpose_(1, 2),
        )

        return torch.mul(dec_out, self.coef, out=dec_out)
0 commit comments