feat: Add limited support for MPS devices #1129

Open · wants to merge 1 commit into main
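
For context, a minimal sketch of how the patched wrapper would be exercised on Apple Silicon. The CosyVoice2 class name and model directory are assumptions inferred from the diff context below, not part of this PR:

import torch
from cosyvoice.cli.cosyvoice import CosyVoice2  # wrapper class assumed from diff context

# On an M-series Mac, torch.cuda.is_available() is False, so the patched
# constructor first clears load_jit/load_trt/fp16, then re-enables JIT
# because an MPS device is present.
if torch.backends.mps.is_available():
    tts = CosyVoice2('pretrained_models/CosyVoice2-0.5B')  # hypothetical model dir
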
2 changes: 2 additions & 0 deletions cosyvoice/cli/cosyvoice.py
@@ -145,6 +145,8 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        if torch.mps.is_available() is True:
+            load_jit = True
         self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
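
A standalone sketch of the flag resolution this hunk implements (the helper name is ours, not the repo's): without CUDA, the CUDA-only options are cleared first, and the new MPS check then re-enables TorchScript loading:

import torch

def resolve_flags(load_jit: bool, load_trt: bool, fp16: bool):
    # No CUDA device: TensorRT and fp16 are unusable, so all three flags reset.
    if not torch.cuda.is_available() and (load_jit or load_trt or fp16):
        load_jit, load_trt, fp16 = False, False, False
    # MPS present: the patch force-enables JIT. This sketch uses
    # torch.backends.mps.is_available(), the long-standing spelling of the check.
    if torch.backends.mps.is_available():
        load_jit = True
    return load_jit, load_trt, fp16

print(resolve_flags(load_jit=True, load_trt=True, fp16=True))
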
12 changes: 11 additions & 1 deletion cosyvoice/cli/frontend.py
@@ -36,6 +36,14 @@
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
 
 
+def get_torch_device():
+    if torch.backends.mps.is_available():
+        return torch.device('mps')
+    elif torch.cuda.is_available():
+        return torch.device('cuda')
+    else:
+        return torch.device('cpu')
+
 class CosyVoiceFrontEnd:
 
     def __init__(self,
@@ -47,10 +55,12 @@ def __init__(self,
                  allowed_special: str = 'all'):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = get_torch_device()
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
+        if self.device.type == "mps":
+            logging.warning("ONNXRuntime does not support MPS. ONNX models will run on CPU.")
         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                      providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
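
The substance of this file's change: torch tensors can live on 'mps', but ONNX Runtime ships no MPS execution provider, so the two ONNX sessions stay on CPU. A standalone sketch of that split, assuming only stock torch and onnxruntime:

import onnxruntime
import torch

def get_torch_device() -> torch.device:
    # Same priority as the new helper in frontend.py: MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        return torch.device('mps')
    elif torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

def get_onnx_providers() -> list:
    # No MPS execution provider exists, so Apple Silicon falls back to CPU.
    if torch.cuda.is_available():
        return ['CUDAExecutionProvider']
    return ['CPUExecutionProvider']

print(get_torch_device(), get_onnx_providers())
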
27 changes: 20 additions & 7 deletions cosyvoice/cli/model.py
@@ -22,6 +22,7 @@
 import uuid
 from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.cli.frontend import get_torch_device
 
 
 class CosyVoiceModel:
@@ -31,7 +32,7 @@ def __init__(self,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = get_torch_device()
         self.llm = llm
         self.flow = flow
         self.hift = hift
@@ -57,7 +58,10 @@ def __init__(self,
         # rtf and decoding related
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
-        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        if torch.cuda.is_available():
+            self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device))
+        else:
+            self.llm_context = nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -222,7 +226,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        torch.mps.empty_cache()
 
     def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
@@ -276,7 +280,13 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        self.empty_cache()
+
+    def empty_cache(self):
+        if torch.mps.is_available():
+            torch.mps.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
 
 class CosyVoice2Model(CosyVoiceModel):
@@ -286,7 +296,7 @@ def __init__(self,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device('mps' if torch.mps.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
@@ -307,7 +317,10 @@ def __init__(self,
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.stream_scale_factor = 1
-        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        if torch.cuda.is_available():
+            self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device))
+        else:
+            self.llm_context = nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -408,4 +421,4 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        self.empty_cache()
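
The guarded empty_cache helper is the pattern worth reusing: each backend call runs only when that backend exists, so the method is a safe no-op on CPU-only machines. A standalone sketch (using torch.backends.mps.is_available(), the long-standing check, in place of the PR's torch.mps.is_available()):

import torch

def empty_device_cache() -> None:
    # Release cached allocator blocks on whichever accelerator is present;
    # both guards are False on a CPU-only machine, so nothing happens.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

empty_device_cache()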