
Commit a851b9a

optimize: log & webui (#398)
- move log definition out of ChatTTS
- apply colorful log levels
- optimize webui logic
- split webui into 2 files for clearer reading
1 parent e58fe48 commit a851b9a

File tree

12 files changed: +229 -152 lines changed
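
The commit relies on tools.logger.get_logger for the colorful log levels, but tools/logger.py itself is not reproduced in this diff view. For orientation only, here is a minimal sketch of what a level-colorizing logger factory could look like; all names and ANSI choices below are assumptions, not the repository's actual implementation.

import logging

# Hypothetical sketch in the spirit of tools.logger.get_logger;
# the real implementation is not part of this excerpt.
LEVEL_COLORS = {
    logging.DEBUG: "\033[36m",     # cyan
    logging.INFO: "\033[32m",      # green
    logging.WARNING: "\033[33m",   # yellow
    logging.ERROR: "\033[31m",     # red
    logging.CRITICAL: "\033[35m",  # magenta
}
RESET = "\033[0m"

class ColorFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        color = LEVEL_COLORS.get(record.levelno, "")
        record.levelname = f"{color}{record.levelname}{RESET}"
        return super().format(record)

def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid stacking handlers on repeated calls
        handler = logging.StreamHandler()
        handler.setFormatter(ColorFormatter("[%(name)s] %(levelname)s | %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger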


ChatTTS/core.py

Lines changed: 11 additions & 13 deletions
@@ -1,11 +1,9 @@
-
-import os, sys
+import os
 import json
 import logging
-from functools import partial
-from typing import Literal
 import tempfile
-from typing import Optional
+from functools import partial
+from typing import Literal, Optional
 
 import torch
 from omegaconf import OmegaConf
@@ -19,16 +17,16 @@
 from .utils.io import get_latest_modified_file, del_all
 from .infer.api import refine_text, infer_code
 from .utils.download import check_all_assets, download_all_assets
-
-logging.basicConfig(level = logging.INFO)
+from .utils.log import set_utils_logger
 
 
 class Chat:
-    def __init__(self, ):
+    def __init__(self, logger=logging.getLogger(__name__)):
         self.pretrain_models = {}
         self.normalizer = {}
         self.homophones_replacer = None
-        self.logger = logging.getLogger(__name__)
+        self.logger = logger
+        set_utils_logger(logger)
 
     def check_model(self, level = logging.INFO, use_decoder = False):
         not_finish = False
@@ -46,7 +44,7 @@ def check_model(self, level = logging.INFO, use_decoder = False):
 
         if not not_finish:
             self.logger.log(level, f'All initialized.')
-
+
         return not not_finish
 
     def load_models(
@@ -62,7 +60,7 @@ def load_models(
                 with tempfile.TemporaryDirectory() as tmp:
                     download_all_assets(tmpdir=tmp)
                 if not check_all_assets(update=False):
-                    logging.error("counld not satisfy all assets needed.")
+                    self.logger.error("counld not satisfy all assets needed.")
                     return False
         elif source == 'huggingface':
             hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
@@ -120,14 +118,14 @@ def _load(
 
         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT_warpper(**cfg, device=device).eval()
+            gpt = GPT_warpper(**cfg, device=device, logger=self.logger).eval()
             assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path))
             if compile and 'cuda' in str(device):
                 try:
                     gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
                 except RuntimeError as e:
-                    logging.warning(f'Compile failed,{e}. fallback to normal mode.')
+                    self.logger.warning(f'Compile failed,{e}. fallback to normal mode.')
             self.pretrain_models['gpt'] = gpt
             spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
             assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
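
With this change ChatTTS no longer calls logging.basicConfig itself, and Chat accepts an optional logger that it also forwards to GPT_warpper and the utils modules. A minimal caller-side sketch (the application logger name is illustrative, not from the diff):

import logging

import ChatTTS

logging.basicConfig(level=logging.INFO)   # logging setup is now the caller's responsibility
app_logger = logging.getLogger("my_app")  # hypothetical application logger
chat = ChatTTS.Chat(app_logger)           # defaults to logging.getLogger(__name__) if omitted
if chat.load_models():
    app_logger.info("ChatTTS models loaded")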

ChatTTS/model/gpt.py

Lines changed: 2 additions & 1 deletion
@@ -41,10 +41,11 @@ def __init__(
         num_text_tokens,
         num_vq=4,
         device="cpu",
+        logger=logging.getLogger(__name__)
     ):
         super().__init__()
 
-        self.logger = logging.getLogger(__name__)
+        self.logger = logger
         self.device = device
         self.device_gpt = device if "mps" not in str(device) else "cpu"
         self.num_vq = num_vq

ChatTTS/utils/download.py

Lines changed: 1 addition & 3 deletions
@@ -3,10 +3,8 @@
 import hashlib
 import requests
 from io import BytesIO
-import logging
-
-logger = logging.getLogger(__name__)
 
+from .log import logger
 
 def sha256(f) -> str:
     sha256_hash = hashlib.sha256()

ChatTTS/utils/gpu_utils.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 
 import torch
-import logging
+
+from .log import logger
 
 def select_device(min_memory=2048):
-    logger = logging.getLogger(__name__)
     if torch.cuda.is_available():
         available_gpus = []
         for i in range(torch.cuda.device_count()):

ChatTTS/utils/infer_utils.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 import re
 import torch
 import torch.nn.functional as F
-import os
 import json
 
 

ChatTTS/utils/io.py

Lines changed: 3 additions & 2 deletions
@@ -3,9 +3,10 @@
 import logging
 from typing import Union
 
+from .log import logger
+
 def get_latest_modified_file(directory):
-    logger = logging.getLogger(__name__)
-
+
     files = [os.path.join(directory, f) for f in os.listdir(directory)]
     if not files:
         logger.log(logging.WARNING, f'No files found in the directory: {directory}')

ChatTTS/utils/log.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+import logging
+from pathlib import Path
+
+logger = logging.getLogger(Path(__file__).parent.name)
+
+def set_utils_logger(l: logging.Logger):
+    global logger
+    logger = l
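
This new module gives the utils package a single shared logger plus a setter; Chat.__init__ (see the core.py hunk above) calls set_utils_logger so the injected logger replaces the default. A small standalone illustration of the swap (the caller-side name is hypothetical):

import logging

from ChatTTS.utils import log

custom = logging.getLogger("my_app")  # hypothetical caller-side logger
log.set_utils_logger(custom)          # rebinds the module-level logger in ChatTTS.utils.log
assert log.logger is custom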

examples/cmd/run.py

Lines changed: 14 additions & 10 deletions
@@ -13,6 +13,10 @@
 import ChatTTS
 from IPython.display import Audio
 
+from tools.logger import get_logger
+
+logger = get_logger("Command")
+
 def save_wav_file(wav, index):
     wav_filename = f"output_audio_{index}.wav"
     # Convert numpy array to bytes and write to WAV file
@@ -22,33 +26,33 @@ def save_wav_file(wav, index):
         wf.setsampwidth(2) # Sample width in bytes
         wf.setframerate(24000) # Sample rate in Hz
         wf.writeframes(wav_bytes)
-    print(f"Audio saved to {wav_filename}")
+    logger.info(f"Audio saved to {wav_filename}")
 
 def main():
     # Retrieve text from command line argument
     text_input = sys.argv[1] if len(sys.argv) > 1 else "<YOUR TEXT HERE>"
-    print("Received text input:", text_input)
+    logger.info("Received text input: %s", text_input)
 
-    chat = ChatTTS.Chat()
-    print("Initializing ChatTTS...")
+    chat = ChatTTS.Chat(get_logger("ChatTTS"))
+    logger.info("Initializing ChatTTS...")
     if chat.load_models():
-        print("Models loaded successfully.")
+        logger.info("Models loaded successfully.")
     else:
-        print("Models load failed.")
+        logger.error("Models load failed.")
         sys.exit(1)
 
     texts = [text_input]
-    print("Text prepared for inference:", texts)
+    logger.info("Text prepared for inference: %s", texts)
 
     wavs = chat.infer(texts, use_decoder=True)
-    print("Inference completed. Audio generation successful.")
+    logger.info("Inference completed. Audio generation successful.")
     # Save each generated wav file to a local file
     for index, wav in enumerate(wavs):
         save_wav_file(wav, index)
 
     return Audio(wavs[0], rate=24_000, autoplay=True)
 
 if __name__ == "__main__":
-    print("Starting the TTS application...")
+    logger.info("Starting the TTS application...")
     main()
-    print("TTS application finished.")
+    logger.info("TTS application finished.")

examples/web/funcs.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+import random
+
+import torch
+import gradio as gr
+import numpy as np
+
+from tools.logger import get_logger
+logger = get_logger(" WebUI ")
+
+import ChatTTS
+chat = ChatTTS.Chat(get_logger("ChatTTS"))
+
+# Voice options: preset seeds for suitable voices
+voices = {
+    "默认": {"seed": 2},
+    "音色1": {"seed": 1111},
+    "音色2": {"seed": 2222},
+    "音色3": {"seed": 3333},
+    "音色4": {"seed": 4444},
+    "音色5": {"seed": 5555},
+    "音色6": {"seed": 6666},
+    "音色7": {"seed": 7777},
+    "音色8": {"seed": 8888},
+    "音色9": {"seed": 9999},
+    "音色10": {"seed": 11111},
+}
+
+def generate_seed():
+    return gr.update(value=random.randint(1, 100000000))
+
+# Return the seed corresponding to the selected voice
+def on_voice_change(vocie_selection):
+    return voices.get(vocie_selection)['seed']
+
+def refine_text(text, audio_seed_input, text_seed_input, refine_text_flag):
+    if not refine_text_flag:
+        return text
+
+    global chat
+
+    torch.manual_seed(audio_seed_input)
+    params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
+
+    torch.manual_seed(text_seed_input)
+
+    text = chat.infer(text,
+                      skip_refine_text=False,
+                      refine_text_only=True,
+                      params_refine_text=params_refine_text,
+                      )
+    return text[0] if isinstance(text, list) else text
+
+def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, stream):
+    if not text: return None
+
+    global chat
+
+    torch.manual_seed(audio_seed_input)
+    rand_spk = chat.sample_random_speaker()
+    params_infer_code = {
+        'spk_emb': rand_spk,
+        'temperature': temperature,
+        'top_P': top_P,
+        'top_K': top_K,
+    }
+    torch.manual_seed(text_seed_input)
+
+    wav = chat.infer(
+        text,
+        skip_refine_text=True,
+        params_infer_code=params_infer_code,
+        stream=stream,
+    )
+
+    if stream:
+        for gen in wav:
+            wavs = [np.array([[]])]
+            wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
+            audio = wavs[0][0]
+
+            # normalize
+            am = np.abs(audio).max() * 32768
+            if am > 32768:
+                am = 32768 * 32768 / am
+            np.multiply(audio, am, audio)
+            audio = audio.astype(np.int16)
+
+            yield 24000, audio
+        return
+
+    audio_data = np.array(wav[0]).flatten()
+    # normalize
+    am = np.abs(audio_data).max() * 32768
+    if am > 32768:
+        am = 32768 * 32768 / am
+    np.multiply(audio_data, am, audio_data)
+    audio_data = audio_data.astype(np.int16)
+    sample_rate = 24000
+
+    yield sample_rate, audio_data
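
The Gradio wiring lives in the companion examples/web/webui.py, which is part of this commit but not reproduced in this view. A rough, hypothetical sketch of how these functions might be connected to Gradio components follows; component names, defaults, and layout are assumptions, not the actual webui.py.

import gradio as gr

from examples.web.funcs import (
    voices, generate_seed, on_voice_change, refine_text, generate_audio,
)

with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Input Text", lines=4)
    voice_selection = gr.Dropdown(label="Timbre", choices=list(voices), value="默认")
    audio_seed_input = gr.Number(label="Audio Seed", value=2)
    audio_seed_button = gr.Button("Random Audio Seed")
    text_seed_input = gr.Number(label="Text Seed", value=42)
    temperature = gr.Slider(0.01, 1.0, value=0.3, label="temperature")
    top_p = gr.Slider(0.1, 0.9, value=0.7, label="top_P")
    top_k = gr.Slider(1, 20, value=20, step=1, label="top_K")
    refine_flag = gr.Checkbox(label="Refine text", value=True)
    stream_flag = gr.Checkbox(label="Stream output", value=False)
    generate_button = gr.Button("Generate")
    audio_output = gr.Audio(label="Output Audio")

    # picking a preset voice fills in its seed; the button draws a random one
    voice_selection.change(on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
    audio_seed_button.click(generate_seed, outputs=audio_seed_input)

    # refine the text first, then synthesize audio from the refined text
    generate_button.click(
        refine_text,
        inputs=[text_input, audio_seed_input, text_seed_input, refine_flag],
        outputs=text_input,
    ).then(
        generate_audio,
        inputs=[text_input, temperature, top_p, top_k, audio_seed_input, text_seed_input, stream_flag],
        outputs=audio_output,
    )

demo.queue().launch()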
