Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

MISC: GroqCloud support & Commands refactor #104

Merged
merged 14 commits into from
Apr 29, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ Hi chatGPT, you are going to pretend to be MEDIC from Team Fortress 2. You can d
## That's all?

If you want to know more here are listed some things that were left unexplained, and some tips and tricks:
[unexplained_explained.md](docs/unexplained_explained.md)
[unexplained_explained.md](docs/unexplained_explained.md) or at project [Wiki](https://github.com/dborodin836/TF2-GPTChatBot/wiki).

## Screenshots

Expand Down
14 changes: 13 additions & 1 deletion config.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
[GENERAL]
TF2_LOGFILE_PATH=H:\Programs\Steam\steamapps\common\Team Fortress 2\tf\console.log
TF2_LOGFILE_PATH=C:\Program Files (x86)\Steam\steamapps\common\Team Fortress 2\tf\console.log
OPENAI_API_KEY=

[GROQ]
GROQ_ENABLE=False
GROQ_API_KEY=
; Available models
; https://console.groq.com/docs/models
GROQ_MODEL=
GROQ_COMMAND=!g
GROQ_CHAT_COMMAND=!gc
GROQ_PRIVATE_CHAT=!gpc
; Example {"max_tokens": 2}
GROQ_SETTINGS=

[COMMANDS]
ENABLE_OPENAI_COMMANDS=True
GPT_COMMAND=!gpt3
Expand Down
21 changes: 20 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ class Config(BaseModel):

CUSTOM_MODEL_SETTINGS: Optional[str | dict]

GROQ_API_KEY: str
GROQ_COMMAND: str
GROQ_CHAT_COMMAND: str
GROQ_PRIVATE_CHAT: str
GROQ_MODEL: str
GROQ_ENABLE: bool
GROQ_SETTINGS: Optional[str | dict]

@validator("OPENAI_API_KEY")
def api_key_pattern_match(cls, v):
if not re.fullmatch(OPENAI_API_KEY_RE_PATTERN, v):
Expand Down Expand Up @@ -158,6 +166,7 @@ def init_config():
for key, value in configparser_config.items(section)
}
global config

try:
if config_dict.get("CUSTOM_MODEL_SETTINGS") != "":
config_dict["CUSTOM_MODEL_SETTINGS"] = json.loads(
Expand All @@ -168,9 +177,19 @@ def init_config():
f"CUSTOM_MODEL_SETTINGS is not dict [{e}].", "BOTH", level="ERROR"
)

try:
if config_dict.get("GROQ_SETTINGS") != "":
config_dict["GROQ_SETTINGS"] = json.loads(
config_dict.get("GROQ_SETTINGS")
)
except Exception as e:
buffered_fail_message(
f"GROQ_SETTINGS is not dict [{e}].", "BOTH", level="ERROR"
)

config = Config(**config_dict)

if not config.ENABLE_OPENAI_COMMANDS and not config.ENABLE_CUSTOM_MODEL:
if not config.ENABLE_OPENAI_COMMANDS and not config.ENABLE_CUSTOM_MODEL and not config.GROQ_ENABLE:
buffered_message("You haven't enabled any AI related commands.")

except (pydantic.ValidationError, Exception) as e:
Expand Down
11 changes: 11 additions & 0 deletions modules/api/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod

from modules.typing import MessageHistory


class LLMProvider(ABC):
    """Abstract base class for large-language-model backends.

    Concrete providers (e.g. the OpenAI and GroqCloud implementations)
    supply ``get_completion_text`` as a static method.
    """

    @staticmethod
    @abstractmethod
    def get_completion_text(message_array: MessageHistory, username: str, model: str) -> str:
        """Return the model's completion text for *message_array*.

        Args:
            message_array: Conversation history to send to the backend.
            username: Requesting user's name (implementations may hash it
                before sending it to the API).
            model: Model identifier understood by the backend.

        Returns:
            The completion text produced by the model.
        """
        ...
35 changes: 35 additions & 0 deletions modules/api/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import hashlib

import groq

from config import config
from modules.api.base import LLMProvider
from modules.utils.text import remove_hashtags


