Optimize ways to import torch components #1099

Open · wants to merge 1 commit into base: main
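This PR swaps module-level `import torch` for direct `from torch import ...` bindings across the scripts below, so each call site resolves a single global name instead of performing an attribute lookup on the `torch` module. The runtime effect is tiny and only visible in tight loops; the more practical win is that each script now states explicitly which torch symbols it uses. A minimal sketch to measure the lookup difference yourself (standard library only, not part of this diff; timings are machine-dependent):

```python
# Compare torch.tensor (attribute lookup per call) with a direct binding.
# Expect a small, not dramatic, gap dominated by tensor construction cost.
import timeit

t_attr = timeit.timeit("torch.tensor([0])", setup="import torch", number=100_000)
t_direct = timeit.timeit("tensor([0])", setup="from torch import tensor", number=100_000)

print(f"attribute lookup: {t_attr:.3f}s  direct binding: {t_direct:.3f}s")
```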
8 changes: 4 additions & 4 deletions cosyvoice/bin/average_model.py
@@ -18,7 +18,7 @@
 import glob
 
 import yaml
-import torch
+from torch import load, device, true_divide, save
 
 
 def get_args():
@@ -73,7 +73,7 @@ def main():
     assert num == len(path_list)
     for path in path_list:
         print('Processing {}'.format(path))
-        states = torch.load(path, map_location=torch.device('cpu'))
+        states = load(path, map_location=device('cpu'))
         for k in states.keys():
             if k not in avg.keys():
                 avg[k] = states[k].clone()
@@ -83,9 +83,9 @@
     for k in avg.keys():
         if avg[k] is not None:
             # pytorch 1.6 use true_divide instead of /=
-            avg[k] = torch.true_divide(avg[k], num)
+            avg[k] = true_divide(avg[k], num)
     print('Saving to {}'.format(args.dst_model))
-    torch.save(avg, args.dst_model)
+    save(avg, args.dst_model)
 
 
 if __name__ == '__main__':
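Context for the `true_divide` hunk above: averaged state dicts can contain integer-typed entries, and in-place `/=` on an integer tensor either truncates or raises, while `true_divide` always performs floating-point division (hence the PyTorch 1.6 comment kept in the diff). A standalone sketch, independent of this repo:

```python
# true_divide returns a float result even when the inputs are integer tensors.
from torch import tensor, true_divide

counts = tensor([3, 4, 5])        # int64 tensor
print(true_divide(counts, 2))     # tensor([1.5000, 2.0000, 2.5000])
# counts /= 2 would raise: result type Float can't be cast to the output type Long
```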
16 changes: 8 additions & 8 deletions cosyvoice/bin/export_jit.py
@@ -19,7 +19,7 @@
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
-import torch
+from torch import jit, _C
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
@@ -38,12 +38,12 @@ def get_args():
 
 
 def get_optimized_script(model, preserved_attrs=[]):
-    script = torch.jit.script(model)
+    script = jit.script(model)
     if preserved_attrs != []:
-        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
+        script = jit.freeze(script, preserved_attrs=preserved_attrs)
     else:
-        script = torch.jit.freeze(script)
-    script = torch.jit.optimize_for_inference(script)
+        script = jit.freeze(script)
+    script = jit.optimize_for_inference(script)
     return script
 
 
@@ -52,9 +52,9 @@ def main():
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
 
-    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
-    torch._C._jit_set_profiling_mode(False)
-    torch._C._jit_set_profiling_executor(False)
+    _C._jit_set_fusion_strategy([('STATIC', 1)])
+    _C._jit_set_profiling_mode(False)
+    _C._jit_set_profiling_executor(False)
 
     try:
         model = CosyVoice(args.model_dir)
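For readers unfamiliar with the helper being renamed here, `get_optimized_script` is a standard script-freeze-optimize pipeline. A minimal sketch of the same flow on a toy module (the `Toy` class is illustrative, not from this repo):

```python
# Script, freeze, and optimize a trivial module, mirroring get_optimized_script.
from torch import jit, nn, rand

class Toy(nn.Module):
    def forward(self, x):
        return x * 2.0

script = jit.script(Toy().eval())            # compile to TorchScript
script = jit.freeze(script)                  # inline parameters/attributes (needs eval mode)
script = jit.optimize_for_inference(script)  # apply inference-only graph optimizations
print(script(rand(3)))
```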
20 changes: 10 additions & 10 deletions cosyvoice/bin/export_onnx.py
@@ -22,7 +22,7 @@
 import sys
 import onnxruntime
 import random
-import torch
+from torch import rand, ones, onnx, float32, cuda, testing, from_numpy
 from tqdm import tqdm
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
@@ -31,12 +31,12 @@
 
 
 def get_dummy_input(batch_size, seq_len, out_channels, device):
-    x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
-    mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
-    mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
-    t = torch.rand((batch_size), dtype=torch.float32, device=device)
-    spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
-    cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    x = rand((batch_size, out_channels, seq_len), dtype=float32, device=device)
+    mask = ones((batch_size, 1, seq_len), dtype=float32, device=device)
+    mu = rand((batch_size, out_channels, seq_len), dtype=float32, device=device)
+    t = rand((batch_size), dtype=float32, device=device)
+    spks = rand((batch_size, out_channels), dtype=float32, device=device)
+    cond = rand((batch_size, out_channels, seq_len), dtype=float32, device=device)
     return x, mask, mu, t, spks, cond
 
 
@@ -71,7 +71,7 @@ def main():
     batch_size, seq_len = 2, 256
     out_channels = model.model.flow.decoder.estimator.out_channels
     x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
-    torch.onnx.export(
+    onnx.export(
         estimator,
         (x, mask, mu, t, spks, cond),
         '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
@@ -93,7 +93,7 @@ def main():
     option = onnxruntime.SessionOptions()
     option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
     option.intra_op_num_threads = 1
-    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    providers = ['CUDAExecutionProvider' if cuda.is_available() else 'CPUExecutionProvider']
     estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
                                                   sess_options=option, providers=providers)
 
@@ -109,7 +109,7 @@ def main():
         'cond': cond.cpu().numpy()
     }
     output_onnx = estimator_onnx.run(None, ort_inputs)[0]
-    torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+    testing.assert_allclose(output_pytorch, from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
 
 
 if __name__ == "__main__":
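One aside on the parity check in this file: `torch.testing.assert_allclose` has been deprecated in favor of `torch.testing.assert_close` since PyTorch 1.12, so a follow-up could modernize the call while keeping this PR's import style. A self-contained sketch of the replacement API (the tensors here are synthetic):

```python
# assert_close is the non-deprecated replacement for assert_allclose.
from torch import rand, testing

a = rand(2, 80, 256)
b = a + 1e-5          # perturbation well inside the tolerances below
testing.assert_close(a, b, rtol=1e-2, atol=1e-4)
```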
11 changes: 6 additions & 5 deletions cosyvoice/bin/inference.py
@@ -18,7 +18,8 @@
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
-import torch
+from torch import no_grad, concat, device as torch_device
+from torch.cuda import is_available as cuda_is_available
 from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
@@ -57,8 +58,8 @@ def main():
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
 
     # Init cosyvoice models from configs
-    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-    device = torch.device('cuda' if use_cuda else 'cpu')
+    use_cuda = args.gpu >= 0 and cuda_is_available()
+    device = torch_device('cuda' if use_cuda else 'cpu')
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f)
 
@@ -73,7 +74,7 @@ def main():
     os.makedirs(args.result_dir, exist_ok=True)
     fn = os.path.join(args.result_dir, 'wav.scp')
     f = open(fn, 'w')
-    with torch.no_grad():
+    with no_grad():
         for _, batch in tqdm(enumerate(test_data_loader)):
             utts = batch["utts"]
             assert len(utts) == 1, "inference mode only support batchsize 1"
@@ -101,7 +102,7 @@ def main():
             tts_speeches = []
             for model_output in model.tts(**model_input):
                 tts_speeches.append(model_output['tts_speech'])
-            tts_speeches = torch.concat(tts_speeches, dim=1)
+            tts_speeches = concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
             torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
6 changes: 3 additions & 3 deletions cosyvoice/cli/cosyvoice.py
@@ -17,7 +17,7 @@
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
-import torch
+from torch.cuda import is_available as cuda_is_available
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
@@ -42,7 +42,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
                                           '{}/spk2info.pt'.format(model_dir),
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+        if cuda_is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
@@ -142,7 +142,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
                                           '{}/spk2info.pt'.format(model_dir),
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+        if cuda_is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
41 changes: 22 additions & 19 deletions cosyvoice/cli/frontend.py
@@ -15,15 +15,16 @@
 from typing import Generator
 import json
 import onnxruntime
-import torch
+from torch import tensor, load, device, int32
+from torch.cuda import is_available
 import numpy as np
 import whisper
 from typing import Callable
-import torchaudio.compliance.kaldi as kaldi
-import torchaudio
+from torchaudio.compliance import kaldi
+from torchaudio.transforms import Resample
 import os
 import re
-import inflect
+from inflect import engine
 try:
     import ttsfrd
     use_ttsfrd = True
@@ -44,19 +45,21 @@ def __init__(self,
                  campplus_model: str,
                  speech_tokenizer_model: str,
                  spk2info: str = '',
-                 allowed_special: str = 'all'):
+                 allowed_special: str = 'all',
+                 refresh_fst_cache: bool = False):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = device('cuda' if is_available() else 'cpu')
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
-                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+                                                                     providers=["CUDAExecutionProvider" if is_available() else
                                                                                 "CPUExecutionProvider"])
+        self.refresh_fst_cache = refresh_fst_cache
         if os.path.exists(spk2info):
-            self.spk2info = torch.load(spk2info, map_location=self.device)
+            self.spk2info = load(spk2info, map_location=self.device, weights_only=True)
         else:
             self.spk2info = {}
         self.allowed_special = allowed_special
@@ -68,19 +71,19 @@ def __init__(self,
                 'failed to initialize ttsfrd resource'
             self.frd.set_lang_type('pinyinvg')
         else:
-            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
+            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=self.refresh_fst_cache)
             self.en_tn_model = EnNormalizer()
-            self.inflect_parser = inflect.engine()
+            self.inflect_parser = engine()  # from inflect
 
     def _extract_text_token(self, text):
         if isinstance(text, Generator):
             logging.info('get tts_text generator, will return _extract_text_token_generator!')
             # NOTE add a dummy text_token_len for compatibility
-            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+            return self._extract_text_token_generator(text), tensor([0], dtype=int32).to(self.device)
         else:
             text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
-            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
-            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            text_token = tensor([text_token], dtype=int32).to(self.device)
+            text_token_len = tensor([text_token.shape[1]], dtype=int32).to(self.device)
             return text_token, text_token_len
 
     def _extract_text_token_generator(self, text_generator):
@@ -97,8 +100,8 @@ def _extract_speech_token(self, speech):
                                                          feat.detach().cpu().numpy(),
                                                          self.speech_tokenizer_session.get_inputs()[1].name:
                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
-        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
-        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+        speech_token = tensor([speech_token], dtype=int32).to(self.device)
+        speech_token_len = tensor([speech_token.shape[1]], dtype=int32).to(self.device)
         return speech_token, speech_token_len
 
     def _extract_spk_embedding(self, speech):
@@ -109,13 +112,13 @@
         feat = feat - feat.mean(dim=0, keepdim=True)
         embedding = self.campplus_session.run(None,
                                               {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-        embedding = torch.tensor([embedding]).to(self.device)
+        embedding = tensor([embedding]).to(self.device)
         return embedding
 
     def _extract_speech_feat(self, speech):
         speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
         speech_feat = speech_feat.unsqueeze(dim=0)
-        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+        speech_feat_len = tensor([speech_feat.shape[1]], dtype=int32).to(self.device)
         return speech_feat, speech_feat_len
 
     def text_normalize(self, text, split=True, text_frontend=True):
@@ -157,7 +160,7 @@ def frontend_sft(self, tts_text, spk_id):
     def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_resample = Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
         speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
         if resample_rate == 24000:
@@ -200,7 +203,7 @@ def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
 
     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
         prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_resample = Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
         prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
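A behavioral note on the `spk2info` change in this file: adding `weights_only=True` makes `load` refuse to unpickle arbitrary Python objects and only reconstruct tensors and allow-listed container types (available since PyTorch 1.13). That is a safety win, but it assumes `spk2info.pt` holds plain tensor data. A sketch of what that implies, assuming a dict of str -> Tensor as the frontend expects:

```python
# weights_only=True restricts torch.load to tensors and allow-listed containers.
from torch import save, load, zeros

save({'speaker_a': zeros(192)}, 'spk2info.pt')  # dict of tensors: allowed

spk2info = load('spk2info.pt', map_location='cpu', weights_only=True)
print(type(spk2info), list(spk2info))

# A checkpoint containing arbitrary pickled objects would instead raise
# an UnpicklingError under weights_only=True, rather than executing them.
```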