4
4
import inspect
5
5
import random
6
6
import warnings
7
- from typing import Callable , List , Optional , Union
7
+ from typing import Callable , Optional , Union
8
8
9
9
import numpy as np
10
10
import torch
21
21
PreTrainedModel ,
22
22
StoppingCriteriaList ,
23
23
)
24
+ from transformers .generation .stopping_criteria import validate_stopping_criteria
24
25
from transformers .generation .utils import GenerateOutput , SampleOutput , logger
25
26
26
27
27
- def setup_seed (seed ) :
28
+ def setup_seed (seed : int ) -> None :
28
29
if seed == - 1 :
29
30
return
30
31
torch .manual_seed (seed )
@@ -49,9 +50,9 @@ def generate( # noqa: PLR0911
49
50
generation_config : Optional [StreamGenerationConfig ] = None ,
50
51
logits_processor : Optional [LogitsProcessorList ] = None ,
51
52
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
52
- prefix_allowed_tokens_fn : Optional [Callable [[int , torch .Tensor ], List [int ]]] = None ,
53
+ prefix_allowed_tokens_fn : Optional [Callable [[int , torch .Tensor ], list [int ]]] = None ,
53
54
synced_gpus : Optional [bool ] = False ,
54
- seed = 0 ,
55
+ seed : int = 0 ,
55
56
** kwargs ,
56
57
) -> Union [GenerateOutput , torch .LongTensor ]:
57
58
r"""
@@ -90,7 +91,7 @@ def generate( # noqa: PLR0911
90
91
Custom stopping criteria that complement the default stopping criteria built from arguments and a
91
92
generation config. If a stopping criteria is passed that is already created with the arguments or a
92
93
generation config an error is thrown. This feature is intended for advanced users.
93
- prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List [int]]`, *optional*):
94
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list [int]]`, *optional*):
94
95
If provided, this function constraints the beam search to allowed tokens only at each step. If not
95
96
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
96
97
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
@@ -568,7 +569,7 @@ def generate( # noqa: PLR0911
568
569
569
570
def typeerror ():
570
571
raise ValueError (
571
- "`force_words_ids` has to either be a `List[List[List [int]]]` or `List[List [int]]`"
572
+ "`force_words_ids` has to either be a `list[list[list [int]]]` or `list[list [int]]`"
572
573
f"of positive integers, but is { generation_config .force_words_ids } ."
573
574
)
574
575
@@ -640,7 +641,7 @@ def sample_stream(
640
641
logits_warper : Optional [LogitsProcessorList ] = None ,
641
642
max_length : Optional [int ] = None ,
642
643
pad_token_id : Optional [int ] = None ,
643
- eos_token_id : Optional [Union [int , List [int ]]] = None ,
644
+ eos_token_id : Optional [Union [int , list [int ]]] = None ,
644
645
output_attentions : Optional [bool ] = None ,
645
646
output_hidden_states : Optional [bool ] = None ,
646
647
output_scores : Optional [bool ] = None ,
0 commit comments