🐛 Added back the ability to get tokens for input and output (#686)
* adding back the ability to get tokens for input and output

Signed-off-by: Shawn Hurley <shawn@hurley.page>

* Apply suggestions from code review

Co-authored-by: Jonah Sussman <42743659+JonahSussman@users.noreply.github.com>
Signed-off-by: Shawn Hurley <shawn@hurley.page>

---------

Signed-off-by: Shawn Hurley <shawn@hurley.page>
Co-authored-by: Jonah Sussman <42743659+JonahSussman@users.noreply.github.com>
shawn-hurley and JonahSussman authored Feb 25, 2025
1 parent 4363ded · commit e8cd1bd
Showing 3 changed files with 70 additions and 8 deletions.
build/build.spec: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ a = Analysis(
    pathex=[os.path.dirname(script_path), '../'],
    binaries=[],
    datas=data_dirs,
-    hiddenimports=["_ssl", "pydantic.deprecated.decorator"],
+    hiddenimports=["_ssl", "pydantic.deprecated.decorator", "tiktoken_ext.openai_public", "tiktoken_ext"],
    hookspath=[],
    runtime_hooks=[],
    excludes=[],
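The extra hiddenimports are needed because tiktoken discovers its encodings at runtime through the tiktoken_ext plugin namespace (tiktoken_ext.openai_public provides cl100k_base, the encoding the new callback below falls back to), which PyInstaller's static analysis does not see. A sanity check along these lines, run inside the frozen build, would confirm the plugins were bundled (a sketch, not part of this commit):

import tiktoken

# Resolves only if the tiktoken_ext plugin packages were bundled into the build.
enc = tiktoken.get_encoding("cl100k_base")
print(len(enc.encode("hello tokens")))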
kai/llm_interfacing/callback/token_output_callback.py: 55 additions & 0 deletions (new file)
@@ -0,0 +1,55 @@
from typing import Any, Optional
from uuid import UUID

import tiktoken
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.outputs.llm_result import LLMResult

from kai.logging.logging import get_logger

LOG = get_logger(__name__)


class TokenOutputCallback(BaseCallbackHandler):

    def __init__(self, llm: BaseChatModel):
        self.model = llm

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any
    ) -> Any:
        llm_token_string = ""  # trunk-ignore(bandit/B105)
        for generation_list in response.generations:
            for generation in generation_list:
                llm_token_string += generation.text
        try:
            tokens = self.model.get_num_tokens(llm_token_string)
            LOG.info("output tokens: %s", tokens)
        except Exception:
            enc = tiktoken.get_encoding("cl100k_base")
            tokens = len(enc.encode(llm_token_string))
            LOG.info("output tokens: %s", tokens)
        return

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        **kwargs: Any
    ) -> None:
        flat_messages = [item for sublist in messages for item in sublist]
        try:
            tokens = self.model.get_num_tokens_from_messages(flat_messages)
            LOG.info("input tokens: %s", tokens)
        except Exception:
            # Here we fall back to a default encoding if no model is found.
            enc = tiktoken.get_encoding("cl100k_base")
            tokens = len(enc.encode(get_buffer_string(flat_messages)))
            LOG.info("input tokens: %s", tokens)
kai/llm_interfacing/model_provider.py: 14 additions & 7 deletions
@@ -10,7 +10,7 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage
from langchain_core.prompt_values import PromptValue
-from langchain_core.runnables import ConfigurableField, RunnableConfig
+from langchain_core.runnables import ConfigurableField, Runnable, RunnableConfig
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
@@ -20,6 +20,7 @@

from kai.cache import Cache, CachePathResolver, SimplePathResolver
from kai.kai_config import KaiConfigModels, SupportedModelProviders
+from kai.llm_interfacing.callback.token_output_callback import TokenOutputCallback
from kai.logging.logging import get_logger

LOG = get_logger(__name__)
@@ -109,7 +110,7 @@ def validate_environment(self) -> None:
    def configurable_llm(
        self,
        configurable_fields: dict[str, Any] | None = None,
-    ) -> BaseChatModel:
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """
        Some fields can only be configured when the model is instantiated. This
        side-steps that by creating a new instance of the model with the configurable
@@ -119,11 +120,12 @@ def configurable_llm(
            result = self.llm.configurable_fields(
                **{k: ConfigurableField(id=k) for k in configurable_fields}
            ).with_config(
-                configurable_fields  # type: ignore[arg-type]
+                configurable_fields,  # type: ignore[arg-type]
+                callbacks=[TokenOutputCallback(self.llm)],
            )
-            return cast(BaseChatModel, result)  # TODO: Check if this cast is ok
+            return result  # TODO: Check if this cast is ok
        else:
-            return self.llm
+            return self.llm.with_config(callbacks=[TokenOutputCallback(self.llm)])

    def invoke_llm(
        self,
@@ -138,6 +140,7 @@ def invoke_llm(
        Method to invoke the actual LLM. This can be overridden by subclasses to
        provide additional functionality.
        """
+
        return self.configurable_llm(configurable_fields).invoke(
            input, config, stop=stop, **kwargs
        )
@@ -150,8 +153,12 @@ def invoke_llm(
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> Iterator[BaseMessageChunk]:
-        return self.configurable_llm(configurable_fields).stream(
-            input, config, stop=stop, **kwargs
+        # This is the same cast that the base LLM does to enable streaming.
+        return cast(
+            Iterator[BaseMessageChunk],
+            self.configurable_llm(configurable_fields).stream(
+                input, config, stop=stop, **kwargs
+            ),
        )

    @tracer.start_as_current_span("invoke_llm")
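For context on the typing changes above (a sketch, not from the repository): with_config wraps the chat model in a generic Runnable binding rather than returning a BaseChatModel, which is why configurable_llm's return annotation widens and why stream_llm casts the stream back to BaseMessageChunk. FakeListChatModel is used here only as a stand-in.

from typing import Iterator, cast

from langchain_core.language_models.fake_chat_models import FakeListChatModel
from langchain_core.messages import BaseMessageChunk

llm = FakeListChatModel(responses=["ok"])
configured = llm.with_config(callbacks=[])  # a Runnable wrapper, no longer a BaseChatModel

# The generic Runnable.stream is typed over its Output, so the chunk type is
# recovered with the same kind of cast the diff adds in stream_llm.
chunks = cast(Iterator[BaseMessageChunk], configured.stream("hi"))
for chunk in chunks:
    print(chunk.content)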
