|
21 | 21 | from trainer.trainer_utils import get_optimizer, get_scheduler
|
22 | 22 |
|
23 | 23 | from TTS.tts.configs.shared_configs import CharactersConfig
|
24 |
| -from TTS.tts.datasets.dataset import TTSDataset, _parse_sample |
| 24 | +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, get_attribute_balancer_weights |
25 | 25 | from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
26 | 26 | from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
27 | 27 | from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
@@ -218,30 +218,6 @@ class VitsAudioConfig(Coqpit):
|
218 | 218 | ##############################
|
219 | 219 |
|
220 | 220 |
|
221 |
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
    """Create inverse-frequency sampling weights for balancing the dataset.

    Each sample receives a weight of ``1 / count(attribute value)``, normalized by the
    L2 norm of the weight vector. Use `multi_dict` to scale relative weights per
    attribute value.

    Args:
        items (list): Dataset samples; each item is a mapping that contains `attr_name`.
        attr_name (str): Key whose value distribution should be balanced (e.g. speaker name).
        multi_dict (dict, optional): Mapping from attribute value to an extra weight
            multiplier. Every key must be a value actually present in the data.

    Returns:
        Tuple of (per-sample weights as a float tensor, sorted unique attribute values,
        unique weight values as a list).
    """
    attr_names_samples = np.array([item[attr_name] for item in items])
    # Single pass: sorted unique values plus their counts, instead of one
    # np.where() scan per unique value (accidental O(classes * n)).
    unique_values, attr_count = np.unique(attr_names_samples, return_counts=True)
    unique_attr_names = unique_values.tolist()
    attr_idx = [unique_attr_names.index(name) for name in attr_names_samples]
    weight_attr = 1.0 / attr_count
    dataset_samples_weight = np.array([weight_attr[idx] for idx in attr_idx])
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    if multi_dict is not None:
        # check if all keys are in the multi_dict
        for k in multi_dict:
            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
        # scale weights
        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
        dataset_samples_weight *= multiplier_samples
    return (
        torch.from_numpy(dataset_samples_weight).float(),
        unique_attr_names,
        np.unique(dataset_samples_weight).tolist(),
    )
245 | 221 | class VitsDataset(TTSDataset):
|
246 | 222 | def __init__(self, model_args, *args, **kwargs):
|
247 | 223 | super().__init__(*args, **kwargs)
|
|
0 commit comments