
Commit ffa080e

chore: fix mypy issues due to transformers 4.51.1 (#9198)
1 parent: 7789876 · commit: ffa080e

File tree: 3 files changed (+14 −4 lines)


haystack/components/embedders/hugging_face_api_document_embedder.py

+1 −1

```diff
@@ -249,7 +249,7 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
             logger.warning(msg)
             normalize = None

-        all_embeddings = []
+        all_embeddings: List = []
         for i in tqdm(
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
```
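The annotation is needed because mypy flags an empty list literal when it cannot infer an element type. A minimal sketch of the pattern, with hypothetical names rather than the Haystack implementation:

```python
from typing import List


def flatten_batches(batches: List[List[List[float]]]) -> List[List[float]]:
    # A bare `List` annotation mirrors the commit and silences mypy's
    # "Need type annotation" error; `List[List[float]]` would be stricter.
    all_embeddings: List = []
    for batch in batches:
        all_embeddings.extend(batch)
    return all_embeddings
```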

haystack/components/generators/chat/hugging_face_local.py

+10 −2

```diff
@@ -7,7 +7,7 @@
 import re
 import sys
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

 from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
```
```diff
@@ -31,7 +31,9 @@

 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_and_transformers_import:
     from huggingface_hub import model_info
-    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteriaList, pipeline
+    from transformers import StoppingCriteriaList, pipeline
+    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

     from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
         HFTokenStreamingHandler,
```
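Pulling the tokenizer classes from their defining submodules sidesteps the top-level re-exports that transformers 4.51.1 no longer types cleanly for mypy. A rough sketch of the pattern, assuming transformers is installed; the helper function is illustrative only:

```python
from typing import Union

# The defining submodules expose the same classes as the top-level
# `transformers` package, but with types that mypy can resolve.
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast


def is_fast_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]) -> bool:
    # Fast tokenizers are backed by the Rust `tokenizers` library.
    return isinstance(tokenizer, PreTrainedTokenizerFast)
```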
```diff
@@ -555,6 +557,9 @@ async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
             hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
         )

+        # prepared_prompt is a string, but transformers has some type issues
+        prepared_prompt = cast(str, prepared_prompt)
+
         # Avoid some unnecessary warnings in the generation pipeline call
         generation_kwargs["pad_token_id"] = (
             generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
```
```diff
@@ -607,6 +612,9 @@ async def _run_non_streaming_async(  # pylint: disable=too-many-positional-arguments
             tools=[tc.tool_spec for tc in tools] if tools else None,
         )

+        # prepared_prompt is a string, but transformers has some type issues
+        prepared_prompt = cast(str, prepared_prompt)
+
         # Avoid some unnecessary warnings in the generation pipeline call
         generation_kwargs["pad_token_id"] = (
             generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
```
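With `tokenize=False`, `apply_chat_template` returns the rendered prompt string, but its return annotation in transformers 4.51.x is a broader union, so the commit narrows it with `typing.cast`, which is a no-op at runtime. A standalone sketch; the model name is only an example:

```python
from typing import cast

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

# tokenize=False yields the rendered prompt string; cast() narrows the
# broader annotation for mypy without changing runtime behavior.
prepared_prompt = cast(str, tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
print(prepared_prompt)
```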

haystack/utils/hf.py

+3 −1

```diff
@@ -283,7 +283,9 @@ def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:


 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
-    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
+    from transformers import StoppingCriteria, TextStreamer
+    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

     torch_import.check()
     transformers_import.check()
```
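For context, Haystack's `LazyImport` defers the ImportError until `check()` is called, which is why these imports can sit behind an optional dependency. A rough sketch of how the guard is typically used; the function below is illustrative, not part of this commit:

```python
from haystack.lazy_imports import LazyImport

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
    from transformers.tokenization_utils import PreTrainedTokenizer


def tokenizer_base_class_name() -> str:
    # Raises with the install hint above if transformers is missing.
    transformers_import.check()
    return PreTrainedTokenizer.__name__
```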
