@@ -103,6 +103,52 @@ def chunk_text(self, text: str, max_words: int = 30) -> List[str]:
            chunks.append(" ".join(current_chunk))
        return chunks

+    def generate_result(
+        self, audio, start_time: float, token_count: int, segment_idx: int, **kwargs
+    ) -> GenerationResult:
+        samples = audio.shape[0] if audio is not None else 0
+        assert samples > 0, "No audio generated"
+
+        sample_rate = (
+            self.config.sample_rate
+            if kwargs.get("sample_rate") is None
+            else kwargs.get("sample_rate")
+        )
+        audio_duration_seconds = samples / sample_rate
+
+        elapsed_time = time.perf_counter() - start_time
+        rtf = audio_duration_seconds / elapsed_time
+
+        duration_mins = int(audio_duration_seconds // 60)
+        duration_secs = int(audio_duration_seconds % 60)
+        duration_ms = int((audio_duration_seconds % 1) * 1000)
+        duration_hours = int(audio_duration_seconds // 3600)
+        duration_str = f"{duration_hours:02d}:{duration_mins:02d}:{duration_secs:02d}.{duration_ms:03d}"
+
+        return GenerationResult(
+            audio=audio,
+            samples=samples,
+            sample_rate=sample_rate,
+            segment_idx=segment_idx,
+            token_count=token_count,
+            audio_duration=duration_str,
+            real_time_factor=rtf,
+            prompt={
+                "tokens": token_count,
+                "tokens-per-sec": (
+                    round(token_count / elapsed_time, 2) if elapsed_time > 0 else 0
+                ),
+            },
+            audio_samples={
+                "samples": samples,
+                "samples-per-sec": (
+                    round(samples / elapsed_time, 2) if elapsed_time > 0 else 0
+                ),
+            },
+            processing_time_seconds=elapsed_time,
+            peak_memory_usage=mx.get_peak_memory() / 1e9,
+        )
+
    def generate(
        self,
        text,
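The new `generate_result` helper above factors the per-segment bookkeeping out of `generate`: it turns a decoded audio array plus a start timestamp into a `GenerationResult` with a formatted duration and a real-time factor (seconds of audio produced per second of wall-clock time). A minimal sketch of that metric with illustrative numbers, not values taken from the model:

```python
# Hypothetical numbers, just to show how the real-time factor is derived.
samples = 48_000        # decoded audio samples in the chunk
sample_rate = 24_000    # Hz; the helper falls back to self.config.sample_rate
elapsed_time = 0.8      # wall-clock seconds spent generating this chunk

audio_duration_seconds = samples / sample_rate   # 2.0 s of audio
rtf = audio_duration_seconds / elapsed_time      # 2.5x faster than real time
```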
@@ -113,6 +159,8 @@ def generate(
        max_tokens: int = 1200,
        verbose: bool = False,
        ref_audio: Optional[str] = None,
+        stream: bool = False,
+        streaming_interval: float = 2.0,
        **kwargs,
    ):
@@ -135,8 +183,6 @@ def generate(
            kwargs.get("repetition_context_size", 64),
        )

-        all_audio = []
-
        for prompt in prompts:
            completion_prompt = self.prompt_processor.get_completion_prompt(
                prompt, speaker
@@ -146,7 +192,12 @@ def generate(
            )
            input_length = input_ids.shape[1]

-            time_start = time.time()
+            generated_token_count = 0
+            yielded_token_count = 0
+            streaming_token_interval = int(streaming_interval * 137.5)
+            yielded_frame_count = 0
+
+            time_start = time.perf_counter()

            for i, response in enumerate(
                tqdm(
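`streaming_interval` is specified in seconds of audio, and the loop above converts it to a token count before comparing it against `generated_token_count`. The constant 137.5 is presumably the model's audio-token rate (tokens per second of synthesized audio); treating that as an assumption, the conversion looks like this:

```python
# Sketch of the interval conversion used above, assuming ~137.5 audio
# tokens per second of synthesized audio (the constant in the diff).
TOKENS_PER_SECOND = 137.5

def tokens_for_interval(streaming_interval: float) -> int:
    # streaming_interval=2.0 -> int(2.0 * 137.5) = 275 tokens between yields
    return int(streaming_interval * TOKENS_PER_SECOND)
```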
@@ -164,63 +215,41 @@ def generate(
            ):
                next_token = mx.array([response.token])
                input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
+                generated_token_count += 1
+
+                # send a partial result in streaming mode
+                if stream and generated_token_count % streaming_token_interval == 0:
+                    output_ids = input_ids[:, input_length:].tolist()[0]
+                    output = self.prompt_processor.extract_audio_from_tokens(output_ids)
+                    audio = self.audio_processor.audio_codec.decode(mx.array([output]))[
+                        -1, -1, :
+                    ]
+
+                    yield self.generate_result(
+                        audio=audio[yielded_frame_count:],
+                        start_time=time_start,
+                        token_count=len(output_ids) - yielded_token_count,
+                        segment_idx=i,
+                        **kwargs,
+                    )
+                    yielded_token_count = len(output_ids)
+                    yielded_frame_count = audio.shape[0]
+                    time_start = time.perf_counter()

            output_ids = input_ids[:, input_length:].tolist()[0]
            output = self.prompt_processor.extract_audio_from_tokens(output_ids)
-            audio = self.audio_processor.audio_codec.decode(mx.array([output])).squeeze(
-                0
-            )
-            all_audio.append(audio)
-
-        time_end = time.time()
-
-        for i in range(len(all_audio)):
-            audio = all_audio[i][0]

-            samples = audio.shape[0] if audio is not None else 0
-            assert samples > 0, "No audio generated"
-
-            token_count = input_ids.shape[1] if input_ids is not None else 0
-
-            sample_rate = (
-                self.config.sample_rate
-                if kwargs.get("sample_rate") is None
-                else kwargs.get("sample_rate")
-            )
-            audio_duration_seconds = samples / sample_rate
-
-            elapsed_time = time_end - time_start
-            rtf = audio_duration_seconds / elapsed_time
-
-            duration_mins = int(audio_duration_seconds // 60)
-            duration_secs = int(audio_duration_seconds % 60)
-            duration_ms = int((audio_duration_seconds % 1) * 1000)
-            duration_hours = int(audio_duration_seconds // 3600)
-            duration_str = f"{duration_hours:02d}:{duration_mins:02d}:{duration_secs:02d}.{duration_ms:03d}"
-
-            yield GenerationResult(
-                audio=audio,
-                samples=samples,
-                sample_rate=sample_rate,
-                segment_idx=i,
-                token_count=token_count,
-                audio_duration=duration_str,
-                real_time_factor=rtf,
-                prompt={
-                    "tokens": token_count,
-                    "tokens-per-sec": (
-                        round(token_count / elapsed_time, 2) if elapsed_time > 0 else 0
-                    ),
-                },
-                audio_samples={
-                    "samples": samples,
-                    "samples-per-sec": (
-                        round(samples / elapsed_time, 2) if elapsed_time > 0 else 0
-                    ),
-                },
-                processing_time_seconds=time_end - time_start,
-                peak_memory_usage=mx.get_peak_memory() / 1e9,
-            )
+            audio = self.audio_processor.audio_codec.decode(mx.array([output]))[
+                -1, -1, :
+            ]
+            if audio.shape[0] > yielded_frame_count:
+                yield self.generate_result(
+                    audio=audio[yielded_frame_count:],
+                    start_time=time_start,
+                    token_count=len(output_ids) - yielded_token_count,
+                    segment_idx=i,
+                    **kwargs,
+                )

            # Clear cache after each segment to avoid memory leaks
            mx.clear_cache()
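For reference, a hedged sketch of how a caller might consume the new streaming path. The `model` object, the NumPy conversion, and the `soundfile` write below are illustrative assumptions and not part of this diff; the point is that `generate(..., stream=True)` now yields partial `GenerationResult` chunks while tokens are still being decoded, instead of one result per prompt at the end:

```python
# Hypothetical consumer of the streaming generator introduced in this diff.
import numpy as np
import soundfile as sf  # assumed available for writing the audio

chunks = []
for result in model.generate(
    "Hello from the streaming path.",
    stream=True,              # new flag added by this diff
    streaming_interval=2.0,   # yield roughly every 2 s of audio
):
    # Each partial result carries only the newly decoded frames.
    print(result.segment_idx, result.audio_duration, result.real_time_factor)
    chunks.append(np.array(result.audio))

sf.write("stream_out.wav", np.concatenate(chunks), result.sample_rate)
```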