diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index d052f144a..11158ea3a 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -264,7 +264,9 @@ def _setup_managed_agents(self, managed_agents: list | None = None) -> None:
         self.managed_agents = {agent.name: agent for agent in managed_agents}
 
     def _setup_tools(self, tools, add_base_tools):
-        assert all(isinstance(tool, Tool) for tool in tools), "All elements must be instance of Tool (or a subclass)"
+        invalid_tools = [tool for tool in tools if not isinstance(tool, Tool)]
+        if invalid_tools:
+            raise ValueError(f"The following tools are not instances of Tool (or a subclass): {invalid_tools}")
         self.tools = {tool.name: tool for tool in tools}
         if add_base_tools:
             self.tools.update(
diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py
index d12a38d5a..265154ae5 100644
--- a/src/smolagents/default_tools.py
+++ b/src/smolagents/default_tools.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, List
 
 from .local_python_executor import (
     BASE_BUILTIN_MODULES,
@@ -130,7 +130,13 @@ class GoogleSearchTool(Tool):
         "query": {"type": "string", "description": "The search query to perform."},
         "filter_year": {
             "type": "integer",
-            "description": "Optionally restrict results to a certain year",
+            "description": "Optional. Filter on one full year (e.g. 2023). Leave None to not filter. Year must be in the format YYYY.",
+            "nullable": True,
+        },
+        "n_results": {
+            "type": "integer",
+            "description": "The number of results to return.",
+            "default": 50,
             "nullable": True,
         },
     }
@@ -150,8 +156,9 @@ def __init__(self, provider: str = "serpapi"):
         self.api_key = os.getenv(api_key_env_name)
         if self.api_key is None:
             raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")
+
 
-    def forward(self, query: str, filter_year: int | None = None) -> str:
+    def forward(self, query: str, filter_year: int | None = None, n_results: int = 50) -> str:
         import requests
 
         if self.provider == "serpapi":
@@ -166,6 +173,7 @@ def forward(self, query: str, filter_year: int | None = None) -> str:
             params = {
                 "q": query,
                 "api_key": self.api_key,
+                "num": n_results,
             }
             base_url = "https://google.serper.dev/search"
             if filter_year is not None:
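Note for reviewers: the two changes above are caller-visible. A minimal sketch of the new behavior, assuming a `SERPER_API_KEY` in the environment; `plain_function` is a hypothetical stand-in for any non-`Tool` object:

```python
from smolagents import CodeAgent, GoogleSearchTool, InferenceClientModel

def plain_function(query: str) -> str:  # a bare function, not a Tool
    return query

try:
    # Previously an AssertionError (silently skipped under `python -O`);
    # now a ValueError that names the offending objects.
    CodeAgent(tools=[plain_function], model=InferenceClientModel())
except ValueError as err:
    print(err)

# The serper provider now forwards `num`, so the result count is tunable.
search = GoogleSearchTool(provider="serper")
results = search(query="open-source agents", filter_year=2023, n_results=10)
```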
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index ae930ae5d..01238005e 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -304,6 +304,8 @@ def __init__(
         self.last_input_token_count: int | None = None
         self.last_output_token_count: int | None = None
         self.model_id: str | None = model_id
+        self.total_input_token_count: int = 0
+        self.total_output_token_count: int = 0
 
     def _prepare_completion_kwargs(
         self,
@@ -363,8 +365,10 @@ def get_token_counts(self) -> dict[str, int]:
         if self.last_input_token_count is None or self.last_output_token_count is None:
             raise ValueError("Token counts are not available")
         return {
-            "input_token_count": self.last_input_token_count,
-            "output_token_count": self.last_output_token_count,
+            "last_input_token_count": self.last_input_token_count,
+            "last_output_token_count": self.last_output_token_count,
+            "total_input_token_count": self.total_input_token_count,
+            "total_output_token_count": self.total_output_token_count,
         }
 
     def generate(
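The key rename above is breaking for consumers of `get_token_counts()`, which now reports both per-call and accumulated usage. An illustrative sketch (model choice and messages are arbitrary):

```python
from smolagents import InferenceClientModel

model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")
model.generate([{"role": "user", "content": "Hello!"}])
model.generate([{"role": "user", "content": "Summarize that."}])

counts = model.get_token_counts()
counts["last_input_token_count"]   # prompt tokens of the second call only
counts["total_input_token_count"]  # prompt tokens summed over both calls
```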
+ """ + self.last_input_token_count = input_tokens + self.last_output_token_count = output_tokens + self.total_input_token_count += input_tokens + self.total_output_token_count += output_tokens def parse_tool_calls(self, message: ChatMessage) -> ChatMessage: """Sometimes APIs do not return the tool call as a specific object, so we need to parse it.""" @@ -418,6 +434,8 @@ def to_dict(self) -> dict: **self.kwargs, "last_input_token_count": self.last_input_token_count, "last_output_token_count": self.last_output_token_count, + "total_input_token_count": self.total_input_token_count, + "total_output_token_count": self.total_output_token_count, "model_id": self.model_id, } for attribute in [ @@ -455,6 +473,8 @@ def from_dict(cls, model_dictionary: dict[str, Any]) -> "Model": ) model_instance.last_input_token_count = model_dictionary.pop("last_input_token_count", None) model_instance.last_output_token_count = model_dictionary.pop("last_output_token_count", None) + model_instance.total_input_token_count = model_dictionary.pop("total_input_token_count", 0) + model_instance.total_output_token_count = model_dictionary.pop("total_output_token_count", 0) return model_instance @@ -554,8 +574,12 @@ def generate( sampling_params=sampling_params, ) output_text = out[0].outputs[0].text - self.last_input_token_count = len(out[0].prompt_token_ids) - self.last_output_token_count = len(out[0].outputs[0].token_ids) + + super().update_token_counts( + input_tokens=len(out[0].prompt_token_ids), + output_tokens=len(out[0].outputs[0].token_ids), + ) + return ChatMessage( role=MessageRole.ASSISTANT, content=output_text, @@ -651,16 +675,20 @@ def generate( add_generation_prompt=True, ) - self.last_input_token_count = len(prompt_ids) - self.last_output_token_count = 0 text = "" + output_tokens = 0 for response in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs): - self.last_output_token_count += 1 + output_tokens += 1 text += response.text if any((stop_index := text.rfind(stop)) != -1 for stop in stops): text = text[:stop_index] break + super().update_token_counts( + input_tokens=len(prompt_ids), + output_tokens=output_tokens, + ) + return ChatMessage( role=MessageRole.ASSISTANT, content=text, raw={"out": text, "completion_kwargs": completion_kwargs} ) @@ -870,9 +898,11 @@ def generate( output_text = self.processor.decode(generated_tokens, skip_special_tokens=True) else: output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) - self.last_input_token_count = count_prompt_tokens - self.last_output_token_count = len(generated_tokens) + super().update_token_counts( + input_tokens=count_prompt_tokens, + output_tokens=len(generated_tokens), + ) if stop_sequences is not None: output_text = remove_stop_sequences(output_text, stop_sequences) @@ -912,7 +942,12 @@ def generate_stream( yield ChatMessageStreamDelta(content=new_text, tool_calls=None) self.last_output_token_count += 1 - self.last_input_token_count = count_prompt_tokens + super().update_token_counts( + input_tokens=count_prompt_tokens, + output_tokens=self.last_output_token_count, + ) + + thread.join() @@ -1030,8 +1065,11 @@ def generate( response = self.client.completion(**completion_kwargs) - self.last_input_token_count = response.usage.prompt_tokens - self.last_output_token_count = response.usage.completion_tokens + super().update_token_counts( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + ) + return ChatMessage.from_dict( 
@@ -1030,8 +1065,11 @@ def generate(
         response = self.client.completion(**completion_kwargs)
 
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
+        super().update_token_counts(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
+        )
+
         return ChatMessage.from_dict(
             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
             raw=response,
@@ -1070,6 +1108,11 @@ def generate_stream(
                     self.last_input_token_count = event.usage.prompt_tokens
                     self.last_output_token_count = event.usage.completion_tokens
 
+                    super().update_token_counts(
+                        input_tokens=self.last_input_token_count,
+                        output_tokens=self.last_output_token_count,
+                    )
+
 
 class LiteLLMRouterModel(LiteLLMModel):
     """Router-based client for interacting with the [LiteLLM Python SDK Router](https://docs.litellm.ai/docs/routing).
@@ -1274,8 +1317,11 @@ def generate(
         )
         response = self.client.chat_completion(**completion_kwargs)
 
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
+        super().update_token_counts(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
+        )
+
         return ChatMessage.from_dict(asdict(response.choices[0].message), raw=response)
 
     def generate_stream(
@@ -1313,6 +1359,10 @@ def generate_stream(
                     self.last_input_token_count = event.usage.prompt_tokens
                     self.last_output_token_count = event.usage.completion_tokens
 
+                    super().update_token_counts(
+                        input_tokens=self.last_input_token_count,
+                        output_tokens=self.last_output_token_count,
+                    )
 
 class HfApiModel(InferenceClientModel):
     def __new__(cls, *args, **kwargs):
@@ -1419,6 +1469,11 @@ def generate_stream(
                     self.last_input_token_count = event.usage.prompt_tokens
                     self.last_output_token_count = event.usage.completion_tokens
 
+                    super().update_token_counts(
+                        input_tokens=self.last_input_token_count,
+                        output_tokens=self.last_output_token_count,
+                    )
+
     def generate(
         self,
         messages: list[dict[str, str | list[dict]]],
@@ -1438,8 +1493,11 @@ def generate(
             **kwargs,
         )
         response = self.client.chat.completions.create(**completion_kwargs)
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
+
+        super().update_token_counts(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
+        )
 
         return ChatMessage.from_dict(
             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
@@ -1665,9 +1723,10 @@ def generate(
         # self.client is created in ApiModel class
         response = self.client.converse(**completion_kwargs)
 
-        # Get usage
-        self.last_input_token_count = response["usage"]["inputTokens"]
-        self.last_output_token_count = response["usage"]["outputTokens"]
+        super().update_token_counts(
+            input_tokens=response["usage"]["inputTokens"],
+            output_tokens=response["usage"]["outputTokens"],
+        )
 
         # Get first message
         response["output"]["message"]["content"] = response["output"]["message"]["content"][0]["text"]
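Since `to_dict`/`from_dict` now carry the totals, accumulated usage survives serialization. A quick sketch, assuming the model subclass can be reconstructed from its serialized kwargs:

```python
state = model.to_dict()  # now includes total_input/output_token_count
restored = type(model).from_dict(state)
assert restored.total_input_token_count == model.total_input_token_count
assert restored.total_output_token_count == model.total_output_token_count
```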