Commit 9dc559f

force set use_flow_cache

1 parent b56dfa2 commit 9dc559f

1 file changed, 5 insertions(+), 4 deletions(-)

Diff for: cosyvoice/cli/model.py

@@ -401,17 +401,17 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
             prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
-        if self.use_flow_cache is True:
-            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
-            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.flow_cache_dict[this_uuid] = self.init_flow_cache()
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
+            assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
+            # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
+            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
+            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
@@ -442,6 +442,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
                 yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
+            assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
             p.join()
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
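Net effect of the change: the flow-prompt trimming that previously ran whenever use_flow_cache was enabled now runs only in the streaming branch, and the two new assertions force the use_flow_cache flag to match the stream argument (True for streaming, False for non-streaming). The self-contained sketch below mirrors only that contract; StubModel and its tensor shapes are hypothetical stand-ins, not the real CosyVoiceModel API.

import torch

# Hypothetical stand-in that mirrors only the assertions added in this commit;
# it is not the real CosyVoiceModel, and the tensor shapes are made up.
class StubModel:
    def __init__(self, use_flow_cache):
        self.use_flow_cache = use_flow_cache

    def tts(self, text, stream=False):
        if stream is True:
            # streaming path: requires the flow cache, mirroring model.py
            assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
            for _ in range(3):                             # pretend to emit 3 audio chunks
                yield {'tts_speech': torch.zeros(1, 2400)}
        else:
            # non-streaming path: requires the cache to be off
            assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
            yield {'tts_speech': torch.zeros(1, 24000)}    # one full utterance

for chunk in StubModel(use_flow_cache=True).tts('hello', stream=True):
    print(chunk['tts_speech'].shape)                       # streaming works with cache on

try:
    # Because tts is a generator, the assertion fires on the first next(),
    # not at call time -- same behaviour as the patched model.py.
    next(StubModel(use_flow_cache=True).tts('hello', stream=False))
except AssertionError as e:
    print('caught:', e)

One consequence of the generator-based design: a mismatched stream/use_flow_cache combination does not fail when tts() is called, only when the returned generator is first advanced, which is why the sketch wraps next() rather than the call itself.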
