Skip to content

Commit 51c2118

Browse files
committed
fix(log): utils log cannot display
1 parent d93bc19 commit 51c2118

File tree

5 files changed

+35
-29
lines changed

5 files changed

+35
-29
lines changed

ChatTTS/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from .utils.infer import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
1818
from .utils.io import get_latest_modified_file, del_all
1919
from .infer.api import refine_text, infer_code
20-
from .utils.download import check_all_assets, download_all_assets
21-
from .utils.log import set_utils_logger
20+
from .utils.dl import check_all_assets, download_all_assets
21+
from .utils.log import logger as utils_logger
2222

2323

2424
class Chat:
@@ -27,7 +27,7 @@ def __init__(self, logger=logging.getLogger(__name__)):
2727
self.normalizer = {}
2828
self.homophones_replacer = None
2929
self.logger = logger
30-
set_utils_logger(logger)
30+
utils_logger.set_logger(logger)
3131

3232
def has_loaded(self, use_decoder = False):
3333
not_finish = False

ChatTTS/utils/download.py renamed to ChatTTS/utils/dl.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@ def check_model(
1919
) -> bool:
2020
target = dir_name / model_name
2121
relname = target.as_posix()
22-
logger.debug(f"checking {relname}...")
22+
logger.get_logger().debug(f"checking {relname}...")
2323
if not os.path.exists(target):
24-
logger.info(f"{target} not exist.")
24+
logger.get_logger().info(f"{target} not exist.")
2525
return False
2626
with open(target, "rb") as f:
2727
digest = sha256(f.fileno())
2828
bakfile = f"{target}.bak"
2929
if digest != hash:
30-
logger.warn(f"{target} sha256 hash mismatch.")
31-
logger.info(f"expected: {hash}")
32-
logger.info(f"real val: {digest}")
33-
logger.warn("please add parameter --update to download the latest assets.")
30+
logger.get_logger().warn(f"{target} sha256 hash mismatch.")
31+
logger.get_logger().info(f"expected: {hash}")
32+
logger.get_logger().info(f"real val: {digest}")
33+
logger.get_logger().warn("please add parameter --update to download the latest assets.")
3434
if remove_incorrect:
3535
if not os.path.exists(bakfile):
3636
os.rename(str(target), bakfile)
@@ -45,7 +45,7 @@ def check_model(
4545
def check_all_assets(update=False) -> bool:
4646
BASE_DIR = Path(os.getcwd())
4747

48-
logger.info("checking assets...")
48+
logger.get_logger().info("checking assets...")
4949
current_dir = BASE_DIR / "asset"
5050
names = [
5151
"Decoder.pt",
@@ -62,7 +62,7 @@ def check_all_assets(update=False) -> bool:
6262
):
6363
return False
6464

65-
logger.info("checking configs...")
65+
logger.get_logger().info("checking configs...")
6666
current_dir = BASE_DIR / "config"
6767
names = [
6868
"decoder.yaml",
@@ -78,44 +78,44 @@ def check_all_assets(update=False) -> bool:
7878
):
7979
return False
8080

81-
logger.info("all assets are already latest.")
81+
logger.get_logger().info("all assets are already latest.")
8282
return True
8383

8484

8585
def download_and_extract_tar_gz(url: str, folder: str):
8686
import tarfile
8787

88-
logger.info(f"downloading {url}")
88+
logger.get_logger().info(f"downloading {url}")
8989
response = requests.get(url, stream=True, timeout=(5, 10))
9090
with BytesIO() as out_file:
9191
out_file.write(response.content)
9292
out_file.seek(0)
93-
logger.info(f"downloaded.")
93+
logger.get_logger().info(f"downloaded.")
9494
with tarfile.open(fileobj=out_file, mode="r:gz") as tar:
9595
tar.extractall(folder)
96-
logger.info(f"extracted into {folder}")
96+
logger.get_logger().info(f"extracted into {folder}")
9797

9898

9999
def download_and_extract_zip(url: str, folder: str):
100100
import zipfile
101101

102-
logger.info(f"downloading {url}")
102+
logger.get_logger().info(f"downloading {url}")
103103
response = requests.get(url, stream=True, timeout=(5, 10))
104104
with BytesIO() as out_file:
105105
out_file.write(response.content)
106106
out_file.seek(0)
107-
logger.info(f"downloaded.")
107+
logger.get_logger().info(f"downloaded.")
108108
with zipfile.ZipFile(out_file) as zip_ref:
109109
zip_ref.extractall(folder)
110-
logger.info(f"extracted into {folder}")
110+
logger.get_logger().info(f"extracted into {folder}")
111111

112112

113113
def download_dns_yaml(url: str, folder: str):
114-
logger.info(f"downloading {url}")
114+
logger.get_logger().info(f"downloading {url}")
115115
response = requests.get(url, stream=True, timeout=(5, 10))
116116
with open(os.path.join(folder, "dns.yaml"), "wb") as out_file:
117117
out_file.write(response.content)
118-
logger.info(f"downloaded into {folder}")
118+
logger.get_logger().info(f"downloaded into {folder}")
119119

120120

121121
def download_all_assets(tmpdir: str, version="0.2.5"):
@@ -140,7 +140,7 @@ def download_all_assets(tmpdir: str, version="0.2.5"):
140140

141141
architecture = archs.get(architecture, None)
142142
if not architecture:
143-
logger.error(f"architecture {architecture} is not supported")
143+
logger.get_logger().error(f"architecture {architecture} is not supported")
144144
exit(1)
145145
try:
146146
BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/"

ChatTTS/utils/gpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ def select_device(min_memory=2048):
1414
device = torch.device(f'cuda:{selected_gpu}')
1515
free_memory_mb = max_free_memory / (1024 * 1024)
1616
if free_memory_mb < min_memory:
17-
logger.warning(f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU.')
17+
logger.get_logger().warning(f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU.')
1818
device = torch.device('cpu')
1919
elif torch.backends.mps.is_available():
2020
# For Apple M1/M2 chips with Metal Performance Shaders
21-
logger.info('Apple GPU found, using MPS.')
21+
logger.get_logger().info('Apple GPU found, using MPS.')
2222
device = torch.device('mps')
2323
else:
24-
logger.warning('No GPU found, use CPU instead')
24+
logger.get_logger().warning('No GPU found, use CPU instead')
2525
device = torch.device('cpu')
2626

2727
return device

ChatTTS/utils/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def get_latest_modified_file(directory):
1010

1111
files = [os.path.join(directory, f) for f in os.listdir(directory)]
1212
if not files:
13-
logger.log(logging.WARNING, f'no files found in the directory: {directory}')
13+
logger.get_logger().log(logging.WARNING, f'no files found in the directory: {directory}')
1414
return None
1515
latest_file = max(files, key=os.path.getmtime)
1616

ChatTTS/utils/log.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
import logging
22
from pathlib import Path
33

4-
logger = logging.getLogger(Path(__file__).parent.name)
4+
class Logger():
5+
def __init__(self, logger=logging.getLogger(Path(__file__).parent.name)):
6+
self.logger = logger
57

6-
def set_utils_logger(l: logging.Logger):
7-
global logger
8-
logger = l
8+
def set_logger(self, logger: logging.Logger):
9+
self.logger = logger
10+
11+
def get_logger(self) -> logging.Logger:
12+
return self.logger
13+
14+
logger = Logger()

0 commit comments

Comments
 (0)