Skip to content

Commit ace4d0c

Browse files
authored
optimize(dvae): remove einops (#383)
1 parent e6412b1 commit ace4d0c

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

ChatTTS/model/dvae.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import math
2-
from einops import rearrange
32
from vector_quantize_pytorch import GroupedResidualFSQ
43

54
import torch
@@ -66,23 +65,32 @@ def __init__(self,
6665
self.G = G
6766
self.R = R
6867

69-
def _embed(self, x):
68+
def _embed(self, x: torch.Tensor):
7069
if self.transpose:
7170
x = x.transpose(1,2)
71+
"""
7272
x = rearrange(
7373
x, "b t (g r) -> g b t r", g = self.G, r = self.R,
74-
)
74+
)
75+
"""
76+
x.view(-1, self.G, self.R).permute(2, 0, 1, 3)
7577
feat = self.quantizer.get_output_from_indices(x)
7678
return feat.transpose(1,2) if self.transpose else feat
77-
79+
7880
def forward(self, x,):
7981
if self.transpose:
8082
x = x.transpose(1,2)
8183
feat, ind = self.quantizer(x)
84+
"""
8285
ind = rearrange(
8386
ind, "g b t r ->b t (g r)",
84-
)
85-
embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
87+
)
88+
"""
89+
ind = ind.permute(1, 2, 0, 3).contiguous()
90+
ind = ind.view(ind.size(0), ind.size(1), -1)
91+
embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
92+
embed_onehot = embed_onehot_tmp.to(x.dtype)
93+
del embed_onehot_tmp
8694
e_mean = torch.mean(embed_onehot, dim=[0,1])
8795
e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
8896
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))

0 commit comments

Comments
 (0)