@@ -170,7 +170,9 @@ def classify_audio_clip(clip, model_dir):
         kernel_size=5,
         distribute_zero_label=False,
     )
-    classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
+    classifier.load_state_dict(
+        torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+    )
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
@@ -488,13 +490,15 @@ def get_random_conditioning_latents(self):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
+                    weights_only=True,
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
             self.rlg_diffusion.load_state_dict(
                 torch.load(
                     os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
+                    weights_only=True,
                 )
             )
         with torch.no_grad():
@@ -881,24 +885,25 @@ def load_checkpoint(
 
         if os.path.exists(ar_path):
             # remove keys from the checkpoint that are not in the model
-            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
+            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)
 
             # strict set False
             # due to removed `bias` and `masked_bias` changes in Transformers
             self.autoregressive.load_state_dict(checkpoint, strict=False)
 
         if os.path.exists(diff_path):
-            self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
+            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)
 
         if os.path.exists(clvp_path):
-            self.clvp.load_state_dict(torch.load(clvp_path), strict=strict)
+            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)
 
         if os.path.exists(vocoder_checkpoint_path):
             self.vocoder.load_state_dict(
                 config.model_args.vocoder.value.optionally_index(
                     torch.load(
                         vocoder_checkpoint_path,
                         map_location=torch.device("cpu"),
+                        weights_only=True,
                     )
                 )
             )
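
For context, a minimal sketch of the pattern this diff applies throughout: passing weights_only=True to torch.load so the unpickler only reconstructs tensors and primitive containers instead of arbitrary Python objects (this is the default behavior from PyTorch 2.6 onward). The model class and checkpoint path below are hypothetical and used only to illustrate the call shape; only torch.load and load_state_dict mirror the diff.

import torch
import torch.nn as nn

# Hypothetical stand-in for one of the sub-models above; illustration only.
model = nn.Linear(16, 4)

# weights_only=True restricts unpickling to tensors and basic containers,
# mitigating arbitrary code execution from untrusted checkpoint files.
state_dict = torch.load(
    "checkpoint.pth",  # hypothetical path
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(state_dict, strict=True)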