Skip to content

Commit 360433f

Browse files
committed
Fix tools were not being used
1 parent bf350e3 commit 360433f

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

Diff for: src/distilabel/llms/huggingface/inference_endpoints.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import random
1717
import sys
1818
import warnings
19-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
19+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
2020

2121
from pydantic import (
2222
Field,
@@ -171,9 +171,6 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided(
171171
" or dedicated inference endpoints, respectively."
172172
)
173173

174-
if self.model_id and self.tokenizer_id is None:
175-
self.tokenizer_id = self.model_id
176-
177174
if self.use_magpie_template and self.tokenizer_id is None:
178175
raise ValueError(
179176
"`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
@@ -392,7 +389,7 @@ async def _generate_with_chat_completion(
392389
seed: Optional[int] = None,
393390
stop_sequences: Optional[List[str]] = None,
394391
temperature: float = 1.0,
395-
tool_choice: Optional[Dict[str, str]] = None,
392+
tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
396393
tool_prompt: Optional[str] = None,
397394
tools: Optional[List[Dict[str, Any]]] = None,
398395
top_p: Optional[float] = None,
@@ -463,7 +460,7 @@ async def agenerate( # type: ignore
463460
seed: Optional[int] = None,
464461
stop_sequences: Optional[List[str]] = None,
465462
temperature: float = 1.0,
466-
tool_choice: Optional[Dict[str, str]] = None,
463+
tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
467464
tool_prompt: Optional[str] = None,
468465
tools: Optional[List[Dict[str, Any]]] = None,
469466
top_p: Optional[float] = None,
@@ -502,10 +499,9 @@ async def agenerate( # type: ignore
502499
`tokenizer.eos_token` if available.
503500
temperature: the temperature to use for the generation. Defaults to `1.0`.
504501
tool_choice: the name of the tool the model should call. It can be a dictionary
505-
like `{"function_name": "my_tool"}`. If not provided, then the model will
506-
automatically choose which tool to use. This argument is exclusive to the
507-
`chat_completion` method and will be used only if `tokenizer_id` is `None`.
508-
Defaults to `None`.
502+
like `{"function_name": "my_tool"}` or "auto". If not provided, then the
503+
model won't use any tool. This argument is exclusive to the `chat_completion`
504+
method and will be used only if `tokenizer_id` is `None`. Defaults to `None`.
509505
tool_prompt: A prompt to be appended before the tools. This argument is exclusive
510506
to the `chat_completion` method and will be used only if `tokenizer_id`
511507
is `None`. Defauls to `None`.

0 commit comments

Comments
 (0)