diff --git a/kai/kai_config.py b/kai/kai_config.py
index 04019392..62d81c86 100644
--- a/kai/kai_config.py
+++ b/kai/kai_config.py
@@ -157,6 +157,7 @@ class SupportedModelProviders(StrEnum):
     CHAT_GOOGLE_GENERATIVE_AI = "ChatGoogleGenerativeAI"
     AZURE_CHAT_OPENAI = "AzureChatOpenAI"
     CHAT_DEEP_SEEK = "ChatDeepSeek"
+    VLLM_OPENAI = "VLLMOpenAI"


 class KaiConfigModels(BaseModel):
diff --git a/kai/llm_interfacing/model_provider.py b/kai/llm_interfacing/model_provider.py
index d4abe499..376ca23f 100644
--- a/kai/llm_interfacing/model_provider.py
+++ b/kai/llm_interfacing/model_provider.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import os
-from typing import Any, Optional
+from typing import Any, Optional, assert_never

 from langchain_aws import ChatBedrock
 from langchain_community.chat_models.fake import FakeListChatModel
@@ -17,7 +17,7 @@
 from pydantic.v1.utils import deep_update

 from kai.cache import Cache, CachePathResolver, SimplePathResolver
-from kai.kai_config import KaiConfigModels
+from kai.kai_config import KaiConfigModels, SupportedModelProviders
 from kai.logging.logging import get_logger

 LOG = get_logger(__name__)
@@ -43,7 +43,7 @@ def __init__(

         # Set the model class, model args, and model id based on the provider
         match config.provider:
-            case "ChatOllama":
+            case SupportedModelProviders.CHAT_OLLAMA:
                 model_class = ChatOllama

                 defaults = {
@@ -56,7 +56,7 @@ def __init__(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["model"]

-            case "ChatOpenAI":
+            case SupportedModelProviders.CHAT_OPENAI:
                 model_class = ChatOpenAI

                 defaults = {
@@ -96,7 +96,7 @@ def _get_request_payload(
                     if "temperature" in model_args:
                         del model_args["temperature"]

-            case "ChatBedrock":
+            case SupportedModelProviders.CHAT_BEDROCK:
                 model_class = ChatBedrock

                 defaults = {
@@ -106,7 +106,7 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["model_id"]

-            case "FakeListChatModel":
+            case SupportedModelProviders.FAKE_LIST_CHAT_MODEL:
                 model_class = FakeListChatModel

                 defaults = {
@@ -132,7 +132,7 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = "fake-list-chat-model"

-            case "ChatGoogleGenerativeAI":
+            case SupportedModelProviders.CHAT_GOOGLE_GENERATIVE_AI:
                 model_class = ChatGoogleGenerativeAI
                 api_key = os.getenv("GOOGLE_API_KEY", "dummy_value")
                 defaults = {
@@ -144,7 +144,7 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["model"]

-            case "AzureChatOpenAI":
+            case SupportedModelProviders.AZURE_CHAT_OPENAI:
                 model_class = AzureChatOpenAI

                 defaults = {
@@ -159,7 +159,7 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["azure_deployment"]

-            case "ChatDeepSeek":
+            case SupportedModelProviders.CHAT_DEEP_SEEK:
                 model_class = ChatDeepSeek

                 defaults = {
@@ -173,7 +173,7 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["model"]

-            case "VLLMOpenAI":
+            case SupportedModelProviders.VLLM_OPENAI:
                 model_class = VLLMOpenAI

                 defaults = {}
@@ -182,6 +182,7 @@ def _get_request_payload(
                 model_id = model_args["model_name"]

             case _:
+                assert_never(config.provider)
                 raise Exception(f"Unrecognized provider '{config.provider}'")

         self.provider_id: str = config.provider
@@ -278,7 +279,7 @@ def invoke(

     def response_to_base_message(self, llm_response: Any) -> BaseMessage:
         if isinstance(llm_response, str):
-            return BaseMessage(content=llm_response)
+            return BaseMessage(content=llm_response, type="ai")
         elif isinstance(llm_response, BaseMessage):
             return llm_response
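
The `case _:` arm now calls `typing.assert_never` (Python 3.11+), which turns the fallback into a static exhaustiveness check. Matching on `SupportedModelProviders` members instead of string literals is what makes this work: a `StrEnum` member still compares equal to its string at runtime, but type checkers generally don't treat string-literal patterns as covering an enum, so the old `case "ChatOllama":` style could never be verified for completeness. With enum patterns, any member that lacks a `case` (as `VLLM_OPENAI` would have before this patch) reaches `assert_never`, and mypy/pyright flag it at check time while the existing `raise` still guards at runtime. A minimal sketch of the pattern, with a hypothetical `Provider` enum standing in for the real config types:

```python
# Sketch only: `Provider` is a hypothetical stand-in for kai's
# SupportedModelProviders, trimmed to two members.
from enum import StrEnum
from typing import assert_never


class Provider(StrEnum):
    CHAT_OLLAMA = "ChatOllama"
    VLLM_OPENAI = "VLLMOpenAI"


def model_id_key(provider: Provider) -> str:
    """Return the args key that names the model for each provider."""
    match provider:
        case Provider.CHAT_OLLAMA:
            return "model"
        case Provider.VLLM_OPENAI:
            return "model_name"
        case _:
            # Unreachable when every member is handled above; at runtime
            # an unexpected value raises AssertionError here.
            assert_never(provider)
```

Deleting either `case` arm in the sketch makes the checker report that `assert_never` receives a non-`Never` value, naming the unhandled member. Separately, the `type="ai"` addition in `response_to_base_message` is needed because `langchain_core`'s `BaseMessage` declares `type` as a required field, so `BaseMessage(content=...)` alone fails pydantic validation.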