Skip to content

Commit 2a28123

Browse files
committed
1 parent 4b6da4e commit 2a28123

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

TTS/tts/layers/xtts/stream_generator.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,15 @@ def generate( # noqa: PLR0911
201201

202202
# 5. Prepare `input_ids` which will be used for auto-regressive generation
203203
if self.config.is_encoder_decoder:
204-
input_ids = self._prepare_decoder_input_ids_for_generation(
205-
batch_size,
206-
decoder_start_token_id=generation_config.decoder_start_token_id,
207-
bos_token_id=generation_config.bos_token_id,
204+
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
205+
batch_size=batch_size,
206+
model_input_name=model_input_name,
208207
model_kwargs=model_kwargs,
208+
decoder_start_token_id=generation_config.decoder_start_token_id,
209209
device=inputs_tensor.device,
210210
)
211211
else:
212-
# if decoder-only then inputs_tensor has to be `input_ids`
213-
input_ids = inputs_tensor
212+
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
214213

215214
# 6. Prepare `max_length` depending on other stopping criteria.
216215
input_ids_seq_length = input_ids.shape[-1]

0 commit comments

Comments
 (0)