Skip to content

Commit be8a5a6

Browse files
authored
MISC: GroqCloud support & Commands refactor (#104)
* add GroqCloud support * refactor commands * refactor textgen_webui commands * move removing prompt command prefix to controller * add support for command decorators * refactor OpenAI commands * move command logic from providers * unify the two approaches to building the system message * minor fixes * update logging for commands * move error handling logic to controller * update README * add groq model settings * fix config
1 parent b36cc23 commit be8a5a6

24 files changed

+386
-341
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ Hi chatGPT, you are going to pretend to be MEDIC from Team Fortress 2. You can d
387387
## That's all?
388388

389389
If you want to know more here are listed some things that were left unexplained, and some tips and tricks:
390-
[unexplained_explained.md](docs/unexplained_explained.md)
390+
[unexplained_explained.md](docs/unexplained_explained.md) or at project [Wiki](https://github.com/dborodin836/TF2-GPTChatBot/wiki).
391391

392392
## Screenshots
393393

config.ini

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
[GENERAL]
2-
TF2_LOGFILE_PATH=H:\Programs\Steam\steamapps\common\Team Fortress 2\tf\console.log
2+
TF2_LOGFILE_PATH=C:\Program Files (x86)\Steam\steamapps\common\Team Fortress 2\tf\console.log
33
OPENAI_API_KEY=
44

5+
[GROQ]
6+
GROQ_ENABLE=False
7+
GROQ_API_KEY=
8+
; Available models
9+
; https://console.groq.com/docs/models
10+
GROQ_MODEL=
11+
GROQ_COMMAND=!g
12+
GROQ_CHAT_COMMAND=!gc
13+
GROQ_PRIVATE_CHAT=!gpc
14+
; Example {"max_tokens": 2}
15+
GROQ_SETTINGS=
16+
517
[COMMANDS]
618
ENABLE_OPENAI_COMMANDS=True
719
GPT_COMMAND=!gpt3

config.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ class Config(BaseModel):
8585

8686
CUSTOM_MODEL_SETTINGS: Optional[str | dict]
8787

88+
GROQ_API_KEY: str
89+
GROQ_COMMAND: str
90+
GROQ_CHAT_COMMAND: str
91+
GROQ_PRIVATE_CHAT: str
92+
GROQ_MODEL: str
93+
GROQ_ENABLE: bool
94+
GROQ_SETTINGS: Optional[str | dict]
95+
8896
@validator("OPENAI_API_KEY")
8997
def api_key_pattern_match(cls, v):
9098
if not re.fullmatch(OPENAI_API_KEY_RE_PATTERN, v):
@@ -158,6 +166,7 @@ def init_config():
158166
for key, value in configparser_config.items(section)
159167
}
160168
global config
169+
161170
try:
162171
if config_dict.get("CUSTOM_MODEL_SETTINGS") != "":
163172
config_dict["CUSTOM_MODEL_SETTINGS"] = json.loads(
@@ -168,9 +177,19 @@ def init_config():
168177
f"CUSTOM_MODEL_SETTINGS is not dict [{e}].", "BOTH", level="ERROR"
169178
)
170179

180+
try:
181+
if config_dict.get("GROQ_SETTINGS") != "":
182+
config_dict["GROQ_SETTINGS"] = json.loads(
183+
config_dict.get("GROQ_SETTINGS")
184+
)
185+
except Exception as e:
186+
buffered_fail_message(
187+
f"GROQ_SETTINGS is not dict [{e}].", "BOTH", level="ERROR"
188+
)
189+
171190
config = Config(**config_dict)
172191

173-
if not config.ENABLE_OPENAI_COMMANDS and not config.ENABLE_CUSTOM_MODEL:
192+
if not config.ENABLE_OPENAI_COMMANDS and not config.ENABLE_CUSTOM_MODEL and not config.GROQ_ENABLE:
174193
buffered_message("You haven't enabled any AI related commands.")
175194

176195
except (pydantic.ValidationError, Exception) as e:

modules/api/base.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from abc import ABC, abstractmethod
2+
3+
from modules.typing import MessageHistory
4+
5+
6+
class LLMProvider(ABC):
    """Abstract interface shared by every chat-completion backend.

    Concrete providers (OpenAI, GroqCloud, text-generation-webui) implement
    ``get_completion_text`` so callers can swap backends transparently.
    """

    @staticmethod
    @abstractmethod
    def get_completion_text(message_array: MessageHistory, username: str, model: str) -> str:
        """Return the model's reply text for the given message history.

        message_array: conversation messages to send to the backend.
        username: requesting user's name (backends may use it as an opaque id).
        model: backend-specific model identifier.
        """
        ...

modules/api/groq.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import hashlib
2+
3+
import groq
4+
5+
from config import config
6+
from modules.api.base import LLMProvider
7+
from modules.utils.text import remove_hashtags
8+
9+
10+
class GroqCloudLLMProvider(LLMProvider):
    """LLM provider backed by the GroqCloud chat-completions API."""

    @staticmethod
    def get_completion_text(message_array, username, model):
        """Send *message_array* to the Groq *model* and return the filtered reply.

        Raises whatever ``groq`` raises on failure (max_retries=0, so errors
        surface immediately for the caller's error handling).
        """
        client = groq.Groq(
            max_retries=0,
            api_key=config.GROQ_API_KEY,
        )

        # GROQ_SETTINGS is JSON-parsed at config load; forward it only when it
        # really is a dict (empty / unparsed values remain plain strings).
        # This replaces two near-identical create() calls with one.
        extra_settings = (
            config.GROQ_SETTINGS if isinstance(config.GROQ_SETTINGS, dict) else {}
        )

        completion = client.chat.completions.create(
            model=model,
            messages=message_array,
            # md5-hash the username to pass a stable, opaque user identifier.
            user=hashlib.md5(username.encode()).hexdigest(),
            **extra_settings,
        )

        response_text = completion.choices[0].message.content.strip()
        filtered_response = remove_hashtags(response_text)
        return filtered_response

modules/api/openai.py

+19-127
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,32 @@
11
import hashlib
2-
import time
32

43
import openai
54

65
from config import config
7-
from modules.conversation_history import ConversationHistory
8-
from modules.logs import get_logger, log_gui_general_message, log_gui_model_message
9-
from modules.servers.tf2 import send_say_command_to_tf2
10-
from modules.typing import Message, MessageHistory
11-
from modules.utils.text import get_system_message, remove_args, remove_hashtags
6+
from modules.api.base import LLMProvider
7+
from modules.logs import get_logger
128

139
main_logger = get_logger("main")
1410
gui_logger = get_logger("gui")
1511

1612

17-
def is_violated_tos(message: str) -> bool:
13+
class OpenAILLMProvider(LLMProvider):
    """LLM provider that talks to the OpenAI ChatCompletion API."""

    @staticmethod
    def get_completion_text(conversation_history, username, model):
        """Send *conversation_history* to *model* and return the stripped reply text."""
        openai.api_key = config.OPENAI_API_KEY

        # md5-hash the username so only an opaque user id reaches the API.
        hashed_user = hashlib.md5(username.encode()).hexdigest()

        completion = openai.ChatCompletion.create(
            model=model,
            messages=conversation_history,
            user=hashed_user,
        )

        reply = completion.choices[0].message["content"]
        return reply.strip()
27+
28+
29+
def is_flagged(message: str) -> bool:
1830
openai.api_key = config.OPENAI_API_KEY
1931
try:
2032
response = openai.Moderation.create(
@@ -28,123 +40,3 @@ def is_violated_tos(message: str) -> bool:
2840
return True
2941

3042
return response.results[0]["flagged"]
31-
32-
33-
def send_gpt_completion_request(
34-
conversation_history: MessageHistory, username: str, model: str
35-
) -> str:
36-
openai.api_key = config.OPENAI_API_KEY
37-
38-
completion = openai.ChatCompletion.create(
39-
model=model,
40-
messages=conversation_history,
41-
user=hashlib.md5(username.encode()).hexdigest(),
42-
)
43-
44-
response_text = completion.choices[0].message["content"].strip()
45-
return response_text
46-
47-
48-
def handle_cgpt_request(
49-
username: str,
50-
user_prompt: str,
51-
conversation_history: ConversationHistory,
52-
model,
53-
is_team: bool = False,
54-
) -> ConversationHistory:
55-
"""
56-
This function is called when the user wants to send a message to the AI chatbot. It logs the
57-
user's message, and sends a request to generate a response.
58-
"""
59-
log_gui_model_message(model, username, user_prompt)
60-
61-
user_message = remove_args(user_prompt)
62-
if (
63-
not config.TOS_VIOLATION
64-
and is_violated_tos(user_message)
65-
and config.HOST_USERNAME != username
66-
):
67-
gui_logger.error(f"Request '{user_prompt}' violates OPENAI TOS. Skipping...")
68-
return conversation_history
69-
70-
conversation_history.add_user_message_from_prompt(user_prompt)
71-
72-
response = get_response(conversation_history.get_messages_array(), username, model)
73-
74-
if response:
75-
conversation_history.add_assistant_message(Message(role="assistant", content=response))
76-
log_gui_model_message(model, username, " ".join(response.split()))
77-
send_say_command_to_tf2(response, username, is_team)
78-
79-
return conversation_history
80-
81-
82-
def handle_gpt_request(
83-
username: str, user_prompt: str, model: str, is_team_chat: bool = False
84-
) -> None:
85-
"""
86-
This function is called when the user wants to send a message to the AI chatbot. It logs the
87-
user's message, and sends a request to GPT-3 to generate a response. Finally, the function
88-
sends the generated response to the TF2 game.
89-
"""
90-
log_gui_model_message(model, username, user_prompt)
91-
92-
user_message = remove_args(user_prompt)
93-
sys_message = get_system_message(user_prompt)
94-
95-
if (
96-
not config.TOS_VIOLATION
97-
and is_violated_tos(user_message)
98-
and config.HOST_USERNAME != username
99-
):
100-
gui_logger.warning(
101-
f"Request '{user_prompt}' by user {username} violates OPENAI TOS. Skipping..."
102-
)
103-
return
104-
105-
payload = [
106-
sys_message,
107-
Message(role="assistant", content=config.GREETING),
108-
Message(role="user", content=user_message),
109-
]
110-
111-
response = get_response(payload, username, model)
112-
113-
if response:
114-
main_logger.info(
115-
f"Got response for user {username}. Response: {' '.join(response.split())}"
116-
)
117-
log_gui_model_message(model, username, " ".join(response.split()))
118-
send_say_command_to_tf2(response, username, is_team_chat)
119-
120-
121-
def get_response(conversation_history: MessageHistory, username: str, model) -> str | None:
122-
attempts = 0
123-
max_attempts = 2
124-
125-
while attempts < max_attempts:
126-
try:
127-
response = send_gpt_completion_request(conversation_history, username, model=model)
128-
filtered_response = remove_hashtags(response)
129-
return filtered_response
130-
except openai.error.RateLimitError:
131-
log_gui_general_message("Rate limited! Trying again...")
132-
main_logger(f"User is rate limited.")
133-
time.sleep(2)
134-
attempts += 1
135-
except openai.error.APIError as e:
136-
log_gui_general_message(f"Wasn't able to connect to OpenAI API. Cancelling...")
137-
main_logger.error(f"APIError happened. [{e}]")
138-
return
139-
except openai.error.AuthenticationError:
140-
log_gui_general_message("Your OpenAI api key is invalid.")
141-
main_logger.error("OpenAI API key is invalid.")
142-
return
143-
except Exception as e:
144-
log_gui_general_message(f"Unhandled error happened! Cancelling ({e})")
145-
main_logger.error(f"Unhandled error happened! Cancelling ({e})")
146-
return
147-
148-
if attempts == max_attempts:
149-
log_gui_general_message("Max number of attempts reached! Try again later!")
150-
main_logger(f"Max number of attempts reached. [{max_attempts}/{max_attempts}]")

modules/api/textgen_webui.py

+12-26
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,25 @@
11
import requests
22

33
from config import config
4+
from modules.api.base import LLMProvider
45
from modules.logs import get_logger
5-
from modules.typing import Message
66

77
main_logger = get_logger("main")
88
combo_logger = get_logger("combo")
99

1010

11-
def get_custom_model_response(conversation_history: list[Message]) -> str | None:
12-
uri = f"http://{config.CUSTOM_MODEL_HOST}/v1/chat/completions"
11+
class TextGenerationWebUILLMProvider(LLMProvider):
    """LLM provider for a local text-generation-webui server (OpenAI-compatible endpoint)."""

    @staticmethod
    def get_completion_text(conversation_history, username, model):
        """POST the chat history to the configured server and return the generated text.

        Raises ``Exception('HTTP 500')`` on a server-side error; other failures
        (connection errors, bad JSON) propagate to the caller's error handling.
        """
        endpoint = f"http://{config.CUSTOM_MODEL_HOST}/v1/chat/completions"

        payload = {"mode": "chat", "messages": conversation_history}
        # Merge user-supplied generation settings into the request body.
        payload.update(config.CUSTOM_MODEL_SETTINGS)

        response = requests.post(
            endpoint,
            headers={"Content-Type": "application/json"},
            json=payload,
            verify=False,  # local endpoint over plain HTTP; no cert to verify
        )

        if response.status_code == 500:
            raise Exception('HTTP 500')

        return response.json()["choices"][0]["message"]["content"]

modules/chat.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from config import config
1+
from config import config, RTDModes
22
from modules.api.github import check_for_updates
33
from modules.bans import bans_manager
44
from modules.bot_state import state_manager
55
from modules.command_controllers import CommandController, InitializerConfig
66
from modules.commands.clear_chat import handle_clear
77
from modules.commands.github import handle_gh_command
8-
from modules.commands.openai import handle_user_chat, handle_gpt3, handle_gpt4, handle_gpt4l, handle_global_chat
8+
from modules.commands.groq import GroqQuickQueryCommand, GroqGlobalChatCommand, GroqPrivateChatCommand
9+
from modules.commands.openai import OpenAIGlobalChatCommand, OpenAIPrivateChatCommand, OpenAIGPT3QuickQueryCommand, \
10+
OpenAIGPT4QuickQueryCommand, OpenAIGPT4LQuickQueryCommand
911
from modules.commands.rtd import handle_rtd
10-
from modules.commands.textgen_webui import handle_custom_user_chat, handle_custom_model, handle_custom_global_chat
12+
from modules.commands.textgen_webui import TextgenWebUIGlobalChatCommand, TextgenWebUIPrivateChatCommand, \
13+
TextgenWebUIQuickQueryCommand
1114
from modules.logs import get_logger
1215
from modules.message_queueing import messaging_queue_service
1316
from modules.servers.tf2 import check_connection, set_host_username
@@ -53,18 +56,23 @@ def parse_console_logs_and_build_conversation_history() -> None:
5356

5457
# Commands
5558
controller.register_command("!gh", handle_gh_command)
56-
controller.register_command(config.RTD_COMMAND, handle_rtd)
5759
controller.register_command(config.CLEAR_CHAT_COMMAND, handle_clear)
60+
if config.RTD_MODE != RTDModes.DISABLED:
61+
controller.register_command(config.RTD_COMMAND, handle_rtd)
5862
if config.ENABLE_OPENAI_COMMANDS:
59-
controller.register_command(config.GPT4_COMMAND, handle_gpt4)
60-
controller.register_command(config.GPT4_LEGACY_COMMAND, handle_gpt4l)
61-
controller.register_command(config.CHATGPT_COMMAND, handle_user_chat)
62-
controller.register_command(config.GLOBAL_CHAT_COMMAND, handle_global_chat)
63-
controller.register_command(config.GPT_COMMAND, handle_gpt3)
63+
controller.register_command(config.GPT4_COMMAND, OpenAIGPT4QuickQueryCommand.as_command())
64+
controller.register_command(config.GPT4_LEGACY_COMMAND, OpenAIGPT4LQuickQueryCommand.as_command())
65+
controller.register_command(config.CHATGPT_COMMAND, OpenAIPrivateChatCommand.as_command())
66+
controller.register_command(config.GLOBAL_CHAT_COMMAND, OpenAIGlobalChatCommand.as_command())
67+
controller.register_command(config.GPT_COMMAND, OpenAIGPT3QuickQueryCommand.as_command())
6468
if config.ENABLE_CUSTOM_MODEL:
65-
controller.register_command(config.CUSTOM_MODEL_COMMAND, handle_custom_model)
66-
controller.register_command(config.CUSTOM_MODEL_CHAT_COMMAND, handle_custom_user_chat)
67-
controller.register_command(config.GLOBAL_CUSTOM_CHAT_COMMAND, handle_custom_global_chat)
69+
controller.register_command(config.CUSTOM_MODEL_COMMAND, TextgenWebUIQuickQueryCommand.as_command())
70+
controller.register_command(config.CUSTOM_MODEL_CHAT_COMMAND, TextgenWebUIPrivateChatCommand.as_command())
71+
controller.register_command(config.GLOBAL_CUSTOM_CHAT_COMMAND, TextgenWebUIGlobalChatCommand.as_command())
72+
if config.GROQ_ENABLE:
73+
controller.register_command(config.GROQ_COMMAND, GroqQuickQueryCommand.as_command())
74+
controller.register_command(config.GROQ_CHAT_COMMAND, GroqGlobalChatCommand.as_command())
75+
controller.register_command(config.GROQ_PRIVATE_CHAT, GroqPrivateChatCommand.as_command())
6876

6977
# Services
7078
controller.register_service(messaging_queue_service)

0 commit comments

Comments
 (0)