diff --git a/kai/kai_config.py b/kai/kai_config.py
index 04019392..ca725d9e 100644
--- a/kai/kai_config.py
+++ b/kai/kai_config.py
@@ -150,11 +150,13 @@ class KaiConfigIncidentStore(BaseModel):
 
 
 class SupportedModelProviders(StrEnum):
+    CHAT_ANTHROPIC = "ChatAnthropic"
     CHAT_OLLAMA = "ChatOllama"
     CHAT_OPENAI = "ChatOpenAI"
     CHAT_BEDROCK = "ChatBedrock"
     FAKE_LIST_CHAT_MODEL = "FakeListChatModel"
     CHAT_GOOGLE_GENERATIVE_AI = "ChatGoogleGenerativeAI"
+    CHAT_GROQ = "ChatGroq"
     AZURE_CHAT_OPENAI = "AzureChatOpenAI"
     CHAT_DEEP_SEEK = "ChatDeepSeek"
 
diff --git a/kai/llm_interfacing/model_provider.py b/kai/llm_interfacing/model_provider.py
index 83d4bafa..3d1da4d3 100644
--- a/kai/llm_interfacing/model_provider.py
+++ b/kai/llm_interfacing/model_provider.py
@@ -3,6 +3,7 @@
 import os
 from typing import Any, Optional
 
+from langchain_anthropic import ChatAnthropic
 from langchain_aws import ChatBedrock
 from langchain_community.chat_models.fake import FakeListChatModel
 from langchain_core.language_models.base import LanguageModelInput
@@ -11,6 +12,7 @@
 from langchain_core.runnables import ConfigurableField, RunnableConfig
 from langchain_deepseek import ChatDeepSeek
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
 from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from opentelemetry import trace
@@ -173,6 +175,33 @@ def _get_request_payload(
             model_args = deep_update(defaults, config.args)
             model_id = model_args["model"]
 
+            case "ChatAnthropic":
+                model_class = ChatAnthropic
+
+                defaults = {
+                    "model": "claude-3-5-sonnet-20241022",
+                    "temperature": 0,
+                    "timeout": None,
+                    "max_retries": 2,
+                }
+
+                model_args = deep_update(defaults, config.args)
+                model_id = model_args["model"]
+
+            case "ChatGroq":
+                model_class = ChatGroq
+
+                defaults = {
+                    "model": "mixtral-8x7b-32768",
+                    "temperature": 0,
+                    "timeout": None,
+                    "max_retries": 2,
+                    "max_tokens": 2049,
+                }
+
+                model_args = deep_update(defaults, config.args)
+                model_id = model_args["model"]
+
             case _:
                 raise Exception(f"Unrecognized provider '{config.provider}'")
 
@@ -212,6 +241,10 @@ def challenge(k: str) -> BaseMessage:
             challenge("max_tokens")
         elif isinstance(self.llm, ChatDeepSeek):
             challenge("max_tokens")
+        elif isinstance(self.llm, ChatAnthropic):
+            challenge("max_tokens")
+        elif isinstance(self.llm, ChatGroq):
+            challenge("max_tokens")
 
     @tracer.start_as_current_span("invoke_llm")
     def invoke(
diff --git a/kai/rpc_server/server.py b/kai/rpc_server/server.py
index 9625143c..17dd856d 100644
--- a/kai/rpc_server/server.py
+++ b/kai/rpc_server/server.py
@@ -464,7 +464,7 @@ class GetCodeplanAgentSolutionParams(BaseModel):
     max_depth: Optional[int] = None
     max_priority: Optional[int] = None
 
-    chat_token: str
+    chat_token: Optional[str] = None
 
 
 class GetCodeplanAgentSolutionResult(BaseModel):
diff --git a/logs/.gitignore b/logs/.gitignore
deleted file mode 100644
index 6a94ef94..00000000
--- a/logs/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-#Allow us to commit the directory but ignore everything in it.
-*
-!.gitignore
diff --git a/pyproject.toml b/pyproject.toml
index e8783ece..c6b1db21 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,10 +27,12 @@ dependencies = [
     "python-dateutil==2.8.2",
     "Jinja2==3.1.4",
     "langchain==0.3.17",
+    "langchain-anthropic==0.3.7",
     "langchain-community==0.3.1",
     "langchain-openai==0.3.3",
     "langchain-ollama==0.2.3",
     "langchain-google-genai==2.0.9",
+    "langchain-groq==0.2.4",
     "langchain-aws==0.2.11",
     "langchain-experimental==0.3.2",
     "langchain-deepseek-official==0.1.0",