Added Anthropic and Groq support
Signed-off-by: devjpt23 <devpatel232408@gmail.com>
devjpt23 committed Feb 24, 2025
1 parent f5814b5 commit 432c8ca
Showing 5 changed files with 35 additions and 4 deletions.
2 changes: 2 additions & 0 deletions kai/kai_config.py
@@ -150,11 +150,13 @@ class KaiConfigIncidentStore(BaseModel):


class SupportedModelProviders(StrEnum):
CHAT_ANTHROPIC = "ChatAnthropic"
CHAT_OLLAMA = "ChatOllama"
CHAT_OPENAI = "ChatOpenAI"
CHAT_BEDROCK = "ChatBedrock"
FAKE_LIST_CHAT_MODEL = "FakeListChatModel"
CHAT_GOOGLE_GENERATIVE_AI = "ChatGoogleGenerativeAI"
CHAT_GROQ = "ChatGroq"

AZURE_CHAT_OPENAI = "AzureChatOpenAI"
CHAT_DEEP_SEEK = "ChatDeepSeek"

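Because SupportedModelProviders is a StrEnum, the provider name a user puts in their config resolves directly to the new members. A minimal sketch of that lookup (illustration only, not part of this commit; members other than the two new ones are elided):

from enum import StrEnum

class SupportedModelProviders(StrEnum):
    CHAT_ANTHROPIC = "ChatAnthropic"
    CHAT_GROQ = "ChatGroq"
    # ... remaining providers elided for brevity

# The raw provider string from configuration maps to an enum member,
# and compares equal to that string because StrEnum members are str.
provider = SupportedModelProviders("ChatGroq")
assert provider is SupportedModelProviders.CHAT_GROQ
assert provider == "ChatGroq"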
30 changes: 30 additions & 0 deletions kai/llm_interfacing/model_provider.py
@@ -3,6 +3,7 @@
import os
from typing import Any, Optional

from langchain_anthropic import ChatAnthropic
from langchain_aws import ChatBedrock
from langchain_community.chat_models.fake import FakeListChatModel
from langchain_core.language_models.base import LanguageModelInput
@@ -11,6 +12,7 @@
from langchain_core.runnables import ConfigurableField, RunnableConfig
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq

from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from opentelemetry import trace
@@ -173,6 +175,30 @@ def _get_request_payload(
model_args = deep_update(defaults, config.args)
model_id = model_args["model"]

case "ChatAnthropic":
model_class = ChatAnthropic

defaults = {
"model": "claude-3-5-sonnet-20241022",
"temperature": 0,
"timeout": None,
"max_retries": 2,
}

case "ChatGroq":

Check failure on line 188 in kai/llm_interfacing/model_provider.py

View workflow job for this annotation

GitHub Actions / Trunk Check

cspell(error)

[new] Unknown word (Groq) Suggestions: [Grog, Grok, gros, Gros, Grot]
model_class = ChatGroq

Check failure on line 189 in kai/llm_interfacing/model_provider.py

View workflow job for this annotation

GitHub Actions / Trunk Check

cspell(error)

[new] Unknown word (Groq) Suggestions: [Grog, Grok, gros, Gros, Grot]

defaults = {
"model": "mixtral-8x7b-32768",
"temperature": 0,
"timeout": None,
"max_retries": 2,
"max_tokens": 2049,
}

model_args = deep_update(defaults, config.args)
model_id = model_args["model"]

case _:
raise Exception(f"Unrecognized provider '{config.provider}'")

Expand Down Expand Up @@ -212,6 +238,10 @@ def challenge(k: str) -> BaseMessage:
challenge("max_tokens")
elif isinstance(self.llm, ChatDeepSeek):
challenge("max_tokens")
elif isinstance(self.llm, ChatAnthropic):
challenge("max_tokens")
elif isinstance(self.llm, ChatGroq):
challenge("max_tokens")

@tracer.start_as_current_span("invoke_llm")
def invoke(
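For reference, a minimal standalone sketch of what the two new provider branches end up constructing, using the defaults shown in the diff; config.args would be deep-merged over these defaults via deep_update. This is an illustration, not code from the commit; API keys are typically picked up from the ANTHROPIC_API_KEY and GROQ_API_KEY environment variables by the respective langchain packages.

from langchain_anthropic import ChatAnthropic
from langchain_groq import ChatGroq

# Defaults mirrored from the "ChatAnthropic" case above.
anthropic_llm = ChatAnthropic(
    model="claude-3-5-sonnet-20241022",
    temperature=0,
    timeout=None,
    max_retries=2,
)

# Defaults mirrored from the "ChatGroq" case above, including the
# max_tokens limit that the challenge("max_tokens") probe exercises.
groq_llm = ChatGroq(
    model="mixtral-8x7b-32768",
    temperature=0,
    timeout=None,
    max_retries=2,
    max_tokens=2049,
)

# Either object behaves like any other langchain chat model, e.g.:
# print(groq_llm.invoke("Reply with a single word.").content)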
2 changes: 1 addition & 1 deletion kai/rpc_server/server.py
@@ -464,7 +464,7 @@ class GetCodeplanAgentSolutionParams(BaseModel):
max_depth: Optional[int] = None
max_priority: Optional[int] = None

chat_token: str
chat_token: Optional[str] = None


class GetCodeplanAgentSolutionResult(BaseModel):
3 changes: 0 additions & 3 deletions logs/.gitignore

This file was deleted.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -27,10 +27,12 @@ dependencies = [
"python-dateutil==2.8.2",
"Jinja2==3.1.4",
"langchain==0.3.17",
"langchain-anthropic==0.3.7",
"langchain-community==0.3.1",
"langchain-openai==0.3.3",
"langchain-ollama==0.2.3",
"langchain-google-genai==2.0.9",
"langchain-groq==0.2.4",
"langchain-aws==0.2.11",
"langchain-experimental==0.3.2",
"langchain-deepseek-official==0.1.0",