class GroqCloudLLMProvider(LLMProvider):
    """LLM provider backed by the GroqCloud chat-completions API."""

    @staticmethod
    def get_completion_text(message_array, username, model):
        """Request a chat completion from GroqCloud and return the cleaned text.

        Args:
            message_array: Conversation history to send to the model.
            username: Requesting user's name; forwarded to the API as an
                MD5 hex digest.
            model: GroqCloud model identifier.

        Returns:
            The completion text with hashtags removed.
        """
        client = groq.Groq(
            max_retries=0,
            api_key=config.GROQ_API_KEY
        )

        # GROQ_SETTINGS is a dict of extra API kwargs when configured, or a
        # non-dict (empty string) when unset. Forward it only when it is a
        # dict, using a single call site instead of duplicating the request.
        extra_settings = config.GROQ_SETTINGS if isinstance(config.GROQ_SETTINGS, dict) else {}

        completion = client.chat.completions.create(
            model=model,
            messages=message_array,
            user=hashlib.md5(username.encode()).hexdigest(),
            **extra_settings
        )

        response_text = completion.choices[0].message.content.strip()
        filtered_response = remove_hashtags(response_text)
        return filtered_response
146 changes: 19 additions & 127 deletions modules/api/openai.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
import hashlib
import time

import openai

from config import config
from modules.conversation_history import ConversationHistory
from modules.logs import get_logger, log_gui_general_message, log_gui_model_message
from modules.servers.tf2 import send_say_command_to_tf2
from modules.typing import Message, MessageHistory
from modules.utils.text import get_system_message, remove_args, remove_hashtags
from modules.api.base import LLMProvider
from modules.logs import get_logger

main_logger = get_logger("main")
gui_logger = get_logger("gui")


def is_violated_tos(message: str) -> bool:
class OpenAILLMProvider(LLMProvider):
    """LLM provider that talks to the OpenAI chat-completions endpoint."""

    @staticmethod
    def get_completion_text(conversation_history, username, model):
        """Send *conversation_history* to OpenAI and return the reply text."""
        openai.api_key = config.OPENAI_API_KEY

        # Identify the end user to the API by an MD5 hash of the name.
        user_id = hashlib.md5(username.encode()).hexdigest()

        completion = openai.ChatCompletion.create(
            model=model,
            messages=conversation_history,
            user=user_id,
        )

        reply = completion.choices[0].message["content"]
        return reply.strip()


