Skip to content

Commit 1eb879e

Browse files
lucasnewman and Blaizzy authored
Add streaming support to OuteTTS (#169)
* Add streaming support to OuteTTS.
* Formatting.

Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
1 parent 8914f08 commit 1eb879e

File tree

2 files changed

+92
-59
lines changed

2 files changed

+92
-59
lines changed

mlx_audio/tts/models/outetts/dac_interface.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,15 @@ def decode(self, codes: mx.array, verbose: bool = False) -> mx.array:
148148
else:
149149
range_fn = range
150150

151+
@mx.compile
152+
def decode_chunk(codes):
153+
z = model.quantizer.from_codes(codes)[0]
154+
r = model.decode(z)
155+
return r
156+
151157
for i in range_fn(0, codes.shape[-1], chunk_length):
152158
c = codes[..., i : i + chunk_length]
153-
z = model.quantizer.from_codes(c)[0]
154-
r = model.decode(z)
155-
recons.append(r)
159+
recons.append(decode_chunk(c))
156160

157161
recons = mx.concatenate(recons, axis=-1)
158162
return process_audio_array(recons.swapaxes(1, 2))

mlx_audio/tts/models/outetts/outetts.py

Lines changed: 85 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,52 @@ def chunk_text(self, text: str, max_words: int = 30) -> List[str]:
103103
chunks.append(" ".join(current_chunk))
104104
return chunks
105105

106+
def generate_result(
107+
self, audio, start_time: float, token_count: int, segment_idx: int, **kwargs
108+
) -> GenerationResult:
109+
samples = audio.shape[0] if audio is not None else 0
110+
assert samples > 0, "No audio generated"
111+
112+
sample_rate = (
113+
self.config.sample_rate
114+
if kwargs.get("sample_rate") is None
115+
else kwargs.get("sample_rate")
116+
)
117+
audio_duration_seconds = samples / sample_rate
118+
119+
elapsed_time = time.perf_counter() - start_time
120+
rtf = audio_duration_seconds / elapsed_time
121+
122+
duration_mins = int(audio_duration_seconds // 60)
123+
duration_secs = int(audio_duration_seconds % 60)
124+
duration_ms = int((audio_duration_seconds % 1) * 1000)
125+
duration_hours = int(audio_duration_seconds // 3600)
126+
duration_str = f"{duration_hours:02d}:{duration_mins:02d}:{duration_secs:02d}.{duration_ms:03d}"
127+
128+
return GenerationResult(
129+
audio=audio,
130+
samples=samples,
131+
sample_rate=sample_rate,
132+
segment_idx=segment_idx,
133+
token_count=token_count,
134+
audio_duration=duration_str,
135+
real_time_factor=rtf,
136+
prompt={
137+
"tokens": token_count,
138+
"tokens-per-sec": (
139+
round(token_count / elapsed_time, 2) if elapsed_time > 0 else 0
140+
),
141+
},
142+
audio_samples={
143+
"samples": samples,
144+
"samples-per-sec": (
145+
round(samples / elapsed_time, 2) if elapsed_time > 0 else 0
146+
),
147+
},
148+
processing_time_seconds=elapsed_time,
149+
peak_memory_usage=mx.get_peak_memory() / 1e9,
150+
)
151+
106152
def generate(
107153
self,
108154
text,
@@ -113,6 +159,8 @@ def generate(
113159
max_tokens: int = 1200,
114160
verbose: bool = False,
115161
ref_audio: Optional[str] = None,
162+
stream: bool = False,
163+
streaming_interval: float = 2.0,
116164
**kwargs,
117165
):
118166

@@ -135,8 +183,6 @@ def generate(
135183
kwargs.get("repetition_context_size", 64),
136184
)
137185

138-
all_audio = []
139-
140186
for prompt in prompts:
141187
completion_prompt = self.prompt_processor.get_completion_prompt(
142188
prompt, speaker
@@ -146,7 +192,12 @@ def generate(
146192
)
147193
input_length = input_ids.shape[1]
148194

149-
time_start = time.time()
195+
generated_token_count = 0
196+
yielded_token_count = 0
197+
streaming_token_interval = int(streaming_interval * 137.5)
198+
yielded_frame_count = 0
199+
200+
time_start = time.perf_counter()
150201

151202
for i, response in enumerate(
152203
tqdm(
@@ -164,63 +215,41 @@ def generate(
164215
):
165216
next_token = mx.array([response.token])
166217
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
218+
generated_token_count += 1
219+
220+
# send a partial result in streaming mode
221+
if stream and generated_token_count % streaming_token_interval == 0:
222+
output_ids = input_ids[:, input_length:].tolist()[0]
223+
output = self.prompt_processor.extract_audio_from_tokens(output_ids)
224+
audio = self.audio_processor.audio_codec.decode(mx.array([output]))[
225+
-1, -1, :
226+
]
227+
228+
yield self.generate_result(
229+
audio=audio[yielded_frame_count:],
230+
start_time=time_start,
231+
token_count=len(output_ids) - yielded_token_count,
232+
segment_idx=i,
233+
**kwargs,
234+
)
235+
yielded_token_count = len(output_ids)
236+
yielded_frame_count = audio.shape[0]
237+
time_start = time.perf_counter()
167238

168239
output_ids = input_ids[:, input_length:].tolist()[0]
169240
output = self.prompt_processor.extract_audio_from_tokens(output_ids)
170-
audio = self.audio_processor.audio_codec.decode(mx.array([output])).squeeze(
171-
0
172-
)
173-
all_audio.append(audio)
174-
175-
time_end = time.time()
176-
177-
for i in range(len(all_audio)):
178-
audio = all_audio[i][0]
179241

180-
samples = audio.shape[0] if audio is not None else 0
181-
assert samples > 0, "No audio generated"
182-
183-
token_count = input_ids.shape[1] if input_ids is not None else 0
184-
185-
sample_rate = (
186-
self.config.sample_rate
187-
if kwargs.get("sample_rate") is None
188-
else kwargs.get("sample_rate")
189-
)
190-
audio_duration_seconds = samples / sample_rate
191-
192-
elapsed_time = time_end - time_start
193-
rtf = audio_duration_seconds / elapsed_time
194-
195-
duration_mins = int(audio_duration_seconds // 60)
196-
duration_secs = int(audio_duration_seconds % 60)
197-
duration_ms = int((audio_duration_seconds % 1) * 1000)
198-
duration_hours = int(audio_duration_seconds // 3600)
199-
duration_str = f"{duration_hours:02d}:{duration_mins:02d}:{duration_secs:02d}.{duration_ms:03d}"
200-
201-
yield GenerationResult(
202-
audio=audio,
203-
samples=samples,
204-
sample_rate=sample_rate,
205-
segment_idx=i,
206-
token_count=token_count,
207-
audio_duration=duration_str,
208-
real_time_factor=rtf,
209-
prompt={
210-
"tokens": token_count,
211-
"tokens-per-sec": (
212-
round(token_count / elapsed_time, 2) if elapsed_time > 0 else 0
213-
),
214-
},
215-
audio_samples={
216-
"samples": samples,
217-
"samples-per-sec": (
218-
round(samples / elapsed_time, 2) if elapsed_time > 0 else 0
219-
),
220-
},
221-
processing_time_seconds=time_end - time_start,
222-
peak_memory_usage=mx.get_peak_memory() / 1e9,
223-
)
242+
audio = self.audio_processor.audio_codec.decode(mx.array([output]))[
243+
-1, -1, :
244+
]
245+
if audio.shape[0] > yielded_frame_count:
246+
yield self.generate_result(
247+
audio=audio[yielded_frame_count:],
248+
start_time=time_start,
249+
token_count=len(output_ids) - yielded_token_count,
250+
segment_idx=i,
251+
**kwargs,
252+
)
224253

225254
# Clear cache after each segment to avoid memory leaks
226255
mx.clear_cache()

0 commit comments

Comments (0)