Skip to content

Commit

Permalink
Fixup
Browse files Browse the repository at this point in the history
Signed-off-by: JonahSussman <sussmanjonah@gmail.com>
  • Loading branch information
JonahSussman committed Feb 19, 2025
1 parent 87c609c commit 56a844b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
1 change: 1 addition & 0 deletions kai/kai_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ class SupportedModelProviders(StrEnum):
CHAT_GOOGLE_GENERATIVE_AI = "ChatGoogleGenerativeAI"
AZURE_CHAT_OPENAI = "AzureChatOpenAI"
CHAT_DEEP_SEEK = "ChatDeepSeek"
VLLM_OPENAI = "VLLMOpenAI"


class KaiConfigModels(BaseModel):
Expand Down
23 changes: 12 additions & 11 deletions kai/llm_interfacing/model_provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import os
from typing import Any, Optional
from typing import Any, Optional, assert_never

from langchain_aws import ChatBedrock
from langchain_community.chat_models.fake import FakeListChatModel
Expand All @@ -17,7 +17,7 @@
from pydantic.v1.utils import deep_update

from kai.cache import Cache, CachePathResolver, SimplePathResolver
from kai.kai_config import KaiConfigModels
from kai.kai_config import KaiConfigModels, SupportedModelProviders
from kai.logging.logging import get_logger

LOG = get_logger(__name__)
Expand All @@ -43,7 +43,7 @@ def __init__(

# Set the model class, model args, and model id based on the provider
match config.provider:
case "ChatOllama":
case SupportedModelProviders.CHAT_OLLAMA:
model_class = ChatOllama

defaults = {
Expand All @@ -56,7 +56,7 @@ def __init__(
model_args = deep_update(defaults, config.args)
model_id = model_args["model"]

case "ChatOpenAI":
case SupportedModelProviders.CHAT_OPENAI:
model_class = ChatOpenAI

defaults = {
Expand Down Expand Up @@ -96,7 +96,7 @@ def _get_request_payload(
if "temperature" in model_args:
del model_args["temperature"]

case "ChatBedrock":
case SupportedModelProviders.CHAT_BEDROCK:
model_class = ChatBedrock

defaults = {
Expand All @@ -106,7 +106,7 @@ def _get_request_payload(
model_args = deep_update(defaults, config.args)
model_id = model_args["model_id"]

case "FakeListChatModel":
case SupportedModelProviders.FAKE_LIST_CHAT_MODEL:
model_class = FakeListChatModel

defaults = {
Expand All @@ -132,7 +132,7 @@ def _get_request_payload(
model_args = deep_update(defaults, config.args)
model_id = "fake-list-chat-model"

case "ChatGoogleGenerativeAI":
case SupportedModelProviders.CHAT_GOOGLE_GENERATIVE_AI:
model_class = ChatGoogleGenerativeAI
api_key = os.getenv("GOOGLE_API_KEY", "dummy_value")
defaults = {
Expand All @@ -144,7 +144,7 @@ def _get_request_payload(
model_args = deep_update(defaults, config.args)
model_id = model_args["model"]

case "AzureChatOpenAI":
case SupportedModelProviders.AZURE_CHAT_OPENAI:
model_class = AzureChatOpenAI

defaults = {
Expand All @@ -159,7 +159,7 @@ def _get_request_payload(
model_args = deep_update(defaults, config.args)
model_id = model_args["azure_deployment"]

case "ChatDeepSeek":
case SupportedModelProviders.CHAT_DEEP_SEEK:
model_class = ChatDeepSeek

defaults = {
Expand All @@ -173,7 +173,7 @@ def _get_request_payload(
model_args = deep_update(defaults, config.args)
model_id = model_args["model"]

case "VLLMOpenAI":
case SupportedModelProviders.VLLM_OPENAI:
model_class = VLLMOpenAI

defaults = {}
Expand All @@ -182,6 +182,7 @@ def _get_request_payload(
model_id = model_args["model_name"]

case _:
assert_never(config.provider)
raise Exception(f"Unrecognized provider '{config.provider}'")

self.provider_id: str = config.provider
Expand Down Expand Up @@ -278,7 +279,7 @@ def invoke(

def response_to_base_message(self, llm_response: Any) -> BaseMessage:
if isinstance(llm_response, str):
return BaseMessage(content=llm_response)
return BaseMessage(content=llm_response, type="ai")
elif isinstance(llm_response, BaseMessage):
return llm_response

Expand Down

0 comments on commit 56a844b

Please sign in to comment.