Skip to content

Commit f1e7dc2

Browse files
authored
Update core.py
1 parent af1c8f7 commit f1e7dc2

File tree

1 file changed

+73
-42
lines changed

1 file changed

+73
-42
lines changed

ChatTTS/core.py

Lines changed: 73 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
from .norm import Normalizer
2929

30-
3130
class Chat:
3231
def __init__(self, logger=logging.getLogger(__name__)):
3332
self.logger = logger
@@ -67,59 +66,89 @@ def download_models(
6766
source: Literal["huggingface", "local", "custom"] = "local",
6867
force_redownload=False,
6968
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
69+
cache_dir: Optional[str] = None,
70+
local_dir: Optional[str] = None,
7071
) -> Optional[str]:
7172
if source == "local":
72-
download_path = custom_path if custom_path is not None else os.getcwd()
73-
if (
74-
not check_all_assets(Path(download_path), self.sha256_map, update=True)
75-
or force_redownload
76-
):
73+
download_path = local_dir if local_dir else (cache_dir if cache_dir else os.getcwd())
74+
if local_dir:
7775
with tempfile.TemporaryDirectory() as tmp:
7876
download_all_assets(tmpdir=tmp, homedir=download_path)
79-
if not check_all_assets(
80-
Path(download_path), self.sha256_map, update=False
77+
else:
78+
if (
79+
not check_all_assets(Path(download_path), self.sha256_map, update=True)
80+
or force_redownload
8181
):
82-
self.logger.error(
83-
"download to local path %s failed.", download_path
84-
)
85-
return None
82+
with tempfile.TemporaryDirectory() as tmp:
83+
download_all_assets(tmpdir=tmp, homedir=download_path)
84+
if not check_all_assets(
85+
Path(download_path), self.sha256_map, update=False
86+
):
87+
self.logger.error(
88+
"download to local path %s failed.", download_path
89+
)
90+
return None
91+
8692
elif source == "huggingface":
8793
try:
88-
download_path = (
89-
get_latest_modified_file(
90-
os.path.join(
91-
os.getenv(
92-
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
93-
),
94-
"hub/models--2Noise--ChatTTS/snapshots",
95-
)
96-
)
97-
if custom_path is None
98-
else get_latest_modified_file(
99-
os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
94+
if local_dir:
95+
download_path = snapshot_download(
96+
repo_id="2Noise/ChatTTS",
97+
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
98+
local_dir=local_dir,
99+
force_download=force_redownload
100100
)
101-
)
102-
except:
103-
download_path = None
104-
if download_path is None or force_redownload:
105-
self.logger.log(
106-
logging.INFO,
107-
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
108-
)
109-
try:
101+
elif cache_dir:
110102
download_path = snapshot_download(
111103
repo_id="2Noise/ChatTTS",
112104
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
113-
cache_dir=custom_path,
114-
force_download=force_redownload,
105+
cache_dir=cache_dir,
106+
force_download=force_redownload
115107
)
116-
except:
117-
download_path = None
108+
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
109+
self.logger.error("Model verification failed")
110+
return None
118111
else:
119-
self.logger.log(
120-
logging.INFO,
121-
f"load latest snapshot from cache: {download_path}",
122-
)
112+
try:
113+
download_path = (
114+
get_latest_modified_file(
115+
os.path.join(
116+
os.getenv(
117+
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
118+
),
119+
"hub/models--2Noise--ChatTTS/snapshots",
120+
)
121+
)
122+
if custom_path is None
123+
else get_latest_modified_file(
124+
os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
125+
)
126+
)
127+
except:
128+
download_path = None
129+
if download_path is None or force_redownload:
130+
self.logger.log(
131+
logging.INFO,
132+
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
133+
)
134+
try:
135+
download_path = snapshot_download(
136+
repo_id="2Noise/ChatTTS",
137+
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
138+
)
139+
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
140+
self.logger.error("Model verification failed")
141+
return None
142+
except:
143+
download_path = None
144+
else:
145+
self.logger.log(
146+
logging.INFO, f"load latest snapshot from cache: {download_path}"
147+
)
148+
except Exception as e:
149+
self.logger.error(f"Failed to download models: {str(e)}")
150+
download_path = None
151+
123152
elif source == "custom":
124153
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
125154
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
@@ -144,8 +173,10 @@ def load(
144173
use_flash_attn=False,
145174
use_vllm=False,
146175
experimental: bool = False,
176+
cache_dir: Optional[str] = None,
177+
local_dir: Optional[str] = None,
147178
) -> bool:
148-
download_path = self.download_models(source, force_redownload, custom_path)
179+
download_path = self.download_models(source, force_redownload, custom_path, cache_dir, local_dir)
149180
if download_path is None:
150181
return False
151182
return self._load(

0 commit comments

Comments
 (0)