@@ -401,17 +401,17 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
401
401
prompt_speech_feat = torch .zeros (1 , 0 , 80 ), stream = False , speed = 1.0 , ** kwargs ):
402
402
# this_uuid is used to track variables related to this inference thread
403
403
this_uuid = str (uuid .uuid1 ())
404
- # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
405
- if self .use_flow_cache is True :
406
- flow_prompt_speech_token = flow_prompt_speech_token [:, - self .flow_decoder_required_cache_size :]
407
- prompt_speech_feat = prompt_speech_feat [:, - self .flow_decoder_required_cache_size * 2 :]
408
404
with self .lock :
409
405
self .tts_speech_token_dict [this_uuid ], self .llm_end_dict [this_uuid ] = [], False
410
406
self .hift_cache_dict [this_uuid ] = None
411
407
self .flow_cache_dict [this_uuid ] = self .init_flow_cache ()
412
408
p = threading .Thread (target = self .llm_job , args = (text , prompt_text , llm_prompt_speech_token , llm_embedding , this_uuid ))
413
409
p .start ()
414
410
if stream is True :
411
+ assert self .use_flow_cache is True , "set use_flow_cache=True if you want to use stream inference to avoid OOM"
412
+ # NOTE: in cache mode, keep only the last flow_decoder_required_cache_size prompt tokens and the last 2 * flow_decoder_required_cache_size prompt feat frames
413
+ flow_prompt_speech_token = flow_prompt_speech_token [:, - self .flow_decoder_required_cache_size :]
414
+ prompt_speech_feat = prompt_speech_feat [:, - self .flow_decoder_required_cache_size * 2 :]
415
415
while True :
416
416
time .sleep (0.1 )
417
417
if len (self .tts_speech_token_dict [this_uuid ]) >= self .token_hop_len + self .flow .pre_lookahead_len :
@@ -442,6 +442,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
442
442
yield {'tts_speech' : this_tts_speech .cpu ()}
443
443
else :
444
444
# deal with all tokens
445
+ assert self .use_flow_cache is False , "set use_flow_cache=False for nonstream inference"
445
446
p .join ()
446
447
this_tts_speech_token = torch .tensor (self .tts_speech_token_dict [this_uuid ]).unsqueeze (dim = 0 )
447
448
this_tts_speech = self .token2wav (token = this_tts_speech_token ,
0 commit comments