27
27
28
28
from .norm import Normalizer
29
29
30
-
31
30
class Chat :
32
31
def __init__ (self , logger = logging .getLogger (__name__ )):
33
32
self .logger = logger
@@ -67,59 +66,89 @@ def download_models(
67
66
source : Literal ["huggingface" , "local" , "custom" ] = "local" ,
68
67
force_redownload = False ,
69
68
custom_path : Optional [torch .serialization .FILE_LIKE ] = None ,
69
+ cache_dir : Optional [str ] = None ,
70
+ local_dir : Optional [str ] = None ,
70
71
) -> Optional [str ]:
71
72
if source == "local" :
72
- download_path = custom_path if custom_path is not None else os .getcwd ()
73
- if (
74
- not check_all_assets (Path (download_path ), self .sha256_map , update = True )
75
- or force_redownload
76
- ):
73
+ download_path = local_dir if local_dir else (cache_dir if cache_dir else os .getcwd ())
74
+ if local_dir :
77
75
with tempfile .TemporaryDirectory () as tmp :
78
76
download_all_assets (tmpdir = tmp , homedir = download_path )
79
- if not check_all_assets (
80
- Path (download_path ), self .sha256_map , update = False
77
+ else :
78
+ if (
79
+ not check_all_assets (Path (download_path ), self .sha256_map , update = True )
80
+ or force_redownload
81
81
):
82
- self .logger .error (
83
- "download to local path %s failed." , download_path
84
- )
85
- return None
82
+ with tempfile .TemporaryDirectory () as tmp :
83
+ download_all_assets (tmpdir = tmp , homedir = download_path )
84
+ if not check_all_assets (
85
+ Path (download_path ), self .sha256_map , update = False
86
+ ):
87
+ self .logger .error (
88
+ "download to local path %s failed." , download_path
89
+ )
90
+ return None
91
+
86
92
elif source == "huggingface" :
87
93
try :
88
- download_path = (
89
- get_latest_modified_file (
90
- os .path .join (
91
- os .getenv (
92
- "HF_HOME" , os .path .expanduser ("~/.cache/huggingface" )
93
- ),
94
- "hub/models--2Noise--ChatTTS/snapshots" ,
95
- )
96
- )
97
- if custom_path is None
98
- else get_latest_modified_file (
99
- os .path .join (custom_path , "models--2Noise--ChatTTS/snapshots" )
94
+ if local_dir :
95
+ download_path = snapshot_download (
96
+ repo_id = "2Noise/ChatTTS" ,
97
+ allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" ],
98
+ local_dir = local_dir ,
99
+ force_download = force_redownload
100
100
)
101
- )
102
- except :
103
- download_path = None
104
- if download_path is None or force_redownload :
105
- self .logger .log (
106
- logging .INFO ,
107
- f"download from HF: https://huggingface.co/2Noise/ChatTTS" ,
108
- )
109
- try :
101
+ elif cache_dir :
110
102
download_path = snapshot_download (
111
103
repo_id = "2Noise/ChatTTS" ,
112
104
allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" ],
113
- cache_dir = custom_path ,
114
- force_download = force_redownload ,
105
+ cache_dir = cache_dir ,
106
+ force_download = force_redownload
115
107
)
116
- except :
117
- download_path = None
108
+ if not check_all_assets (Path (download_path ), self .sha256_map , update = False ):
109
+ self .logger .error ("Model verification failed" )
110
+ return None
118
111
else :
119
- self .logger .log (
120
- logging .INFO ,
121
- f"load latest snapshot from cache: { download_path } " ,
122
- )
112
+ try :
113
+ download_path = (
114
+ get_latest_modified_file (
115
+ os .path .join (
116
+ os .getenv (
117
+ "HF_HOME" , os .path .expanduser ("~/.cache/huggingface" )
118
+ ),
119
+ "hub/models--2Noise--ChatTTS/snapshots" ,
120
+ )
121
+ )
122
+ if custom_path is None
123
+ else get_latest_modified_file (
124
+ os .path .join (custom_path , "models--2Noise--ChatTTS/snapshots" )
125
+ )
126
+ )
127
+ except :
128
+ download_path = None
129
+ if download_path is None or force_redownload :
130
+ self .logger .log (
131
+ logging .INFO ,
132
+ f"download from HF: https://huggingface.co/2Noise/ChatTTS" ,
133
+ )
134
+ try :
135
+ download_path = snapshot_download (
136
+ repo_id = "2Noise/ChatTTS" ,
137
+ allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" ],
138
+ )
139
+ if not check_all_assets (Path (download_path ), self .sha256_map , update = False ):
140
+ self .logger .error ("Model verification failed" )
141
+ return None
142
+ except :
143
+ download_path = None
144
+ else :
145
+ self .logger .log (
146
+ logging .INFO , f"load latest snapshot from cache: { download_path } "
147
+ )
148
+ except Exception as e :
149
+ self .logger .error (f"Failed to download models: { str (e )} " )
150
+ download_path = None
151
+
123
152
elif source == "custom" :
124
153
self .logger .log (logging .INFO , f"try to load from local: { custom_path } " )
125
154
if not check_all_assets (Path (custom_path ), self .sha256_map , update = False ):
@@ -144,8 +173,10 @@ def load(
144
173
use_flash_attn = False ,
145
174
use_vllm = False ,
146
175
experimental : bool = False ,
176
+ cache_dir : Optional [str ] = None ,
177
+ local_dir : Optional [str ] = None ,
147
178
) -> bool :
148
- download_path = self .download_models (source , force_redownload , custom_path )
179
+ download_path = self .download_models (source , force_redownload , custom_path , cache_dir , local_dir )
149
180
if download_path is None :
150
181
return False
151
182
return self ._load (
0 commit comments