|
15 | 15 | from typing_extensions import Required
|
16 | 16 |
|
17 | 17 | from TTS.config import load_config, read_json_with_comments
|
| 18 | +from TTS.vc.configs.knnvc_config import KNNVCConfig |
18 | 19 |
|
19 | 20 | logger = logging.getLogger(__name__)
|
20 | 21 |
|
@@ -267,9 +268,9 @@ def set_model_url(model_item: ModelItem) -> ModelItem:
|
267 | 268 | model_item["model_url"] = model_item["github_rls_url"]
|
268 | 269 | elif "hf_url" in model_item:
|
269 | 270 | model_item["model_url"] = model_item["hf_url"]
|
270 |
| - elif "fairseq" in model_item["model_name"]: |
| 271 | + elif "fairseq" in model_item.get("model_name", ""): |
271 | 272 | model_item["model_url"] = "https://dl.fbaipublicfiles.com/mms/tts/"
|
272 |
| - elif "xtts" in model_item["model_name"]: |
| 273 | + elif "xtts" in model_item.get("model_name", ""): |
273 | 274 | model_item["model_url"] = "https://huggingface.co/coqui/"
|
274 | 275 | return model_item
|
275 | 276 |
|
@@ -367,6 +368,9 @@ def create_dir_and_download_model(self, model_name: str, model_item: ModelItem,
|
367 | 368 | logger.exception("Failed to download the model file to %s", output_path)
|
368 | 369 | rmtree(output_path)
|
369 | 370 | raise e
|
| 371 | + checkpoints = list(Path(output_path).glob("*.pt*")) |
| 372 | + if len(checkpoints) == 1: |
| 373 | + checkpoints[0].rename(checkpoints[0].parent / "model.pth") |
370 | 374 | self.print_model_license(model_item=model_item)
|
371 | 375 |
|
372 | 376 | def check_if_configs_are_equal(self, model_name: str, model_item: ModelItem, output_path: Path) -> None:
|
@@ -431,11 +435,14 @@ def download_model(self, model_name: str) -> tuple[Path, Optional[Path], ModelIt
|
431 | 435 | output_model_path = output_path
|
432 | 436 | output_config_path = None
|
433 | 437 | if (
|
434 |
| - model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name |
| 438 | + model not in ["tortoise-v2", "bark", "knnvc"] and "fairseq" not in model_name and "xtts" not in model_name |
435 | 439 | ): # TODO:This is stupid but don't care for now.
|
436 | 440 | output_model_path, output_config_path = self._find_files(output_path)
|
437 | 441 | else:
|
438 | 442 | output_config_path = output_model_path / "config.json"
|
| 443 | + if model == "knnvc" and not output_config_path.exists(): |
| 444 | + knnvc_config = KNNVCConfig() |
| 445 | + knnvc_config.save_json(output_config_path) |
439 | 446 | # update paths in the config.json
|
440 | 447 | self._update_paths(output_path, output_config_path)
|
441 | 448 | return output_model_path, output_config_path, model_item
|
|
0 commit comments