From 6de98ff480f8f921dc8cb346477d4647702cc381 Mon Sep 17 00:00:00 2001 From: akulkarni Date: Thu, 13 Jun 2024 16:27:50 +0200 Subject: [PATCH 01/15] feat(openvoice): initial integration --- TTS/vc/modules/openvoice/__init__.py | 0 TTS/vc/modules/openvoice/attentions.py | 423 +++++++++++++++ TTS/vc/modules/openvoice/commons.py | 151 ++++++ TTS/vc/modules/openvoice/config.json | 57 ++ TTS/vc/modules/openvoice/models.py | 480 +++++++++++++++++ TTS/vc/modules/openvoice/modules.py | 588 +++++++++++++++++++++ TTS/vc/modules/openvoice/standalone_api.py | 342 ++++++++++++ TTS/vc/modules/openvoice/transforms.py | 203 +++++++ 8 files changed, 2244 insertions(+) create mode 100644 TTS/vc/modules/openvoice/__init__.py create mode 100644 TTS/vc/modules/openvoice/attentions.py create mode 100644 TTS/vc/modules/openvoice/commons.py create mode 100644 TTS/vc/modules/openvoice/config.json create mode 100644 TTS/vc/modules/openvoice/models.py create mode 100644 TTS/vc/modules/openvoice/modules.py create mode 100644 TTS/vc/modules/openvoice/standalone_api.py create mode 100644 TTS/vc/modules/openvoice/transforms.py diff --git a/TTS/vc/modules/openvoice/__init__.py b/TTS/vc/modules/openvoice/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/TTS/vc/modules/openvoice/attentions.py b/TTS/vc/modules/openvoice/attentions.py new file mode 100644 index 0000000000..73c5554c98 --- /dev/null +++ b/TTS/vc/modules/openvoice/attentions.py @@ -0,0 +1,423 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.vc.modules.openvoice import commons + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + isflow=True, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + # if isflow: + # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) + # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) + # self.cond_layer = weight_norm(cond_layer, name='weight') + # self.gin_channels = 256 + self.cond_layer_idx = self.n_layers + if "gin_channels" in kwargs: + self.gin_channels = kwargs["gin_channels"] + if self.gin_channels != 0: + self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) + # vits2 says 3rd block, so idx is 2 by default + self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 + # logging.debug(self.gin_channels, self.cond_layer_idx) + assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers" + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, g=None): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + if i == self.cond_layer_idx and g is not None: + g = self.spk_emb_linear(g.transpose(1, 2)) + g = g.transpose(1, 2) + x = x + g + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # pad along column + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/TTS/vc/modules/openvoice/commons.py b/TTS/vc/modules/openvoice/commons.py new file mode 100644 index 0000000000..123ee7e156 --- /dev/null +++ b/TTS/vc/modules/openvoice/commons.py @@ -0,0 +1,151 @@ +import math + +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + layer = pad_shape[::-1] + pad_shape = [item for sublist in layer for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/TTS/vc/modules/openvoice/config.json b/TTS/vc/modules/openvoice/config.json new file mode 100644 index 0000000000..3e33566b0d --- /dev/null +++ b/TTS/vc/modules/openvoice/config.json @@ -0,0 +1,57 @@ +{ + "_version_": "v2", + "data": { + "sampling_rate": 22050, + "filter_length": 1024, + "hop_length": 256, + "win_length": 1024, + "n_speakers": 0 + }, + "model": { + "zero_g": true, + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 4, + 4 + ], + "gin_channels": 256 + } +} \ No newline at end of file diff --git a/TTS/vc/modules/openvoice/models.py b/TTS/vc/modules/openvoice/models.py new file mode 100644 index 0000000000..c1ae7574ce --- /dev/null +++ b/TTS/vc/modules/openvoice/models.py @@ -0,0 +1,480 @@ +import math + +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +from TTS.vc.modules.openvoice import attentions, commons, modules +from TTS.vc.modules.openvoice.commons import init_weights + + +class TextEncoder(nn.Module): + def __init__( + self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None, tau=1.0): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for layer in self.ups: + remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + + +class ReferenceEncoder(nn.Module): + """ + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + """ + + def __init__(self, spec_channels, gin_channels=0, layernorm=True): + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, gin_channels) + if layernorm: + self.layernorm = nn.LayerNorm(self.spec_channels) + else: + self.layernorm = None + + def forward(self, inputs, mask=None): + N = inputs.size(0) + + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + if self.layernorm is not None: + out = self.layernorm(out) + + for conv in self.convs: + out = conv(out) + # out = wn(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)) + + def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + n_vocab, + spec_channels, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=256, + gin_channels=256, + zero_g=False, + **kwargs, + ): + super().__init__() + + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + self.n_speakers = n_speakers + if n_speakers == 0: + self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) + else: + self.enc_p = TextEncoder( + n_vocab, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + self.emb_g = nn.Embedding(n_speakers, gin_channels) + self.zero_g = zero_g + + def infer( + self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, sdp_ratio=0.2, max_len=None + ): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio + self.dp( + x, x_mask, g=g + ) * (1 - sdp_ratio) + + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:, :, :max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): + g_src = sid_src + g_tgt = sid_tgt + z, m_q, logs_q, y_mask = self.enc_q( + y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau + ) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt)) + return o_hat, y_mask, (z, z_p, z_hat) diff --git a/TTS/vc/modules/openvoice/modules.py b/TTS/vc/modules/openvoice/modules.py new file mode 100644 index 0000000000..b3a60d5b12 --- /dev/null +++ b/TTS/vc/modules/openvoice/modules.py @@ -0,0 +1,588 @@ +import math + +import torch +from torch import nn +from torch.nn import Conv1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +from TTS.vc.modules.openvoice import commons +from TTS.vc.modules.openvoice.attentions import Encoder +from TTS.vc.modules.openvoice.commons import get_padding, init_weights +from TTS.vc.modules.openvoice.transforms import piecewise_rational_quadratic_transform + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dilated and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x + + +class TransformerCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout=0, + filter_channels=0, + mean_only=False, + wn_sharing_parameter=None, + gin_channels=0, + ): + assert n_layers == 3, n_layers + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = ( + Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + isflow=True, + gin_channels=gin_channels, + ) + if wn_sharing_parameter is None + else wn_sharing_parameter + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x diff --git a/TTS/vc/modules/openvoice/standalone_api.py b/TTS/vc/modules/openvoice/standalone_api.py new file mode 100644 index 0000000000..831fd4dc43 --- /dev/null +++ b/TTS/vc/modules/openvoice/standalone_api.py @@ -0,0 +1,342 @@ +import json +import os + +import librosa +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +from TTS.vc.modules.openvoice.models import SynthesizerTrn + +# vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) + +# vc_config.audio.output_sample_rate + + +class custom_sr_config: + """Class defined to make combatible sampling rate defination with TTS api.py. + + Args: + sampling rate. + """ + + def __init__(self, value): + self.audio = self.Audio(value) + + class Audio: + def __init__(self, value): + self.output_sample_rate = value + + +class OpenVoiceSynthesizer(object): + def __init__(self, vc_checkpoint, vc_config, use_cuda="cpu"): + + if use_cuda: + self.device = "cuda" + else: + self.device = "cpu" + + hps = get_hparams_from_file(vc_config) + self.vc_config = custom_sr_config(hps.data.sampling_rate) + + # vc_config.audio.output_sample_rate + self.model = SynthesizerTrn( + len(getattr(hps, "symbols", [])), + hps.data.filter_length // 2 + 1, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(torch.device(self.device)) + + self.hps = hps + self.load_ckpt(vc_checkpoint) + self.model.eval() + + def load_ckpt(self, ckpt_path): + checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device)) + a, b = self.model.load_state_dict(checkpoint_dict["model"], strict=False) + # print("Loaded checkpoint '{}'".format(ckpt_path)) + # print('missing/unexpected keys:', a, b) + + def extract_se(self, fpath): + audio_ref, sr = librosa.load(fpath, sr=self.hps.data.sampling_rate) + y = torch.FloatTensor(audio_ref) + y = y.to(self.device) + y = y.unsqueeze(0) + y = spectrogram_torch( + y, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False, + ).to(self.device) + with torch.no_grad(): + g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1) + + return g + + # source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav" + def voice_conversion(self, source_wav, target_wav, tau=0.3, message="default"): + + if not os.path.exists(source_wav): + print("source wavpath dont exists") + exit(0) + + if not os.path.exists(target_wav): + print("target wavpath dont exists") + exit(0) + + src_se = self.extract_se(source_wav) + tgt_se = self.extract_se(target_wav) + + # load audio + audio, sample_rate = librosa.load(source_wav, sr=self.hps.data.sampling_rate) + audio = torch.tensor(audio).float() + + with torch.no_grad(): + y = torch.FloatTensor(audio).to(self.device) + y = y.unsqueeze(0) + spec = spectrogram_torch( + y, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False, + ).to(self.device) + spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device) + audio = ( + self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][0, 0] + .data.cpu() + .float() + .numpy() + ) + + return audio + + +def get_hparams_from_file(config_path): + with open(config_path, "r", encoding="utf-8") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + + +class HParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, dict): + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.1: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.1: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False): + # if torch.min(y) < -1.: + # print('min value is ', torch.min(y)) + # if torch.max(y) > 1.: + # print('max value is ', torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + + # ******************** original ************************# + # y = y.squeeze(1) + # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + + # ******************** ConvSTFT ************************# + freq_cutoff = n_fft // 2 + 1 + fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft))) + forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1]) + forward_basis = ( + forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float() + ) + + import torch.nn.functional as F + + # if center: + # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1) + assert center is False + + forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride=hop_size) + spec2 = torch.stack( + [forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim=-1 + ) + + # ******************** Verification ************************# + spec1 = torch.stft( + y.squeeze(1), + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + assert torch.allclose(spec1, spec2, atol=1e-4) + + spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/TTS/vc/modules/openvoice/transforms.py b/TTS/vc/modules/openvoice/transforms.py new file mode 100644 index 0000000000..4270ebae3f --- /dev/null +++ b/TTS/vc/modules/openvoice/transforms.py @@ -0,0 +1,203 @@ +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet From 4124b9d663b4eea5e7034e96351fe5d4180cfb89 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Tue, 25 Jun 2024 22:28:41 +0200 Subject: [PATCH 02/15] feat(vits): add tau parameter to posterior encoder --- TTS/tts/layers/vits/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index 50ed1024de..ab2ca5667a 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -256,7 +256,7 @@ def __init__( ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, g=None): + def forward(self, x, x_lengths, g=None, tau=1.0): """ Shapes: - x: :math:`[B, C, T]` @@ -268,5 +268,5 @@ def forward(self, x, x_lengths, g=None): x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask mean, log_scale = torch.split(stats, self.out_channels, dim=1) - z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask return z, mean, log_scale, x_mask From b97d5378a534acd7fa57f49cde955bec2e1a8085 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 20 Jun 2024 12:42:25 +0200 Subject: [PATCH 03/15] refactor(openvoice): remove duplicate and unused code --- TTS/vc/modules/openvoice/attentions.py | 423 --------------- TTS/vc/modules/openvoice/commons.py | 151 ------ TTS/vc/modules/openvoice/config.json | 57 -- TTS/vc/modules/openvoice/models.py | 368 +------------ TTS/vc/modules/openvoice/modules.py | 588 --------------------- TTS/vc/modules/openvoice/standalone_api.py | 342 ------------ TTS/vc/modules/openvoice/transforms.py | 203 ------- 7 files changed, 11 insertions(+), 2121 deletions(-) delete mode 100644 TTS/vc/modules/openvoice/attentions.py delete mode 100644 TTS/vc/modules/openvoice/commons.py delete mode 100644 TTS/vc/modules/openvoice/config.json delete mode 100644 TTS/vc/modules/openvoice/modules.py delete mode 100644 TTS/vc/modules/openvoice/standalone_api.py delete mode 100644 TTS/vc/modules/openvoice/transforms.py diff --git a/TTS/vc/modules/openvoice/attentions.py b/TTS/vc/modules/openvoice/attentions.py deleted file mode 100644 index 73c5554c98..0000000000 --- a/TTS/vc/modules/openvoice/attentions.py +++ /dev/null @@ -1,423 +0,0 @@ -import math - -import torch -from torch import nn -from torch.nn import functional as F - -from TTS.vc.modules.openvoice import commons - - -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) - - -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -class Encoder(nn.Module): - def __init__( - self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - window_size=4, - isflow=True, - **kwargs, - ): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.window_size = window_size - # if isflow: - # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) - # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) - # self.cond_layer = weight_norm(cond_layer, name='weight') - # self.gin_channels = 256 - self.cond_layer_idx = self.n_layers - if "gin_channels" in kwargs: - self.gin_channels = kwargs["gin_channels"] - if self.gin_channels != 0: - self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) - # vits2 says 3rd block, so idx is 2 by default - self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 - # logging.debug(self.gin_channels, self.cond_layer_idx) - assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers" - self.drop = nn.Dropout(p_dropout) - self.attn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_2 = nn.ModuleList() - - for i in range(self.n_layers): - self.attn_layers.append( - MultiHeadAttention( - hidden_channels, - hidden_channels, - n_heads, - p_dropout=p_dropout, - window_size=window_size, - ) - ) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - ) - ) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - - def forward(self, x, x_mask, g=None): - attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - x = x * x_mask - for i in range(self.n_layers): - if i == self.cond_layer_idx and g is not None: - g = self.spk_emb_linear(g.transpose(1, 2)) - g = g.transpose(1, 2) - x = x + g - x = x * x_mask - y = self.attn_layers[i](x, x, attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x - - -class Decoder(nn.Module): - def __init__( - self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - proximal_bias=False, - proximal_init=True, - **kwargs, - ): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init - - self.drop = nn.Dropout(p_dropout) - self.self_attn_layers = nn.ModuleList() - self.norm_layers_0 = nn.ModuleList() - self.encdec_attn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_2 = nn.ModuleList() - for i in range(self.n_layers): - self.self_attn_layers.append( - MultiHeadAttention( - hidden_channels, - hidden_channels, - n_heads, - p_dropout=p_dropout, - proximal_bias=proximal_bias, - proximal_init=proximal_init, - ) - ) - self.norm_layers_0.append(LayerNorm(hidden_channels)) - self.encdec_attn_layers.append( - MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout) - ) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - causal=True, - ) - ) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - - def forward(self, x, x_mask, h, h_mask): - """ - x: decoder input - h: encoder output - """ - self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) - encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - x = x * x_mask - for i in range(self.n_layers): - y = self.self_attn_layers[i](x, x, self_attn_mask) - y = self.drop(y) - x = self.norm_layers_0[i](x + y) - - y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x - - -class MultiHeadAttention(nn.Module): - def __init__( - self, - channels, - out_channels, - n_heads, - p_dropout=0.0, - window_size=None, - heads_share=True, - block_length=None, - proximal_bias=False, - proximal_init=False, - ): - super().__init__() - assert channels % n_heads == 0 - - self.channels = channels - self.out_channels = out_channels - self.n_heads = n_heads - self.p_dropout = p_dropout - self.window_size = window_size - self.heads_share = heads_share - self.block_length = block_length - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init - self.attn = None - - self.k_channels = channels // n_heads - self.conv_q = nn.Conv1d(channels, channels, 1) - self.conv_k = nn.Conv1d(channels, channels, 1) - self.conv_v = nn.Conv1d(channels, channels, 1) - self.conv_o = nn.Conv1d(channels, out_channels, 1) - self.drop = nn.Dropout(p_dropout) - - if window_size is not None: - n_heads_rel = 1 if heads_share else n_heads - rel_stddev = self.k_channels**-0.5 - self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) - self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) - - nn.init.xavier_uniform_(self.conv_q.weight) - nn.init.xavier_uniform_(self.conv_k.weight) - nn.init.xavier_uniform_(self.conv_v.weight) - if proximal_init: - with torch.no_grad(): - self.conv_k.weight.copy_(self.conv_q.weight) - self.conv_k.bias.copy_(self.conv_q.bias) - - def forward(self, x, c, attn_mask=None): - q = self.conv_q(x) - k = self.conv_k(c) - v = self.conv_v(c) - - x, self.attn = self.attention(q, k, v, mask=attn_mask) - - x = self.conv_o(x) - return x - - def attention(self, query, key, value, mask=None): - # reshape [b, d, t] -> [b, n_h, t, d_k] - b, d, t_s, t_t = (*key.size(), query.size(2)) - query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) - key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - - scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) - if self.window_size is not None: - assert t_s == t_t, "Relative attention is only available for self-attention." - key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) - scores_local = self._relative_position_to_absolute_position(rel_logits) - scores = scores + scores_local - if self.proximal_bias: - assert t_s == t_t, "Proximal bias is only available for self-attention." - scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) - if mask is not None: - scores = scores.masked_fill(mask == 0, -1e4) - if self.block_length is not None: - assert t_s == t_t, "Local attention is only available for self-attention." - block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) - scores = scores.masked_fill(block_mask == 0, -1e4) - p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] - p_attn = self.drop(p_attn) - output = torch.matmul(p_attn, value) - if self.window_size is not None: - relative_weights = self._absolute_position_to_relative_position(p_attn) - value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) - output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) - output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] - return output, p_attn - - def _matmul_with_relative_values(self, x, y): - """ - x: [b, h, l, m] - y: [h or 1, m, d] - ret: [b, h, l, d] - """ - ret = torch.matmul(x, y.unsqueeze(0)) - return ret - - def _matmul_with_relative_keys(self, x, y): - """ - x: [b, h, l, d] - y: [h or 1, m, d] - ret: [b, h, l, m] - """ - ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) - return ret - - def _get_relative_embeddings(self, relative_embeddings, length): - 2 * self.window_size + 1 - # Pad first before slice to avoid using cond ops. - pad_length = max(length - (self.window_size + 1), 0) - slice_start_position = max((self.window_size + 1) - length, 0) - slice_end_position = slice_start_position + 2 * length - 1 - if pad_length > 0: - padded_relative_embeddings = F.pad( - relative_embeddings, - commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), - ) - else: - padded_relative_embeddings = relative_embeddings - used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] - return used_relative_embeddings - - def _relative_position_to_absolute_position(self, x): - """ - x: [b, h, l, 2*l-1] - ret: [b, h, l, l] - """ - batch, heads, length, _ = x.size() - # Concat columns of pad to shift from relative to absolute indexing. - x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) - - # Concat extra elements so to add up to shape (len+1, 2*len-1). - x_flat = x.view([batch, heads, length * 2 * length]) - x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) - - # Reshape and slice out the padded elements. - x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] - return x_final - - def _absolute_position_to_relative_position(self, x): - """ - x: [b, h, l, l] - ret: [b, h, l, 2*l-1] - """ - batch, heads, length, _ = x.size() - # pad along column - x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) - x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) - # add 0's in the beginning that will skew the elements after reshape - x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) - x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] - return x_final - - def _attention_bias_proximal(self, length): - """Bias for self-attention to encourage attention to close positions. - Args: - length: an integer scalar. - Returns: - a Tensor with shape [1, 1, length, length] - """ - r = torch.arange(length, dtype=torch.float32) - diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) - return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) - - -class FFN(nn.Module): - def __init__( - self, - in_channels, - out_channels, - filter_channels, - kernel_size, - p_dropout=0.0, - activation=None, - causal=False, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.activation = activation - self.causal = causal - - if causal: - self.padding = self._causal_padding - else: - self.padding = self._same_padding - - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) - self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) - self.drop = nn.Dropout(p_dropout) - - def forward(self, x, x_mask): - x = self.conv_1(self.padding(x * x_mask)) - if self.activation == "gelu": - x = x * torch.sigmoid(1.702 * x) - else: - x = torch.relu(x) - x = self.drop(x) - x = self.conv_2(self.padding(x * x_mask)) - return x * x_mask - - def _causal_padding(self, x): - if self.kernel_size == 1: - return x - pad_l = self.kernel_size - 1 - pad_r = 0 - padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, commons.convert_pad_shape(padding)) - return x - - def _same_padding(self, x): - if self.kernel_size == 1: - return x - pad_l = (self.kernel_size - 1) // 2 - pad_r = self.kernel_size // 2 - padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, commons.convert_pad_shape(padding)) - return x diff --git a/TTS/vc/modules/openvoice/commons.py b/TTS/vc/modules/openvoice/commons.py deleted file mode 100644 index 123ee7e156..0000000000 --- a/TTS/vc/modules/openvoice/commons.py +++ /dev/null @@ -1,151 +0,0 @@ -import math - -import torch -from torch.nn import functional as F - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - -def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask - - -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -def convert_pad_shape(pad_shape): - layer = pad_shape[::-1] - pad_shape = [item for sublist in layer for item in sublist] - return pad_shape - - -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - -def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type - if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1.0 / norm_type) - return total_norm diff --git a/TTS/vc/modules/openvoice/config.json b/TTS/vc/modules/openvoice/config.json deleted file mode 100644 index 3e33566b0d..0000000000 --- a/TTS/vc/modules/openvoice/config.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "_version_": "v2", - "data": { - "sampling_rate": 22050, - "filter_length": 1024, - "hop_length": 256, - "win_length": 1024, - "n_speakers": 0 - }, - "model": { - "zero_g": true, - "inter_channels": 192, - "hidden_channels": 192, - "filter_channels": 768, - "n_heads": 2, - "n_layers": 6, - "kernel_size": 3, - "p_dropout": 0.1, - "resblock": "1", - "resblock_kernel_sizes": [ - 3, - 7, - 11 - ], - "resblock_dilation_sizes": [ - [ - 1, - 3, - 5 - ], - [ - 1, - 3, - 5 - ], - [ - 1, - 3, - 5 - ] - ], - "upsample_rates": [ - 8, - 8, - 2, - 2 - ], - "upsample_initial_channel": 512, - "upsample_kernel_sizes": [ - 16, - 16, - 4, - 4 - ], - "gin_channels": 256 - } -} \ No newline at end of file diff --git a/TTS/vc/modules/openvoice/models.py b/TTS/vc/modules/openvoice/models.py index c1ae7574ce..89a1c3a40c 100644 --- a/TTS/vc/modules/openvoice/models.py +++ b/TTS/vc/modules/openvoice/models.py @@ -1,276 +1,9 @@ -import math - import torch from torch import nn -from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F -from torch.nn.utils import remove_weight_norm, weight_norm - -from TTS.vc.modules.openvoice import attentions, commons, modules -from TTS.vc.modules.openvoice.commons import init_weights - - -class TextEncoder(nn.Module): - def __init__( - self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout - ): - super().__init__() - self.n_vocab = n_vocab - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.emb = nn.Embedding(n_vocab, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - - self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths): - x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] - x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - - x = self.encoder(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - - m, logs = torch.split(stats, self.out_channels, dim=1) - return x, m, logs, x_mask - - -class DurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): - super().__init__() - - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.gin_channels = gin_channels - - self.drop = nn.Dropout(p_dropout) - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) - self.norm_1 = modules.LayerNorm(filter_channels) - self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) - self.norm_2 = modules.LayerNorm(filter_channels) - self.proj = nn.Conv1d(filter_channels, 1, 1) - - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, in_channels, 1) - - def forward(self, x, x_mask, g=None): - x = torch.detach(x) - if g is not None: - g = torch.detach(g) - x = x + self.cond(g) - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.norm_1(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - x = torch.relu(x) - x = self.norm_2(x) - x = self.drop(x) - x = self.proj(x * x_mask) - return x * x_mask - - -class StochasticDurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): - super().__init__() - filter_channels = in_channels # it needs to be removed from future version. - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.log_flow = modules.Log() - self.flows = nn.ModuleList() - self.flows.append(modules.ElementwiseAffine(2)) - for i in range(n_flows): - self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) - self.flows.append(modules.Flip()) - - self.post_pre = nn.Conv1d(1, filter_channels, 1) - self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - self.post_flows = nn.ModuleList() - self.post_flows.append(modules.ElementwiseAffine(2)) - for i in range(4): - self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) - self.post_flows.append(modules.Flip()) - - self.pre = nn.Conv1d(in_channels, filter_channels, 1) - self.proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, filter_channels, 1) - - def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): - x = torch.detach(x) - x = self.pre(x) - if g is not None: - g = torch.detach(g) - x = x + self.cond(g) - x = self.convs(x, x_mask) - x = self.proj(x) * x_mask - - if not reverse: - flows = self.flows - assert w is not None - - logdet_tot_q = 0 - h_w = self.post_pre(w) - h_w = self.post_convs(h_w, x_mask) - h_w = self.post_proj(h_w) * x_mask - e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask - z_q = e_q - for flow in self.post_flows: - z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) - logdet_tot_q += logdet_q - z_u, z1 = torch.split(z_q, [1, 1], 1) - u = torch.sigmoid(z_u) * x_mask - z0 = (w - u) * x_mask - logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) - logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q - - logdet_tot = 0 - z0, logdet = self.log_flow(z0, x_mask) - logdet_tot += logdet - z = torch.cat([z0, z1], 1) - for flow in flows: - z, logdet = flow(z, x_mask, g=x, reverse=reverse) - logdet_tot = logdet_tot + logdet - nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot - return nll + logq # [b] - else: - flows = list(reversed(self.flows)) - flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale - for flow in flows: - z = flow(z, x_mask, g=x, reverse=reverse) - z0, z1 = torch.split(z, [1, 1], 1) - logw = z0 - return logw - - -class PosteriorEncoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN( - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths, g=None, tau=1.0): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - -class Generator(torch.nn.Module): - def __init__( - self, - initial_channel, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=0, - ): - super(Generator, self).__init__() - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) - resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d)) - - self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) - - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - - def forward(self, x, g=None): - x = self.conv_pre(x) - if g is not None: - x = x + self.cond(g) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, modules.LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - print("Removing weight norm...") - for layer in self.ups: - remove_weight_norm(layer) - for layer in self.resblocks: - layer.remove_weight_norm() +from TTS.tts.layers.vits.networks import PosteriorEncoder +from TTS.vc.models.freevc import Generator, ResidualCouplingBlock class ReferenceEncoder(nn.Module): @@ -286,7 +19,7 @@ def __init__(self, spec_channels, gin_channels=0, layernorm=True): K = len(ref_enc_filters) filters = [1] + ref_enc_filters convs = [ - weight_norm( + torch.nn.utils.parametrizations.weight_norm( nn.Conv2d( in_channels=filters[i], out_channels=filters[i + 1], @@ -311,7 +44,7 @@ def __init__(self, spec_channels, gin_channels=0, layernorm=True): else: self.layernorm = None - def forward(self, inputs, mask=None): + def forward(self, inputs): N = inputs.size(0) out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] @@ -320,7 +53,6 @@ def forward(self, inputs, mask=None): for conv in self.convs: out = conv(out) - # out = wn(out) out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] @@ -329,52 +61,16 @@ def forward(self, inputs, mask=None): out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] self.gru.flatten_parameters() - memory, out = self.gru(out) # out --- [1, N, 128] + _memory, out = self.gru(out) # out --- [1, N, 128] return self.proj(out.squeeze(0)) def calculate_channels(self, L, kernel_size, stride, pad, n_convs): - for i in range(n_convs): + for _ in range(n_convs): L = (L - kernel_size + 2 * pad) // stride + 1 return L -class ResidualCouplingBlock(nn.Module): - def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append( - modules.ResidualCouplingLayer( - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - mean_only=True, - ) - ) - self.flows.append(modules.Flip()) - - def forward(self, x, x_mask, g=None, reverse=False): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, reverse=reverse) - else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, reverse=reverse) - return x - - class SynthesizerTrn(nn.Module): """ Synthesizer for Training @@ -382,22 +78,16 @@ class SynthesizerTrn(nn.Module): def __init__( self, - n_vocab, spec_channels, inter_channels, hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, - n_speakers=256, + n_speakers=0, gin_channels=256, zero_g=False, **kwargs, @@ -421,53 +111,17 @@ def __init__( 5, 1, 16, - gin_channels=gin_channels, + cond_channels=gin_channels, ) self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) self.n_speakers = n_speakers - if n_speakers == 0: - self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) - else: - self.enc_p = TextEncoder( - n_vocab, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout - ) - self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) - self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) - self.emb_g = nn.Embedding(n_speakers, gin_channels) + if n_speakers != 0: + raise ValueError("OpenVoice inference only supports n_speaker==0") + self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) self.zero_g = zero_g - def infer( - self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, sdp_ratio=0.2, max_len=None - ): - x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) - if self.n_speakers > 0: - g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - else: - g = None - - logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio + self.dp( - x, x_mask, g=g - ) * (1 - sdp_ratio) - - w = torch.exp(logw) * x_mask * length_scale - w_ceil = torch.ceil(w) - y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = commons.generate_path(w_ceil, attn_mask) - - m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose( - 1, 2 - ) # [b, t', t], [b, t, d] -> [b, d, t'] - - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale - z = self.flow(z_p, y_mask, g=g, reverse=True) - o = self.dec((z * y_mask)[:, :, :max_len], g=g) - return o, attn, y_mask, (z, z_p, m_p, logs_p) - def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): g_src = sid_src g_tgt = sid_tgt diff --git a/TTS/vc/modules/openvoice/modules.py b/TTS/vc/modules/openvoice/modules.py deleted file mode 100644 index b3a60d5b12..0000000000 --- a/TTS/vc/modules/openvoice/modules.py +++ /dev/null @@ -1,588 +0,0 @@ -import math - -import torch -from torch import nn -from torch.nn import Conv1d -from torch.nn import functional as F -from torch.nn.utils import remove_weight_norm, weight_norm - -from TTS.vc.modules.openvoice import commons -from TTS.vc.modules.openvoice.attentions import Encoder -from TTS.vc.modules.openvoice.commons import get_padding, init_weights -from TTS.vc.modules.openvoice.transforms import piecewise_rational_quadratic_transform - -LRELU_SLOPE = 0.1 - - -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) - - -class ConvReluNorm(nn.Module): - def __init__( - self, - in_channels, - hidden_channels, - out_channels, - kernel_size, - n_layers, - p_dropout, - ): - super().__init__() - self.in_channels = in_channels - self.hidden_channels = hidden_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - assert n_layers > 1, "Number of layers should be larger than 0." - - self.conv_layers = nn.ModuleList() - self.norm_layers = nn.ModuleList() - self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) - for _ in range(n_layers - 1): - self.conv_layers.append( - nn.Conv1d( - hidden_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2, - ) - ) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.proj = nn.Conv1d(hidden_channels, out_channels, 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward(self, x, x_mask): - x_org = x - for i in range(self.n_layers): - x = self.conv_layers[i](x * x_mask) - x = self.norm_layers[i](x) - x = self.relu_drop(x) - x = x_org + self.proj(x) - return x * x_mask - - -class DDSConv(nn.Module): - """ - Dilated and Depth-Separable Convolution - """ - - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size**i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append( - nn.Conv1d( - channels, - channels, - kernel_size, - groups=channels, - dilation=dilation, - padding=padding, - ) - ) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = self.drop(y) - x = x + y - return x * x_mask - - -class WN(torch.nn.Module): - def __init__( - self, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0, - p_dropout=0, - ): - super(WN, self).__init__() - assert kernel_size % 2 == 1 - self.hidden_channels = hidden_channels - self.kernel_size = (kernel_size,) - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - self.p_dropout = p_dropout - - self.in_layers = torch.nn.ModuleList() - self.res_skip_layers = torch.nn.ModuleList() - self.drop = nn.Dropout(p_dropout) - - if gin_channels != 0: - cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") - - for i in range(n_layers): - dilation = dilation_rate**i - padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d( - hidden_channels, - 2 * hidden_channels, - kernel_size, - dilation=dilation, - padding=padding, - ) - in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") - self.in_layers.append(in_layer) - - # last one is not necessary - if i < n_layers - 1: - res_skip_channels = 2 * hidden_channels - else: - res_skip_channels = hidden_channels - - res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") - self.res_skip_layers.append(res_skip_layer) - - def forward(self, x, x_mask, g=None, **kwargs): - output = torch.zeros_like(x) - n_channels_tensor = torch.IntTensor([self.hidden_channels]) - - if g is not None: - g = self.cond_layer(g) - - for i in range(self.n_layers): - x_in = self.in_layers[i](x) - if g is not None: - cond_offset = i * 2 * self.hidden_channels - g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] - else: - g_l = torch.zeros_like(x_in) - - acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) - acts = self.drop(acts) - - res_skip_acts = self.res_skip_layers[i](acts) - if i < self.n_layers - 1: - res_acts = res_skip_acts[:, : self.hidden_channels, :] - x = (x + res_acts) * x_mask - output = output + res_skip_acts[:, self.hidden_channels :, :] - else: - output = output + res_skip_acts - return output * x_mask - - def remove_weight_norm(self): - if self.gin_channels != 0: - torch.nn.utils.remove_weight_norm(self.cond_layer) - for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) - for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) - - -class ResBlock1(torch.nn.Module): - def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): - super(ResBlock1, self).__init__() - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x, x_mask=None): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - if x_mask is not None: - xt = xt * x_mask - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - if x_mask is not None: - xt = xt * x_mask - xt = c2(xt) - x = xt + x - if x_mask is not None: - x = x * x_mask - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - -class ResBlock2(torch.nn.Module): - def __init__(self, channels, kernel_size=3, dilation=(1, 3)): - super(ResBlock2, self).__init__() - self.convs = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - ] - ) - self.convs.apply(init_weights) - - def forward(self, x, x_mask=None): - for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) - if x_mask is not None: - xt = xt * x_mask - xt = c(xt) - x = xt + x - if x_mask is not None: - x = x * x_mask - return x - - def remove_weight_norm(self): - for l in self.convs: - remove_weight_norm(l) - - -class Log(nn.Module): - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask - logdet = torch.sum(-y, [1, 2]) - return y, logdet - else: - x = torch.exp(x) * x_mask - return x - - -class Flip(nn.Module): - def forward(self, x, *args, reverse=False, **kwargs): - x = torch.flip(x, [1]) - if not reverse: - logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) - return x, logdet - else: - return x - - -class ElementwiseAffine(nn.Module): - def __init__(self, channels): - super().__init__() - self.channels = channels - self.m = nn.Parameter(torch.zeros(channels, 1)) - self.logs = nn.Parameter(torch.zeros(channels, 1)) - - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = self.m + torch.exp(self.logs) * x - y = y * x_mask - logdet = torch.sum(self.logs * x_mask, [1, 2]) - return y, logdet - else: - x = (x - self.m) * torch.exp(-self.logs) * x_mask - return x - - -class ResidualCouplingLayer(nn.Module): - def __init__( - self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - p_dropout=0, - gin_channels=0, - mean_only=False, - ): - assert channels % 2 == 0, "channels should be divisible by 2" - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.half_channels = channels // 2 - self.mean_only = mean_only - - self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = WN( - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - p_dropout=p_dropout, - gin_channels=gin_channels, - ) - self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) - self.post.weight.data.zero_() - self.post.bias.data.zero_() - - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels] * 2, 1) - h = self.pre(x0) * x_mask - h = self.enc(h, x_mask, g=g) - stats = self.post(h) * x_mask - if not self.mean_only: - m, logs = torch.split(stats, [self.half_channels] * 2, 1) - else: - m = stats - logs = torch.zeros_like(m) - - if not reverse: - x1 = m + x1 * torch.exp(logs) * x_mask - x = torch.cat([x0, x1], 1) - logdet = torch.sum(logs, [1, 2]) - return x, logdet - else: - x1 = (x1 - m) * torch.exp(-logs) * x_mask - x = torch.cat([x0, x1], 1) - return x - - -class ConvFlow(nn.Module): - def __init__( - self, - in_channels, - filter_channels, - kernel_size, - n_layers, - num_bins=10, - tail_bound=5.0, - ): - super().__init__() - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.num_bins = num_bins - self.tail_bound = tail_bound - self.half_channels = in_channels // 2 - - self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) - self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) - self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels] * 2, 1) - h = self.pre(x0) - h = self.convs(h, x_mask, g=g) - h = self.proj(h) * x_mask - - b, c, t = x0.shape - h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] - - unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) - unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) - unnormalized_derivatives = h[..., 2 * self.num_bins :] - - x1, logabsdet = piecewise_rational_quadratic_transform( - x1, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=reverse, - tails="linear", - tail_bound=self.tail_bound, - ) - - x = torch.cat([x0, x1], 1) * x_mask - logdet = torch.sum(logabsdet * x_mask, [1, 2]) - if not reverse: - return x, logdet - else: - return x - - -class TransformerCouplingLayer(nn.Module): - def __init__( - self, - channels, - hidden_channels, - kernel_size, - n_layers, - n_heads, - p_dropout=0, - filter_channels=0, - mean_only=False, - wn_sharing_parameter=None, - gin_channels=0, - ): - assert n_layers == 3, n_layers - assert channels % 2 == 0, "channels should be divisible by 2" - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.half_channels = channels // 2 - self.mean_only = mean_only - - self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = ( - Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - isflow=True, - gin_channels=gin_channels, - ) - if wn_sharing_parameter is None - else wn_sharing_parameter - ) - self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) - self.post.weight.data.zero_() - self.post.bias.data.zero_() - - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels] * 2, 1) - h = self.pre(x0) * x_mask - h = self.enc(h, x_mask, g=g) - stats = self.post(h) * x_mask - if not self.mean_only: - m, logs = torch.split(stats, [self.half_channels] * 2, 1) - else: - m = stats - logs = torch.zeros_like(m) - - if not reverse: - x1 = m + x1 * torch.exp(logs) * x_mask - x = torch.cat([x0, x1], 1) - logdet = torch.sum(logs, [1, 2]) - return x, logdet - else: - x1 = (x1 - m) * torch.exp(-logs) * x_mask - x = torch.cat([x0, x1], 1) - return x - - x1, logabsdet = piecewise_rational_quadratic_transform( - x1, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=reverse, - tails="linear", - tail_bound=self.tail_bound, - ) - - x = torch.cat([x0, x1], 1) * x_mask - logdet = torch.sum(logabsdet * x_mask, [1, 2]) - if not reverse: - return x, logdet - else: - return x diff --git a/TTS/vc/modules/openvoice/standalone_api.py b/TTS/vc/modules/openvoice/standalone_api.py deleted file mode 100644 index 831fd4dc43..0000000000 --- a/TTS/vc/modules/openvoice/standalone_api.py +++ /dev/null @@ -1,342 +0,0 @@ -import json -import os - -import librosa -import torch -import torch.utils.data -from librosa.filters import mel as librosa_mel_fn - -from TTS.vc.modules.openvoice.models import SynthesizerTrn - -# vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) - -# vc_config.audio.output_sample_rate - - -class custom_sr_config: - """Class defined to make combatible sampling rate defination with TTS api.py. - - Args: - sampling rate. - """ - - def __init__(self, value): - self.audio = self.Audio(value) - - class Audio: - def __init__(self, value): - self.output_sample_rate = value - - -class OpenVoiceSynthesizer(object): - def __init__(self, vc_checkpoint, vc_config, use_cuda="cpu"): - - if use_cuda: - self.device = "cuda" - else: - self.device = "cpu" - - hps = get_hparams_from_file(vc_config) - self.vc_config = custom_sr_config(hps.data.sampling_rate) - - # vc_config.audio.output_sample_rate - self.model = SynthesizerTrn( - len(getattr(hps, "symbols", [])), - hps.data.filter_length // 2 + 1, - n_speakers=hps.data.n_speakers, - **hps.model, - ).to(torch.device(self.device)) - - self.hps = hps - self.load_ckpt(vc_checkpoint) - self.model.eval() - - def load_ckpt(self, ckpt_path): - checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device)) - a, b = self.model.load_state_dict(checkpoint_dict["model"], strict=False) - # print("Loaded checkpoint '{}'".format(ckpt_path)) - # print('missing/unexpected keys:', a, b) - - def extract_se(self, fpath): - audio_ref, sr = librosa.load(fpath, sr=self.hps.data.sampling_rate) - y = torch.FloatTensor(audio_ref) - y = y.to(self.device) - y = y.unsqueeze(0) - y = spectrogram_torch( - y, - self.hps.data.filter_length, - self.hps.data.sampling_rate, - self.hps.data.hop_length, - self.hps.data.win_length, - center=False, - ).to(self.device) - with torch.no_grad(): - g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1) - - return g - - # source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav" - def voice_conversion(self, source_wav, target_wav, tau=0.3, message="default"): - - if not os.path.exists(source_wav): - print("source wavpath dont exists") - exit(0) - - if not os.path.exists(target_wav): - print("target wavpath dont exists") - exit(0) - - src_se = self.extract_se(source_wav) - tgt_se = self.extract_se(target_wav) - - # load audio - audio, sample_rate = librosa.load(source_wav, sr=self.hps.data.sampling_rate) - audio = torch.tensor(audio).float() - - with torch.no_grad(): - y = torch.FloatTensor(audio).to(self.device) - y = y.unsqueeze(0) - spec = spectrogram_torch( - y, - self.hps.data.filter_length, - self.hps.data.sampling_rate, - self.hps.data.hop_length, - self.hps.data.win_length, - center=False, - ).to(self.device) - spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device) - audio = ( - self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][0, 0] - .data.cpu() - .float() - .numpy() - ) - - return audio - - -def get_hparams_from_file(config_path): - with open(config_path, "r", encoding="utf-8") as f: - data = f.read() - config = json.loads(data) - - hparams = HParams(**config) - return hparams - - -class HParams: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if isinstance(v, dict): - v = HParams(**v) - self[k] = v - - def keys(self): - return self.__dict__.keys() - - def items(self): - return self.__dict__.items() - - def values(self): - return self.__dict__.values() - - def __len__(self): - return len(self.__dict__) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - return setattr(self, key, value) - - def __contains__(self, key): - return key in self.__dict__ - - def __repr__(self): - return self.__dict__.__repr__() - - -MAX_WAV_VALUE = 32768.0 - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def spectral_de_normalize_torch(magnitudes): - output = dynamic_range_decompression_torch(magnitudes) - return output - - -mel_basis = {} -hann_window = {} - - -def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): - if torch.min(y) < -1.1: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.1: - print("max value is ", torch.max(y)) - - global hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_size) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - -def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False): - # if torch.min(y) < -1.: - # print('min value is ', torch.min(y)) - # if torch.max(y) > 1.: - # print('max value is ', torch.max(y)) - - global hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_size) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" - ) - - # ******************** original ************************# - # y = y.squeeze(1) - # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) - - # ******************** ConvSTFT ************************# - freq_cutoff = n_fft // 2 + 1 - fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft))) - forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1]) - forward_basis = ( - forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float() - ) - - import torch.nn.functional as F - - # if center: - # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1) - assert center is False - - forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride=hop_size) - spec2 = torch.stack( - [forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim=-1 - ) - - # ******************** Verification ************************# - spec1 = torch.stft( - y.squeeze(1), - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, - ) - assert torch.allclose(spec1, spec2, atol=1e-4) - - spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6) - return spec - - -def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): - global mel_basis - dtype_device = str(spec.dtype) + "_" + str(spec.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) - return spec - - -def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - wnsize_dtype_device = str(win_size) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) - - return spec diff --git a/TTS/vc/modules/openvoice/transforms.py b/TTS/vc/modules/openvoice/transforms.py deleted file mode 100644 index 4270ebae3f..0000000000 --- a/TTS/vc/modules/openvoice/transforms.py +++ /dev/null @@ -1,203 +0,0 @@ -import numpy as np -import torch -from torch.nn import functional as F - -DEFAULT_MIN_BIN_WIDTH = 1e-3 -DEFAULT_MIN_BIN_HEIGHT = 1e-3 -DEFAULT_MIN_DERIVATIVE = 1e-3 - - -def piecewise_rational_quadratic_transform( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails=None, - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - if tails is None: - spline_fn = rational_quadratic_spline - spline_kwargs = {} - else: - spline_fn = unconstrained_rational_quadratic_spline - spline_kwargs = {"tails": tails, "tail_bound": tail_bound} - - outputs, logabsdet = spline_fn( - inputs=inputs, - unnormalized_widths=unnormalized_widths, - unnormalized_heights=unnormalized_heights, - unnormalized_derivatives=unnormalized_derivatives, - inverse=inverse, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - **spline_kwargs, - ) - return outputs, logabsdet - - -def searchsorted(bin_locations, inputs, eps=1e-6): - bin_locations[..., -1] += eps - return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 - - -def unconstrained_rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails="linear", - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) - outside_interval_mask = ~inside_interval_mask - - outputs = torch.zeros_like(inputs) - logabsdet = torch.zeros_like(inputs) - - if tails == "linear": - unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) - constant = np.log(np.exp(1 - min_derivative) - 1) - unnormalized_derivatives[..., 0] = constant - unnormalized_derivatives[..., -1] = constant - - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 - else: - raise RuntimeError("{} tails are not implemented.".format(tails)) - - ( - outputs[inside_interval_mask], - logabsdet[inside_interval_mask], - ) = rational_quadratic_spline( - inputs=inputs[inside_interval_mask], - unnormalized_widths=unnormalized_widths[inside_interval_mask, :], - unnormalized_heights=unnormalized_heights[inside_interval_mask, :], - unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], - inverse=inverse, - left=-tail_bound, - right=tail_bound, - bottom=-tail_bound, - top=tail_bound, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - ) - - return outputs, logabsdet - - -def rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - left=0.0, - right=1.0, - bottom=0.0, - top=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - if torch.min(inputs) < left or torch.max(inputs) > right: - raise ValueError("Input to a transform is not within its domain") - - num_bins = unnormalized_widths.shape[-1] - - if min_bin_width * num_bins > 1.0: - raise ValueError("Minimal bin width too large for the number of bins") - if min_bin_height * num_bins > 1.0: - raise ValueError("Minimal bin height too large for the number of bins") - - widths = F.softmax(unnormalized_widths, dim=-1) - widths = min_bin_width + (1 - min_bin_width * num_bins) * widths - cumwidths = torch.cumsum(widths, dim=-1) - cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) - cumwidths = (right - left) * cumwidths + left - cumwidths[..., 0] = left - cumwidths[..., -1] = right - widths = cumwidths[..., 1:] - cumwidths[..., :-1] - - derivatives = min_derivative + F.softplus(unnormalized_derivatives) - - heights = F.softmax(unnormalized_heights, dim=-1) - heights = min_bin_height + (1 - min_bin_height * num_bins) * heights - cumheights = torch.cumsum(heights, dim=-1) - cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) - cumheights = (top - bottom) * cumheights + bottom - cumheights[..., 0] = bottom - cumheights[..., -1] = top - heights = cumheights[..., 1:] - cumheights[..., :-1] - - if inverse: - bin_idx = searchsorted(cumheights, inputs)[..., None] - else: - bin_idx = searchsorted(cumwidths, inputs)[..., None] - - input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] - input_bin_widths = widths.gather(-1, bin_idx)[..., 0] - - input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] - delta = heights / widths - input_delta = delta.gather(-1, bin_idx)[..., 0] - - input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] - input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] - - input_heights = heights.gather(-1, bin_idx)[..., 0] - - if inverse: - a = (inputs - input_cumheights) * ( - input_derivatives + input_derivatives_plus_one - 2 * input_delta - ) + input_heights * (input_delta - input_derivatives) - b = input_heights * input_derivatives - (inputs - input_cumheights) * ( - input_derivatives + input_derivatives_plus_one - 2 * input_delta - ) - c = -input_delta * (inputs - input_cumheights) - - discriminant = b.pow(2) - 4 * a * c - assert (discriminant >= 0).all() - - root = (2 * c) / (-b - torch.sqrt(discriminant)) - outputs = root * input_bin_widths + input_cumwidths - - theta_one_minus_theta = root * (1 - root) - denominator = input_delta + ( - (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta - ) - derivative_numerator = input_delta.pow(2) * ( - input_derivatives_plus_one * root.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - root).pow(2) - ) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, -logabsdet - else: - theta = (inputs - input_cumwidths) / input_bin_widths - theta_one_minus_theta = theta * (1 - theta) - - numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) - denominator = input_delta + ( - (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta - ) - outputs = input_cumheights + numerator / denominator - - derivative_numerator = input_delta.pow(2) * ( - input_derivatives_plus_one * theta.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - theta).pow(2) - ) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, logabsdet From 95998374bf7d7aeb2ff8556e9b03c4d29475c189 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 20 Jun 2024 10:16:35 +0200 Subject: [PATCH 04/15] feat(openvoice): add config classes --- TTS/vc/configs/freevc_config.py | 2 +- TTS/vc/configs/openvoice_config.py | 201 +++++++++++++++++++++++++++++ 2 files changed, 202 insertions(+), 1 deletion(-) create mode 100644 TTS/vc/configs/openvoice_config.py diff --git a/TTS/vc/configs/freevc_config.py b/TTS/vc/configs/freevc_config.py index 207181b303..d600bfb1f4 100644 --- a/TTS/vc/configs/freevc_config.py +++ b/TTS/vc/configs/freevc_config.py @@ -229,7 +229,7 @@ class FreeVCConfig(BaseVCConfig): If true, language embedding is used. Defaults to `False`. Note: - Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + Check :class:`TTS.tts.configs.shared_configs.BaseVCConfig` for the inherited parameters. Example: diff --git a/TTS/vc/configs/openvoice_config.py b/TTS/vc/configs/openvoice_config.py new file mode 100644 index 0000000000..261cdd6f47 --- /dev/null +++ b/TTS/vc/configs/openvoice_config.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass, field +from typing import Optional + +from coqpit import Coqpit + +from TTS.vc.configs.shared_configs import BaseVCConfig + + +@dataclass +class OpenVoiceAudioConfig(Coqpit): + """Audio configuration + + Args: + input_sample_rate (int): + The sampling rate of the input waveform. + + output_sample_rate (int): + The sampling rate of the output waveform. + + fft_size (int): + The length of the filter. + + hop_length (int): + The hop length. + + win_length (int): + The window length. + """ + + input_sample_rate: int = field(default=22050) + output_sample_rate: int = field(default=22050) + fft_size: int = field(default=1024) + hop_length: int = field(default=256) + win_length: int = field(default=1024) + + +@dataclass +class OpenVoiceArgs(Coqpit): + """OpenVoice model arguments. + + zero_g (bool): + Whether to zero the gradients. + + inter_channels (int): + The number of channels in the intermediate layers. + + hidden_channels (int): + The number of channels in the hidden layers. + + filter_channels (int): + The number of channels in the filter layers. + + n_heads (int): + The number of attention heads. + + n_layers (int): + The number of layers. + + kernel_size (int): + The size of the kernel. + + p_dropout (float): + The dropout probability. + + resblock (str): + The type of residual block. + + resblock_kernel_sizes (List[int]): + The kernel sizes for the residual blocks. + + resblock_dilation_sizes (List[List[int]]): + The dilation sizes for the residual blocks. + + upsample_rates (List[int]): + The upsample rates. + + upsample_initial_channel (int): + The number of channels in the initial upsample layer. + + upsample_kernel_sizes (List[int]): + The kernel sizes for the upsample layers. + + n_layers_q (int): + The number of layers in the quantization network. + + use_spectral_norm (bool): + Whether to use spectral normalization. + + gin_channels (int): + The number of channels in the global conditioning vector. + + tau (float): + Tau parameter for the posterior encoder + """ + + zero_g: bool = field(default=True) + inter_channels: int = field(default=192) + hidden_channels: int = field(default=192) + filter_channels: int = field(default=768) + n_heads: int = field(default=2) + n_layers: int = field(default=6) + kernel_size: int = field(default=3) + p_dropout: float = field(default=0.1) + resblock: str = field(default="1") + resblock_kernel_sizes: list[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes: list[list[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates: list[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel: int = field(default=512) + upsample_kernel_sizes: list[int] = field(default_factory=lambda: [16, 16, 4, 4]) + n_layers_q: int = field(default=3) + use_spectral_norm: bool = field(default=False) + gin_channels: int = field(default=256) + tau: float = field(default=0.3) + + +@dataclass +class OpenVoiceConfig(BaseVCConfig): + """Defines parameters for OpenVoice VC model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (OpenVoiceArgs): + Model architecture arguments. Defaults to `OpenVoiceArgs()`. + + audio (OpenVoiceAudioConfig): + Audio processing configuration. Defaults to `OpenVoiceAudioConfig()`. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. + + weighted_sampler_attrs (dict): + Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities + by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`. + It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseVCConfig` for the inherited parameters. + + Example: + + >>> from TTS.vc.configs.openvoice_config import OpenVoiceConfig + >>> config = OpenVoiceConfig() + """ + + model: str = "openvoice" + # model specific params + model_args: OpenVoiceArgs = field(default_factory=OpenVoiceArgs) + audio: OpenVoiceAudioConfig = field(default_factory=OpenVoiceAudioConfig) + + # optimizer + # TODO with training support + + # loss params + # TODO with training support + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + speakers_file: Optional[str] = None + speaker_embedding_channels: int = 256 + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: Optional[list[str]] = None + d_vector_dim: Optional[int] = None + + def __post_init__(self) -> None: + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val From ca02d0352bd5c9118d56fdbb06aef1c15782cf11 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 13 Nov 2024 19:47:32 +0100 Subject: [PATCH 05/15] feat(openvoice): add to .models.json --- TTS/.models.json | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/TTS/.models.json b/TTS/.models.json index 7c3a498bff..36654d0555 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -931,6 +931,28 @@ "license": "MIT", "commit": null } + }, + "multi-dataset": { + "openvoice_v1": { + "hf_url": [ + "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/config.json", + "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/checkpoint.pth" + ], + "description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2", + "author": "MyShell.ai", + "license": "MIT", + "commit": null + }, + "openvoice_v2": { + "hf_url": [ + "https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter/config.json", + "https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter/checkpoint.pth" + ], + "description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2", + "author": "MyShell.ai", + "license": "MIT", + "commit": null + } } } } From 1a21853b9022596ba1e609b687e278ec0beed0d8 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 13 Nov 2024 19:58:30 +0100 Subject: [PATCH 06/15] ci: validate .models.json file --- .pre-commit-config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 92f6f3ab3c..62420e9958 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,8 @@ repos: - repo: "https://github.com/pre-commit/pre-commit-hooks" rev: v5.0.0 hooks: + - id: check-json + files: "TTS/.models.json" - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace From fce3137e0d0d0cad101bd0264673dc60447c3a8a Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Tue, 25 Jun 2024 23:01:47 +0200 Subject: [PATCH 07/15] feat: add openvoice vc model --- README.md | 1 + TTS/api.py | 6 +- TTS/utils/manage.py | 2 +- TTS/utils/synthesizer.py | 21 +- TTS/vc/models/openvoice.py | 320 +++++++++++++++++++++++++++ TTS/vc/modules/openvoice/__init__.py | 0 TTS/vc/modules/openvoice/models.py | 134 ----------- 7 files changed, 346 insertions(+), 138 deletions(-) create mode 100644 TTS/vc/models/openvoice.py delete mode 100644 TTS/vc/modules/openvoice/__init__.py delete mode 100644 TTS/vc/modules/openvoice/models.py diff --git a/README.md b/README.md index 5ca825b6ba..381a8e95f2 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,7 @@ repository are also still a useful source of information. ### Voice Conversion - FreeVC: [paper](https://arxiv.org/abs/2210.15418) +- OpenVoice: [technical report](https://arxiv.org/abs/2312.01479) You can also help us implement more models. diff --git a/TTS/api.py b/TTS/api.py index 250ed1a0d9..12e82af52c 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -155,8 +155,10 @@ def load_vc_model_by_name(self, model_name: str, gpu: bool = False): gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ self.model_name = model_name - model_path, config_path, _, _, _ = self.download_model_by_name(model_name) - self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) + model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name) + self.voice_converter = Synthesizer( + vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu + ) def load_tts_model_by_name(self, model_name: str, gpu: bool = False): """Load one of 🐸TTS models by name. diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index bd445b3a2f..38fcfd60e9 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -424,7 +424,7 @@ def _find_files(output_path: str) -> Tuple[str, str]: model_file = None config_file = None for file_name in os.listdir(output_path): - if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: + if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]: model_file = os.path.join(output_path, file_name) elif file_name == "config.json": config_file = os.path.join(output_path, file_name) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 90af4f48f9..a158df60e1 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,6 +1,7 @@ import logging import os import time +from pathlib import Path from typing import List import numpy as np @@ -15,7 +16,9 @@ from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import save_wav +from TTS.vc.configs.openvoice_config import OpenVoiceConfig from TTS.vc.models import setup_model as setup_vc_model +from TTS.vc.models.openvoice import OpenVoice from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input @@ -97,7 +100,7 @@ def __init__( self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) self.output_sample_rate = self.vocoder_config.audio["sample_rate"] - if vc_checkpoint: + if vc_checkpoint and model_dir is None: self._load_vc(vc_checkpoint, vc_config, use_cuda) self.output_sample_rate = self.vc_config.audio["output_sample_rate"] @@ -105,6 +108,9 @@ def __init__( if "fairseq" in model_dir: self._load_fairseq_from_dir(model_dir, use_cuda) self.output_sample_rate = self.tts_config.audio["sample_rate"] + elif "openvoice" in model_dir: + self._load_openvoice_from_dir(Path(model_dir), use_cuda) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] else: self._load_tts_from_dir(model_dir, use_cuda) self.output_sample_rate = self.tts_config.audio["output_sample_rate"] @@ -153,6 +159,19 @@ def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None: if use_cuda: self.tts_model.cuda() + def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None: + """Load the OpenVoice model from a directory. + + We assume the model knows how to load itself from the directory and + there is a config.json file in the directory. + """ + self.vc_config = OpenVoiceConfig() + self.vc_model = OpenVoice.init_from_config(self.vc_config) + self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True) + self.vc_config = self.vc_model.config + if use_cuda: + self.vc_model.cuda() + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: """Load the TTS model from a directory. diff --git a/TTS/vc/models/openvoice.py b/TTS/vc/models/openvoice.py new file mode 100644 index 0000000000..135b0861b9 --- /dev/null +++ b/TTS/vc/models/openvoice.py @@ -0,0 +1,320 @@ +import json +import logging +import os +from pathlib import Path +from typing import Any, Mapping, Optional, Union + +import librosa +import numpy as np +import numpy.typing as npt +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import functional as F +from trainer.io import load_fsspec + +from TTS.tts.layers.vits.networks import PosteriorEncoder +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.audio.torch_transforms import wav_to_spec +from TTS.vc.configs.openvoice_config import OpenVoiceConfig +from TTS.vc.models.base_vc import BaseVC +from TTS.vc.models.freevc import Generator, ResidualCouplingBlock + +logger = logging.getLogger(__name__) + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. + + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, spec_channels: int, embedding_dim: int = 0, layernorm: bool = True) -> None: + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + torch.nn.utils.parametrizations.weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, embedding_dim) + self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + N = inputs.size(0) + + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + if self.layernorm is not None: + out = self.layernorm(out) + + for conv in self.convs: + out = conv(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + _memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)) + + def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int: + for _ in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class OpenVoice(BaseVC): + """ + OpenVoice voice conversion model (inference only). + + Source: https://github.com/myshell-ai/OpenVoice + Paper: https://arxiv.org/abs/2312.01479 + + Paper abstract: + We introduce OpenVoice, a versatile voice cloning approach that requires + only a short audio clip from the reference speaker to replicate their voice and + generate speech in multiple languages. OpenVoice represents a significant + advancement in addressing the following open challenges in the field: 1) + Flexible Voice Style Control. OpenVoice enables granular control over voice + styles, including emotion, accent, rhythm, pauses, and intonation, in addition + to replicating the tone color of the reference speaker. The voice styles are not + directly copied from and constrained by the style of the reference speaker. + Previous approaches lacked the ability to flexibly manipulate voice styles after + cloning. 2) Zero-Shot Cross-Lingual Voice Cloning. OpenVoice achieves zero-shot + cross-lingual voice cloning for languages not included in the massive-speaker + training set. Unlike previous approaches, which typically require extensive + massive-speaker multi-lingual (MSML) dataset for all languages, OpenVoice can + clone voices into a new language without any massive-speaker training data for + that language. OpenVoice is also computationally efficient, costing tens of + times less than commercially available APIs that offer even inferior + performance. To foster further research in the field, we have made the source + code and trained model publicly accessible. We also provide qualitative results + in our demo website. Prior to its public release, our internal version of + OpenVoice was used tens of millions of times by users worldwide between May and + October 2023, serving as the backend of MyShell. + """ + + def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = None) -> None: + super().__init__(config, None, speaker_manager, None) + + self.init_multispeaker(config) + + self.zero_g = self.args.zero_g + self.inter_channels = self.args.inter_channels + self.hidden_channels = self.args.hidden_channels + self.filter_channels = self.args.filter_channels + self.n_heads = self.args.n_heads + self.n_layers = self.args.n_layers + self.kernel_size = self.args.kernel_size + self.p_dropout = self.args.p_dropout + self.resblock = self.args.resblock + self.resblock_kernel_sizes = self.args.resblock_kernel_sizes + self.resblock_dilation_sizes = self.args.resblock_dilation_sizes + self.upsample_rates = self.args.upsample_rates + self.upsample_initial_channel = self.args.upsample_initial_channel + self.upsample_kernel_sizes = self.args.upsample_kernel_sizes + self.n_layers_q = self.args.n_layers_q + self.use_spectral_norm = self.args.use_spectral_norm + self.gin_channels = self.args.gin_channels + self.tau = self.args.tau + + self.spec_channels = config.audio.fft_size // 2 + 1 + + self.dec = Generator( + self.inter_channels, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + gin_channels=self.gin_channels, + ) + self.enc_q = PosteriorEncoder( + self.spec_channels, + self.inter_channels, + self.hidden_channels, + kernel_size=5, + dilation_rate=1, + num_layers=16, + cond_channels=self.gin_channels, + ) + + self.flow = ResidualCouplingBlock( + self.inter_channels, + self.hidden_channels, + kernel_size=5, + dilation_rate=1, + n_layers=4, + gin_channels=self.gin_channels, + ) + + self.ref_enc = ReferenceEncoder(self.spec_channels, self.gin_channels) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @staticmethod + def init_from_config(config: OpenVoiceConfig) -> "OpenVoice": + return OpenVoice(config) + + def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None: + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (list, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.num_spks = config.num_speakers + if self.speaker_manager: + self.num_spks = self.speaker_manager.num_speakers + + def load_checkpoint( + self, + config: OpenVoiceConfig, + checkpoint_path: Union[str, os.PathLike[Any]], + eval: bool = False, + strict: bool = True, + cache: bool = False, + ) -> None: + """Map from OpenVoice's config structure.""" + config_path = Path(checkpoint_path).parent / "config.json" + with open(config_path, encoding="utf-8") as f: + config_org = json.load(f) + self.config.audio.input_sample_rate = config_org["data"]["sampling_rate"] + self.config.audio.output_sample_rate = config_org["data"]["sampling_rate"] + self.config.audio.fft_size = config_org["data"]["filter_length"] + self.config.audio.hop_length = config_org["data"]["hop_length"] + self.config.audio.win_length = config_org["data"]["win_length"] + state = load_fsspec(str(checkpoint_path), map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"], strict=strict) + if eval: + self.eval() + + def forward(self) -> None: ... + def train_step(self) -> None: ... + def eval_step(self) -> None: ... + + @staticmethod + def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tensor]]) -> torch.Tensor: + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + @torch.no_grad() + def inference( + self, + x: torch.Tensor, + aux_input: Mapping[str, Optional[torch.Tensor]] = {"x_lengths": None, "g_src": None, "g_tgt": None}, + ) -> dict[str, torch.Tensor]: + """ + Inference pass of the model + + Args: + x (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len). + x_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,). + g_src (torch.Tensor): Source speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + g_tgt (torch.Tensor): Target speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + + Returns: + o_hat: Output spectrogram tensor. Shape: (batch_size, spec_seq_len, spec_dim). + x_mask: Spectrogram mask. Shape: (batch_size, spec_seq_len). + (z, z_p, z_hat): A tuple of latent variables. + """ + x_lengths = self._set_x_lengths(x, aux_input) + if "g_src" in aux_input and aux_input["g_src"] is not None: + g_src = aux_input["g_src"] + else: + raise ValueError("aux_input must define g_src") + if "g_tgt" in aux_input and aux_input["g_tgt"] is not None: + g_tgt = aux_input["g_tgt"] + else: + raise ValueError("aux_input must define g_tgt") + z, _m_q, _logs_q, y_mask = self.enc_q( + x, x_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=self.tau + ) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt)) + return { + "model_outputs": o_hat, + "y_mask": y_mask, + "z": z, + "z_p": z_p, + "z_hat": z_hat, + } + + def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list[float]]) -> torch.Tensor: + """Read and format the input audio.""" + if isinstance(wav, str): + out = torch.from_numpy(librosa.load(wav, sr=self.config.audio.input_sample_rate)[0]) + elif isinstance(wav, np.ndarray): + out = torch.from_numpy(wav) + elif isinstance(wav, list): + out = torch.from_numpy(np.array(wav)) + else: + out = wav + return out.to(self.device).float() + + def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + audio_ref = self.load_audio(audio) + y = torch.FloatTensor(audio_ref) + y = y.to(self.device) + y = y.unsqueeze(0) + spec = wav_to_spec( + y, + n_fft=self.config.audio.fft_size, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + center=False, + ).to(self.device) + with torch.no_grad(): + g = self.ref_enc(spec.transpose(1, 2)).unsqueeze(-1) + + return g, spec + + @torch.inference_mode() + def voice_conversion(self, src: Union[str, torch.Tensor], tgt: Union[str, torch.Tensor]) -> npt.NDArray[np.float32]: + """ + Voice conversion pass of the model. + + Args: + src (str or torch.Tensor): Source utterance. + tgt (str or torch.Tensor): Target utterance. + + Returns: + Output numpy array. + """ + src_se, src_spec = self.extract_se(src) + tgt_se, _ = self.extract_se(tgt) + + aux_input = {"g_src": src_se, "g_tgt": tgt_se} + audio = self.inference(src_spec, aux_input) + return audio["model_outputs"][0, 0].data.cpu().float().numpy() diff --git a/TTS/vc/modules/openvoice/__init__.py b/TTS/vc/modules/openvoice/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/TTS/vc/modules/openvoice/models.py b/TTS/vc/modules/openvoice/models.py deleted file mode 100644 index 89a1c3a40c..0000000000 --- a/TTS/vc/modules/openvoice/models.py +++ /dev/null @@ -1,134 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - -from TTS.tts.layers.vits.networks import PosteriorEncoder -from TTS.vc.models.freevc import Generator, ResidualCouplingBlock - - -class ReferenceEncoder(nn.Module): - """ - inputs --- [N, Ty/r, n_mels*r] mels - outputs --- [N, ref_enc_gru_size] - """ - - def __init__(self, spec_channels, gin_channels=0, layernorm=True): - super().__init__() - self.spec_channels = spec_channels - ref_enc_filters = [32, 32, 64, 64, 128, 128] - K = len(ref_enc_filters) - filters = [1] + ref_enc_filters - convs = [ - torch.nn.utils.parametrizations.weight_norm( - nn.Conv2d( - in_channels=filters[i], - out_channels=filters[i + 1], - kernel_size=(3, 3), - stride=(2, 2), - padding=(1, 1), - ) - ) - for i in range(K) - ] - self.convs = nn.ModuleList(convs) - - out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) - self.gru = nn.GRU( - input_size=ref_enc_filters[-1] * out_channels, - hidden_size=256 // 2, - batch_first=True, - ) - self.proj = nn.Linear(128, gin_channels) - if layernorm: - self.layernorm = nn.LayerNorm(self.spec_channels) - else: - self.layernorm = None - - def forward(self, inputs): - N = inputs.size(0) - - out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] - if self.layernorm is not None: - out = self.layernorm(out) - - for conv in self.convs: - out = conv(out) - out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] - - out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] - T = out.size(1) - N = out.size(0) - out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] - - self.gru.flatten_parameters() - _memory, out = self.gru(out) # out --- [1, N, 128] - - return self.proj(out.squeeze(0)) - - def calculate_channels(self, L, kernel_size, stride, pad, n_convs): - for _ in range(n_convs): - L = (L - kernel_size + 2 * pad) // stride + 1 - return L - - -class SynthesizerTrn(nn.Module): - """ - Synthesizer for Training - """ - - def __init__( - self, - spec_channels, - inter_channels, - hidden_channels, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - n_speakers=0, - gin_channels=256, - zero_g=False, - **kwargs, - ): - super().__init__() - - self.dec = Generator( - inter_channels, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=gin_channels, - ) - self.enc_q = PosteriorEncoder( - spec_channels, - inter_channels, - hidden_channels, - 5, - 1, - 16, - cond_channels=gin_channels, - ) - - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - - self.n_speakers = n_speakers - if n_speakers != 0: - raise ValueError("OpenVoice inference only supports n_speaker==0") - self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) - self.zero_g = zero_g - - def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): - g_src = sid_src - g_tgt = sid_tgt - z, m_q, logs_q, y_mask = self.enc_q( - y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau - ) - z_p = self.flow(z, y_mask, g=g_src) - z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) - o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt)) - return o_hat, y_mask, (z, z_p, z_hat) From d488441b756570ff4b82c1fe5e27d4406bf553a7 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 13 Nov 2024 22:55:46 +0100 Subject: [PATCH 08/15] test(freevc): remove unused code --- tests/vc_tests/test_freevc.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py index c90551b494..914237b520 100644 --- a/tests/vc_tests/test_freevc.py +++ b/tests/vc_tests/test_freevc.py @@ -22,15 +22,12 @@ class TestFreeVC(unittest.TestCase): def _create_inputs(self, config, batch_size=2): - input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device) - input_lengths = torch.randint(100, 30 * config.audio["hop_length"], (batch_size,)).long().to(device) - input_lengths[-1] = 30 * config.audio["hop_length"] spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device) mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device) spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) spec_lengths[-1] = spec.size(2) waveform = torch.rand(batch_size, spec.size(2) * config.audio["hop_length"]).to(device) - return input_dummy, input_lengths, mel, spec, spec_lengths, waveform + return mel, spec, spec_lengths, waveform @staticmethod def _create_inputs_inference(): @@ -38,15 +35,6 @@ def _create_inputs_inference(): target_wav = torch.rand(16000) return source_wav, target_wav - @staticmethod - def _check_parameter_changes(model, model_ref): - count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): - assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref - ) - count += 1 - def test_methods(self): config = FreeVCConfig() model = FreeVC(config).to(device) @@ -69,7 +57,7 @@ def _test_forward(self, batch_size): model.train() print(" > Num parameters for FreeVC model:%s" % (count_parameters(model))) - _, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size) + mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size) wavlm_vec = model.extract_wavlm_features(waveform) wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long) @@ -86,7 +74,7 @@ def _test_inference(self, batch_size): model = FreeVC(config).to(device) model.eval() - _, _, mel, _, _, waveform = self._create_inputs(config, batch_size) + mel, _, _, waveform = self._create_inputs(config, batch_size) wavlm_vec = model.extract_wavlm_features(waveform) wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long) From 6927e0bb89f0c76dbbf5d14716cebd72ee13b2a5 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 29 Nov 2024 16:17:02 +0100 Subject: [PATCH 09/15] fix(api): clearer error message when model doesn't support VC --- TTS/api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 12e82af52c..ed82825007 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -357,15 +357,17 @@ def voice_conversion( target_wav (str):` Path to the target wav file. """ - wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav) - return wav + if self.voice_converter is None: + msg = "The selected model does not support voice conversion." + raise RuntimeError(msg) + return self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav) def voice_conversion_to_file( self, source_wav: str, target_wav: str, file_path: str = "output.wav", - ): + ) -> str: """Voice conversion with FreeVC. Convert source wav to target speaker. Args: From 546f43cb254793366f996deab33eb1cc88e915bd Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 29 Nov 2024 16:27:14 +0100 Subject: [PATCH 10/15] refactor: only use keyword args in Synthesizer --- TTS/bin/synthesize.py | 24 +++++++++++------------ TTS/utils/synthesizer.py | 1 + tests/inference_tests/test_synthesizer.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 20e429df04..454f528ab4 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -407,18 +407,18 @@ def main(): # load models synthesizer = Synthesizer( - tts_path, - tts_config_path, - speakers_file_path, - language_ids_file_path, - vocoder_path, - vocoder_config_path, - encoder_path, - encoder_config_path, - vc_path, - vc_config_path, - model_dir, - args.voice_dir, + tts_checkpoint=tts_path, + tts_config_path=tts_config_path, + tts_speakers_file=speakers_file_path, + tts_languages_file=language_ids_file_path, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint=encoder_path, + encoder_config=encoder_config_path, + vc_checkpoint=vc_path, + vc_config=vc_config_path, + model_dir=model_dir, + voice_dir=args.voice_dir, ).to(device) # query speaker ids of a multi-speaker model. diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index a158df60e1..73f596d167 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -28,6 +28,7 @@ class Synthesizer(nn.Module): def __init__( self, + *, tts_checkpoint: str = "", tts_config_path: str = "", tts_speakers_file: str = "", diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index ce4fc751c2..21cc194131 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -23,7 +23,7 @@ def test_in_out(self): tts_root_path = get_tests_input_path() tts_checkpoint = os.path.join(tts_root_path, "checkpoint_10.pth") tts_config = os.path.join(tts_root_path, "dummy_model_config.json") - synthesizer = Synthesizer(tts_checkpoint, tts_config, None, None) + synthesizer = Synthesizer(tts_checkpoint=tts_checkpoint, tts_config_path=tts_config) synthesizer.tts("Better this test works!!") def test_split_into_sentences(self): From 9ef2c7ed624fda8ac8052ea3824132b5ab6b4481 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 2 Dec 2024 00:09:39 +0100 Subject: [PATCH 11/15] test(freevc): fix output length check --- tests/vc_tests/test_freevc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py index 914237b520..fe07b2723c 100644 --- a/tests/vc_tests/test_freevc.py +++ b/tests/vc_tests/test_freevc.py @@ -31,7 +31,7 @@ def _create_inputs(self, config, batch_size=2): @staticmethod def _create_inputs_inference(): - source_wav = torch.rand(16000) + source_wav = torch.rand(15999) target_wav = torch.rand(16000) return source_wav, target_wav @@ -96,8 +96,8 @@ def test_voice_conversion(self): source_wav, target_wav = self._create_inputs_inference() output_wav = model.voice_conversion(source_wav, target_wav) assert ( - output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0] - ), f"{output_wav.shape} != {source_wav.shape}" + output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length + ), f"{output_wav.shape} != {source_wav.shape}, {config.audio.hop_length}" def test_train_step(self): ... From 5f8ad4c64b26960dad6b1399deae5f9a0a4aade2 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 29 Nov 2024 17:23:30 +0100 Subject: [PATCH 12/15] test(openvoice): add sanity check --- tests/vc_tests/test_openvoice.py | 42 ++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/vc_tests/test_openvoice.py diff --git a/tests/vc_tests/test_openvoice.py b/tests/vc_tests/test_openvoice.py new file mode 100644 index 0000000000..c9f7ae3931 --- /dev/null +++ b/tests/vc_tests/test_openvoice.py @@ -0,0 +1,42 @@ +import os +import unittest + +import torch + +from tests import get_tests_input_path +from TTS.vc.models.openvoice import OpenVoice, OpenVoiceConfig + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +c = OpenVoiceConfig() + +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") + + +class TestOpenVoice(unittest.TestCase): + + @staticmethod + def _create_inputs_inference(): + source_wav = torch.rand(16100) + target_wav = torch.rand(16000) + return source_wav, target_wav + + def test_load_audio(self): + config = OpenVoiceConfig() + model = OpenVoice(config).to(device) + wav = model.load_audio(WAV_FILE) + wav2 = model.load_audio(wav) + assert all(torch.isclose(wav, wav2)) + + def test_voice_conversion(self): + config = OpenVoiceConfig() + model = OpenVoice(config).to(device) + model.eval() + + source_wav, target_wav = self._create_inputs_inference() + output_wav = model.voice_conversion(source_wav, target_wav) + assert ( + output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length + ), f"{output_wav.shape} != {source_wav.shape}" From 32c99e8e66d06055ddb44a321481447bffcf8bb1 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 13 Jun 2024 16:35:59 +0200 Subject: [PATCH 13/15] docs(readme): mention openvoice vc --- README.md | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 381a8e95f2..7dddf3a37b 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,12 @@ ## 🐸Coqui TTS News - 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts) +- 📣 [OpenVoice](https://github.com/myshell-ai/OpenVoice) models now available for voice conversion. - 📣 Prebuilt wheels are now also published for Mac and Windows (in addition to Linux as before) for easier installation across platforms. -- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board. +- 📣 ⓍTTSv2 is here with 17 languages and better performance across the board. ⓍTTS can stream with <200ms latency. - 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech). -- 📣 ⓍTTS can now stream with <200ms latency. -- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html) - 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html) -- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS. +- 📣 You can use [Fairseq models in ~1100 languages](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS. ## @@ -245,8 +244,14 @@ tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progr tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav") ``` -#### Example voice cloning together with the voice conversion model. -This way, you can clone voices by using any model in 🐸TTS. +Other available voice conversion models: +- `voice_conversion_models/multilingual/multi-dataset/openvoice_v1` +- `voice_conversion_models/multilingual/multi-dataset/openvoice_v2` + +#### Example voice cloning together with the default voice conversion model. + +This way, you can clone voices by using any model in 🐸TTS. The FreeVC model is +used for voice conversion after synthesizing speech. ```python @@ -413,4 +418,6 @@ $ tts --out_path output/path/speech.wav --model_name "// Date: Mon, 2 Dec 2024 00:16:39 +0100 Subject: [PATCH 14/15] refactor(vc): rename TTS.vc.modules to TTS.vc.layers for consistency Same as in TTS.tts and TTS.vocoder --- TTS/vc/{modules => layers}/__init__.py | 0 TTS/vc/{modules => layers}/freevc/__init__.py | 0 TTS/vc/{modules => layers}/freevc/commons.py | 0 TTS/vc/{modules => layers}/freevc/mel_processing.py | 0 TTS/vc/{modules => layers}/freevc/modules.py | 2 +- .../freevc/speaker_encoder/__init__.py | 0 .../freevc/speaker_encoder/audio.py | 2 +- .../freevc/speaker_encoder/hparams.py | 0 .../freevc/speaker_encoder/speaker_encoder.py | 4 ++-- TTS/vc/{modules => layers}/freevc/wavlm/__init__.py | 2 +- TTS/vc/{modules => layers}/freevc/wavlm/config.json | 0 TTS/vc/{modules => layers}/freevc/wavlm/modules.py | 0 TTS/vc/{modules => layers}/freevc/wavlm/wavlm.py | 2 +- TTS/vc/models/freevc.py | 13 ++++++------- 14 files changed, 12 insertions(+), 13 deletions(-) rename TTS/vc/{modules => layers}/__init__.py (100%) rename TTS/vc/{modules => layers}/freevc/__init__.py (100%) rename TTS/vc/{modules => layers}/freevc/commons.py (100%) rename TTS/vc/{modules => layers}/freevc/mel_processing.py (100%) rename TTS/vc/{modules => layers}/freevc/modules.py (99%) rename TTS/vc/{modules => layers}/freevc/speaker_encoder/__init__.py (100%) rename TTS/vc/{modules => layers}/freevc/speaker_encoder/audio.py (97%) rename TTS/vc/{modules => layers}/freevc/speaker_encoder/hparams.py (100%) rename TTS/vc/{modules => layers}/freevc/speaker_encoder/speaker_encoder.py (98%) rename TTS/vc/{modules => layers}/freevc/wavlm/__init__.py (94%) rename TTS/vc/{modules => layers}/freevc/wavlm/config.json (100%) rename TTS/vc/{modules => layers}/freevc/wavlm/modules.py (100%) rename TTS/vc/{modules => layers}/freevc/wavlm/wavlm.py (99%) diff --git a/TTS/vc/modules/__init__.py b/TTS/vc/layers/__init__.py similarity index 100% rename from TTS/vc/modules/__init__.py rename to TTS/vc/layers/__init__.py diff --git a/TTS/vc/modules/freevc/__init__.py b/TTS/vc/layers/freevc/__init__.py similarity index 100% rename from TTS/vc/modules/freevc/__init__.py rename to TTS/vc/layers/freevc/__init__.py diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/layers/freevc/commons.py similarity index 100% rename from TTS/vc/modules/freevc/commons.py rename to TTS/vc/layers/freevc/commons.py diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/layers/freevc/mel_processing.py similarity index 100% rename from TTS/vc/modules/freevc/mel_processing.py rename to TTS/vc/layers/freevc/mel_processing.py diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/layers/freevc/modules.py similarity index 99% rename from TTS/vc/modules/freevc/modules.py rename to TTS/vc/layers/freevc/modules.py index ea17be24d6..c34f22d701 100644 --- a/TTS/vc/modules/freevc/modules.py +++ b/TTS/vc/layers/freevc/modules.py @@ -7,7 +7,7 @@ from TTS.tts.layers.generic.normalization import LayerNorm2 from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply -from TTS.vc.modules.freevc.commons import init_weights +from TTS.vc.layers.freevc.commons import init_weights from TTS.vocoder.models.hifigan_generator import get_padding LRELU_SLOPE = 0.1 diff --git a/TTS/vc/modules/freevc/speaker_encoder/__init__.py b/TTS/vc/layers/freevc/speaker_encoder/__init__.py similarity index 100% rename from TTS/vc/modules/freevc/speaker_encoder/__init__.py rename to TTS/vc/layers/freevc/speaker_encoder/__init__.py diff --git a/TTS/vc/modules/freevc/speaker_encoder/audio.py b/TTS/vc/layers/freevc/speaker_encoder/audio.py similarity index 97% rename from TTS/vc/modules/freevc/speaker_encoder/audio.py rename to TTS/vc/layers/freevc/speaker_encoder/audio.py index 5b23a4dbb6..5fa317ce45 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/audio.py +++ b/TTS/vc/layers/freevc/speaker_encoder/audio.py @@ -5,7 +5,7 @@ import librosa import numpy as np -from TTS.vc.modules.freevc.speaker_encoder.hparams import ( +from TTS.vc.layers.freevc.speaker_encoder.hparams import ( audio_norm_target_dBFS, mel_n_channels, mel_window_length, diff --git a/TTS/vc/modules/freevc/speaker_encoder/hparams.py b/TTS/vc/layers/freevc/speaker_encoder/hparams.py similarity index 100% rename from TTS/vc/modules/freevc/speaker_encoder/hparams.py rename to TTS/vc/layers/freevc/speaker_encoder/hparams.py diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py similarity index 98% rename from TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py rename to TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py index 294bf322cb..a6d5bcf942 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py +++ b/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py @@ -7,8 +7,8 @@ from torch import nn from trainer.io import load_fsspec -from TTS.vc.modules.freevc.speaker_encoder import audio -from TTS.vc.modules.freevc.speaker_encoder.hparams import ( +from TTS.vc.layers.freevc.speaker_encoder import audio +from TTS.vc.layers.freevc.speaker_encoder.hparams import ( mel_n_channels, mel_window_step, model_embedding_size, diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/layers/freevc/wavlm/__init__.py similarity index 94% rename from TTS/vc/modules/freevc/wavlm/__init__.py rename to TTS/vc/layers/freevc/wavlm/__init__.py index 4046e137f5..62f7e74aaf 100644 --- a/TTS/vc/modules/freevc/wavlm/__init__.py +++ b/TTS/vc/layers/freevc/wavlm/__init__.py @@ -6,7 +6,7 @@ from trainer.io import get_user_data_dir from TTS.utils.generic_utils import is_pytorch_at_least_2_4 -from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig +from TTS.vc.layers.freevc.wavlm.wavlm import WavLM, WavLMConfig logger = logging.getLogger(__name__) diff --git a/TTS/vc/modules/freevc/wavlm/config.json b/TTS/vc/layers/freevc/wavlm/config.json similarity index 100% rename from TTS/vc/modules/freevc/wavlm/config.json rename to TTS/vc/layers/freevc/wavlm/config.json diff --git a/TTS/vc/modules/freevc/wavlm/modules.py b/TTS/vc/layers/freevc/wavlm/modules.py similarity index 100% rename from TTS/vc/modules/freevc/wavlm/modules.py rename to TTS/vc/layers/freevc/wavlm/modules.py diff --git a/TTS/vc/modules/freevc/wavlm/wavlm.py b/TTS/vc/layers/freevc/wavlm/wavlm.py similarity index 99% rename from TTS/vc/modules/freevc/wavlm/wavlm.py rename to TTS/vc/layers/freevc/wavlm/wavlm.py index 10dd09ed0c..775f3e5979 100644 --- a/TTS/vc/modules/freevc/wavlm/wavlm.py +++ b/TTS/vc/layers/freevc/wavlm/wavlm.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from torch.nn import LayerNorm -from TTS.vc.modules.freevc.wavlm.modules import ( +from TTS.vc.layers.freevc.wavlm.modules import ( Fp32GroupNorm, Fp32LayerNorm, GLU_Linear, diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index 62559de534..c654219c39 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -12,17 +12,16 @@ from torch.nn.utils.parametrize import remove_parametrizations from trainer.io import load_fsspec -import TTS.vc.modules.freevc.commons as commons -import TTS.vc.modules.freevc.modules as modules +import TTS.vc.layers.freevc.modules as modules from TTS.tts.layers.vits.discriminator import DiscriminatorS from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.vc.configs.freevc_config import FreeVCConfig +from TTS.vc.layers.freevc.commons import init_weights, rand_slice_segments +from TTS.vc.layers.freevc.mel_processing import mel_spectrogram_torch +from TTS.vc.layers.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx +from TTS.vc.layers.freevc.wavlm import get_wavlm from TTS.vc.models.base_vc import BaseVC -from TTS.vc.modules.freevc.commons import init_weights -from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch -from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx -from TTS.vc.modules.freevc.wavlm import get_wavlm from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP logger = logging.getLogger(__name__) @@ -385,7 +384,7 @@ def forward( z_p = self.flow(z, spec_mask, g=g) # Randomly slice z and compute o using dec - z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size) + z_slice, ids_slice = rand_slice_segments(z, spec_lengths, self.segment_size) o = self.dec(z_slice, g=g) return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) From 3539e65d8e9d31d44c57b2c4a84ae1f372ade611 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 2 Dec 2024 22:50:33 +0100 Subject: [PATCH 15/15] refactor(synthesizer): set sample rate in loading methods --- TTS/utils/synthesizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 73f596d167..a9b9feffc1 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -95,26 +95,20 @@ def __init__( if tts_checkpoint: self._load_tts(tts_checkpoint, tts_config_path, use_cuda) - self.output_sample_rate = self.tts_config.audio["sample_rate"] if vocoder_checkpoint: self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) - self.output_sample_rate = self.vocoder_config.audio["sample_rate"] if vc_checkpoint and model_dir is None: self._load_vc(vc_checkpoint, vc_config, use_cuda) - self.output_sample_rate = self.vc_config.audio["output_sample_rate"] if model_dir: if "fairseq" in model_dir: self._load_fairseq_from_dir(model_dir, use_cuda) - self.output_sample_rate = self.tts_config.audio["sample_rate"] elif "openvoice" in model_dir: self._load_openvoice_from_dir(Path(model_dir), use_cuda) - self.output_sample_rate = self.vc_config.audio["output_sample_rate"] else: self._load_tts_from_dir(model_dir, use_cuda) - self.output_sample_rate = self.tts_config.audio["output_sample_rate"] @staticmethod def _get_segmenter(lang: str): @@ -143,6 +137,7 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N """ # pylint: disable=global-statement self.vc_config = load_config(vc_config_path) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] self.vc_model = setup_vc_model(config=self.vc_config) self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint) if use_cuda: @@ -157,6 +152,7 @@ def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None: self.tts_model = Vits.init_from_config(self.tts_config) self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True) self.tts_config = self.tts_model.config + self.output_sample_rate = self.tts_config.audio["sample_rate"] if use_cuda: self.tts_model.cuda() @@ -170,6 +166,7 @@ def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None: self.vc_model = OpenVoice.init_from_config(self.vc_config) self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True) self.vc_config = self.vc_model.config + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] if use_cuda: self.vc_model.cuda() @@ -180,6 +177,7 @@ def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: """ config = load_config(os.path.join(model_dir, "config.json")) self.tts_config = config + self.output_sample_rate = self.tts_config.audio["output_sample_rate"] self.tts_model = setup_tts_model(config) self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) if use_cuda: @@ -201,6 +199,7 @@ def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) - """ # pylint: disable=global-statement self.tts_config = load_config(tts_config_path) + self.output_sample_rate = self.tts_config.audio["sample_rate"] if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None: raise ValueError("Phonemizer is not defined in the TTS config.") @@ -238,6 +237,7 @@ def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> N use_cuda (bool): enable/disable CUDA use. """ self.vocoder_config = load_config(model_config) + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio) self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)