Commit 5688090

chore(stream_generator): address lint issues
1 parent 60aaf9b commit 5688090

2 files changed (+8, -7 lines)


TTS/tts/layers/xtts/__init__.py

Whitespace-only changes.

TTS/tts/layers/xtts/stream_generator.py

+8 -7
@@ -4,7 +4,7 @@
 import inspect
 import random
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import torch
@@ -21,10 +21,11 @@
     PreTrainedModel,
     StoppingCriteriaList,
 )
+from transformers.generation.stopping_criteria import validate_stopping_criteria
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger
 
 
-def setup_seed(seed):
+def setup_seed(seed: int) -> None:
     if seed == -1:
         return
     torch.manual_seed(seed)
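
The hunk above only adds the annotation to `setup_seed`; the diff truncates the body after `torch.manual_seed(seed)`. A minimal sketch of how such a seeding helper typically continues, assuming (based on the module's `random` and `numpy` imports, not on lines shown in this diff) that it also seeds Python's and NumPy's RNGs:

import random

import numpy as np
import torch


def setup_seed(seed: int) -> None:
    # -1 is treated as "do not seed": keep non-deterministic behaviour.
    if seed == -1:
        return
    torch.manual_seed(seed)
    # The remaining lines are assumptions, not shown in the truncated hunk.
    torch.cuda.manual_seed_all(seed)  # deferred no-op until CUDA is initialised
    np.random.seed(seed)
    random.seed(seed)
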
@@ -49,9 +50,9 @@ def generate( # noqa: PLR0911
         generation_config: Optional[StreamGenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
         synced_gpus: Optional[bool] = False,
-        seed=0,
+        seed: int = 0,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
5758
r"""
@@ -90,7 +91,7 @@ def generate( # noqa: PLR0911
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 generation config. If a stopping criteria is passed that is already created with the arguments or a
                 generation config an error is thrown. This feature is intended for advanced users.
-            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                 If provided, this function constraints the beam search to allowed tokens only at each step. If not
                 provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                 `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
@@ -568,7 +569,7 @@ def generate( # noqa: PLR0911
 
             def typeerror():
                 raise ValueError(
-                    "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+                    "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
                     f"of positive integers, but is {generation_config.force_words_ids}."
                 )
 

@@ -640,7 +641,7 @@ def sample_stream(
         logits_warper: Optional[LogitsProcessorList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        eos_token_id: Optional[Union[int, list[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
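
Most of the remaining one-line changes in this file are the same lint fix: the `typing.List` alias is replaced by the built-in `list` generic (PEP 585, Python 3.9+), which is why the `List` import is dropped in the first hunk while `Callable`, `Optional`, and `Union` stay. A minimal before/after illustration of the pattern (not code from the commit):

from typing import Callable, List, Optional

# before: the generic comes from typing
fn_before: Optional[Callable[[int], List[int]]] = None

# after: built-in generics; Optional and Callable are still imported from
# typing here, matching what the commit keeps importing
fn_after: Optional[Callable[[int], list[int]]] = None
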
