diff --git a/README.md b/README.md
index e3683a5..d3e6575 100644
--- a/README.md
+++ b/README.md
@@ -387,7 +387,7 @@ Hi chatGPT, you are going to pretend to be MEDIC from Team Fortress 2. You can d
 
 ## That's all?
 
 If you want to know more here are listed some things that were left unexplained, and some tips and tricks:
-[unexplained_explained.md](docs/unexplained_explained.md)
+[unexplained_explained.md](docs/unexplained_explained.md) or at project [Wiki](https://github.com/dborodin836/TF2-GPTChatBot/wiki).
 
 ## Screenshots
diff --git a/config.ini b/config.ini
index c924f6c..2ad3502 100644
--- a/config.ini
+++ b/config.ini
@@ -1,7 +1,19 @@
 [GENERAL]
-TF2_LOGFILE_PATH=H:\Programs\Steam\steamapps\common\Team Fortress 2\tf\console.log
+TF2_LOGFILE_PATH=C:\Program Files (x86)\Steam\steamapps\common\Team Fortress 2\tf\console.log
 OPENAI_API_KEY=
 
+[GROQ]
+GROQ_ENABLE=False
+GROQ_API_KEY=
+; Available models
+; https://console.groq.com/docs/models
+GROQ_MODEL=
+GROQ_COMMAND=!g
+GROQ_CHAT_COMMAND=!gc
+GROQ_PRIVATE_CHAT=!gpc
+; Example {"max_tokens": 2}
+GROQ_SETTINGS=
+
 [COMMANDS]
 ENABLE_OPENAI_COMMANDS=True
 GPT_COMMAND=!gpt3
diff --git a/config.py b/config.py
index 573be6c..c9ce427 100644
--- a/config.py
+++ b/config.py
@@ -85,6 +85,14 @@ class Config(BaseModel):
 
     CUSTOM_MODEL_SETTINGS: Optional[str | dict]
 
+    GROQ_API_KEY: str
+    GROQ_COMMAND: str
+    GROQ_CHAT_COMMAND: str
+    GROQ_PRIVATE_CHAT: str
+    GROQ_MODEL: str
+    GROQ_ENABLE: bool
+    GROQ_SETTINGS: Optional[str | dict]
+
     @validator("OPENAI_API_KEY")
     def api_key_pattern_match(cls, v):
         if not re.fullmatch(OPENAI_API_KEY_RE_PATTERN, v):
@@ -158,6 +166,7 @@ def init_config():
             for key, value in configparser_config.items(section)
         }
         global config
+
         try:
             if config_dict.get("CUSTOM_MODEL_SETTINGS") != "":
                 config_dict["CUSTOM_MODEL_SETTINGS"] = json.loads(
@@ -168,9 +177,19 @@ def init_config():
                 f"CUSTOM_MODEL_SETTINGS is not dict [{e}].", "BOTH", level="ERROR"
             )
 
+        try:
+            if config_dict.get("GROQ_SETTINGS") != "":
+                config_dict["GROQ_SETTINGS"] = json.loads(
+                    config_dict.get("GROQ_SETTINGS")
+                )
+        except Exception as e:
+            buffered_fail_message(
+                f"GROQ_SETTINGS is not dict [{e}].", "BOTH", level="ERROR"
+            )
+
         config = Config(**config_dict)
 
-        if not config.ENABLE_OPENAI_COMMANDS and not config.ENABLE_CUSTOM_MODEL:
+        if not config.ENABLE_OPENAI_COMMANDS and not config.ENABLE_CUSTOM_MODEL and not config.GROQ_ENABLE:
             buffered_message("You haven't enabled any AI related commands.")
 
     except (pydantic.ValidationError, Exception) as e:
diff --git a/modules/api/base.py b/modules/api/base.py
new file mode 100644
index 0000000..9452361
--- /dev/null
+++ b/modules/api/base.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+from modules.typing import MessageHistory
+
+
+class LLMProvider(ABC):
+
+    @staticmethod
+    @abstractmethod
+    def get_completion_text(message_array: MessageHistory, username: str, model: str) -> str:
+        ...
diff --git a/modules/api/groq.py b/modules/api/groq.py
new file mode 100644
index 0000000..a76e2d9
--- /dev/null
+++ b/modules/api/groq.py
@@ -0,0 +1,35 @@
+import hashlib
+
+import groq
+
+from config import config
+from modules.api.base import LLMProvider
+from modules.utils.text import remove_hashtags
+
+
+class GroqCloudLLMProvider(LLMProvider):
+
+    @staticmethod
+    def get_completion_text(message_array, username, model):
+        client = groq.Groq(
+            max_retries=0,
+            api_key=config.GROQ_API_KEY
+        )
+
+        if isinstance(config.GROQ_SETTINGS, dict):
+            completion = client.chat.completions.create(
+                model=model,
+                messages=message_array,
+                user=hashlib.md5(username.encode()).hexdigest(),
+                **config.GROQ_SETTINGS
+            )
+        else:
+            completion = client.chat.completions.create(
+                model=model,
+                messages=message_array,
+                user=hashlib.md5(username.encode()).hexdigest(),
+            )
+
+        response_text = completion.choices[0].message.content.strip()
+        filtered_response = remove_hashtags(response_text)
+        return filtered_response
diff --git a/modules/api/openai.py b/modules/api/openai.py
index 04049dc..683fed9 100644
--- a/modules/api/openai.py
+++ b/modules/api/openai.py
@@ -1,20 +1,32 @@
 import hashlib
