@@ -62,105 +62,104 @@ def has_loaded(self, use_decoder=False):
62
62
return not not_finish
63
63
64
64
def download_models (
65
- self ,
66
- source : Literal ["huggingface" , "local" , "custom" ] = "local" ,
67
- force_redownload = False ,
68
- custom_path : Optional [torch .serialization .FILE_LIKE ] = None ,
69
- cache_dir : Optional [str ] = None ,
70
- local_dir : Optional [str ] = None ,
65
+ self ,
66
+ source : Literal ["huggingface" , "local" , "custom" ] = "local" ,
67
+ force_redownload = False ,
68
+ custom_path : Optional [torch .serialization .FILE_LIKE ] = None ,
69
+ cache_dir : Optional [str ] = None ,
70
+ local_dir : Optional [str ] = None ,
71
71
) -> Optional [str ]:
72
- if source == "local" :
73
- download_path = local_dir if local_dir else (cache_dir if cache_dir else os .getcwd ())
74
- if local_dir :
75
- with tempfile .TemporaryDirectory () as tmp :
76
- download_all_assets (tmpdir = tmp , homedir = download_path )
77
- else :
78
- if (
79
- not check_all_assets (Path (download_path ), self .sha256_map , update = True )
80
- or force_redownload
81
- ):
82
- with tempfile .TemporaryDirectory () as tmp :
83
- download_all_assets (tmpdir = tmp , homedir = download_path )
84
- if not check_all_assets (
85
- Path (download_path ), self .sha256_map , update = False
86
- ):
87
- self .logger .error (
88
- "download to local path %s failed." , download_path
89
- )
90
- return None
91
-
92
- elif source == "huggingface" :
93
- try :
94
- if local_dir :
95
- download_path = snapshot_download (
96
- repo_id = "2Noise/ChatTTS" ,
97
- allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" ],
98
- local_dir = local_dir ,
99
- force_download = force_redownload
100
- )
101
- elif cache_dir :
102
- download_path = snapshot_download (
103
- repo_id = "2Noise/ChatTTS" ,
104
- allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" ],
105
- cache_dir = cache_dir ,
106
- force_download = force_redownload
107
- )
108
- if not check_all_assets (Path (download_path ), self .sha256_map , update = False ):
109
- self .logger .error ("Model verification failed" )
110
- return None
111
- else :
112
- try :
113
- download_path = (
114
- get_latest_modified_file (
115
- os .path .join (
116
- os .getenv (
117
- "HF_HOME" , os .path .expanduser ("~/.cache/huggingface" )
118
- ),
119
- "hub/models--2Noise--ChatTTS/snapshots" ,
120
- )
121
- )
122
- if custom_path is None
123
- else get_latest_modified_file (
124
- os .path .join (custom_path , "models--2Noise--ChatTTS/snapshots" )
125
- )
126
- )
127
- except :
128
- download_path = None
129
- if download_path is None or force_redownload :
130
- self .logger .log (
131
- logging .INFO ,
132
- f"download from HF: https://huggingface.co/2Noise/ChatTTS" ,
133
- )
134
- try :
135
- download_path = snapshot_download (
136
- repo_id = "2Noise/ChatTTS" ,
137
- allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" ],
138
- )
139
- if not check_all_assets (Path (download_path ), self .sha256_map , update = False ):
140
- self .logger .error ("Model verification failed" )
141
- return None
142
- except :
143
- download_path = None
144
- else :
145
- self .logger .log (
146
- logging .INFO , f"load latest snapshot from cache: { download_path } "
147
- )
148
- except Exception as e :
149
- self .logger .error (f"Failed to download models: { str (e )} " )
150
- download_path = None
151
-
152
- elif source == "custom" :
153
- self .logger .log (logging .INFO , f"try to load from local: { custom_path } " )
154
- if not check_all_assets (Path (custom_path ), self .sha256_map , update = False ):
155
- self .logger .error ("check models in custom path %s failed." , custom_path )
156
- return None
157
- download_path = custom_path
158
-
159
- if download_path is None :
160
- self .logger .error ("Model download failed" )
161
- return None
162
-
163
- return download_path
72
+ if source == "local" :
73
+ download_path = local_dir if local_dir else (cache_dir if cache_dir else os .getcwd ())
74
+ if (
75
+ not check_all_assets (Path (download_path ), self .sha256_map , update = True )
76
+ or force_redownload
77
+ ):
78
+ with tempfile .TemporaryDirectory () as tmp :
79
+ download_all_assets (tmpdir = tmp , homedir = download_path )
80
+ if not check_all_assets (
81
+ Path (download_path ), self .sha256_map , update = False
82
+ ):
83
+ self .logger .error (
84
+ "download to local path %s failed." , download_path
85
+ )
86
+ return None
87
+
88
+ elif source == "huggingface" :
89
+ try :
90
+ if local_dir :
91
+ download_path = snapshot_download (
92
+ repo_id = "2Noise/ChatTTS" ,
93
+ allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" , "spk_stat.pt" , "tokenizer.pt" ],
94
+ local_dir = local_dir ,
95
+ force_download = force_redownload
96
+ )
97
+ if not check_all_assets (Path (download_path ), self .sha256_map , update = False ):
98
+ self .logger .error ("Model verification failed" )
99
+ return None
100
+ elif cache_dir :
101
+ download_path = snapshot_download (
102
+ repo_id = "2Noise/ChatTTS" ,
103
+ allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" , "spk_stat.pt" , "tokenizer.pt" ],
104
+ cache_dir = cache_dir ,
105
+ force_download = force_redownload
106
+ )
107
+ if not check_all_assets (Path (download_path ), self .sha256_map , update = False ):
108
+ self .logger .error ("Model verification failed" )
109
+ return None
110
+ else :
111
+ try :
112
+ download_path = (
113
+ get_latest_modified_file (
114
+ os .path .join (
115
+ os .getenv (
116
+ "HF_HOME" , os .path .expanduser ("~/.cache/huggingface" )
117
+ ),
118
+ "hub/models--2Noise--ChatTTS/snapshots" ,
119
+ )
120
+ )
121
+ if custom_path is None
122
+ else get_latest_modified_file (
123
+ os .path .join (custom_path , "models--2Noise--ChatTTS/snapshots" )
124
+ )
125
+ )
126
+ except :
127
+ download_path = None
128
+ if download_path is None or force_redownload :
129
+ self .logger .log (
130
+ logging .INFO ,
131
+ f"download from HF: https://huggingface.co/2Noise/ChatTTS" ,
132
+ )
133
+ try :
134
+ download_path = snapshot_download (
135
+ repo_id = "2Noise/ChatTTS" ,
136
+ allow_patterns = ["*.yaml" , "*.json" , "*.safetensors" , "spk_stat.pt" , "tokenizer.pt" ],
137
+ )
138
+ if not check_all_assets (Path (download_path ), self .sha256_map , update = False ):
139
+ self .logger .error ("Model verification failed" )
140
+ return None
141
+ except :
142
+ download_path = None
143
+ else :
144
+ self .logger .log (
145
+ logging .INFO , f"load latest snapshot from cache: { download_path } "
146
+ )
147
+ except Exception as e :
148
+ self .logger .error (f"Failed to download models: { str (e )} " )
149
+ download_path = None
150
+
151
+ elif source == "custom" :
152
+ self .logger .log (logging .INFO , f"try to load from local: { custom_path } " )
153
+ if not check_all_assets (Path (custom_path ), self .sha256_map , update = False ):
154
+ self .logger .error ("check models in custom path %s failed." , custom_path )
155
+ return None
156
+ download_path = custom_path
157
+
158
+ if download_path is None :
159
+ self .logger .error ("Model download failed" )
160
+ return None
161
+
162
+ return download_path
164
163
165
164
def load (
166
165
self ,
0 commit comments