|
1 | 1 | import math
|
2 |
| -from einops import rearrange |
3 | 2 | from vector_quantize_pytorch import GroupedResidualFSQ
|
4 | 3 |
|
5 | 4 | import torch
|
@@ -66,23 +65,32 @@ def __init__(self,
|
66 | 65 | self.G = G
|
67 | 66 | self.R = R
|
68 | 67 |
|
69 |
def _embed(self, x: torch.Tensor):
    """Turn flat code indices into quantizer output features.

    Args:
        x: Index tensor whose last axis (after the optional transpose
            below) has length ``self.G * self.R`` -- i.e. laid out as
            ``(b, t, G * R)``. When ``self.transpose`` is set, the
            caller passes it with axes 1 and 2 swapped.

    Returns:
        Whatever ``self.quantizer.get_output_from_indices`` returns for
        the regrouped indices, with axes 1 and 2 swapped back when
        ``self.transpose`` is set.
    """
    if self.transpose:
        x = x.transpose(1, 2)
    # Equivalent of einops' "b t (g r) -> g b t r": split the last axis
    # into (G, R) and move the group axis to the front.  The result must
    # be assigned back, and the view needs explicit batch/time sizes so
    # that it is 4-D before the 4-axis permute.
    x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
    feat = self.quantizer.get_output_from_indices(x)
    return feat.transpose(1, 2) if self.transpose else feat
|
77 |
| - |
| 79 | + |
78 | 80 | def forward(self, x,):
|
79 | 81 | if self.transpose:
|
80 | 82 | x = x.transpose(1,2)
|
81 | 83 | feat, ind = self.quantizer(x)
|
| 84 | + """ |
82 | 85 | ind = rearrange(
|
83 | 86 | ind, "g b t r ->b t (g r)",
|
84 |
| - ) |
85 |
| - embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype) |
| 87 | + ) |
| 88 | + """ |
| 89 | + ind = ind.permute(1, 2, 0, 3).contiguous() |
| 90 | + ind = ind.view(ind.size(0), ind.size(1), -1) |
| 91 | + embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind) |
| 92 | + embed_onehot = embed_onehot_tmp.to(x.dtype) |
| 93 | + del embed_onehot_tmp |
86 | 94 | e_mean = torch.mean(embed_onehot, dim=[0,1])
|
87 | 95 | e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
|
88 | 96 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
|
|
0 commit comments