
Commit e97735c

shaviteginhard authored and committed
fix: load weights only in torch.load
1 parent 3e1e2b8 commit e97735c

15 files changed: +27 -22 lines changed
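
The commit adds weights_only=True to every torch.load call. With weights_only=True, PyTorch uses a restricted unpickler that only reconstructs tensors, primitives, and builtin containers instead of running the full pickle machinery, so loading an untrusted checkpoint can no longer execute arbitrary code. A minimal sketch of the safe round trip (the file name is illustrative):

import torch

# A plain state dict: tensors inside builtin containers only.
torch.save({"model": {"weight": torch.zeros(2, 2)}}, "ckpt.pth")

# weights_only=True swaps full pickle for a restricted unpickler,
# so a crafted checkpoint cannot run code at load time.
checkpoint = torch.load("ckpt.pth", map_location="cpu", weights_only=True)
print(checkpoint["model"]["weight"].shape)  # torch.Size([2, 2])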

TTS/tts/layers/bark/load_model.py (+1 -1)

@@ -118,7 +118,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
         logger.info(f"{model_type} model not found, downloading...")
         _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)

-    checkpoint = torch.load(ckpt_path, map_location=device)
+    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
     # this is a hack
     model_args = checkpoint["model_args"]
     if "input_vocab_size" not in model_args:

TTS/tts/layers/tortoise/arch_utils.py (+1 -1)

@@ -332,7 +332,7 @@ def __init__(
         self.mel_norm_file = mel_norm_file
         if self.mel_norm_file is not None:
             with fsspec.open(self.mel_norm_file) as f:
-                self.mel_norms = torch.load(f)
+                self.mel_norms = torch.load(f, weights_only=True)
         else:
             self.mel_norms = None
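
torch.load applies the same restriction when given a file-like object, so the fsspec-opened stream above needs no special handling. A small sketch using fsspec's in-memory filesystem as a stand-in for a real mel_norm_file:

import fsspec
import torch

# Write a tensor through fsspec, then load it back with the
# restricted unpickler; "memory://" is an illustrative location.
with fsspec.open("memory://mel_norms.pth", "wb") as f:
    torch.save(torch.ones(80), f)

with fsspec.open("memory://mel_norms.pth", "rb") as f:
    mel_norms = torch.load(f, weights_only=True)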

TTS/tts/layers/tortoise/audio_utils.py (+1 -1)

@@ -124,7 +124,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
     voices = get_voices(extra_voice_dirs)
     paths = voices[voice]
     if len(paths) == 1 and paths[0].endswith(".pth"):
-        return None, torch.load(paths[0])
+        return None, torch.load(paths[0], weights_only=True)
     else:
         conds = []
         for cond_path in paths:

TTS/tts/layers/xtts/dvae.py (+1 -1)

@@ -46,7 +46,7 @@ def dvae_wav_to_mel(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel

TTS/tts/layers/xtts/hifigan_decoder.py (+1 -1)

@@ -328,7 +328,7 @@ def remove_weight_norm(self):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

TTS/tts/layers/xtts/trainer/gpt_trainer.py (+2 -2)

@@ -91,7 +91,7 @@ def __init__(self, config: Coqpit):

         # load GPT if available
         if self.args.gpt_checkpoint:
-            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
+            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                 logger.info("Coqui Trainer checkpoint detected! Converting it!")

@@ -184,7 +184,7 @@ def __init__(self, config: Coqpit):

         self.dvae.eval()
         if self.args.dvae_checkpoint:
-            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
+            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
             self.dvae.load_state_dict(dvae_checkpoint, strict=False)
             logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
         else:

TTS/tts/layers/xtts/xtts_manager.py (+1 -1)

@@ -3,7 +3,7 @@

 class SpeakerManager:
     def __init__(self, speaker_file_path=None):
-        self.speakers = torch.load(speaker_file_path)
+        self.speakers = torch.load(speaker_file_path, weights_only=True)

     @property
     def name_to_id(self):

TTS/tts/models/neuralhmm_tts.py (+2 -2)

@@ -107,7 +107,7 @@ def update_mean_std(self, statistics_dict: Dict):

     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
             self.update_mean_std(statistics_dict)

         mels = self.normalize(mels)

@@ -292,7 +292,7 @@ def on_init_start(self, trainer):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path)
+            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],

TTS/tts/models/overflow.py (+2 -2)

@@ -120,7 +120,7 @@ def update_mean_std(self, statistics_dict: Dict):

     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
             self.update_mean_std(statistics_dict)

         mels = self.normalize(mels)

@@ -308,7 +308,7 @@ def on_init_start(self, trainer):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path)
+            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],

TTS/tts/models/tortoise.py (+9 -4)

@@ -170,7 +170,9 @@ def classify_audio_clip(clip, model_dir):
         kernel_size=5,
         distribute_zero_label=False,
     )
-    classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
+    classifier.load_state_dict(
+        torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+    )
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]

@@ -488,13 +490,15 @@ def get_random_conditioning_latents(self):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
+                    weights_only=True,
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
             self.rlg_diffusion.load_state_dict(
                 torch.load(
                     os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
+                    weights_only=True,
                 )
             )
             with torch.no_grad():

@@ -881,24 +885,25 @@ def load_checkpoint(

         if os.path.exists(ar_path):
             # remove keys from the checkpoint that are not in the model
-            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
+            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)

             # strict set False
             # due to removed `bias` and `masked_bias` changes in Transformers
             self.autoregressive.load_state_dict(checkpoint, strict=False)

         if os.path.exists(diff_path):
-            self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
+            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)

         if os.path.exists(clvp_path):
-            self.clvp.load_state_dict(torch.load(clvp_path), strict=strict)
+            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)

         if os.path.exists(vocoder_checkpoint_path):
             self.vocoder.load_state_dict(
                 config.model_args.vocoder.value.optionally_index(
                     torch.load(
                         vocoder_checkpoint_path,
                         map_location=torch.device("cpu"),
+                        weights_only=True,
                     )
                 )
             )
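
A caveat that applies to all of these call sites: weights_only=True raises an UnpicklingError if a checkpoint pickles anything beyond tensors, primitives, and builtin containers. If one of these .pth files turns out to contain a custom object, PyTorch 2.4+ offers torch.serialization.add_safe_globals to allowlist the class; a hedged sketch with a hypothetical config class:

import torch
from torch.serialization import add_safe_globals

class MyConfig:  # hypothetical stand-in for a custom pickled object
    def __init__(self, sample_rate=22050):
        self.sample_rate = sample_rate

torch.save({"model": {"w": torch.zeros(2)}, "config": MyConfig()}, "ckpt.pth")

# Without the allowlist, the load below would raise UnpicklingError.
add_safe_globals([MyConfig])
checkpoint = torch.load("ckpt.pth", map_location="cpu", weights_only=True)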

TTS/tts/models/xtts.py (+1 -1)

@@ -65,7 +65,7 @@ def wav_to_mel_cloning(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel

TTS/tts/utils/fairseq.py (+1 -1)

@@ -2,7 +2,7 @@


 def rehash_fairseq_vits_checkpoint(checkpoint_file):
-    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
+    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
     new_chk = {}
     for k, v in chk.items():
         if "enc_p." in k:

TTS/tts/utils/managers.py (+1 -1)

@@ -17,7 +17,7 @@ def load_file(path: str):
         return json.load(f)
     elif path.endswith(".pth"):
         with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location="cpu")
+            return torch.load(f, map_location="cpu", weights_only=True)
     else:
         raise ValueError("Unsupported file type")

TTS/vc/modules/freevc/wavlm/__init__.py (+1 -1)

@@ -26,7 +26,7 @@ def get_wavlm(device="cpu"):
         logger.info("Downloading WavLM model to %s ...", output_path)
         urllib.request.urlretrieve(model_uri, output_path)

-    checkpoint = torch.load(output_path, map_location=torch.device(device))
+    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
     cfg = WavLMConfig(checkpoint["cfg"])
     wavlm = WavLM(cfg).to(device)
     wavlm.load_state_dict(checkpoint["model"])
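
The WavLM checkpoint appears to hold only builtin containers, a cfg dict consumed by WavLMConfig plus a model state dict, which is exactly the shape of data the restricted unpickler accepts with no allowlisting. An illustrative round trip under that assumption (the layout and values here are hypothetical):

import torch

# Hypothetical checkpoint mirroring the cfg/model layout above.
ckpt = {"cfg": {"encoder_layers": 12}, "model": {"w": torch.randn(4, 4)}}
torch.save(ckpt, "wavlm_like.pth")

restored = torch.load("wavlm_like.pth", map_location="cpu", weights_only=True)
assert restored["cfg"]["encoder_layers"] == 12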

notebooks/TestAttention.ipynb (+2 -2)

@@ -119,9 +119,9 @@
 "\n",
 "# load model state\n",
 "if use_cuda:\n",
-"    cp = torch.load(MODEL_PATH)\n",
+"    cp = torch.load(MODEL_PATH, weights_only=True)\n",
 "else:\n",
-"    cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
+"    cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage, weights_only=True)\n",
 "\n",
 "# load the model\n",
 "model.load_state_dict(cp['model'])\n",
