Skip to content

Commit 1ee323d

Browse files
committed
refactor: move shared function into dataset.py
1 parent f93b74e commit 1ee323d

File tree

3 files changed

+27
-45
lines changed

3 files changed

+27
-45
lines changed

TTS/tts/datasets/dataset.py

+25
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,31 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
6363
raise RuntimeError(msg) from e
6464

6565

66+
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: Optional[dict] = None):
    """Create inverse frequency weights for balancing the dataset.

    Each sample receives a weight inversely proportional to how often its
    ``attr_name`` value (e.g. speaker or language name) occurs in ``items``;
    the weight vector is then L2-normalized. Use `multi_dict` to scale
    relative weights.

    Args:
        items: List of sample dicts; each must contain the key ``attr_name``.
        attr_name: Name of the attribute to balance on.
        multi_dict: Optional mapping from attribute value to an extra
            multiplier applied on top of the inverse-frequency weight.
            Every key must be an existing attribute value (AssertionError
            otherwise).

    Returns:
        Tuple of:
            - per-sample weight tensor (float32),
            - sorted list of unique attribute values,
            - sorted list of the unique weight values.
    """
    attr_names_samples = np.array([item[attr_name] for item in items])
    # return_inverse/return_counts give the per-sample index and the class
    # counts in one pass, replacing the O(n^2) list.index() lookups of the
    # naive implementation.
    unique_values, attr_idx, attr_count = np.unique(
        attr_names_samples, return_inverse=True, return_counts=True
    )
    unique_attr_names = unique_values.tolist()
    weight_attr = 1.0 / attr_count
    dataset_samples_weight = weight_attr[attr_idx]
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    if multi_dict is not None:
        # check if all keys are in the multi_dict
        for k in multi_dict:
            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
        # scale weights
        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
        dataset_samples_weight *= multiplier_samples
    return (
        torch.from_numpy(dataset_samples_weight).float(),
        unique_attr_names,
        np.unique(dataset_samples_weight).tolist(),
    )
89+
90+
6691
class TTSDataset(Dataset):
6792
def __init__(
6893
self,

TTS/tts/models/delightful_tts.py

+1-20
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
1818
from trainer.trainer_utils import get_optimizer, get_scheduler
1919

20-
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
20+
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample, get_attribute_balancer_weights
2121
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
2222
from TTS.tts.layers.losses import (
2323
ForwardSumLoss,
@@ -193,25 +193,6 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
193193
##############################
194194

195195

196-
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
197-
"""Create balancer weight for torch WeightedSampler"""
198-
attr_names_samples = np.array([item[attr_name] for item in items])
199-
unique_attr_names = np.unique(attr_names_samples).tolist()
200-
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
201-
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
202-
weight_attr = 1.0 / attr_count
203-
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
204-
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
205-
if multi_dict is not None:
206-
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
207-
dataset_samples_weight *= multiplier_samples
208-
return (
209-
torch.from_numpy(dataset_samples_weight).float(),
210-
unique_attr_names,
211-
np.unique(dataset_samples_weight).tolist(),
212-
)
213-
214-
215196
class ForwardTTSE2eF0Dataset(F0Dataset):
216197
"""Override F0Dataset to avoid slow computing of pitches"""
217198

TTS/tts/models/vits.py

+1-25
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from trainer.trainer_utils import get_optimizer, get_scheduler
2222

2323
from TTS.tts.configs.shared_configs import CharactersConfig
24-
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
24+
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, get_attribute_balancer_weights
2525
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
2626
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
2727
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
@@ -218,30 +218,6 @@ class VitsAudioConfig(Coqpit):
218218
##############################
219219

220220

221-
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
222-
"""Create inverse frequency weights for balancing the dataset.
223-
Use `multi_dict` to scale relative weights."""
224-
attr_names_samples = np.array([item[attr_name] for item in items])
225-
unique_attr_names = np.unique(attr_names_samples).tolist()
226-
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
227-
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
228-
weight_attr = 1.0 / attr_count
229-
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
230-
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
231-
if multi_dict is not None:
232-
# check if all keys are in the multi_dict
233-
for k in multi_dict:
234-
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
235-
# scale weights
236-
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
237-
dataset_samples_weight *= multiplier_samples
238-
return (
239-
torch.from_numpy(dataset_samples_weight).float(),
240-
unique_attr_names,
241-
np.unique(dataset_samples_weight).tolist(),
242-
)
243-
244-
245221
class VitsDataset(TTSDataset):
246222
def __init__(self, model_args, *args, **kwargs):
247223
super().__init__(*args, **kwargs)

0 commit comments

Comments
 (0)