Skip to content

Add total token counts to models #1282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,9 @@ def _setup_managed_agents(self, managed_agents: list | None = None) -> None:
self.managed_agents = {agent.name: agent for agent in managed_agents}

def _setup_tools(self, tools, add_base_tools):
assert all(isinstance(tool, Tool) for tool in tools), "All elements must be instance of Tool (or a subclass)"
invalid_tools = [tool for tool in tools if not isinstance(tool, Tool)]
if invalid_tools:
raise ValueError(f"The following tools are not instances of Tool (or a subclass): {invalid_tools}")
self.tools = {tool.name: tool for tool in tools}
if add_base_tools:
self.tools.update(
Expand Down
14 changes: 11 additions & 3 deletions src/smolagents/default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
from typing import Any, List

from .local_python_executor import (
BASE_BUILTIN_MODULES,
Expand Down Expand Up @@ -130,7 +130,13 @@ class GoogleSearchTool(Tool):
"query": {"type": "string", "description": "The search query to perform."},
"filter_year": {
"type": "integer",
"description": "Optionally restrict results to a certain year",
"description": "Optional. Filter on one full year (e.g. 2023). Leave None to not filter. Year must be in the format YYYY.",
"nullable": True,
},
"n_results": {
"type": "integer",
"description": "The number of results to return.",
"default": 50,
"nullable": True,
},
}
Expand All @@ -150,8 +156,9 @@ def __init__(self, provider: str = "serpapi"):
self.api_key = os.getenv(api_key_env_name)
if self.api_key is None:
raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")


def forward(self, query: str, filter_year: int | None = None) -> str:
def forward(self, query: str, filter_year: int | None = None, n_results: int = 50) -> str:
import requests

if self.provider == "serpapi":
Expand All @@ -166,6 +173,7 @@ def forward(self, query: str, filter_year: int | None = None) -> str:
params = {
"q": query,
"api_key": self.api_key,
"num": n_results,
}
base_url = "https://google.serper.dev/search"
if filter_year is not None:
Expand Down
97 changes: 78 additions & 19 deletions src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ def __init__(
self.last_input_token_count: int | None = None
self.last_output_token_count: int | None = None
self.model_id: str | None = model_id
self.total_input_token_count: int = 0
self.total_output_token_count: int = 0

def _prepare_completion_kwargs(
self,
Expand Down Expand Up @@ -363,8 +365,10 @@ def get_token_counts(self) -> dict[str, int]:
if self.last_input_token_count is None or self.last_output_token_count is None:
raise ValueError("Token counts are not available")
return {
"input_token_count": self.last_input_token_count,
"output_token_count": self.last_output_token_count,
"last_input_token_count": self.last_input_token_count,
"last_output_token_count": self.last_output_token_count,
"total_input_token_count": self.total_input_token_count,
"total_output_token_count": self.total_output_token_count,
}

def generate(
Expand Down Expand Up @@ -396,6 +400,18 @@ def generate(

def __call__(self, *args, **kwargs):
    """Make the model instance directly callable; forwards all arguments to `generate`."""
    return self.generate(*args, **kwargs)

def update_token_counts(self, input_tokens: int, output_tokens: int):
    """Record the token usage of the most recent call and fold it into the running totals.

    Parameters:
        input_tokens (`int`): The number of input tokens.
        output_tokens (`int`): The number of output tokens.
    """
    # Remember the counts of the latest call...
    self.last_input_token_count, self.last_output_token_count = input_tokens, output_tokens
    # ...and accumulate them into the lifetime totals.
    self.total_input_token_count = self.total_input_token_count + input_tokens
    self.total_output_token_count = self.total_output_token_count + output_tokens

def parse_tool_calls(self, message: ChatMessage) -> ChatMessage:
"""Sometimes APIs do not return the tool call as a specific object, so we need to parse it."""
Expand All @@ -418,6 +434,8 @@ def to_dict(self) -> dict:
**self.kwargs,
"last_input_token_count": self.last_input_token_count,
"last_output_token_count": self.last_output_token_count,
"total_input_token_count": self.total_input_token_count,
"total_output_token_count": self.total_output_token_count,
"model_id": self.model_id,
}
for attribute in [
Expand Down Expand Up @@ -455,6 +473,8 @@ def from_dict(cls, model_dictionary: dict[str, Any]) -> "Model":
)
model_instance.last_input_token_count = model_dictionary.pop("last_input_token_count", None)
model_instance.last_output_token_count = model_dictionary.pop("last_output_token_count", None)
model_instance.total_input_token_count = model_dictionary.pop("total_input_token_count", 0)
model_instance.total_output_token_count = model_dictionary.pop("total_output_token_count", 0)
return model_instance


Expand Down Expand Up @@ -554,8 +574,12 @@ def generate(
sampling_params=sampling_params,
)
output_text = out[0].outputs[0].text
self.last_input_token_count = len(out[0].prompt_token_ids)
self.last_output_token_count = len(out[0].outputs[0].token_ids)

super().update_token_counts(
input_tokens=len(out[0].prompt_token_ids),
output_tokens=len(out[0].outputs[0].token_ids),
)

return ChatMessage(
role=MessageRole.ASSISTANT,
content=output_text,
Expand Down Expand Up @@ -651,16 +675,20 @@ def generate(
add_generation_prompt=True,
)

self.last_input_token_count = len(prompt_ids)
self.last_output_token_count = 0
text = ""
output_tokens = 0
for response in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
self.last_output_token_count += 1
output_tokens += 1
text += response.text
if any((stop_index := text.rfind(stop)) != -1 for stop in stops):
text = text[:stop_index]
break

super().update_token_counts(
input_tokens=len(prompt_ids),
output_tokens=output_tokens,
)

return ChatMessage(
role=MessageRole.ASSISTANT, content=text, raw={"out": text, "completion_kwargs": completion_kwargs}
)
Expand Down Expand Up @@ -870,9 +898,11 @@ def generate(
output_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
else:
output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.last_input_token_count = count_prompt_tokens
self.last_output_token_count = len(generated_tokens)

super().update_token_counts(
input_tokens=count_prompt_tokens,
output_tokens=len(generated_tokens),
)
if stop_sequences is not None:
output_text = remove_stop_sequences(output_text, stop_sequences)

Expand Down Expand Up @@ -912,7 +942,12 @@ def generate_stream(
yield ChatMessageStreamDelta(content=new_text, tool_calls=None)
self.last_output_token_count += 1

self.last_input_token_count = count_prompt_tokens
super().update_token_counts(
input_tokens=count_prompt_tokens,
output_tokens=self.last_output_token_count,
)


thread.join()


Expand Down Expand Up @@ -1030,8 +1065,11 @@ def generate(

response = self.client.completion(**completion_kwargs)

self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
super().update_token_counts(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
)

return ChatMessage.from_dict(
response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
raw=response,
Expand Down Expand Up @@ -1070,6 +1108,11 @@ def generate_stream(
self.last_input_token_count = event.usage.prompt_tokens
self.last_output_token_count = event.usage.completion_tokens

super().update_token_counts(
input_tokens=self.last_input_token_count,
output_tokens=self.last_output_token_count,
)


class LiteLLMRouterModel(LiteLLMModel):
"""Router‑based client for interacting with the [LiteLLM Python SDK Router](https://docs.litellm.ai/docs/routing).
Expand Down Expand Up @@ -1274,8 +1317,11 @@ def generate(
)
response = self.client.chat_completion(**completion_kwargs)

self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
super().update_token_counts(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
)

return ChatMessage.from_dict(asdict(response.choices[0].message), raw=response)

def generate_stream(
Expand Down Expand Up @@ -1313,6 +1359,10 @@ def generate_stream(
self.last_input_token_count = event.usage.prompt_tokens
self.last_output_token_count = event.usage.completion_tokens

super().update_token_counts(
input_tokens=self.last_input_token_count,
output_tokens=self.last_output_token_count,
)

class HfApiModel(InferenceClientModel):
def __new__(cls, *args, **kwargs):
Expand Down Expand Up @@ -1419,6 +1469,11 @@ def generate_stream(
self.last_input_token_count = event.usage.prompt_tokens
self.last_output_token_count = event.usage.completion_tokens

super().update_token_counts(
input_tokens=self.last_input_token_count,
output_tokens=self.last_output_token_count,
)

def generate(
self,
messages: list[dict[str, str | list[dict]]],
Expand All @@ -1438,8 +1493,11 @@ def generate(
**kwargs,
)
response = self.client.chat.completions.create(**completion_kwargs)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens

super().update_token_counts(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
)

return ChatMessage.from_dict(
response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
Expand Down Expand Up @@ -1665,9 +1723,10 @@ def generate(
# self.client is created in ApiModel class
response = self.client.converse(**completion_kwargs)

# Get usage
self.last_input_token_count = response["usage"]["inputTokens"]
self.last_output_token_count = response["usage"]["outputTokens"]
super().update_token_counts(
input_tokens=response["usage"]["inputTokens"],
output_tokens=response["usage"]["outputTokens"],
)

# Get first message
response["output"]["message"]["content"] = response["output"]["message"]["content"][0]["text"]
Expand Down