def is_flagged(message: str) -> bool:
openai.api_key = config.OPENAI_API_KEY
try:
response = openai.Moderation.create(
Expand All @@ -28,123 +40,3 @@ def is_violated_tos(message: str) -> bool:
return True

return response.results[0]["flagged"]


def send_gpt_completion_request(
    conversation_history: MessageHistory, username: str, model: str
) -> str:
    """Query the OpenAI chat-completion API and return the stripped reply text."""
    openai.api_key = config.OPENAI_API_KEY

    # The API receives only an MD5 hash of the username.
    hashed_user = hashlib.md5(username.encode()).hexdigest()

    completion = openai.ChatCompletion.create(
        model=model,
        messages=conversation_history,
        user=hashed_user,
    )

    return completion.choices[0].message["content"].strip()


def handle_cgpt_request(
    username: str,
    user_prompt: str,
    conversation_history: ConversationHistory,
    model: str,
    is_team: bool = False,
) -> ConversationHistory:
    """
    Handle a chat-style request that keeps per-conversation history.

    Logs the user's message, runs an OpenAI TOS moderation check (skipped
    when TOS_VIOLATION is enabled or the requester is the host), asks the
    model for a response, records it in the history and relays it to TF2.

    Args:
        username: Name of the requesting player.
        user_prompt: Raw chat message, including any command arguments.
        conversation_history: Running history for this conversation.
        model: Model identifier forwarded to the completion API.
        is_team: Whether the reply goes to team chat instead of global chat.

    Returns:
        The conversation history (updated when a response was produced).
    """
    log_gui_model_message(model, username, user_prompt)

    # Strip command arguments before the moderation check.
    user_message = remove_args(user_prompt)
    if (
        not config.TOS_VIOLATION
        and is_violated_tos(user_message)
        and config.HOST_USERNAME != username
    ):
        gui_logger.error(f"Request '{user_prompt}' violates OPENAI TOS. Skipping...")
        return conversation_history

    conversation_history.add_user_message_from_prompt(user_prompt)

    response = get_response(conversation_history.get_messages_array(), username, model)

    if response:
        conversation_history.add_assistant_message(Message(role="assistant", content=response))
        # Collapse whitespace so the GUI log stays on one line.
        log_gui_model_message(model, username, " ".join(response.split()))
        send_say_command_to_tf2(response, username, is_team)

    return conversation_history


def handle_gpt_request(
    username: str, user_prompt: str, model: str, is_team_chat: bool = False
) -> None:
    """
    Handle a one-shot (stateless) AI request.

    Logs the user's message, runs an OpenAI TOS moderation check (skipped
    when TOS_VIOLATION is enabled or the requester is the host), builds a
    minimal three-message payload, and sends the generated response to the
    TF2 game chat.

    Args:
        username: Name of the requesting player.
        user_prompt: Raw chat message, including any command arguments.
        model: Model identifier forwarded to the completion API.
        is_team_chat: Whether the reply goes to team chat.
    """
    log_gui_model_message(model, username, user_prompt)

    # Separate the user's text from any inline command arguments.
    user_message = remove_args(user_prompt)
    sys_message = get_system_message(user_prompt)

    if (
        not config.TOS_VIOLATION
        and is_violated_tos(user_message)
        and config.HOST_USERNAME != username
    ):
        gui_logger.warning(
            f"Request '{user_prompt}' by user {username} violates OPENAI TOS. Skipping..."
        )
        return

    # Stateless payload: system message, configured greeting, user message.
    payload = [
        sys_message,
        Message(role="assistant", content=config.GREETING),
        Message(role="user", content=user_message),
    ]

    response = get_response(payload, username, model)

    if response:
        main_logger.info(
            f"Got response for user {username}. Response: {' '.join(response.split())}"
        )
        # Collapse whitespace so the GUI log stays on one line.
        log_gui_model_message(model, username, " ".join(response.split()))
        send_say_command_to_tf2(response, username, is_team_chat)


def get_response(conversation_history: MessageHistory, username: str, model) -> str | None:
    """
    Request a completion, retrying on rate limits.

    Retries up to ``max_attempts`` times when OpenAI rate-limits the user;
    any other API failure aborts immediately.

    Args:
        conversation_history: Messages to send to the completion API.
        username: Requesting user's name (forwarded for per-user tracking).
        model: Model identifier passed through to the API.

    Returns:
        The hashtag-filtered response text, or None on failure.
    """
    attempts = 0
    max_attempts = 2

    while attempts < max_attempts:
        try:
            response = send_gpt_completion_request(conversation_history, username, model=model)
            filtered_response = remove_hashtags(response)
            return filtered_response
        except openai.error.RateLimitError:
            log_gui_general_message("Rate limited! Trying again...")
            # BUG FIX: main_logger is a Logger object, not a callable —
            # the original `main_logger(...)` raised TypeError here.
            main_logger.warning("User is rate limited.")
            time.sleep(2)
            attempts += 1
        except openai.error.APIError as e:
            log_gui_general_message("Wasn't able to connect to OpenAI API. Cancelling...")
            main_logger.error(f"APIError happened. [{e}]")
            return
        except openai.error.AuthenticationError:
            log_gui_general_message("Your OpenAI api key is invalid.")
            main_logger.error("OpenAI API key is invalid.")
            return
        except Exception as e:
            # Top-level boundary for this request: log and abort.
            log_gui_general_message(f"Unhandled error happened! Cancelling ({e})")
            main_logger.error(f"Unhandled error happened! Cancelling ({e})")
            return

    if attempts == max_attempts:
        log_gui_general_message("Max number of attempts reached! Try again later!")
        # BUG FIX: same non-callable Logger misuse as above.
        main_logger.error(f"Max number of attempts reached. [{max_attempts}/{max_attempts}]")
38 changes: 12 additions & 26 deletions modules/api/textgen_webui.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,25 @@
import requests

from config import config
from modules.api.base import LLMProvider
from modules.logs import get_logger
from modules.typing import Message

main_logger = get_logger("main")
combo_logger = get_logger("combo")


def get_custom_model_response(conversation_history: list[Message]) -> str | None:
uri = f"http://{config.CUSTOM_MODEL_HOST}/v1/chat/completions"
class TextGenerationWebUILLMProvider(LLMProvider):

headers = {"Content-Type": "application/json"}
@staticmethod
def get_completion_text(conversation_history, username, model):
uri = f"http://{config.CUSTOM_MODEL_HOST}/v1/chat/completions"
headers = {"Content-Type": "application/json"}

data = {"mode": "chat", "messages": conversation_history}
data = {"mode": "chat", "messages": conversation_history}
data.update(config.CUSTOM_MODEL_SETTINGS)

data.update(config.CUSTOM_MODEL_SETTINGS)

try:
response = requests.post(uri, headers=headers, json=data, verify=False)
except Exception as e:
combo_logger.error(f"Failed to get response from the text-generation-webui server. [{e}]")
return

if response.status_code == 200:
try:
data = response.json()["choices"][0]["message"]["content"]
return data
except Exception as e:
combo_logger.error(f"Failed to parse data from server [{e}].")
elif response.status_code == 500:
combo_logger.error(f"There's error on the text-generation-webui server. [HTTP 500]")
else:
main_logger.error(
f"Got non-200 status code from the text-generation-webui server. [HTTP {response.status_code}]"
)

return None
if response.status_code == 500:
raise Exception('HTTP 500')
data = response.json()["choices"][0]["message"]["content"]
return data
32 changes: 20 additions & 12 deletions modules/chat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from config import config
from config import config, RTDModes
from modules.api.github import check_for_updates
from modules.bans import bans_manager
from modules.bot_state import state_manager
from modules.command_controllers import CommandController, InitializerConfig
from modules.commands.clear_chat import handle_clear
from modules.commands.github import handle_gh_command
from modules.commands.openai import handle_user_chat, handle_gpt3, handle_gpt4, handle_gpt4l, handle_global_chat
from modules.commands.groq import GroqQuickQueryCommand, GroqGlobalChatCommand, GroqPrivateChatCommand
from modules.commands.openai import OpenAIGlobalChatCommand, OpenAIPrivateChatCommand, OpenAIGPT3QuickQueryCommand, \
OpenAIGPT4QuickQueryCommand, OpenAIGPT4LQuickQueryCommand
from modules.commands.rtd import handle_rtd
from modules.commands.textgen_webui import handle_custom_user_chat, handle_custom_model, handle_custom_global_chat
from modules.commands.textgen_webui import TextgenWebUIGlobalChatCommand, TextgenWebUIPrivateChatCommand, \
TextgenWebUIQuickQueryCommand
from modules.logs import get_logger
from modules.message_queueing import messaging_queue_service
from modules.servers.tf2 import check_connection, set_host_username
Expand Down Expand Up @@ -53,18 +56,23 @@ def parse_console_logs_and_build_conversation_history() -> None:

# Commands
controller.register_command("!gh", handle_gh_command)
controller.register_command(config.RTD_COMMAND, handle_rtd)
controller.register_command(config.CLEAR_CHAT_COMMAND, handle_clear)
if config.RTD_MODE != RTDModes.DISABLED:
controller.register_command(config.RTD_COMMAND, handle_rtd)
if config.ENABLE_OPENAI_COMMANDS:
controller.register_command(config.GPT4_COMMAND, handle_gpt4)
controller.register_command(config.GPT4_LEGACY_COMMAND, handle_gpt4l)
controller.register_command(config.CHATGPT_COMMAND, handle_user_chat)
controller.register_command(config.GLOBAL_CHAT_COMMAND, handle_global_chat)
controller.register_command(config.GPT_COMMAND, handle_gpt3)
controller.register_command(config.GPT4_COMMAND, OpenAIGPT4QuickQueryCommand.as_command())
controller.register_command(config.GPT4_LEGACY_COMMAND, OpenAIGPT4LQuickQueryCommand.as_command())
controller.register_command(config.CHATGPT_COMMAND, OpenAIPrivateChatCommand.as_command())
controller.register_command(config.GLOBAL_CHAT_COMMAND, OpenAIGlobalChatCommand.as_command())
controller.register_command(config.GPT_COMMAND, OpenAIGPT3QuickQueryCommand.as_command())
if config.ENABLE_CUSTOM_MODEL:
controller.register_command(config.CUSTOM_MODEL_COMMAND, handle_custom_model)
controller.register_command(config.CUSTOM_MODEL_CHAT_COMMAND, handle_custom_user_chat)
controller.register_command(config.GLOBAL_CUSTOM_CHAT_COMMAND, handle_custom_global_chat)
controller.register_command(config.CUSTOM_MODEL_COMMAND, TextgenWebUIQuickQueryCommand.as_command())
controller.register_command(config.CUSTOM_MODEL_CHAT_COMMAND, TextgenWebUIPrivateChatCommand.as_command())
controller.register_command(config.GLOBAL_CUSTOM_CHAT_COMMAND, TextgenWebUIGlobalChatCommand.as_command())
if config.GROQ_ENABLE:
controller.register_command(config.GROQ_COMMAND, GroqQuickQueryCommand.as_command())
controller.register_command(config.GROQ_CHAT_COMMAND, GroqGlobalChatCommand.as_command())
controller.register_command(config.GROQ_PRIVATE_CHAT, GroqPrivateChatCommand.as_command())

# Services
controller.register_service(messaging_queue_service)
Expand Down
Loading
Loading