@@ -129,7 +129,7 @@ def _load(
                 "cpu" if "mps" in str(device) else device
             ).eval()
             assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
-            vocos.load_state_dict(torch.load(vocos_ckpt_path))
+            vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
             self.vocos = vocos
             if "mps" in str(self.device):
                 self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode(
@@ -146,15 +146,15 @@ def _load(
             dvae = DVAE(**cfg, coef=coef).to(device).eval()
             coef = str(dvae)
             assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
-            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
+            dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
             self.dvae = dvae
             self.logger.log(logging.INFO, 'dvae loaded.')

         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
             gpt = GPT(**cfg, device=device, logger=self.logger).eval()
             assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
-            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
+            gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
             if compile and 'cuda' in str(device):
                 try:
                     gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
@@ -163,20 +163,20 @@ def _load(
             self.gpt = gpt
             spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
             assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
-            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path, map_location=device).to(device)
+            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path, weights_only=True, mmap=True).to(device)
             self.logger.log(logging.INFO, 'gpt loaded.')

         if decoder_config_path:
             cfg = OmegaConf.load(decoder_config_path)
             decoder = DVAE(**cfg, coef=coef).to(device).eval()
             coef = str(decoder)
             assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
-            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
+            decoder.load_state_dict(torch.load(decoder_ckpt_path, weights_only=True, mmap=True))
             self.decoder = decoder
             self.logger.log(logging.INFO, 'decoder loaded.')

         if tokenizer_path:
-            tokenizer = torch.load(tokenizer_path, map_location=device)
+            tokenizer = torch.load(tokenizer_path, map_location=device, mmap=True)
             tokenizer.padding_side = 'left'
             self.pretrain_models['tokenizer'] = tokenizer
             self.logger.log(logging.INFO, 'tokenizer loaded.')
@@ -187,7 +187,11 @@ def _load(

     def unload(self):
         logger = self.logger
-        del_all(self)
+        del_all(self.pretrain_models)
+        del_list = ["vocos", "_vocos_decode", 'gpt', 'decoder', 'dvae']
+        for module in del_list:
+            if hasattr(self, module):
+                delattr(self, module)
         self.__init__(logger)

     def _infer(
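For context, a minimal sketch of the loading pattern these hunks switch to (a sketch only, assuming PyTorch ≥ 2.1, where `torch.load` accepts both `weights_only` and `mmap`; the checkpoint file name is made up for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical checkpoint, saved with the default zipfile serialization.
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "example_ckpt.pt")

# weights_only=True restricts unpickling to tensors and plain containers
# (safer than arbitrary pickle), and mmap=True memory-maps the file instead
# of reading all storages into RAM up front.
state_dict = torch.load("example_ckpt.pt", weights_only=True, mmap=True)
model.load_state_dict(state_dict)
```

In the hunks above, `load_state_dict` copies the loaded tensors into parameters of modules that are already on the target device, which is presumably why the explicit `map_location=device` could be dropped for the state dicts while the tokenizer load keeps it.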