Skip to content

Commit 87b1913

Browse files
authored
Update core.py
1 parent f1e7dc2 commit 87b1913

File tree

1 file changed

+97
-98
lines changed

1 file changed

+97
-98
lines changed

ChatTTS/core.py

Lines changed: 97 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -62,105 +62,104 @@ def has_loaded(self, use_decoder=False):
6262
return not not_finish
6363

6464
def download_models(
65-
self,
66-
source: Literal["huggingface", "local", "custom"] = "local",
67-
force_redownload=False,
68-
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
69-
cache_dir: Optional[str] = None,
70-
local_dir: Optional[str] = None,
65+
self,
66+
source: Literal["huggingface", "local", "custom"] = "local",
67+
force_redownload=False,
68+
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
69+
cache_dir: Optional[str] = None,
70+
local_dir: Optional[str] = None,
7171
) -> Optional[str]:
72-
if source == "local":
73-
download_path = local_dir if local_dir else (cache_dir if cache_dir else os.getcwd())
74-
if local_dir:
75-
with tempfile.TemporaryDirectory() as tmp:
76-
download_all_assets(tmpdir=tmp, homedir=download_path)
77-
else:
78-
if (
79-
not check_all_assets(Path(download_path), self.sha256_map, update=True)
80-
or force_redownload
81-
):
82-
with tempfile.TemporaryDirectory() as tmp:
83-
download_all_assets(tmpdir=tmp, homedir=download_path)
84-
if not check_all_assets(
85-
Path(download_path), self.sha256_map, update=False
86-
):
87-
self.logger.error(
88-
"download to local path %s failed.", download_path
89-
)
90-
return None
91-
92-
elif source == "huggingface":
93-
try:
94-
if local_dir:
95-
download_path = snapshot_download(
96-
repo_id="2Noise/ChatTTS",
97-
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
98-
local_dir=local_dir,
99-
force_download=force_redownload
100-
)
101-
elif cache_dir:
102-
download_path = snapshot_download(
103-
repo_id="2Noise/ChatTTS",
104-
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
105-
cache_dir=cache_dir,
106-
force_download=force_redownload
107-
)
108-
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
109-
self.logger.error("Model verification failed")
110-
return None
111-
else:
112-
try:
113-
download_path = (
114-
get_latest_modified_file(
115-
os.path.join(
116-
os.getenv(
117-
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
118-
),
119-
"hub/models--2Noise--ChatTTS/snapshots",
120-
)
121-
)
122-
if custom_path is None
123-
else get_latest_modified_file(
124-
os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
125-
)
126-
)
127-
except:
128-
download_path = None
129-
if download_path is None or force_redownload:
130-
self.logger.log(
131-
logging.INFO,
132-
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
133-
)
134-
try:
135-
download_path = snapshot_download(
136-
repo_id="2Noise/ChatTTS",
137-
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
138-
)
139-
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
140-
self.logger.error("Model verification failed")
141-
return None
142-
except:
143-
download_path = None
144-
else:
145-
self.logger.log(
146-
logging.INFO, f"load latest snapshot from cache: {download_path}"
147-
)
148-
except Exception as e:
149-
self.logger.error(f"Failed to download models: {str(e)}")
150-
download_path = None
151-
152-
elif source == "custom":
153-
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
154-
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
155-
self.logger.error("check models in custom path %s failed.", custom_path)
156-
return None
157-
download_path = custom_path
158-
159-
if download_path is None:
160-
self.logger.error("Model download failed")
161-
return None
162-
163-
return download_path
72+
if source == "local":
73+
download_path = local_dir if local_dir else (cache_dir if cache_dir else os.getcwd())
74+
if (
75+
not check_all_assets(Path(download_path), self.sha256_map, update=True)
76+
or force_redownload
77+
):
78+
with tempfile.TemporaryDirectory() as tmp:
79+
download_all_assets(tmpdir=tmp, homedir=download_path)
80+
if not check_all_assets(
81+
Path(download_path), self.sha256_map, update=False
82+
):
83+
self.logger.error(
84+
"download to local path %s failed.", download_path
85+
)
86+
return None
87+
88+
elif source == "huggingface":
89+
try:
90+
if local_dir:
91+
download_path = snapshot_download(
92+
repo_id="2Noise/ChatTTS",
93+
allow_patterns=["*.yaml", "*.json", "*.safetensors", "spk_stat.pt", "tokenizer.pt"],
94+
local_dir=local_dir,
95+
force_download=force_redownload
96+
)
97+
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
98+
self.logger.error("Model verification failed")
99+
return None
100+
elif cache_dir:
101+
download_path = snapshot_download(
102+
repo_id="2Noise/ChatTTS",
103+
allow_patterns=["*.yaml", "*.json", "*.safetensors", "spk_stat.pt", "tokenizer.pt"],
104+
cache_dir=cache_dir,
105+
force_download=force_redownload
106+
)
107+
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
108+
self.logger.error("Model verification failed")
109+
return None
110+
else:
111+
try:
112+
download_path = (
113+
get_latest_modified_file(
114+
os.path.join(
115+
os.getenv(
116+
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
117+
),
118+
"hub/models--2Noise--ChatTTS/snapshots",
119+
)
120+
)
121+
if custom_path is None
122+
else get_latest_modified_file(
123+
os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
124+
)
125+
)
126+
except:
127+
download_path = None
128+
if download_path is None or force_redownload:
129+
self.logger.log(
130+
logging.INFO,
131+
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
132+
)
133+
try:
134+
download_path = snapshot_download(
135+
repo_id="2Noise/ChatTTS",
136+
allow_patterns=["*.yaml", "*.json", "*.safetensors", "spk_stat.pt", "tokenizer.pt"],
137+
)
138+
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
139+
self.logger.error("Model verification failed")
140+
return None
141+
except:
142+
download_path = None
143+
else:
144+
self.logger.log(
145+
logging.INFO, f"load latest snapshot from cache: {download_path}"
146+
)
147+
except Exception as e:
148+
self.logger.error(f"Failed to download models: {str(e)}")
149+
download_path = None
150+
151+
elif source == "custom":
152+
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
153+
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
154+
self.logger.error("check models in custom path %s failed.", custom_path)
155+
return None
156+
download_path = custom_path
157+
158+
if download_path is None:
159+
self.logger.error("Model download failed")
160+
return None
161+
162+
return download_path
164163

165164
def load(
166165
self,

0 commit comments

Comments
 (0)