diff --git a/TTS/tts/layers/xtts/__init__.py b/TTS/tts/layers/xtts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py index b7e07589c5..cb09895824 100644 --- a/TTS/tts/layers/xtts/stream_generator.py +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -4,7 +4,7 @@ import inspect import random import warnings -from typing import Callable, List, Optional, Union +from typing import Callable, Optional, Union import numpy as np import torch @@ -21,10 +21,11 @@ PreTrainedModel, StoppingCriteriaList, ) +from transformers.generation.stopping_criteria import validate_stopping_criteria from transformers.generation.utils import GenerateOutput, SampleOutput, logger -def setup_seed(seed): +def setup_seed(seed: int) -> None: if seed == -1: return torch.manual_seed(seed) @@ -49,9 +50,9 @@ def generate( # noqa: PLR0911 generation_config: Optional[StreamGenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, synced_gpus: Optional[bool] = False, - seed=0, + seed: int = 0, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" @@ -90,7 +91,7 @@ def generate( # noqa: PLR0911 Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. - prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned @@ -151,18 +152,7 @@ def generate( # noqa: PLR0911 # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs # inputs_tensor has to be defined @@ -174,6 +164,9 @@ def generate( # noqa: PLR0911 ) batch_size = inputs_tensor.shape[0] + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + # 4. Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states @@ -182,7 +175,7 @@ def generate( # noqa: PLR0911 accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, generation_config.pad_token_id, @@ -209,16 +202,15 @@ def generate( # noqa: PLR0911 # 5. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, model_kwargs=model_kwargs, + decoder_start_token_id=generation_config.decoder_start_token_id, device=inputs_tensor.device, ) else: - # if decoder-only then inputs_tensor has to be `input_ids` - input_ids = inputs_tensor + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") # 6. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] @@ -577,7 +569,7 @@ def generate( # noqa: PLR0911 def typeerror(): raise ValueError( - "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`" f"of positive integers, but is {generation_config.force_words_ids}." ) @@ -649,7 +641,7 @@ def sample_stream( logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, + eos_token_id: Optional[Union[int, list[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, diff --git a/pyproject.toml b/pyproject.toml index dd4ebaed6f..dad0d5ed0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ dependencies = [ "gruut[de,es,fr]==2.2.3", # Tortoise "einops>=0.6.0", - "transformers>=4.33.0,<4.41.0", + "transformers>=4.41.1", # Bark "encodec>=0.1.1", # XTTS