Skip to content

Commit

Permalink
Merge pull request #193 from kyaukyuai/feat/llm
Browse files Browse the repository at this point in the history
Update Azure OpenAI API version and LLM creation logic
  • Loading branch information
kyaukyuai authored Mar 20, 2024
2 parents b3b29ae + 947d1ec commit be3a150
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 47 deletions.
3 changes: 2 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ OPENAI_API_KEY=<your-openai-api-key>

# USE when ENDPOINT=AZURE
AZURE_OPENAI_API_KEY=<your-azure-openai-api-key>
AZURE_OPENAI_ENDPOINT=https://<your-azure-openai-endpoint>.openai.azure.com/
AZURE_OPENAI_API_VERSION=2023-07-01-preview
AZURE_OPENAI_DEPLOYMENT_NAME=<your-azure-openai-deployment-name>
AZURE_OPENAI_ENDPOINT=https://<your-azure-openai-endpoint>.openai.azure.com/

# LangSmith
LANGCHAIN_TRACING_V2=true
Expand Down
46 changes: 2 additions & 44 deletions gpt_all_star/core/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,21 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from functools import lru_cache

import openai
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.file_management.toolkit import (
FileManagementToolkit,
)
from langchain.agents.openai_tools.base import create_openai_tools_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts.prompt import PromptTemplate
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from rich.markdown import Markdown
from rich.panel import Panel
from rich.table import Table

from gpt_all_star.cli.console_terminal import ConsoleTerminal
from gpt_all_star.core.llm import LLM_TYPE, create_llm
from gpt_all_star.core.message import Message
from gpt_all_star.core.storage import Storages
from gpt_all_star.core.tools.shell_tool import ShellTool
Expand All @@ -47,7 +44,7 @@ def __init__(
language: str | None = None,
) -> None:
self.console = ConsoleTerminal()
self._llm = _create_llm(os.getenv("OPENAI_API_MODEL_NAME"), 0.1)
self._llm = create_llm(LLM_TYPE[os.getenv("ENDPOINT", default="OPENAI")])

self.role: AgentRole = role
self.name: str = name or self._get_default_profile().name
Expand Down Expand Up @@ -169,45 +166,6 @@ def _create_executor(self, tools: list) -> AgentExecutor:
)


def _create_llm(model_name: str, temperature: float) -> BaseChatModel:
    """Return a chat model for the backend selected by the ENDPOINT env var.

    "AZURE" builds an Azure client from the AZURE_OPENAI_DEPLOYMENT_NAME
    env var (ignoring ``model_name``/``temperature``); any other value,
    or an unset ENDPOINT, falls back to the public OpenAI API.
    """
    endpoint = os.getenv("ENDPOINT", default="OPENAI")
    if endpoint == "AZURE":
        return _create_azure_chat_openai_instance(
            os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
        )
    else:
        return _create_chat_openai_instance(model_name, temperature)


def _create_chat_openai_instance(model_name: str, temperature: float):
    """Create a streaming ChatOpenAI client after validating the model name.

    Raises:
        ValueError: if ``model_name`` is not in the account's model list
            (fetched via ``_get_supported_models``).
    """
    if model_name not in _get_supported_models():
        raise ValueError(f"Model {model_name} not supported")
    return ChatOpenAI(
        model=model_name,
        temperature=temperature,
        streaming=True,
        client=openai.chat.completions,
    )


def _create_azure_chat_openai_instance(model_name: str):
    """Create a streaming AzureChatOpenAI client for the given deployment.

    The API version comes from AZURE_OPENAI_API_VERSION, defaulting to
    "2023-07-01-preview" when unset.
    """
    return AzureChatOpenAI(
        openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2023-07-01-preview"),
        deployment_name=model_name,
        streaming=True,
    )


def _get_supported_models() -> list[str]:
    """Return the model IDs available to the configured OpenAI account.

    Makes a live ``openai.models.list()`` call on first use; the result is
    memoized for the life of the process.
    """
    # cache the models list since it is unlikely to change frequently.
    @lru_cache(maxsize=1)
    def _fetch_supported_models():
        # Force the shared openai module to target the public API
        # (not Azure) before listing models.
        openai.api_type = "openai"
        return [model.id for model in openai.models.list()]

    return _fetch_supported_models()


class AgentRole(str, Enum):
COPILOT = "copilot"
PRODUCT_OWNER = "product_owner"
Expand Down
5 changes: 3 additions & 2 deletions gpt_all_star/core/agents/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder

from gpt_all_star.core.agents.agent import Agent, _create_llm
from gpt_all_star.core.agents.agent import Agent
from gpt_all_star.core.llm import LLM_TYPE, create_llm

