Skip to content

Commit a63e9c2

Browse files
authored
feat(core): add load_models source local (#361)
rename original `local` to `custom`
1 parent ce1c962 commit a63e9c2

File tree

9 files changed

+353
-13
lines changed

9 files changed

+353
-13
lines changed

.github/workflows/checksum.yml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name: Calculate and Sync SHA256
2+
on:
3+
push:
4+
branches:
5+
- main
6+
- dev
7+
jobs:
8+
checksum:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- uses: actions/checkout@master
12+
13+
- name: Setup Go Environment
14+
uses: actions/setup-go@master
15+
16+
- name: Run RVC-Models-Downloader
17+
run: |
18+
wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.5/rvcmd_linux_amd64.deb
19+
sudo apt -y install ./rvcmd_linux_amd64.deb
20+
rm -f ./rvcmd_linux_amd64.deb
21+
rvcmd -notrs -w 1 -notui assets/chtts
22+
23+
- name: Calculate all Checksums
24+
run: go run tools/checksum/*.go
25+
26+
- name: Commit back
27+
if: ${{ !github.head_ref }}
28+
id: commitback
29+
continue-on-error: true
30+
run: |
31+
git config --local user.name 'github-actions[bot]'
32+
git config --local user.email 'github-actions[bot]@users.noreply.github.com'
33+
git add --all
34+
git commit -m "chore(env): sync checksum on ${{github.ref_name}}"
35+
36+
- name: Create Pull Request
37+
if: steps.commitback.outcome == 'success'
38+
continue-on-error: true
39+
uses: peter-evans/create-pull-request@v5
40+
with:
41+
delete-branch: true
42+
body: "Automatically sync checksum in .env"
43+
title: "chore(env): sync checksum on ${{github.ref_name}}"
44+
commit-message: "chore(env): sync checksum on ${{github.ref_name}}"
45+
branch: checksum-${{github.ref_name}}

ChatTTS/core.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@
33
import json
44
import logging
55
from functools import partial
6-
from omegaconf import OmegaConf
6+
from typing import Literal
7+
import tempfile
78

89
import torch
10+
from omegaconf import OmegaConf
911
from vocos import Vocos
12+
from huggingface_hub import snapshot_download
13+
1014
from .model.dvae import DVAE
1115
from .model.gpt import GPT_warpper
1216
from .utils.gpu_utils import select_device
1317
from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
1418
from .utils.io_utils import get_latest_modified_file
1519
from .infer.api import refine_text, infer_code
16-
17-
from huggingface_hub import snapshot_download
20+
from .utils.download import check_all_assets, download_all_assets
1821

1922
logging.basicConfig(level = logging.INFO)
2023

@@ -44,9 +47,23 @@ def check_model(self, level = logging.INFO, use_decoder = False):
4447
self.logger.log(level, f'All initialized.')
4548

4649
return not not_finish
47-
48-
def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>', **kwargs):
49-
if source == 'huggingface':
50+
51+
def load_models(
52+
self,
53+
source: Literal['huggingface', 'local', 'custom']='local',
54+
force_redownload=False,
55+
custom_path='<LOCAL_PATH>',
56+
**kwargs,
57+
):
58+
if source == 'local':
59+
download_path = os.getcwd()
60+
if not check_all_assets(update=True):
61+
with tempfile.TemporaryDirectory() as tmp:
62+
download_all_assets(tmpdir=tmp)
63+
if not check_all_assets(update=False):
64+
logging.error("counld not satisfy all assets needed.")
65+
exit(1)
66+
elif source == 'huggingface':
5067
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
5168
try:
5269
download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
@@ -57,9 +74,9 @@ def load_models(self, source='huggingface', force_redownload=False, local_path='
5774
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
5875
else:
5976
self.logger.log(logging.INFO, f'Load from cache: {download_path}')
60-
elif source == 'local':
61-
self.logger.log(logging.INFO, f'Load from local: {local_path}')
62-
download_path = local_path
77+
elif source == 'custom':
78+
self.logger.log(logging.INFO, f'Load from local: {custom_path}')
79+
download_path = custom_path
6380

6481
self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
6582

ChatTTS/utils/download.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import os
2+
from pathlib import Path
3+
import hashlib
4+
import requests
5+
from io import BytesIO
6+
import logging
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
def sha256(f) -> str:
12+
sha256_hash = hashlib.sha256()
13+
# Read and update hash in chunks of 4M
14+
for byte_block in iter(lambda: f.read(4 * 1024 * 1024), b""):
15+
sha256_hash.update(byte_block)
16+
return sha256_hash.hexdigest()
17+
18+
19+
def check_model(
20+
dir_name: Path, model_name: str, hash: str, remove_incorrect=False
21+
) -> bool:
22+
target = dir_name / model_name
23+
relname = target.as_posix()
24+
logger.debug(f"checking {relname}...")
25+
if not os.path.exists(target):
26+
logger.info(f"{target} not exist.")
27+
return False
28+
with open(target, "rb") as f:
29+
digest = sha256(f)
30+
bakfile = f"{target}.bak"
31+
if digest != hash:
32+
logger.warn(f"{target} sha256 hash mismatch.")
33+
logger.info(f"expected: {hash}")
34+
logger.info(f"real val: {digest}")
35+
logger.warn("please add parameter --update to download the latest assets.")
36+
if remove_incorrect:
37+
if not os.path.exists(bakfile):
38+
os.rename(str(target), bakfile)
39+
else:
40+
os.remove(str(target))
41+
return False
42+
if remove_incorrect and os.path.exists(bakfile):
43+
os.remove(bakfile)
44+
return True
45+
46+
47+
def check_all_assets(update=False) -> bool:
48+
BASE_DIR = Path(__file__).resolve().parent.parent.parent
49+
50+
logger.info("checking assets...")
51+
current_dir = BASE_DIR / "asset"
52+
names = [
53+
"Decoder.pt",
54+
"DVAE.pt",
55+
"GPT.pt",
56+
"spk_stat.pt",
57+
"tokenizer.pt",
58+
"Vocos.pt",
59+
]
60+
for model in names:
61+
menv = model.replace(".", "_")
62+
if not check_model(
63+
current_dir, model, os.environ[f"sha256_asset_{menv}"], update
64+
):
65+
return False
66+
67+
logger.info("checking configs...")
68+
current_dir = BASE_DIR / "config"
69+
names = [
70+
"decoder.yaml",
71+
"dvae.yaml",
72+
"gpt.yaml",
73+
"path.yaml",
74+
"vocos.yaml",
75+
]
76+
for model in names:
77+
menv = model.replace(".", "_")
78+
if not check_model(
79+
current_dir, model, os.environ[f"sha256_config_{menv}"], update
80+
):
81+
return False
82+
83+
logger.info("all assets are already latest.")
84+
return True
85+
86+
87+
def download_and_extract_tar_gz(url: str, folder: str):
88+
import tarfile
89+
90+
logger.info(f"downloading {url}")
91+
response = requests.get(url, stream=True, timeout=(5, 10))
92+
with BytesIO() as out_file:
93+
out_file.write(response.content)
94+
out_file.seek(0)
95+
logger.info(f"downloaded.")
96+
with tarfile.open(fileobj=out_file, mode="r:gz") as tar:
97+
tar.extractall(folder)
98+
logger.info(f"extracted into {folder}")
99+
100+
101+
def download_and_extract_zip(url: str, folder: str):
102+
import zipfile
103+
104+
logger.info(f"downloading {url}")
105+
response = requests.get(url, stream=True, timeout=(5, 10))
106+
with BytesIO() as out_file:
107+
out_file.write(response.content)
108+
out_file.seek(0)
109+
logger.info(f"downloaded.")
110+
with zipfile.ZipFile(out_file) as zip_ref:
111+
zip_ref.extractall(folder)
112+
logger.info(f"extracted into {folder}")
113+
114+
115+
def download_dns_yaml(url: str, folder: str):
116+
logger.info(f"downloading {url}")
117+
response = requests.get(url, stream=True, timeout=(5, 10))
118+
with open(os.path.join(folder, "dns.yaml"), "wb") as out_file:
119+
out_file.write(response.content)
120+
logger.info(f"downloaded into {folder}")
121+
122+
123+
def download_all_assets(tmpdir: str, version="0.2.5"):
124+
import subprocess
125+
import platform
126+
127+
archs = {
128+
"aarch64": "arm64",
129+
"armv8l": "arm64",
130+
"arm64": "arm64",
131+
"x86": "386",
132+
"i386": "386",
133+
"i686": "386",
134+
"386": "386",
135+
"x86_64": "amd64",
136+
"x64": "amd64",
137+
"amd64": "amd64",
138+
}
139+
system_type = platform.system().lower()
140+
architecture = platform.machine().lower()
141+
is_win = system_type == "windows"
142+
143+
architecture = archs.get(architecture, None)
144+
if not architecture:
145+
logger.error(f"architecture {architecture} is not supported")
146+
exit(1)
147+
try:
148+
BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/"
149+
suffix = "zip" if is_win else "tar.gz"
150+
RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}"
151+
cmdfile = os.path.join(tmpdir, "rvcmd")
152+
if is_win:
153+
download_and_extract_zip(RVCMD_URL, tmpdir)
154+
cmdfile += ".exe"
155+
else:
156+
download_and_extract_tar_gz(RVCMD_URL, tmpdir)
157+
os.chmod(cmdfile, 0o755)
158+
subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"])
159+
except Exception:
160+
BASE_URL = "https://raw.gitcode.com/u011570312/RVC-Models-Downloader/assets/"
161+
suffix = {
162+
"darwin_amd64": "555",
163+
"darwin_arm64": "556",
164+
"linux_386": "557",
165+
"linux_amd64": "558",
166+
"linux_arm64": "559",
167+
"windows_386": "562",
168+
"windows_amd64": "563",
169+
}[f"{system_type}_{architecture}"]
170+
RVCMD_URL = BASE_URL + suffix
171+
download_dns_yaml(
172+
"https://raw.gitcode.com/u011570312/RVC-Models-Downloader/raw/main/dns.yaml",
173+
tmpdir,
174+
)
175+
if is_win:
176+
download_and_extract_zip(RVCMD_URL, tmpdir)
177+
cmdfile += ".exe"
178+
else:
179+
download_and_extract_tar_gz(RVCMD_URL, tmpdir)
180+
os.chmod(cmdfile, 0o755)
181+
subprocess.run(
182+
[
183+
cmdfile,
184+
"-notui",
185+
"-w",
186+
"0",
187+
"-dns",
188+
os.path.join(tmpdir, "dns.yaml"),
189+
"assets/chtts",
190+
]
191+
)

examples/cmd/run.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
now_dir = os.getcwd()
77
sys.path.append(now_dir)
88

9+
from dotenv import load_dotenv
10+
load_dotenv("sha256.env")
11+
912
import wave
1013
import ChatTTS
1114
from IPython.display import Audio

examples/web/webui.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import gradio as gr
1414
import numpy as np
1515

16+
from dotenv import load_dotenv
17+
load_dotenv("sha256.env")
18+
1619
import ChatTTS
1720

1821
# 音色选项:用于预置合适的音色
@@ -132,18 +135,18 @@ def main():
132135
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
133136
parser.add_argument('--server_port', type=int, default=8080, help='Server port')
134137
parser.add_argument('--root_path', type=str, default=None, help='Root Path')
135-
parser.add_argument('--local_path', type=str, default=None, help='the local_path if need')
138+
parser.add_argument('--custom_path', type=str, default=None, help='the custom model path')
136139
args = parser.parse_args()
137140

138141
print("loading ChatTTS model...")
139142
global chat
140143
chat = ChatTTS.Chat()
141144

142-
if args.local_path == None:
145+
if args.custom_path == None:
143146
chat.load_models()
144147
else:
145-
print('local model path:', args.local_path)
146-
chat.load_models('local', local_path=args.local_path)
148+
print('local model path:', args.custom_path)
149+
chat.load_models('custom', custom_path=args.custom_path)
147150

148151
demo.launch(server_name=args.server_name, server_port=args.server_port, root_path=args.root_path, inbrowser=True)
149152

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ transformers~=4.41.1
88
vocos
99
IPython
1010
gradio
11+
python-dotenv
1112
pynini==2.1.5
1213
WeTextProcessing
1314
nemo_text_processing

sha256.env

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
sha256_asset_Decoder_pt = 9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38
2+
sha256_asset_DVAE_pt = 613cb128adf89188c93ea5880ea0b798e66b1fe6186d0c535d99bcd87bfd6976
3+
sha256_asset_GPT_pt = d7d4ee6461ea097a2be23eb40d73fb94ad3b3d39cb64fbb50cb3357fd466cadb
4+
sha256_asset_spk_stat_pt = 3228d8a4cbbf349d107a1b76d2f47820865bd3c9928c4bdfe1cefd5c7071105f
5+
sha256_asset_tokenizer_pt = e911ae7c6a7c27953433f35c44227a67838fe229a1f428503bdb6cd3d1bcc69c
6+
sha256_asset_Vocos_pt = 09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58
7+
8+
sha256_config_decoder_yaml = 0890ab719716b0ad8abcb9eba0a9bf52c59c2e45ddedbbbb5ed514ff87bff369
9+
sha256_config_dvae_yaml = 1b3a5aa0c6a314f766d4432ab36f84e882e29561648d837f71c04c7bea494fc6
10+
sha256_config_gpt_yaml = 0c3c7277b674094bdd00b63b18b18aa3156502101dbd03c7f802e0fcf26cff51
11+
sha256_config_path_yaml = 79829705c2d2a29b3f55e3b3f228bb81875e4e265211595fb50a73eb6434684b
12+
sha256_config_vocos_yaml = 1ca837ce790dd8b55bdd5a16c6af8f813926b9c9b48f2a4da305e7e9ff0c9b0c

0 commit comments

Comments
 (0)