Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load weights only in torch.load #77

Merged
merged 4 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ jobs:
- name: Upload coverage data
uses: actions/upload-artifact@v4
with:
include-hidden-files: true
name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
path: .coverage.*
if-no-files-found: ignore
Expand Down
1 change: 0 additions & 1 deletion TTS/.models.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
"https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt",
"https://coqui.gateway.scarf.sh/hf/bark/text_2.pt",
"https://coqui.gateway.scarf.sh/hf/bark/config.json",
"https://coqui.gateway.scarf.sh/hf/bark/hubert.pt",
"https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth"
],
"default_vocoder": null,
Expand Down
26 changes: 26 additions & 0 deletions TTS/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
import _codecs
import importlib.metadata
from collections import defaultdict

import numpy as np
import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
from TTS.utils.radam import RAdam

__version__ = importlib.metadata.version("coqui-tts")


torch.serialization.add_safe_globals([dict, defaultdict, RAdam])

# Bark
torch.serialization.add_safe_globals(
[
np.core.multiarray.scalar,
np.dtype,
np.dtypes.Float64DType,
_codecs.encode, # TODO: safe by default from Pytorch 2.5
]
)

# XTTS
torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
1 change: 0 additions & 1 deletion TTS/tts/configs/bark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def __post_init__(self):
"coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"),
"fine": os.path.join(self.CACHE_DIR, "fine_2.pt"),
"hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"),
"hubert": os.path.join(self.CACHE_DIR, "hubert.pt"),
}
self.SMALL_REMOTE_MODEL_PATHS = {
"text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")},
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/bark/hubert/kmeans_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class CustomHubert(nn.Module):
or you can train your own
"""

def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
def __init__(self, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
super().__init__()
self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of
Expand Down
3 changes: 1 addition & 2 deletions TTS/tts/layers/bark/inference_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,9 @@ def generate_voice(
# generate semantic tokens
# Load the HuBERT model
hubert_manager = HubertManager()
# hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])

hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
hubert_model = CustomHubert().to(model.device)

# Load the CustomTokenizer model
tokenizer = HubertTokenizer.load_from_checkpoint(
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/bark/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
logger.info(f"{model_type} model not found, downloading...")
_download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)

checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
# this is a hack
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/tortoise/arch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def __init__(
self.mel_norm_file = mel_norm_file
if self.mel_norm_file is not None:
with fsspec.open(self.mel_norm_file) as f:
self.mel_norms = torch.load(f)
self.mel_norms = torch.load(f, weights_only=True)
else:
self.mel_norms = None

Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/tortoise/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
voices = get_voices(extra_voice_dirs)
paths = voices[voice]
if len(paths) == 1 and paths[0].endswith(".pth"):
return None, torch.load(paths[0])
return None, torch.load(paths[0], weights_only=True)
else:
conds = []
for cond_path in paths:
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/dvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def dvae_wav_to_mel(
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
mel_norms = torch.load(mel_norms_file, map_location=device)
mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel

Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/hifigan_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def remove_weight_norm(self):
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
self.load_state_dict(state["model"])
if eval:
self.eval()
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(self, config: Coqpit):

# load GPT if available
if self.args.gpt_checkpoint:
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
logger.info("Coqui Trainer checkpoint detected! Converting it!")
Expand Down Expand Up @@ -184,7 +184,7 @@ def __init__(self, config: Coqpit):

self.dvae.eval()
if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
else:
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/xtts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class SpeakerManager:
def __init__(self, speaker_file_path=None):
self.speakers = torch.load(speaker_file_path)
self.speakers = torch.load(speaker_file_path, weights_only=True)

@property
def name_to_id(self):
Expand Down
3 changes: 0 additions & 3 deletions TTS/tts/models/bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ def load_checkpoint(
text_model_path=None,
coarse_model_path=None,
fine_model_path=None,
hubert_model_path=None,
hubert_tokenizer_path=None,
eval=False,
strict=True,
Expand All @@ -266,13 +265,11 @@ def load_checkpoint(
text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt")
coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt")
fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt")
hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt")
hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth")

self.config.LOCAL_MODEL_PATHS["text"] = text_model_path
self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path
self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path
self.config.LOCAL_MODEL_PATHS["hubert"] = hubert_model_path
self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path

self.load_bark_models()
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/models/neuralhmm_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def update_mean_std(self, statistics_dict: Dict):

def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1:
statistics_dict = torch.load(self.mel_statistics_parameter_path)
statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
self.update_mean_std(statistics_dict)

mels = self.normalize(mels)
Expand Down Expand Up @@ -292,7 +292,7 @@ def on_init_start(self, trainer):
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
data_mean, data_std, init_transition_prob = (
statistics["mean"],
statistics["std"],
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/models/overflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def update_mean_std(self, statistics_dict: Dict):

def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1:
statistics_dict = torch.load(self.mel_statistics_parameter_path)
statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
self.update_mean_std(statistics_dict)

mels = self.normalize(mels)
Expand Down Expand Up @@ -308,7 +308,7 @@ def on_init_start(self, trainer):
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
data_mean, data_std, init_transition_prob = (
statistics["mean"],
statistics["std"],
Expand Down
13 changes: 9 additions & 4 deletions TTS/tts/models/tortoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def classify_audio_clip(clip, model_dir):
kernel_size=5,
distribute_zero_label=False,
)
classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
classifier.load_state_dict(
torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
)
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
return results[0][0]
Expand Down Expand Up @@ -488,13 +490,15 @@ def get_random_conditioning_latents(self):
torch.load(
os.path.join(self.models_dir, "rlg_auto.pth"),
map_location=torch.device("cpu"),
weights_only=True,
)
)
self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(
torch.load(
os.path.join(self.models_dir, "rlg_diffuser.pth"),
map_location=torch.device("cpu"),
weights_only=True,
)
)
with torch.no_grad():
Expand Down Expand Up @@ -881,24 +885,25 @@ def load_checkpoint(

if os.path.exists(ar_path):
# remove keys from the checkpoint that are not in the model
checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)

# strict set False
# due to removed `bias` and `masked_bias` changes in Transformers
self.autoregressive.load_state_dict(checkpoint, strict=False)

if os.path.exists(diff_path):
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)

if os.path.exists(clvp_path):
self.clvp.load_state_dict(torch.load(clvp_path), strict=strict)
self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)

if os.path.exists(vocoder_checkpoint_path):
self.vocoder.load_state_dict(
config.model_args.vocoder.value.optionally_index(
torch.load(
vocoder_checkpoint_path,
map_location=torch.device("cpu"),
weights_only=True,
)
)
)
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def wav_to_mel_cloning(
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
mel_norms = torch.load(mel_norms_file, map_location=device)
mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel

Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/utils/fairseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def rehash_fairseq_vits_checkpoint(checkpoint_file):
chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
new_chk = {}
for k, v in chk.items():
if "enc_p." in k:
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/utils/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def load_file(path: str):
return json.load(f)
elif path.endswith(".pth"):
with fsspec.open(path, "rb") as f:
return torch.load(f, map_location="cpu")
return torch.load(f, map_location="cpu", weights_only=True)
else:
raise ValueError("Unsupported file type")

Expand Down
3 changes: 0 additions & 3 deletions TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.models.vits import Vits

# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav
Expand Down
2 changes: 1 addition & 1 deletion TTS/vc/modules/freevc/wavlm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_wavlm(device="cpu"):
logger.info("Downloading WavLM model to %s ...", output_path)
urllib.request.urlretrieve(model_uri, output_path)

checkpoint = torch.load(output_path, map_location=torch.device(device))
checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
cfg = WavLMConfig(checkpoint["cfg"])
wavlm = WavLM(cfg).to(device)
wavlm.load_state_dict(checkpoint["model"])
Expand Down
4 changes: 2 additions & 2 deletions notebooks/TestAttention.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@
"\n",
"# load model state\n",
"if use_cuda:\n",
" cp = torch.load(MODEL_PATH)\n",
" cp = torch.load(MODEL_PATH, weights_only=True)\n",
"else:\n",
" cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
" cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage, weights_only=True)\n",
"\n",
"# load the model\n",
"model.load_state_dict(cp['model'])\n",
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ classifiers = [
]
dependencies = [
# Core
"numpy>=1.24.3,<2.0.0", # TODO: remove upper bound after spacy/thinc release
"numpy>=1.25.2,<2.0.0", # TODO: remove upper bound after spacy/thinc release
"cython>=0.29.30",
"scipy>=1.11.2",
"torch>=2.1",
"torch>=2.4",
"torchaudio",
"soundfile>=0.12.0",
"librosa>=0.10.1",
Expand Down
Loading