@@ -53,6 +53,7 @@ def __init__(self,
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.mel_overlap_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}

     def load(self, llm_model, flow_model, hift_model):
@@ -100,15 +101,18 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         self.llm_end_dict[uuid] = True

     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel = self.flow.inference(token=token.to(self.device),
-                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                      prompt_token=prompt_token.to(self.device),
-                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                      prompt_feat=prompt_feat.to(self.device),
-                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                      embedding=embedding.to(self.device))
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
+
         # mel overlap fade in out
-        if self.mel_overlap_dict[uuid] is not None:
+        if self.mel_overlap_dict[uuid].shape[2] != 0:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
@@ -145,7 +149,9 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
@@ -201,7 +207,9 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat,
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
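
Note: the commit replaces the None sentinels with empty tensors, so the streaming path can test for cached state by shape rather than by identity. A minimal sketch of that sentinel pattern (plain PyTorch; the has_overlap helper is illustrative, not part of the repo):

    import torch

    # Empty per-session caches, as initialized in tts()/vc():
    mel_overlap = torch.zeros(1, 80, 0)    # (batch, n_mels, T); T == 0 means "no overlap yet"
    flow_cache = torch.zeros(1, 80, 0, 2)  # empty cache threaded through flow.inference()

    def has_overlap(mel):
        # Mirrors the new check `self.mel_overlap_dict[uuid].shape[2] != 0` in token2wav().
        return mel.shape[2] != 0

    assert not has_overlap(mel_overlap)    # fresh session: skip fade_in_out
    mel_overlap = torch.randn(1, 80, 34)   # after a chunk, a real overlap is cached
    assert has_overlap(mel_overlap)        # later chunks: apply fade_in_out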