Commit 2a28123 1 parent 4b6da4e commit 2a28123 Copy full SHA for 2a28123
File tree 1 file changed +5
-6
lines changed
1 file changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -201,16 +201,15 @@ def generate( # noqa: PLR0911
201
201
202
202
# 5. Prepare `input_ids` which will be used for auto-regressive generation
203
203
if self .config .is_encoder_decoder :
204
- input_ids = self ._prepare_decoder_input_ids_for_generation (
205
- batch_size ,
206
- decoder_start_token_id = generation_config .decoder_start_token_id ,
207
- bos_token_id = generation_config .bos_token_id ,
204
+ input_ids , model_kwargs = self ._prepare_decoder_input_ids_for_generation (
205
+ batch_size = batch_size ,
206
+ model_input_name = model_input_name ,
208
207
model_kwargs = model_kwargs ,
208
+ decoder_start_token_id = generation_config .decoder_start_token_id ,
209
209
device = inputs_tensor .device ,
210
210
)
211
211
else :
212
- # if decoder-only then inputs_tensor has to be `input_ids`
213
- input_ids = inputs_tensor
212
+ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs .pop ("input_ids" )
214
213
215
214
# 6. Prepare `max_length` depending on other stopping criteria.
216
215
input_ids_seq_length = input_ids .shape [- 1 ]
You can’t perform that action at this time.
0 commit comments