-import time
 
 import openai
 
 from config import config
-from modules.conversation_history import ConversationHistory
-from modules.logs import get_logger, log_gui_general_message, log_gui_model_message
-from modules.servers.tf2 import send_say_command_to_tf2
-from modules.typing import Message, MessageHistory
-from modules.utils.text import get_system_message, remove_args, remove_hashtags
+from modules.api.base import LLMProvider
+from modules.logs import get_logger
 
 main_logger = get_logger("main")
 gui_logger = get_logger("gui")
 
 
-def is_violated_tos(message: str) -> bool:
+class OpenAILLMProvider(LLMProvider):
+
+    @staticmethod
+    def get_completion_text(conversation_history, username, model):
+        openai.api_key = config.OPENAI_API_KEY
+
+        completion = openai.ChatCompletion.create(
+            model=model,
+            messages=conversation_history,
+            user=hashlib.md5(username.encode()).hexdigest(),
+        )
+
+        response_text = completion.choices[0].message["content"].strip()
+        return response_text
+
+
+def is_flagged(message: str) -> bool:
     openai.api_key = config.OPENAI_API_KEY
     try:
         response = openai.Moderation.create(
@@ -28,123 +40,3 @@ def is_violated_tos(message: str) -> bool:
             return True
 
     return response.results[0]["flagged"]
-
-
-def send_gpt_completion_request(
-    conversation_history: MessageHistory, username: str, model: str
-) -> str:
-    openai.api_key = config.OPENAI_API_KEY
-
-    completion = openai.ChatCompletion.create(
-        model=model,
-        messages=conversation_history,
-        user=hashlib.md5(username.encode()).hexdigest(),
-    )
-
-    response_text = completion.choices[0].message["content"].strip()
-    return response_text
-
-
-def handle_cgpt_request(
-    username: str,
-    user_prompt: str,
-    conversation_history: ConversationHistory,
-    model,
-    is_team: bool = False,
-) -> ConversationHistory:
-    """
-    This function is called when the user wants to send a message to the AI chatbot. It logs the
-    user's message, and sends a request to generate a response.
-    """
-    log_gui_model_message(model, username, user_prompt)
-
-    user_message = remove_args(user_prompt)
-    if (
-        not config.TOS_VIOLATION
-        and is_violated_tos(user_message)
-        and config.HOST_USERNAME != username
-    ):
-        gui_logger.error(f"Request '{user_prompt}' violates OPENAI TOS. Skipping...")
-        return conversation_history
-
-    conversation_history.add_user_message_from_prompt(user_prompt)
-
-    response = get_response(conversation_history.get_messages_array(), username, model)
-
-    if response:
-        conversation_history.add_assistant_message(Message(role="assistant", content=response))
-        log_gui_model_message(model, username, " ".join(response.split()))
-        send_say_command_to_tf2(response, username, is_team)
-
-    return conversation_history
-
-
-def handle_gpt_request(
-    username: str, user_prompt: str, model: str, is_team_chat: bool = False
-) -> None:
-    """
-    This function is called when the user wants to send a message to the AI chatbot. It logs the
-    user's message, and sends a request to GPT-3 to generate a response. Finally, the function
-    sends the generated response to the TF2 game.
-    """
-    log_gui_model_message(model, username, user_prompt)
-
-    user_message = remove_args(user_prompt)
-    sys_message = get_system_message(user_prompt)
-
-    if (
-        not config.TOS_VIOLATION
-        and is_violated_tos(user_message)
-        and config.HOST_USERNAME != username
-    ):
-        gui_logger.warning(
-            f"Request '{user_prompt}' by user {username} violates OPENAI TOS. Skipping..."
-        )
-        return
-
-    payload = [
-        sys_message,
-        Message(role="assistant", content=config.GREETING),
-        Message(role="user", content=user_message),
-    ]
-
-    response = get_response(payload, username, model)
-
-    if response:
-        main_logger.info(
-            f"Got response for user {username}. Response: {' '.join(response.split())}"
-        )
-        log_gui_model_message(model, username, " ".join(response.split()))
-        send_say_command_to_tf2(response, username, is_team_chat)
-
-
-def get_response(conversation_history: MessageHistory, username: str, model) -> str | None:
-    attempts = 0
-    max_attempts = 2
-
-    while attempts < max_attempts:
-        try:
-            response = send_gpt_completion_request(conversation_history, username, model=model)
-            filtered_response = remove_hashtags(response)
-            return filtered_response
-        except openai.error.RateLimitError:
-            log_gui_general_message("Rate limited! Trying again...")
-            main_logger(f"User is rate limited.")
-            time.sleep(2)
-            attempts += 1
-        except openai.error.APIError as e:
-            log_gui_general_message(f"Wasn't able to connect to OpenAI API. Cancelling...")
-            main_logger.error(f"APIError happened. [{e}]")
-            return
-        except openai.error.AuthenticationError:
-            log_gui_general_message("Your OpenAI api key is invalid.")
-            main_logger.error("OpenAI API key is invalid.")
-            return
-        except Exception as e:
-            log_gui_general_message(f"Unhandled error happened! Cancelling ({e})")
-            main_logger.error(f"Unhandled error happened! Cancelling ({e})")
-            return
-
-    if attempts == max_attempts:
-        log_gui_general_message("Max number of attempts reached! Try again later!")
-        main_logger(f"Max number of attempts reached. [{max_attempts}/{max_attempts}]")
diff --git a/modules/api/textgen_webui.py b/modules/api/textgen_webui.py
index 0422062..200f905 100644
--- a/modules/api/textgen_webui.py
+++ b/modules/api/textgen_webui.py
@@ -1,39 +1,25 @@
 import requests
 
 from config import config
+from modules.api.base import LLMProvider
 from modules.logs import get_logger
-from modules.typing import Message
 
 main_logger = get_logger("main")
 combo_logger = get_logger("combo")
 
 
-def get_custom_model_response(conversation_history: list[Message]) -> str | None:
-    uri = f"http://{config.CUSTOM_MODEL_HOST}/v1/chat/completions"
+class TextGenerationWebUILLMProvider(LLMProvider):
 
-    headers = {"Content-Type": "application/json"}
+    @staticmethod
+    def get_completion_text(conversation_history, username, model):
+        uri = f"http://{config.CUSTOM_MODEL_HOST}/v1/chat/completions"
+        headers = {"Content-Type": "application/json"}
 
-    data = {"mode": "chat", "messages": conversation_history}
+        data = {"mode": "chat", "messages": conversation_history}
+        data.update(config.CUSTOM_MODEL_SETTINGS)
 
-    data.update(config.CUSTOM_MODEL_SETTINGS)
-
-    try:
         response = requests.post(uri, headers=headers, json=data, verify=False)
-    except Exception as e:
-        combo_logger.error(f"Failed to get response from the text-generation-webui server. [{e}]")
-        return
-
-    if response.status_code == 200:
-        try:
-            data = response.json()["choices"][0]["message"]["content"]
-            return data
-        except Exception as e:
-            combo_logger.error(f"Failed to parse data from server [{e}].")
-    elif response.status_code == 500:
-        combo_logger.error(f"There's error on the text-generation-webui server. [HTTP 500]")
-    else:
-        main_logger.error(
-            f"Got non-200 status code from the text-generation-webui server. [HTTP {response.status_code}]"
-        )
-
-    return None
+        if response.status_code == 500:
+            raise Exception('HTTP 500')
+        data = response.json()["choices"][0]["message"]["content"]
+        return data
diff --git a/modules/chat.py b/modules/chat.py
index 7cf0cfb..43a7ad7 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -1,13 +1,16 @@
-from config import config
+from config import config, RTDModes
 from modules.api.github import check_for_updates
 from modules.bans import bans_manager
 from modules.bot_state import state_manager
 from modules.command_controllers import CommandController, InitializerConfig
 from modules.commands.clear_chat import handle_clear
 from modules.commands.github import handle_gh_command
-from modules.commands.openai import handle_user_chat, handle_gpt3, handle_gpt4, handle_gpt4l, handle_global_chat
+from modules.commands.groq import GroqQuickQueryCommand, GroqGlobalChatCommand, GroqPrivateChatCommand
+from modules.commands.openai import OpenAIGlobalChatCommand, OpenAIPrivateChatCommand, OpenAIGPT3QuickQueryCommand, \
+    OpenAIGPT4QuickQueryCommand, OpenAIGPT4LQuickQueryCommand
 from modules.commands.rtd import handle_rtd
-from modules.commands.textgen_webui import handle_custom_user_chat, handle_custom_model, handle_custom_global_chat
+from modules.commands.textgen_webui import TextgenWebUIGlobalChatCommand, TextgenWebUIPrivateChatCommand, \
+    TextgenWebUIQuickQueryCommand
 from modules.logs import get_logger
 from modules.message_queueing import messaging_queue_service
 from modules.servers.tf2 import check_connection, set_host_username
@@ -53,18 +56,23 @@ def parse_console_logs_and_build_conversation_history() -> None:
     # Commands
     controller.register_command("!gh", handle_gh_command)
-    controller.register_command(config.RTD_COMMAND, handle_rtd)
     controller.register_command(config.CLEAR_CHAT_COMMAND, handle_clear)
+    if config.RTD_MODE != RTDModes.DISABLED:
+        controller.register_command(config.RTD_COMMAND, handle_rtd)
 
     if config.ENABLE_OPENAI_COMMANDS:
-        controller.register_command(config.GPT4_COMMAND, handle_gpt4)
-        controller.register_command(config.GPT4_LEGACY_COMMAND, handle_gpt4l)
-        controller.register_command(config.CHATGPT_COMMAND, handle_user_chat)
-        controller.register_command(config.GLOBAL_CHAT_COMMAND, handle_global_chat)
-        controller.register_command(config.GPT_COMMAND, handle_gpt3)
+        controller.register_command(config.GPT4_COMMAND, OpenAIGPT4QuickQueryCommand.as_command())
+        controller.register_command(config.GPT4_LEGACY_COMMAND, OpenAIGPT4LQuickQueryCommand.as_command())
+        controller.register_command(config.CHATGPT_COMMAND, OpenAIPrivateChatCommand.as_command())
+        controller.register_command(config.GLOBAL_CHAT_COMMAND, OpenAIGlobalChatCommand.as_command())
+        controller.register_command(config.GPT_COMMAND, OpenAIGPT3QuickQueryCommand.as_command())
     if config.ENABLE_CUSTOM_MODEL:
-        controller.register_command(config.CUSTOM_MODEL_COMMAND, handle_custom_model)
-        controller.register_command(config.CUSTOM_MODEL_CHAT_COMMAND, handle_custom_user_chat)
-        controller.register_command(config.GLOBAL_CUSTOM_CHAT_COMMAND, handle_custom_global_chat)
+        controller.register_command(config.CUSTOM_MODEL_COMMAND, TextgenWebUIQuickQueryCommand.as_command())
+        controller.register_command(config.CUSTOM_MODEL_CHAT_COMMAND, TextgenWebUIPrivateChatCommand.as_command())
+        controller.register_command(config.GLOBAL_CUSTOM_CHAT_COMMAND, TextgenWebUIGlobalChatCommand.as_command())
+    if config.GROQ_ENABLE:
+        controller.register_command(config.GROQ_COMMAND, GroqQuickQueryCommand.as_command())
+        controller.register_command(config.GROQ_CHAT_COMMAND, GroqGlobalChatCommand.as_command())
+        controller.register_command(config.GROQ_PRIVATE_CHAT, GroqPrivateChatCommand.as_command())
 
     # Services
     controller.register_service(messaging_queue_service)
diff --git a/modules/command_controllers.py b/modules/command_controllers.py
index 6268da1..66b795f 100644
--- a/modules/command_controllers.py
+++ b/modules/command_controllers.py
@@ -5,7 +5,7 @@
 from pydantic import BaseModel, BaseConfig
 
 from modules.conversation_history import ConversationHistory
-from modules.logs import get_logger
+from modules.logs import get_logger, log_gui_model_message
 from modules.set_once_dict import SetOnceDictionary
 from modules.typing import Command, LogLine, Player
 
@@ -83,7 +83,7 @@ def help(self, command: str, shared_dict: dict):
 class CommandController:
     def __init__(self, initializer_config: InitializerConfig = None) -> None:
         self.__services = OrderedSet()
-        self.__named_commands_registry: SetOnceDictionary[str, Callable] = SetOnceDictionary()
+        self.__named_commands_registry: SetOnceDictionary[str, Callable[[LogLine, InitializerConfig], Optional[str]]] = SetOnceDictionary()
         self.__shared = InitializerConfig()
 
         if initializer_config is not None:
@@ -105,4 +105,14 @@ def process_line(self, logline: LogLine):
         if handler is None:
             return
 
-        handler(logline, self.__shared)
+        cleaned_prompt = logline.prompt.removeprefix(command_name).strip()
+
+        logline = LogLine(cleaned_prompt, logline.username, logline.is_team_message, logline.player)
+
+        log_gui_model_message(command_name.upper(), logline.username, logline.prompt)
+        try:
+            result = handler(logline, self.__shared)
+            if result:
+                log_gui_model_message(command_name.upper(), logline.username, result)
+        except Exception as e:
+            log_gui_model_message(command_name.upper(), logline.username, f"Error occurred: [{e}]")
diff --git a/modules/commands/base.py b/modules/commands/base.py
new file mode 100644
index 0000000..bee09ea
--- /dev/null
+++ b/modules/commands/base.py
@@ -0,0 +1,89 @@
+from abc import ABC, abstractmethod
+from typing import Callable, List, Optional
+
+from modules.api.base import LLMProvider
+from modules.command_controllers import InitializerConfig
+from modules.conversation_history import ConversationHistory
+from modules.servers.tf2 import send_say_command_to_tf2
+from modules.typing import LogLine, Message
+from modules.utils.text import remove_args
+
+
+class BaseLLMCommand(ABC):
+    provider: LLMProvider = None
+    model: str = None
+    wrappers: List[Callable] = []
+
+    @classmethod
+    @abstractmethod
+    def get_handler(cls):
+        ...
+
+    @classmethod
+    def as_command(cls) -> Callable[[LogLine, InitializerConfig], None]:
+        func = cls.get_handler()
+
+        for decorator in cls.wrappers:
+            func = decorator(func)
+
+        return func
+
+
+class QuickQueryLLMCommand(BaseLLMCommand):
+
+    @classmethod
+    def get_handler(cls) -> Callable[[LogLine, InitializerConfig], None]:
+        def func(logline: LogLine, shared_dict: InitializerConfig) -> Optional[str]:
+            tmp_chat_history = ConversationHistory()
+
+            user_message = remove_args(logline.prompt)
+            tmp_chat_history.add_user_message_from_prompt(user_message)
+
+            response = cls.provider.get_completion_text(tmp_chat_history.get_messages_array(), logline.username,
+                                                        cls.model)
+            if response:
+                tmp_chat_history.add_assistant_message(Message(role="assistant", content=response))
+                send_say_command_to_tf2(response, logline.username, logline.is_team_message)
+                return " ".join(response.split())
+
+        return func
+
+
+class GlobalChatLLMCommand(BaseLLMCommand):
+
+    @classmethod
+    def get_handler(cls) -> Callable[[LogLine, InitializerConfig], None]:
+        def func(logline: LogLine, shared_dict: InitializerConfig) -> Optional[str]:
+            chat_history = shared_dict.CHAT_CONVERSATION_HISTORY.GLOBAL
+
+            user_message = remove_args(logline.prompt)
+            chat_history.add_user_message_from_prompt(user_message)
+
+            response = cls.provider.get_completion_text(chat_history.get_messages_array(), logline.username,
+                                                        cls.model)
+            if response:
+                chat_history.add_assistant_message(Message(role="assistant", content=response))
+                send_say_command_to_tf2(response, logline.username, logline.is_team_message)
+                return " ".join(response.split())
+
+        return func
+
+
+class PrivateChatLLMCommand(BaseLLMCommand):
+
+    @classmethod
+    def get_handler(cls) -> Callable[[LogLine, InitializerConfig], None]:
+        def func(logline: LogLine, shared_dict: InitializerConfig) -> Optional[str]:
+            chat_history = shared_dict.CHAT_CONVERSATION_HISTORY.get_conversation_history(logline.player)
+
+            user_message = remove_args(logline.prompt)
+            chat_history.add_user_message_from_prompt(user_message)
+
+            response = cls.provider.get_completion_text(chat_history.get_messages_array(), logline.username,
+                                                        cls.model)
+            if response:
+                chat_history.add_assistant_message(Message(role="assistant", content=response))
+                send_say_command_to_tf2(response, logline.username, logline.is_team_message)
+                return " ".join(response.split())
+
+        return func
diff --git a/modules/commands/clear_chat.py b/modules/commands/clear_chat.py
index 1cddb94..9b59d2e 100644
--- a/modules/commands/clear_chat.py
+++ b/modules/commands/clear_chat.py
@@ -9,10 +9,12 @@
 main_logger = get_logger("main")
 combo_logger = get_logger("combo")
 
+CLEAR_WRONG_SYNTAX_MSG = r'Wrong syntax! e.g. !clear \global \user="username"'
+
 
 def handle_clear(logline: LogLine, shared_dict: InitializerConfig):
     if is_admin(logline.player):
-        args = get_args(logline.prompt.removeprefix(config.CLEAR_CHAT_COMMAND).strip())
+        args = get_args(logline.prompt)
 
         if len(args) == 0:
             combo_logger.info(f"Clearing chat history for user '{logline.username}'.")
@@ -30,7 +32,7 @@ def handle_clear(logline: LogLine, shared_dict: InitializerConfig):
         try:
             parts = arg.split("=")
             if len(parts) != 2:
-                combo_logger.error(r'Wrong syntax! e.g. !clear \global \user="username"')
+                combo_logger.error(CLEAR_WRONG_SYNTAX_MSG)
                 continue
 
             name: str
@@ -38,11 +40,11 @@ def handle_clear(logline: LogLine, shared_dict: InitializerConfig):
             arg, name = parts
 
             if arg != r"\user":
-                combo_logger.error(r'Wrong syntax! e.g. !clear \global \user="username"')
+                combo_logger.error(CLEAR_WRONG_SYNTAX_MSG)
                 continue
 
             if not (name.startswith("'") and name.endswith("'")):
-                combo_logger.error(r'Wrong syntax! e.g. !clear \global \user="username"')
+                combo_logger.error(CLEAR_WRONG_SYNTAX_MSG)
                 continue
 
             name = name.removeprefix("'")
diff --git a/modules/commands/decorators.py b/modules/commands/decorators.py
new file mode 100644
index 0000000..f0e392e
--- /dev/null
+++ b/modules/commands/decorators.py
@@ -0,0 +1,62 @@
+from typing import Callable, List
+
+from config import config
+from modules.api.openai import is_flagged
+from modules.command_controllers import InitializerConfig
+from modules.permissions import is_admin
+from modules.typing import LogLine, Player
+from modules.logs import get_logger
+
+gui_logger = get_logger('gui')
+
+
+def empty_prompt_wrapper_handler_factory(handler: Callable):
+    def decorator(func):
+        def wrapper(logline: LogLine, shared_dict: InitializerConfig):
+            if logline.prompt == '':
+                handler(logline, shared_dict)
+                return None
+            return func(logline, shared_dict)
+
+        return wrapper
+
+    return decorator
+
+
+def gpt4_admin_only(func):
+    def wrapper(logline: LogLine, shared_dict: InitializerConfig):
+        if (
+                config.GPT4_ADMIN_ONLY
+                and is_admin(logline.player)
+                or not config.GPT4_ADMIN_ONLY
+        ):
+            return func(logline, shared_dict)
+        raise Exception('User is not admin.')
+
+    return wrapper
+
+
+def openai_moderated_message(func):
+    def wrapper(logline: LogLine, shared_dict: InitializerConfig):
+        if (
+                not config.TOS_VIOLATION
+                and is_flagged(logline.prompt)
+                and not is_admin(logline.player)
+        ):
+            raise Exception("Request was flagged during moderation. Skipping...")
+
+        return func(logline, shared_dict)
+
+    return wrapper
+
+
+def permission_decorator_factory(permissions_funcs: List[Callable[[Player], bool]]):
+    def permissions_decorator(func):
+        def wrapper(logline: LogLine, shared_dict: InitializerConfig):
+            if all(map(lambda x: x(logline.player), permissions_funcs)):
+                return func(logline, shared_dict)
+            return None
+
+        return wrapper
+
+    return permissions_decorator
diff --git a/modules/commands/github.py b/modules/commands/github.py
index 74b73df..bcb7fdc 100644
--- a/modules/commands/github.py
+++ b/modules/commands/github.py
@@ -9,7 +9,7 @@
 GITHUB_LINK = "bit.ly/tf2-gpt3"
 
 
-def handle_gh_command(logline: LogLine, shared_dict: InitializerConfig) -> None:
+def handle_gh_command(logline: LogLine, shared_dict: InitializerConfig):
     time.sleep(1)
 
     if config.ENABLE_SHORTENED_USERNAMES_RESPONSE:
diff --git a/modules/commands/groq.py b/modules/commands/groq.py
new file mode 100644
index 0000000..6413897
--- /dev/null
+++ b/modules/commands/groq.py
@@ -0,0 +1,18 @@
+from config import config
+from modules.api.groq import GroqCloudLLMProvider
+from modules.commands.base import QuickQueryLLMCommand, GlobalChatLLMCommand, PrivateChatLLMCommand
+
+
+class GroqQuickQueryCommand(QuickQueryLLMCommand):
+    provider = GroqCloudLLMProvider
+    model = config.GROQ_MODEL
+
+
+class GroqGlobalChatCommand(GlobalChatLLMCommand):
+    provider = GroqCloudLLMProvider
+    model = config.GROQ_MODEL
+
+
+class GroqPrivateChatCommand(PrivateChatLLMCommand):
+    provider = GroqCloudLLMProvider
+    model = config.GROQ_MODEL
diff --git a/modules/commands/gui/openai.py b/modules/commands/gui/openai.py
index ce76ec4..c1ae0a4 100644
--- a/modules/commands/gui/openai.py
+++ b/modules/commands/gui/openai.py
@@ -3,7 +3,7 @@
 
 import openai
 
-from modules.api.openai import send_gpt_completion_request
+from modules.api.openai import OpenAILLMProvider
 from modules.logs import get_logger
 from modules.typing import Message
 
@@ -23,7 +23,7 @@ def gpt3_cmd_handler() -> None:
         if GPT3_PROMPTS_QUEUE.qsize() != 0:
             prompt = GPT3_PROMPTS_QUEUE.get()
             try:
-                response = send_gpt_completion_request(
+                response = OpenAILLMProvider.get_completion_text(
                     [Message(role="user", content=prompt)],
                     "admin",
                     model="gpt-3.5-turbo",
diff --git a/modules/commands/openai.py b/modules/commands/openai.py
index c6e2bda..425cd7d 100644
--- a/modules/commands/openai.py
+++ b/modules/commands/openai.py
@@ -1,87 +1,66 @@
 import time
 
 from config import config
-from modules.api.openai import handle_cgpt_request, handle_gpt_request
+from modules.api.openai import OpenAILLMProvider
 from modules.command_controllers import InitializerConfig
-from modules.logs import get_logger, log_gui_model_message
+from modules.logs import get_logger
 from modules.servers.tf2 import send_say_command_to_tf2
 from modules.typing import LogLine
+from modules.commands.base import GlobalChatLLMCommand, PrivateChatLLMCommand, QuickQueryLLMCommand
+from modules.commands.decorators import empty_prompt_wrapper_handler_factory, gpt4_admin_only, openai_moderated_message
 
 main_logger = get_logger("main")
 
 
-def handle_gpt3(logline: LogLine, shared_dict: InitializerConfig) -> None:
-    if logline.prompt.removeprefix(config.GPT_COMMAND).strip() == "":
-        time.sleep(1)
-        send_say_command_to_tf2(
-            "Hello there! I am ChatGPT, a ChatGPT plugin integrated into"
-            " Team Fortress 2. Ask me anything!",
-            username=None,
-            is_team_chat=logline.is_team_message,
-        )
-        log_gui_model_message(config.GPT3_MODEL, logline.username, logline.prompt.strip())
-        main_logger.info(f"Empty '{config.GPT_COMMAND}' command from user '{logline.username}'.")
-        return
-
-    main_logger.info(
-        f"'{config.GPT_COMMAND}' command from user '{logline.username}'. "
-        f"Message: '{logline.prompt.removeprefix(config.GPT_COMMAND).strip()}'"
-    )
-    handle_gpt_request(
-        logline.username,
-        logline.prompt.removeprefix(config.GPT_COMMAND).strip(),
-        model=config.GPT3_MODEL,
+def handle_empty(logline: LogLine, shared_dict: InitializerConfig):
+    time.sleep(1)
+    send_say_command_to_tf2(
+        "Hello there! I am ChatGPT, a ChatGPT plugin integrated into"
+        " Team Fortress 2. Ask me anything!",
+        username=None,
         is_team_chat=logline.is_team_message,
     )
+    main_logger.info(f"Empty '{config.GPT_COMMAND}' command from user '{logline.username}'.")
 
 
-def handle_user_chat(logline: LogLine, shared_dict: InitializerConfig):
-    user_chat = shared_dict.CHAT_CONVERSATION_HISTORY.get_conversation_history(logline.player)
+class OpenAIGPT3QuickQueryCommand(QuickQueryLLMCommand):
+    provider = OpenAILLMProvider
+    model = config.GPT3_MODEL
+    wrappers = [
+        empty_prompt_wrapper_handler_factory(handle_empty),
+        openai_moderated_message
+    ]
 
-    conv_his = handle_cgpt_request(
-        logline.username,
-        logline.prompt.removeprefix(config.CHATGPT_COMMAND).strip(),
-        user_chat,
-        is_team=logline.is_team_message,
-        model=config.GPT3_CHAT_MODEL,
-    )
-    shared_dict.CHAT_CONVERSATION_HISTORY.set_conversation_history(logline.player, conv_his)
+class OpenAIPrivateChatCommand(PrivateChatLLMCommand):
+    provider = OpenAILLMProvider
+    model = config.GPT3_CHAT_MODEL
+    wrappers = [
+        openai_moderated_message
+    ]
 
-def handle_global_chat(logline: LogLine, shared_dict: InitializerConfig):
-    conv_his = handle_cgpt_request(
-        logline.username,
-        logline.prompt.removeprefix(config.GLOBAL_CHAT_COMMAND).strip(),
-        shared_dict.CHAT_CONVERSATION_HISTORY.GLOBAL,
-        is_team=logline.is_team_message,
-        model=config.GPT3_CHAT_MODEL,
-    )
-    shared_dict.CHAT_CONVERSATION_HISTORY.GLOBAL = conv_his
+
+class OpenAIGlobalChatCommand(GlobalChatLLMCommand):
+    provider = OpenAILLMProvider
+    model = config.GPT3_CHAT_MODEL
+    wrappers = [
+        openai_moderated_message
+    ]
 
-def handle_gpt4(logline: LogLine, shared_dict: InitializerConfig):
-    if (
-        config.GPT4_ADMIN_ONLY
-        and config.HOST_USERNAME == logline.username
-        or not config.GPT4_ADMIN_ONLY
-    ):
-        handle_gpt_request(
-            logline.username,
-            logline.prompt.removeprefix(config.GPT4_COMMAND).strip(),
-            model=config.GPT4_MODEL,
-            is_team_chat=logline.is_team_message,
-        )
+class OpenAIGPT4QuickQueryCommand(QuickQueryLLMCommand):
+    provider = OpenAILLMProvider
+    model = config.GPT4_MODEL
+    wrappers = [
+        gpt4_admin_only,
+        openai_moderated_message
+    ]
 
-def handle_gpt4l(logline: LogLine, shared_dict: InitializerConfig):
-    if (
-        config.GPT4_ADMIN_ONLY
-        and config.HOST_USERNAME == logline.username
-        or not config.GPT4_ADMIN_ONLY
-    ):
-        handle_gpt_request(
-            logline.username,
-            logline.prompt.removeprefix(config.GPT4_LEGACY_COMMAND).strip(),
-            model=config.GPT4L_MODEL,
-            is_team_chat=logline.is_team_message,
-        )
+class OpenAIGPT4LQuickQueryCommand(QuickQueryLLMCommand):
+    provider = OpenAILLMProvider
+    model = config.GPT4L_MODEL
+    wrappers = [
+        gpt4_admin_only,
+        openai_moderated_message
+    ]
diff --git a/modules/commands/textgen_webui.py b/modules/commands/textgen_webui.py
index 13545a7..bc751a2 100644
--- a/modules/commands/textgen_webui.py
+++ b/modules/commands/textgen_webui.py
@@ -1,80 +1,14 @@
-from config import config
-from modules.api.textgen_webui import get_custom_model_response
-from modules.command_controllers import InitializerConfig
-from modules.conversation_history import ConversationHistory
-from modules.logs import get_logger, log_gui_model_message
-from modules.servers.tf2 import send_say_command_to_tf2
-from modules.typing import LogLine, Message
-from modules.utils.text import get_system_message
+from modules.api.textgen_webui import TextGenerationWebUILLMProvider
+from modules.commands.base import QuickQueryLLMCommand, GlobalChatLLMCommand, PrivateChatLLMCommand
 
-main_logger = get_logger("main")
 
+class TextgenWebUIQuickQueryCommand(QuickQueryLLMCommand):
+    provider = TextGenerationWebUILLMProvider
 
-def handle_custom_model(logline: LogLine, shared_dict: InitializerConfig):
-    main_logger.info(
-        f"'{config.CUSTOM_MODEL_COMMAND}' command from user '{logline.username}'. "
-        f"Message: '{logline.prompt.removeprefix(config.CUSTOM_MODEL_COMMAND).strip()}'"
-    )
-    log_gui_model_message(
-        "CUSTOM",
-        logline.username,
-        logline.prompt.removeprefix(config.CUSTOM_MODEL_COMMAND).strip(),
-    )
-    user_message = logline.prompt.removeprefix(config.CUSTOM_MODEL_COMMAND).strip()
-    sys_message = get_system_message(
-        logline.prompt, enable_soft_limit=config.ENABLE_SOFT_LIMIT_FOR_CUSTOM_MODEL
-    )
+class TextgenWebUIGlobalChatCommand(GlobalChatLLMCommand):
+    provider = TextGenerationWebUILLMProvider
 
-    response = get_custom_model_response(
-        [
-            sys_message,
-            Message(role="assistant", content=config.GREETING),
-            Message(role="user", content=user_message),
-        ]
-    )
-
-    if response:
-        log_gui_model_message("CUSTOM", logline.username, response.strip())
-        send_say_command_to_tf2(response, logline.username, logline.is_team_message)
-
-
-def handle_custom_user_chat(logline: LogLine, shared_dict: InitializerConfig):
-    conversation_history: ConversationHistory = shared_dict.CHAT_CONVERSATION_HISTORY.get_conversation_history(
-        logline.player)
-
-    log_gui_model_message(
-        "CUSTOM CHAT",
-        logline.username,
-        logline.prompt.removeprefix(config.CUSTOM_MODEL_CHAT_COMMAND).strip(),
-    )
-
-    user_message = logline.prompt.removeprefix(config.CUSTOM_MODEL_CHAT_COMMAND).strip()
-    conversation_history.add_user_message_from_prompt(user_message)
-    response = get_custom_model_response(conversation_history.get_messages_array())
-
-    if response:
-        conversation_history.add_assistant_message(Message(role="assistant", content=response))
-        log_gui_model_message("CUSTOM CHAT", logline.username, response.strip())
-        send_say_command_to_tf2(response, logline.username, logline.is_team_message)
-        shared_dict.CHAT_CONVERSATION_HISTORY.set_conversation_history(logline.player, conversation_history)
-
-
-def handle_custom_global_chat(logline: LogLine, shared_dict: InitializerConfig):
-    conversation_history: ConversationHistory = shared_dict.CHAT_CONVERSATION_HISTORY.GLOBAL
-
-    log_gui_model_message(
-        "GLOBAL CUSTOM CHAT",
-        logline.username,
-        logline.prompt.removeprefix(config.CUSTOM_MODEL_CHAT_COMMAND).strip(),
-    )
-
-    user_message = logline.prompt.removeprefix(config.GLOBAL_CUSTOM_CHAT_COMMAND).strip()
-    conversation_history.add_user_message_from_prompt(user_message)
-    response = get_custom_model_response(conversation_history.get_messages_array())
-
-    if response:
-        conversation_history.add_assistant_message(Message(role="assistant", content=response))
-        log_gui_model_message("GLOBAL CUSTOM CHAT", logline.username, response.strip())
-        send_say_command_to_tf2(response, logline.username, logline.is_team_message)
-        shared_dict.CHAT_CONVERSATION_HISTORY.GLOBAL = conversation_history
+
+class TextgenWebUIPrivateChatCommand(PrivateChatLLMCommand):
+    provider = TextGenerationWebUILLMProvider
diff --git a/modules/conversation_history.py b/modules/conversation_history.py
index 52f57c0..c5257d1 100644
--- a/modules/conversation_history.py
+++ b/modules/conversation_history.py
@@ -48,7 +48,7 @@ def add_assistant_message(self, message: Message) -> None:
         self.message_history.append(message)
 
     def add_user_message_from_prompt(
-        self, user_prompt: str, enable_soft_limit: bool = True
+            self, user_prompt: str, enable_soft_limit: bool = True
     ) -> None:
         user_message = remove_args(user_prompt)
         args = get_args(user_prompt)
diff --git a/modules/lobby_manager.py b/modules/lobby_manager.py
index 7ebd34e..08a11f8 100644
--- a/modules/lobby_manager.py
+++ b/modules/lobby_manager.py
@@ -64,7 +64,8 @@
     "batteaxe",
     "scotland_shard",
     "voodoo_pin",
-    "eternal_reward" "apocofists",
+    "eternal_reward",
+    "apocofists",
     "bread_bite",
     "eviction_notice",
     "gloves_running_urgently",
diff --git a/modules/logs.py b/modules/logs.py
index 14d5203..e10adc5 100644
--- a/modules/logs.py
+++ b/modules/logs.py
@@ -80,11 +80,14 @@ def get_time_stamp() -> str:
     return f"{dt.now().strftime('%H:%M:%S')}"
 
 
-def log_gui_model_message(message_type: str, username: str, user_prompt: str) -> None:
+def log_gui_model_message(type_: str, username: str, message: str) -> None:
     """
     Logs a message with the current timestamp, message type, username, user_id, and prompt text.
     """
-    log_msg = f"[{get_time_stamp()}] ({message_type}) User: '{username}' --- '{user_prompt}'"
+    if message:
+        log_msg = f"[{get_time_stamp()}] ({type_}) User: '{username}' --- '{message}'"
+    else:
+        log_msg = f"[{get_time_stamp()}] ({type_}) User: '{username}'"
     __gui_logger.info(log_msg)
 
 
diff --git a/modules/permissions.py b/modules/permissions.py
index 469a390..24eda75 100644
--- a/modules/permissions.py
+++ b/modules/permissions.py
@@ -2,16 +2,6 @@
 from modules.typing import Player
 
 
-# def check_permission(func):
-#     def wrapper(*args, **kwargs):
-#         if user_has_permission():
-#             return func(*args, **kwargs)
-#         else:
-#             return "Permission denied"
-#
-#     return wrapper
-
-
 def is_admin(user: Player) -> bool:
     if config.HOST_STEAMID3 == user.steamid3:
         return True
diff --git a/modules/utils/text.py b/modules/utils/text.py
index 6875862..096cc54 100644
--- a/modules/utils/text.py
+++ b/modules/utils/text.py
@@ -120,7 +120,7 @@ def follow_tail(file_path: str) -> typing.Generator:
         for line in latest_lines[:-1]:
             yield line + "\n"
     except FileNotFoundError:
-        gui_logger.warning(f"Logfile doesn't exist. Checking again in 4 seconds.")
+        gui_logger.warning("Logfile doesn't exist. Checking again in 4 seconds.")
         time.sleep(4)
         yield ""
     except Exception as e:
diff --git a/requirements.txt b/requirements.txt
index 2ba116c..586af11 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,5 @@ pytest-mock==3.12.0
 requests-mock==1.11.0
 pytest-loguru==0.3.0
 pytest-random-order==1.1.1
-pytest-repeat==0.9.3
\ No newline at end of file
+pytest-repeat==0.9.3
+groq==0.5.0
diff --git a/tests/test_text.py b/tests/test_text.py
index 71872e4..95551ef 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -4,7 +4,6 @@
 from modules.lobby_manager import LobbyManager
 from modules.typing import Message, LogLine, Player
 from modules.utils.text import (
-    get_system_message,
     get_chunk_size,
     has_cyrillic,
     split_into_chunks,
@@ -68,12 +67,6 @@ def test_get_chunk_size_with_non_cyrillic_text():
     assert get_chunk_size(text) == MAX_LENGTH_OTHER
 
 
-def test_get_system_message():
-    expected_output = Message(role="system", content="")
-    result = get_system_message(r"\l Please enter your name")
-    assert result == expected_output
-
-
 def test_get_args():
     assert get_args(r'\user="123" \global') == [r'\user="123"', r'\global']
     assert get_args(r"\user='123' \global") == [r"\user='123'", r'\global']