Merge pull request #109 from idiap/transformers
Add compatibility with transformers>=4.43
eginhard authored Oct 21, 2024
2 parents 073f8de + ad435b5 commit b66c782
Showing 2 changed files with 68 additions and 21 deletions.
82 changes: 63 additions & 19 deletions TTS/tts/layers/xtts/stream_generator.py
@@ -20,8 +20,10 @@
PhrasalConstraint,
PreTrainedModel,
StoppingCriteriaList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from transformers.generation.stopping_criteria import validate_stopping_criteria
from transformers.generation.utils import GenerateOutput, SampleOutput, logger


@@ -152,7 +154,18 @@ def generate( # noqa: PLR0911
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id

# 3. Define model inputs
# inputs_tensor has to be defined
@@ -164,22 +177,38 @@ def generate( # noqa: PLR0911
)
batch_size = inputs_tensor.shape[0]

device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["cache_position"] = torch.Tensor([0]).to(inputs_tensor.device)

accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs

if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
setattr(
generation_config,
"_pad_token_tensor",
torch.full(
(inputs_tensor.shape[0], inputs_tensor.shape[1]),
generation_config.pad_token_id,
device=inputs_tensor.device,
),
)
setattr(
generation_config,
"_eos_token_tensor",
torch.full(
(inputs_tensor.shape[0], inputs_tensor.shape[1]),
generation_config.eos_token_id,
device=inputs_tensor.device,
),
)
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
generation_config._pad_token_tensor,
generation_config._eos_token_tensor,
)

# decoder-only models should use left-padding for generation
@@ -202,15 +231,16 @@ def generate( # noqa: PLR0911

# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device,
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor

# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
@@ -376,7 +406,7 @@ def generate( # noqa: PLR0911

elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
logits_warper = _get_logits_warper(generation_config)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -401,7 +431,7 @@ def generate( # noqa: PLR0911
)
elif is_sample_gen_stream_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
logits_warper = _get_logits_warper(generation_config)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -463,7 +493,7 @@ def generate( # noqa: PLR0911

elif is_beam_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)
logits_warper = _get_logits_warper(generation_config)

if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
@@ -877,10 +907,10 @@ def init_stream_support():


if __name__ == "__main__":
from transformers import AutoModelForCausalLM, AutoTokenizer

init_stream_support()
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

PreTrainedModel.generate = NewGenerationMixin.generate
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
@@ -920,3 +950,17 @@ def init_stream_support():
chunk = tokenizer.decode(x, skip_special_tokens=True)
stream_result += chunk
print(stream_result)


def _get_logits_warper(generation_config: GenerationConfig) -> LogitsProcessorList:

warpers = LogitsProcessorList()

if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))

return warpers
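
For context, a minimal usage sketch (not part of this commit) of the module-level _get_logits_warper helper added above. It assumes the helper is in scope, e.g. imported from TTS.tts.layers.xtts.stream_generator, and shows how it replaces the old self._get_logits_warper(generation_config, device) call and how the returned warper list is applied to next-token logits during sampling; all tensor values below are dummies.

import torch
from transformers import GenerationConfig

config = GenerationConfig(do_sample=True, temperature=0.8, top_k=50, top_p=0.9)
logits_warper = _get_logits_warper(config)  # previously: self._get_logits_warper(config, inputs_tensor.device)

input_ids = torch.tensor([[1, 2, 3]])       # dummy prompt token ids
next_token_logits = torch.randn(1, 32)      # dummy logits over a 32-token vocabulary
warped = logits_warper(input_ids, next_token_logits)  # applies temperature, then top-k, then top-p
next_token = torch.multinomial(torch.softmax(warped, dim=-1), num_samples=1)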
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -44,7 +44,7 @@ classifiers = [
]
dependencies = [
# Core
"numpy>=1.25.2",
"numpy>=1.25.2,<2.0",
"cython>=3.0.0",
"scipy>=1.11.2",
"torch>=2.4",
@@ -68,7 +68,7 @@ dependencies = [
"gruut[de,es,fr]>=2.4.0",
# Tortoise
"einops>=0.6.0",
"transformers>=4.42.0,<4.43.0",
"transformers>=4.43.0",
# Bark
"encodec>=0.1.1",
# XTTS
@@ -147,6 +147,9 @@ Discussions = "https://github.com/idiap/coqui-ai-TTS/discussions"
tts = "TTS.bin.synthesize:main"
tts-server = "TTS.server.server:main"

[tool.uv]
constraint-dependencies = ["numba>0.58.0"]

[tool.ruff]
target-version = "py39"
line-length = 120
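
As a quick sanity check (a hedged sketch, not part of this commit), an existing environment can be verified against the updated bounds before upgrading: transformers must now be at least 4.43.0 and numpy must stay below 2.0. It assumes the packaging package is available.

from importlib.metadata import version
from packaging.version import Version

# Bounds taken from the pyproject.toml changes in this commit.
assert Version(version("transformers")) >= Version("4.43.0"), "transformers too old for this release"
assert Version(version("numpy")) < Version("2.0"), "numpy 2.x is not supported by this release"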
