diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 73a840f7..f5be289f 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -2,23 +2,9 @@ - - - - - - - - - - - - - - - - - + + + diff --git a/src/so_vits_svc_fork/f0.py b/src/so_vits_svc_fork/f0.py index 990f2050..c104d4f5 100644 --- a/src/so_vits_svc_fork/f0.py +++ b/src/so_vits_svc_fork/f0.py @@ -29,8 +29,6 @@ def normalize_f0( factor = torch.ones(f0.shape[0], 1).to(f0.device) # normalize f0 based on means and factor f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1) - if torch.isnan(f0_norm).any(): - exit(0) return f0_norm * x_mask @@ -218,17 +216,17 @@ def compute_f0( def f0_to_coarse(f0: torch.Tensor | float): is_torch = isinstance(f0, torch.Tensor) f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) - f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / ( - f0_mel_max - f0_mel_min - ) + 1 + # f0_mel[f0_mel > 0] = ... + f0_mel = (f0_mel - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 - f0_mel[f0_mel <= 1] = 1 - f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 + # f0_mel[f0_mel <= 1] = 1 + # f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 + f0_mel = torch.clamp(f0_mel, 1, f0_bin - 1) f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) - assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( - f0_coarse.max(), - f0_coarse.min(), - ) + # assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( + # f0_coarse.max(), + # f0_coarse.min(), + # ) return f0_coarse diff --git a/src/so_vits_svc_fork/modules/commons.py b/src/so_vits_svc_fork/modules/commons.py index 16b3cfbe..b9ef83c7 100644 --- a/src/so_vits_svc_fork/modules/commons.py +++ b/src/so_vits_svc_fork/modules/commons.py @@ -1,27 +1,49 @@ -import math - import torch -from torch.nn import functional as F +from torch import Tensor + + +def slice_segments(x: Tensor, starts: Tensor, length: int) -> Tensor: + x_slice = torch.zeros((x.size()[:-1] + (length,)), dtype=x.dtype, device=x.device) + ends = starts + length + for i, (start, end) in enumerate(zip(starts, ends)): + x_slice[i, ...] = x[i, ..., start:end] + return x_slice -def slice_pitch_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, idx_str:idx_end] - return ret +def slice_2d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor: + batch_size, num_features, seq_len = x.shape + ends = starts + length + idxs = ( + torch.arange(seq_len) + .unsqueeze(0) + .unsqueeze(1) + .repeat(batch_size, num_features, 1) + ) + mask = (idxs >= starts.unsqueeze(-1).unsqueeze(-1)) & ( + idxs < ends.unsqueeze(-1).unsqueeze(-1) + ) + return x[mask].reshape(batch_size, num_features, length) -def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size) - return ret, ret_pitch, ids_str +def slice_1d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor: + batch_size, seq_len = x.shape + ends = starts + length + idxs = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1) + mask = (idxs >= starts.unsqueeze(-1)) & (idxs < ends.unsqueeze(-1)) + return x[mask].reshape(batch_size, length) + + +def _slice_segments_v3(x: Tensor, starts: Tensor, length: int) -> Tensor: + shape = x.shape[:-1] + (length,) + ends = starts + length + idxs = torch.arange(x.shape[-1], device=x.device).unsqueeze(0).unsqueeze(0) + unsqueeze_dims = len(shape) - len( + x.shape + ) # calculate number of dimensions to unsqueeze + starts = starts.reshape(starts.shape + (1,) * unsqueeze_dims) + ends = ends.reshape(ends.shape + (1,) * unsqueeze_dims) + mask = (idxs >= starts) & (idxs < ends) + return x[mask].reshape(shape) def init_weights(m, mean=0.0, std=0.01): @@ -40,89 +62,6 @@ def convert_pad_shape(pad_shape): return pad_shape -def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += ( - 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - ) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def rand_spec_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( - num_timescales - 1 - ) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - def subsequent_mask(length): mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) return mask @@ -138,11 +77,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): return acts -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - def sequence_mask(length, max_length=None): if max_length is None: max_length = length.max() @@ -150,24 +84,6 @@ def sequence_mask(length, max_length=None): return x.unsqueeze(0) < length.unsqueeze(1) -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - def clip_grad_value_(parameters, clip_value, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] diff --git a/src/so_vits_svc_fork/modules/mel_processing.py b/src/so_vits_svc_fork/modules/mel_processing.py index 298acd86..6b325a86 100644 --- a/src/so_vits_svc_fork/modules/mel_processing.py +++ b/src/so_vits_svc_fork/modules/mel_processing.py @@ -1,4 +1,4 @@ -"""from logging import getLogger +from logging import getLogger import torch import torch.utils.data @@ -42,7 +42,8 @@ def mel_spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor: power=1.0, window_fn=torch.hann_window, normalized=False, - ).to(audio.device)(audio)""" + ).to(audio.device)(audio) + from logging import getLogger @@ -87,7 +88,7 @@ def spectral_de_normalize_torch(magnitudes): hann_window = {} -def spectrogram_torch(y, hps, center=False): +def spectrogram_torch_old(y, hps, center=False): if torch.min(y) < -1.0: LOG.info("min value is ", torch.min(y)) if torch.max(y) > 1.0: @@ -127,7 +128,7 @@ def spectrogram_torch(y, hps, center=False): return spec -def spec_to_mel_torch(spec, hps): +def spec_to_mel_torch_old(spec, hps): sampling_rate = hps.data.sampling_rate n_fft = hps.data.filter_length num_mels = hps.data.n_mel_channels @@ -148,7 +149,7 @@ def spec_to_mel_torch(spec, hps): return spec -def mel_spectrogram_torch(y, hps, center=False): +def mel_spectrogram_torch_old(y, hps, center=False): sampling_rate = hps.data.sampling_rate n_fft = hps.data.filter_length num_mels = hps.data.n_mel_channels diff --git a/src/so_vits_svc_fork/modules/synthesizers.py b/src/so_vits_svc_fork/modules/synthesizers.py index c96e021b..732fbb20 100644 --- a/src/so_vits_svc_fork/modules/synthesizers.py +++ b/src/so_vits_svc_fork/modules/synthesizers.py @@ -1,3 +1,4 @@ +import random import warnings from logging import getLogger from typing import Any, Literal, Sequence @@ -51,6 +52,7 @@ def __init__( gen_istft_n_fft: int = 16, gen_istft_hop_size: int = 4, subbands: int = 4, + ipu: bool = False, **kwargs: Any, ): super().__init__() @@ -162,8 +164,28 @@ def __init__( spk_channels=gin_channels, ) self.emb_uv = nn.Embedding(2, hidden_channels) + self.ipu = ipu - def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None): + def forward( + self, + c: torch.Tensor, + f0: torch.Tensor, + uv: torch.Tensor, + spec: torch.Tensor, + g: torch.Tensor, + c_lengths: torch.Tensor, + spec_lengths: torch.Tensor, + ) -> torch.Tensor: + """ + B: batch size + c: content, (B, ssl_dim, T) + f0: f0, (B, T) + uv: uv, (B, T) + spec: spectrogram, (B, F, T) + g: speaker id, (B,) + c_lengths: content length, (B,) + spec_lengths: spectrogram length, (B,) + """ g = self.emb_g(g).transpose(1, 2) # ssl prenet x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to( @@ -179,24 +201,34 @@ def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None): # encoder z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0)) z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) + # z: [batch, dim, time] # flow z_p = self.flow(z, spec_mask, g=g) - z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch( - z, f0, spec_lengths, self.segment_size - ) + + # randomly slice to time = self.segment_size + if not self.ipu: + slice_starts = ( + torch.rand(z.size(0)) * (spec_lengths - self.segment_size) + ).long() + z_slice = commons.slice_2d_segments(z, slice_starts, self.segment_size) + f0_slice = commons.slice_1d_segments(f0, slice_starts, self.segment_size) + else: + slice_starts = random.randint(0, spec.size(2) - self.segment_size) + z_slice = z[..., slice_starts : slice_starts + self.segment_size] + f0_slice = f0[..., slice_starts : slice_starts + self.segment_size] # MB-iSTFT-VITS if self.mb: o, o_mb = self.dec(z_slice, g=g) # HiFi-GAN else: - o = self.dec(z_slice, g=g, f0=pitch_slice) + o = self.dec(z_slice, g=g, f0=f0_slice) o_mb = None return ( o, o_mb, - ids_slice, + slice_starts, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, @@ -204,7 +236,15 @@ def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None): lf0, ) - def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False): + def infer( + self, + c: torch.Tensor, + f0: torch.Tensor, + uv: torch.Tensor, + g: torch.Tensor, + noice_scale: float = 0.35, + predict_f0: bool = False, + ) -> torch.Tensor: c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) g = self.emb_g(g).transpose(1, 2) x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to( diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index 2209de6b..195dbff8 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -8,7 +8,7 @@ import lightning.pytorch as pl import torch -from lightning.pytorch.accelerators import TPUAccelerator +from lightning.pytorch.accelerators import IPUAccelerator, TPUAccelerator from lightning.pytorch.loggers import TensorBoardLogger from torch.cuda.amp import autocast from torch.nn import functional as F @@ -120,6 +120,92 @@ def stft( torch.stft = stft + if isinstance(self.trainer.accelerator, IPUAccelerator): + # patch mel_scale_fbanks + LOG.warning( + "Using IPU. Patching torchaudio.functional.mel_scale_fbanks not to use max()" + ) + + import torchaudio.functional + from torchaudio.functional.functional import ( + _create_triangular_filterbank, + _hz_to_mel, + _mel_to_hz, + ) + + def melscale_fbanks( + n_freqs: int, + f_min: float, + f_max: float, + n_mels: int, + sample_rate: int, + norm: str | None = None, + mel_scale: str = "htk", + ) -> torch.Tensor: + r"""Create a frequency bin conversion matrix. + + .. devices:: CPU + + .. properties:: TorchScript + + Note: + For the sake of the numerical compatibility with librosa, not all the coefficients + in the resulting filter bank has magnitude of 1. + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png + :alt: Visualization of generated filter bank + + Args: + n_freqs (int): Number of frequencies to highlight/apply + f_min (float): Minimum frequency (Hz) + f_max (float): Maximum frequency (Hz) + n_mels (int): Number of mel filterbanks + sample_rate (int): Sample rate of the audio waveform + norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band + (area normalization). (Default: ``None``) + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + + Returns: + Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) + meaning number of frequencies to highlight/apply to x the number of filterbanks. + Each column is a filterbank so that assuming there is a matrix A of + size (..., ``n_freqs``), the applied result would be + ``A * melscale_fbanks(A.size(-1), ...)``. + + """ + + if norm is not None and norm != "slaney": + raise ValueError('norm must be one of None or "slaney"') + + # freq bins + all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) + + # calculate mel freq bins + m_min = _hz_to_mel(f_min, mel_scale=mel_scale) + m_max = _hz_to_mel(f_max, mel_scale=mel_scale) + + m_pts = torch.linspace(m_min, m_max, n_mels + 2) + f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale) + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + if norm is not None and norm == "slaney": + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) + fb *= enorm.unsqueeze(0) + + # if (fb.max(dim=0).values == 0.0).any(): + # warnings.warn( + # "At least one mel filterbank has all zero values. " + # f"The value for `n_mels` ({n_mels}) may be set too high. " + # f"Or, the value for `n_freqs` ({n_freqs}) may be set too low." + # ) + + return fb + + torchaudio.functional.melscale_fbanks = melscale_fbanks + def set_current_epoch(self, epoch: int): LOG.info(f"Setting current epoch to {epoch}") self.trainer.fit_loop.epoch_progress.current.completed = epoch @@ -171,6 +257,7 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any): self.net_g = SynthesizerTrn( self.hparams.data.filter_length // 2 + 1, self.hparams.train.segment_size // self.hparams.data.hop_length, + ipu=True, **self.hparams.model, ) self.net_d = MultiPeriodDiscriminator(self.hparams.model.use_spectral_norm) @@ -178,13 +265,13 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any): self.optim_g = torch.optim.AdamW( self.net_g.parameters(), self.hparams.train.learning_rate, - betas=self.hparams.train.betas, + betas=tuple(self.hparams.train.betas), eps=self.hparams.train.eps, ) self.optim_d = torch.optim.AdamW( self.net_d.parameters(), self.hparams.train.learning_rate, - betas=self.hparams.train.betas, + betas=tuple(self.hparams.train.betas), eps=self.hparams.train.eps, ) self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR( @@ -195,6 +282,8 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any): ) def configure_optimizers(self): + if isinstance(self.trainer.accelerator, IPUAccelerator): + return [self.optim_g], [self.scheduler_g] return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d] def log_image_dict( @@ -225,7 +314,11 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None: self.net_d.train() # get optims - optim_g, optim_d = self.optimizers() + self.is_ipu = isinstance(self.trainer.accelerator, IPUAccelerator) + if self.is_ipu: + optim_g = self.optimizers() + else: + optim_g, optim_d = self.optimizers() # Generator # train @@ -241,17 +334,32 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None: norm_lf0, lf0, ) = self.net_g(c, f0, uv, spec, g=g, c_lengths=lengths, spec_lengths=lengths) - y_mel = commons.slice_segments( - mel, - ids_slice, - self.hparams.train.segment_size // self.hparams.data.hop_length, - ) + y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1), self.hparams) - y = commons.slice_segments( - y, - ids_slice * self.hparams.data.hop_length, - self.hparams.train.segment_size, - ) + if self.is_ipu: + y_mel = mel[ + ..., + ids_slice : self.hparams.train.segment_size + // self.hparams.data.hop_length + + ids_slice, + ] + y = y[ + ..., + ids_slice + * self.hparams.data.hop_length : self.hparams.train.segment_size + + ids_slice * self.hparams.data.hop_length, + ] + else: + y_mel = commons.slice_2d_segments( + mel, + ids_slice, + self.hparams.train.segment_size // self.hparams.data.hop_length, + ) + y = commons.slice_2d_segments( + y, + ids_slice * self.hparams.data.hop_length, + self.hparams.train.segment_size, + ) # generator loss LOG.debug("Calculating generator loss") @@ -320,6 +428,8 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None: optim_g.zero_grad() self.untoggle_optimizer(optim_g) + if self.is_ipu: + return # Discriminator # train self.toggle_optimizer(optim_d)