@@ -7,7 +7,7 @@
 import re
 import sys
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
 
 from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
@@ -31,7 +31,9 @@
 
 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_and_transformers_import:
     from huggingface_hub import model_info
-    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteriaList, pipeline
+    from transformers import StoppingCriteriaList, pipeline
+    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
     from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
         HFTokenStreamingHandler,
@@ -555,6 +557,9 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
             hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
         )
 
+        # prepared_prompt is a string, but transformers has some type issues
+        prepared_prompt = cast(str, prepared_prompt)
+
         # Avoid some unnecessary warnings in the generation pipeline call
         generation_kwargs["pad_token_id"] = (
             generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
@@ -607,6 +612,9 @@ async def _run_non_streaming_async( # pylint: disable=too-many-positional-arguments
             tools=[tc.tool_spec for tc in tools] if tools else None,
         )
 
+        # prepared_prompt is a string, but transformers has some type issues
+        prepared_prompt = cast(str, prepared_prompt)
+
         # Avoid some unnecessary warnings in the generation pipeline call
         generation_kwargs["pad_token_id"] = (
             generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
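
For context on the repeated `cast` hunks: `tokenizer.apply_chat_template(...)` is annotated in transformers to return a union of several types (token-id lists, batched outputs, or a string), so a static type checker will not accept assigning its result to a `str`-typed variable even though `tokenize=False` always produces a plain string at runtime. `typing.cast` narrows the static type and has no runtime effect. The tokenizer classes are likewise imported from their concrete submodules rather than the `transformers` top-level namespace, a common workaround when type checkers fail to resolve the library's lazy re-exports. A minimal, standalone sketch of the same pattern (the model name and messages here are illustrative, not taken from this diff):

from typing import cast

from transformers import AutoTokenizer

# Any chat-tuned checkpoint works; this small model is only an example.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
messages = [{"role": "user", "content": "Summarize this pull request."}]

# With tokenize=False the call returns the rendered prompt string at runtime,
# but its return annotation is a union, so cast() is used to satisfy the
# type checker. cast() performs no conversion or validation.
prompt = cast(str, tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
assert isinstance(prompt, str)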