Skip to content

Commit 8848d50

Browse files
committed
feat(vc): add knnvc model
1 parent b091605 commit 8848d50

File tree

12 files changed

+314
-18
lines changed

12 files changed

+314
-18
lines changed

TTS/.models.json

+26
Original file line number | Diff line number | Diff line change
@@ -787,6 +787,22 @@
787787
"license": "apache 2.0"
788788
}
789789
},
790+
"librispeech100": {
791+
"wavlm-hifigan": {
792+
"description": "HiFiGAN vocoder for WavLM features from kNN-VC",
793+
"github_rls_url": "https://github.com/idiap/coqui-ai-TTS/releases/download/v0.25.2_models/vocoder_models--en--librispeech100--wavlm-hifigan.zip",
794+
"commit": "cfba7e0",
795+
"author": "Benjamin van Niekerk @bshall, Matthew Baas @RF5",
796+
"license": "MIT"
797+
},
798+
"wavlm-hifigan_prematched": {
799+
"description": "Prematched HiFiGAN vocoder for WavLM features from kNN-VC",
800+
"github_rls_url": "https://github.com/idiap/coqui-ai-TTS/releases/download/v0.25.2_models/vocoder_models--en--librispeech100--wavlm-hifigan_prematched.zip",
801+
"commit": "cfba7e0",
802+
"author": "Benjamin van Niekerk @bshall, Matthew Baas @RF5",
803+
"license": "MIT"
804+
}
805+
},
790806
"ljspeech": {
791807
"multiband-melgan": {
792808
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.6.1_models/vocoder_models--en--ljspeech--multiband-melgan.zip",
@@ -927,18 +943,27 @@
927943
"freevc24": {
928944
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/voice_conversion_models--multilingual--vctk--freevc24.zip",
929945
"description": "FreeVC model trained on VCTK dataset from https://github.com/OlaWod/FreeVC",
946+
"default_vocoder": null,
930947
"author": "Jing-Yi Li @OlaWod",
931948
"license": "MIT",
932949
"commit": null
933950
}
934951
},
935952
"multi-dataset": {
953+
"knnvc": {
954+
"description": "kNN-VC model from https://github.com/bshall/knn-vc",
955+
"default_vocoder": "vocoder_models/en/librispeech100/wavlm-hifigan_prematched",
956+
"author": "Benjamin van Niekerk @bshall, Matthew Baas @RF5",
957+
"license": "MIT",
958+
"commit": null
959+
},
936960
"openvoice_v1": {
937961
"hf_url": [
938962
"https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/config.json",
939963
"https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/checkpoint.pth"
940964
],
941965
"description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2",
966+
"default_vocoder": null,
942967
"author": "MyShell.ai",
943968
"license": "MIT",
944969
"commit": null
@@ -949,6 +974,7 @@
949974
"https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter/checkpoint.pth"
950975
],
951976
"description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2",
977+
"default_vocoder": null,
952978
"author": "MyShell.ai",
953979
"license": "MIT",
954980
"commit": null

TTS/utils/generic_utils.py

+1
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@ def to_camel(text):
3131
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
3232
text = text.replace("Tts", "TTS")
3333
text = text.replace("vc", "VC")
34+
text = text.replace("Knn", "KNN")
3435
return text
3536

3637

TTS/utils/manage.py

+10-3
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from typing_extensions import Required
1616

1717
from TTS.config import load_config, read_json_with_comments
18+
from TTS.vc.configs.knnvc_config import KNNVCConfig
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -267,9 +268,9 @@ def set_model_url(model_item: ModelItem) -> ModelItem:
267268
model_item["model_url"] = model_item["github_rls_url"]
268269
elif "hf_url" in model_item:
269270
model_item["model_url"] = model_item["hf_url"]
270-
elif "fairseq" in model_item["model_name"]:
271+
elif "fairseq" in model_item.get("model_name", ""):
271272
model_item["model_url"] = "https://dl.fbaipublicfiles.com/mms/tts/"
272-
elif "xtts" in model_item["model_name"]:
273+
elif "xtts" in model_item.get("model_name", ""):
273274
model_item["model_url"] = "https://huggingface.co/coqui/"
274275
return model_item
275276

@@ -367,6 +368,9 @@ def create_dir_and_download_model(self, model_name: str, model_item: ModelItem,
367368
logger.exception("Failed to download the model file to %s", output_path)
368369
rmtree(output_path)
369370
raise e
371+
checkpoints = list(Path(output_path).glob("*.pt*"))
372+
if len(checkpoints) == 1:
373+
checkpoints[0].rename(checkpoints[0].parent / "model.pth")
370374
self.print_model_license(model_item=model_item)
371375

372376
def check_if_configs_are_equal(self, model_name: str, model_item: ModelItem, output_path: Path) -> None:
@@ -431,11 +435,14 @@ def download_model(self, model_name: str) -> tuple[Path, Optional[Path], ModelIt
431435
output_model_path = output_path
432436
output_config_path = None
433437
if (
434-
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
438+
model not in ["tortoise-v2", "bark", "knnvc"] and "fairseq" not in model_name and "xtts" not in model_name
435439
): # TODO: These models are special-cased by name; replace with a less brittle check.
436440
output_model_path, output_config_path = self._find_files(output_path)
437441
else:
438442
output_config_path = output_model_path / "config.json"
443+
if model == "knnvc" and not output_config_path.exists():
444+
knnvc_config = KNNVCConfig()
445+
knnvc_config.save_json(output_config_path)
439446
# update paths in the config.json
440447
self._update_paths(output_path, output_config_path)
441448
return output_model_path, output_config_path, model_item

TTS/utils/synthesizer.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,7 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N
139139
"""
140140
# pylint: disable=global-statement
141141
self.vc_config = load_config(vc_config_path)
142-
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
142+
self.output_sample_rate = self.vc_config.audio.get("output_sample_rate", self.vc_config.audio["sample_rate"])
143143
self.vc_model = setup_vc_model(config=self.vc_config)
144144
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
145145
if use_cuda:

TTS/vc/configs/knnvc_config.py

+59
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,59 @@
1+
from dataclasses import dataclass, field
2+
3+
from coqpit import Coqpit
4+
5+
from TTS.config.shared_configs import BaseAudioConfig
6+
from TTS.vc.configs.shared_configs import BaseVCConfig
7+
8+
9+
@dataclass
10+
class KNNVCAudioConfig(BaseAudioConfig):
11+
"""Audio configuration.
12+
13+
Args:
14+
sample_rate (int):
15+
The sampling rate of the input waveform.
16+
"""
17+
18+
sample_rate: int = field(default=16000)
19+
20+
21+
@dataclass
22+
class KNNVCArgs(Coqpit):
23+
"""Model arguments.
24+
25+
Args:
26+
ssl_dim (int):
27+
The dimension of the self-supervised learning embedding.
28+
"""
29+
30+
ssl_dim: int = field(default=1024)
31+
32+
33+
@dataclass
34+
class KNNVCConfig(BaseVCConfig):
35+
"""Parameters.
36+
37+
Args:
38+
model (str):
39+
Model name. Do not change unless you know what you are doing.
40+
41+
model_args (KNNVCArgs):
42+
Model architecture arguments. Defaults to `KNNVCArgs()`.
43+
44+
audio (KNNVCAudioConfig):
45+
Audio processing configuration. Defaults to `KNNVCAudioConfig()`.
46+
47+
wavlm_layer (int):
48+
WavLM layer to use for feature extraction.
49+
50+
topk (int):
51+
k in the kNN -- the number of nearest neighbors to average over
52+
"""
53+
54+
model: str = "knnvc"
55+
model_args: KNNVCArgs = field(default_factory=KNNVCArgs)
56+
audio: KNNVCAudioConfig = field(default_factory=KNNVCAudioConfig)
57+
58+
wavlm_layer: int = 6
59+
topk: int = 4

TTS/vc/layers/freevc/wavlm/__init__.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"
1414

1515

16-
def get_wavlm(device="cpu"):
16+
def get_wavlm(device="cpu") -> WavLM:
1717
"""Download the model and return the model object."""
1818

1919
output_path = get_user_data_dir("tts")

TTS/vc/models/__init__.py

+11-4
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,22 @@
11
import importlib
22
import logging
33
import re
4-
from typing import Dict, List, Union
4+
from typing import Dict, List, Optional, Union
5+
6+
from TTS.vc.configs.shared_configs import BaseVCConfig
7+
from TTS.vc.models.base_vc import BaseVC
58

69
logger = logging.getLogger(__name__)
710

811

912
def setup_model(config: BaseVCConfig) -> BaseVC:
1013
logger.info("Using model: %s", config.model)
1114
# fetch the right model implementation.
12-
if "model" in config and config["model"].lower() == "freevc":
15+
if config["model"].lower() == "freevc":
1316
MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
14-
model = MyModel.init_from_config(config)
15-
return model
17+
elif config["model"].lower() == "knnvc":
18+
MyModel = importlib.import_module("TTS.vc.models.knnvc").KNNVC
19+
else:
20+
msg = f"Model {config.model} does not exist!"
21+
raise ValueError(msg)
22+
return MyModel.init_from_config(config)

0 commit comments

Comments
 (0)