ACTIONS = [
"Execute a command",
Expand All @@ -17,7 +18,7 @@

class Chain:
def __init__(self) -> None:
self._llm = _create_llm(os.getenv("OPENAI_API_MODEL_NAME"), 0.1)
self._llm = create_llm(LLM_TYPE[os.getenv("ENDPOINT", default="OPENAI")])

def create_supervisor_chain(self, members: list[Agent] = []):
members = [member.name for member in members]
Expand Down
99 changes: 99 additions & 0 deletions gpt_all_star/core/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
from enum import Enum

import openai
from langchain_anthropic import ChatAnthropic
from langchain_anthropic.experimental import ChatAnthropicTools
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import AzureChatOpenAI, ChatOpenAI


class LLM_TYPE(str, Enum):
    """Supported chat-model backends.

    Member names double as the accepted values of the ``ENDPOINT``
    environment variable: callers look them up by name, e.g.
    ``LLM_TYPE[os.getenv("ENDPOINT", default="OPENAI")]``.
    """

    OPENAI = "OPENAI"
    AZURE = "AZURE"
    ANTHROPIC = "ANTHROPIC"
    ANTHROPIC_TOOLS = "ANTHROPIC_TOOLS"


def create_llm(llm_name: LLM_TYPE, temperature: float = 0.1) -> BaseChatModel:
    """Create a streaming chat model for the requested backend.

    Model names, credentials, and endpoints are read from environment
    variables, with sensible fallbacks for the model identifiers.

    Args:
        llm_name: Which backend to build (OPENAI, AZURE, ANTHROPIC,
            or ANTHROPIC_TOOLS).
        temperature: Sampling temperature for the model. Defaults to 0.1,
            the value previously hard-coded at every call site.

    Returns:
        A configured ``BaseChatModel`` with streaming enabled.

    Raises:
        ValueError: If ``llm_name`` is not a supported ``LLM_TYPE``.
    """
    if llm_name == LLM_TYPE.OPENAI:
        return _create_chat_openai(
            model_name=os.getenv("OPENAI_API_MODEL", "gpt-4-turbo-preview"),
            temperature=temperature,
        )
    elif llm_name == LLM_TYPE.AZURE:
        return _create_azure_chat_openai(
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            # NOTE(review): this fallback endpoint looks project-specific
            # ("interpreter"); confirm it is a sensible default for other
            # deployments rather than a leftover from local testing.
            azure_endpoint=os.getenv(
                "AZURE_OPENAI_ENDPOINT", "https://interpreter.openai.azure.com/"
            ),
            openai_api_version=os.getenv(
                "AZURE_OPENAI_API_VERSION", "2023-07-01-preview"
            ),
            deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4-32k"),
            temperature=temperature,
        )
    elif llm_name == LLM_TYPE.ANTHROPIC:
        return _create_chat_anthropic(
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            model_name=os.getenv("ANTHROPIC_API_MODEL", "claude-3-opus-20240229"),
            temperature=temperature,
        )
    elif llm_name == LLM_TYPE.ANTHROPIC_TOOLS:
        return _create_chat_anthropic_tools(
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            model_name=os.getenv("ANTHROPIC_API_MODEL", "claude-3-opus-20240229"),
            temperature=temperature,
        )
    else:
        raise ValueError(f"Unsupported LLM type: {llm_name}")


def _create_chat_openai(model_name: str, temperature: float) -> ChatOpenAI:
    """Build a streaming ``ChatOpenAI`` client for the public OpenAI API."""
    # Ensure the shared openai module targets the public API, not Azure,
    # before handing its completions client to LangChain.
    openai.api_type = "openai"
    settings = {
        "model_name": model_name,
        "temperature": temperature,
        "streaming": True,
        "client": openai.chat.completions,
    }
    return ChatOpenAI(**settings)


def _create_azure_chat_openai(
    api_key: str,
    azure_endpoint: str,
    openai_api_version: str,
    deployment_name: str,
    temperature: float,
) -> AzureChatOpenAI:
    """Build a streaming ``AzureChatOpenAI`` client for an Azure deployment."""
    # Switch the shared openai module into Azure mode before constructing.
    openai.api_type = "azure"
    settings = {
        "api_key": api_key,
        "azure_endpoint": azure_endpoint,
        "openai_api_version": openai_api_version,
        "deployment_name": deployment_name,
        "temperature": temperature,
        "streaming": True,
    }
    return AzureChatOpenAI(**settings)


def _create_chat_anthropic(
    anthropic_api_key: str, model_name: str, temperature: float
) -> ChatAnthropic:
    """Build a streaming ``ChatAnthropic`` client."""
    settings = {
        "anthropic_api_key": anthropic_api_key,
        "model": model_name,
        "temperature": temperature,
        "streaming": True,
    }
    return ChatAnthropic(**settings)


def _create_chat_anthropic_tools(
    anthropic_api_key: str, model_name: str, temperature: float
) -> ChatAnthropicTools:
    """Build a streaming ``ChatAnthropicTools`` client (experimental API)."""
    settings = {
        "anthropic_api_key": anthropic_api_key,
        "model": model_name,
        "temperature": temperature,
        "streaming": True,
    }
    return ChatAnthropicTools(**settings)

0 comments on commit be3a150

Please sign in to comment.