
CUDA with additional MPS and XPU support #1075

Open. Wants to merge 16 commits into main.
58 changes: 36 additions & 22 deletions cosyvoice/bin/inference.py
@@ -16,7 +16,8 @@

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)

logging.getLogger("matplotlib").setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
@@ -53,13 +54,20 @@ def get_args():

def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

# Init cosyvoice models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(args.gpu))
elif torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")

try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
@@ -74,15 +82,14 @@ def main():

model.load(args.llm_model, args.flow_model, args.hifigan_model)

test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
test_dataset = Dataset(args.prompt_data, data_pipeline=configs["data_pipeline"], mode="inference", shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

sample_rate = configs['sample_rate']
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
fn = os.path.join(args.result_dir, "wav.scp")
f = open(fn, "w")
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
@@ -98,28 +105,35 @@ def main():
speech_feat_len = batch["speech_feat_len"].to(device)
utt_embedding = batch["utt_embedding"].to(device)
spk_embedding = batch["spk_embedding"].to(device)
if args.mode == 'sft':
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
if args.mode == "sft":
    model_input = {"text": tts_text_token, "text_len": tts_text_token_len, "llm_embedding": spk_embedding, "flow_embedding": spk_embedding}
else:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': text_token, 'prompt_text_len': text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
model_input = {
    "text": tts_text_token,
    "text_len": tts_text_token_len,
    "prompt_text": text_token,
    "prompt_text_len": text_token_len,
    "llm_prompt_speech_token": speech_token,
    "llm_prompt_speech_token_len": speech_token_len,
    "flow_prompt_speech_token": speech_token,
    "flow_prompt_speech_token_len": speech_token_len,
    "prompt_speech_feat": speech_feat,
    "prompt_speech_feat_len": speech_feat_len,
    "llm_embedding": utt_embedding,
    "flow_embedding": utt_embedding,
}
tts_speeches = []
for model_output in model.tts(**model_input):
tts_speeches.append(model_output['tts_speech'])
tts_speeches.append(model_output["tts_speech"])
tts_speeches = torch.concat(tts_speeches, dim=1)
tts_key = '{}_{}'.format(utts[0], tts_index[0])
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
f.write('{} {}\n'.format(tts_key, tts_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
logging.info("Result wav.scp saved in {}".format(fn))


if __name__ == '__main__':
if __name__ == "__main__":
main()
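
The CUDA → MPS → XPU → CPU fallback added above reappears almost verbatim in frontend.py, model.py and executor.py below. A minimal sketch of how the selection could be centralized (the get_device helper and its location are assumptions for illustration, not something this PR adds):

import torch

def get_device(index: int = 0) -> torch.device:
    # Prefer CUDA, then Apple MPS, then Intel XPU; fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda:{}".format(index))
    if torch.backends.mps.is_available():
        return torch.device("mps")  # MPS exposes a single device, so no index
    if torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")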
11 changes: 10 additions & 1 deletion cosyvoice/cli/frontend.py
@@ -47,7 +47,16 @@ def __init__(self,
allowed_special: str = 'all'):
self.tokenizer = get_tokenizer()
self.feat_extractor = feat_extractor
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    self.device = torch.device('cuda')
elif torch.backends.mps.is_available():
    self.device = torch.device('mps')
elif torch.xpu.is_available():
    self.device = torch.device('xpu')
else:
    self.device = torch.device('cpu')

option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
48 changes: 42 additions & 6 deletions cosyvoice/cli/model.py
@@ -31,7 +31,16 @@ def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    self.device = torch.device('cuda')
elif torch.backends.mps.is_available():
    self.device = torch.device('mps')
elif torch.xpu.is_available():
    self.device = torch.device('xpu')
else:
    self.device = torch.device('cpu')

self.llm = llm
self.flow = flow
self.hift = hift
@@ -81,7 +90,7 @@ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
self.flow.encoder = flow_encoder

def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
assert torch.cuda.is_available() or torch.backends.mps.is_available() or torch.xpu.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model):
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
if os.path.getsize(flow_decoder_estimator_model) == 0:
@@ -224,7 +233,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
torch.cuda.empty_cache()
self.clear_cache()

def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
@@ -279,7 +288,18 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat,
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
torch.cuda.empty_cache()
import gc
gc.collect()
self.clear_cache()

def clear_cache(self):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.xpu.is_available():
        torch.xpu.empty_cache()



class CosyVoice2Model(CosyVoiceModel):
@@ -290,7 +310,17 @@ def __init__(self,
hift: torch.nn.Module,
fp16: bool,
use_flow_cache: bool):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    self.device = torch.device('cuda')
elif torch.backends.mps.is_available():
    self.device = torch.device('mps')
elif torch.xpu.is_available():
    self.device = torch.device('xpu')
else:
    self.device = torch.device('cpu')


self.llm = llm
self.flow = flow
self.hift = hift
@@ -458,4 +488,10 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
self.llm_end_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
torch.cuda.empty_cache()

if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.xpu.is_available():
    torch.xpu.empty_cache()
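
Note that this cleanup in CosyVoice2Model.tts repeats the per-backend empty_cache chain that CosyVoiceModel now wraps in clear_cache(). Since CosyVoice2Model subclasses CosyVoiceModel, the final lines could arguably be collapsed to a single call; a sketch under that assumption, not part of this diff:

    # hypothetical simplification: reuse the inherited helper
    self.clear_cache()  # frees the CUDA, MPS or XPU cache, whichever backend is active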
8 changes: 7 additions & 1 deletion cosyvoice/utils/common.py
@@ -152,7 +152,13 @@ def set_all_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
elif torch.backends.mps.is_available():
    torch.mps.manual_seed(seed)
elif torch.xpu.is_available():
    torch.xpu.manual_seed(seed)


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
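
A quick way to sanity-check the seeding change is to draw the same tensor twice under the same seed. The snippet below is illustrative only and assumes a CosyVoice checkout on the import path:

import torch
from cosyvoice.utils.common import set_all_random_seed

set_all_random_seed(42)
a = torch.randn(3)
set_all_random_seed(42)
b = torch.randn(3)
assert torch.equal(a, b)  # same seed, identical CPU sample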
8 changes: 7 additions & 1 deletion cosyvoice/utils/executor.py
@@ -30,7 +30,13 @@ def __init__(self, gan: bool = False):
self.step = 0
self.epoch = 0
self.rank = int(os.environ.get('RANK', 0))
self.device = torch.device('cuda:{}'.format(self.rank))
self.device = torch.device("cpu")
if torch.cuda.is_available():
    self.device = torch.device('cuda:{}'.format(self.rank))
elif torch.backends.mps.is_available():
    self.device = torch.device('mps')
elif torch.xpu.is_available():
    self.device = torch.device('xpu')

def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
''' Train one epoch
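
In this executor only CUDA selects a device by rank, while MPS and XPU always map to a single device. XPU installations can expose several devices, so a rank-indexed variant is conceivable; the sketch below is an assumption, not something this PR implements:

if torch.cuda.is_available():
    self.device = torch.device('cuda:{}'.format(self.rank))
elif torch.xpu.is_available():
    self.device = torch.device('xpu:{}'.format(self.rank))  # hypothetical multi-XPU setup
elif torch.backends.mps.is_available():
    self.device = torch.device('mps')  # Apple silicon exposes one device
else:
    self.device = torch.device('cpu')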