@@ -1,20 +1,32 @@
 import hashlib
-import time
 
 import openai
 
 from config import config
-from modules.conversation_history import ConversationHistory
-from modules.logs import get_logger, log_gui_general_message, log_gui_model_message
-from modules.servers.tf2 import send_say_command_to_tf2
-from modules.typing import Message, MessageHistory
-from modules.utils.text import get_system_message, remove_args, remove_hashtags
+from modules.api.base import LLMProvider
+from modules.logs import get_logger
 
 main_logger = get_logger("main")
 gui_logger = get_logger("gui")
 
 
-def is_violated_tos(message: str) -> bool:
+class OpenAILLMProvider(LLMProvider):
+
+    @staticmethod
+    def get_completion_text(conversation_history, username, model):
+        openai.api_key = config.OPENAI_API_KEY
+
+        completion = openai.ChatCompletion.create(
+            model=model,
+            messages=conversation_history,
+            user=hashlib.md5(username.encode()).hexdigest(),
+        )
+
+        response_text = completion.choices[0].message["content"].strip()
+        return response_text
+
+
+def is_flagged(message: str) -> bool:
     openai.api_key = config.OPENAI_API_KEY
     try:
         response = openai.Moderation.create(
@@ -28,123 +40,3 @@ def is_violated_tos(message: str) -> bool:
         return True
 
     return response.results[0]["flagged"]
-
-
-def send_gpt_completion_request(
-    conversation_history: MessageHistory, username: str, model: str
-) -> str:
-    openai.api_key = config.OPENAI_API_KEY
-
-    completion = openai.ChatCompletion.create(
-        model=model,
-        messages=conversation_history,
-        user=hashlib.md5(username.encode()).hexdigest(),
-    )
-
-    response_text = completion.choices[0].message["content"].strip()
-    return response_text
-
-
-def handle_cgpt_request(
-    username: str,
-    user_prompt: str,
-    conversation_history: ConversationHistory,
-    model,
-    is_team: bool = False,
-) -> ConversationHistory:
-    """
-    This function is called when the user wants to send a message to the AI chatbot. It logs the
-    user's message, and sends a request to generate a response.
-    """
-    log_gui_model_message(model, username, user_prompt)
-
-    user_message = remove_args(user_prompt)
-    if (
-        not config.TOS_VIOLATION
-        and is_violated_tos(user_message)
-        and config.HOST_USERNAME != username
-    ):
-        gui_logger.error(f"Request '{user_prompt}' violates OPENAI TOS. Skipping...")
-        return conversation_history
-
-    conversation_history.add_user_message_from_prompt(user_prompt)
-
-    response = get_response(conversation_history.get_messages_array(), username, model)
-
-    if response:
-        conversation_history.add_assistant_message(Message(role="assistant", content=response))
-        log_gui_model_message(model, username, " ".join(response.split()))
-        send_say_command_to_tf2(response, username, is_team)
-
-    return conversation_history
-
-
-def handle_gpt_request(
-    username: str, user_prompt: str, model: str, is_team_chat: bool = False
-) -> None:
-    """
-    This function is called when the user wants to send a message to the AI chatbot. It logs the
-    user's message, and sends a request to GPT-3 to generate a response. Finally, the function
-    sends the generated response to the TF2 game.
-    """
-    log_gui_model_message(model, username, user_prompt)
-
-    user_message = remove_args(user_prompt)
-    sys_message = get_system_message(user_prompt)
-
-    if (
-        not config.TOS_VIOLATION
-        and is_violated_tos(user_message)
-        and config.HOST_USERNAME != username
-    ):
-        gui_logger.warning(
-            f"Request '{user_prompt}' by user {username} violates OPENAI TOS. Skipping..."
-        )
-        return
-
-    payload = [
-        sys_message,
-        Message(role="assistant", content=config.GREETING),
-        Message(role="user", content=user_message),
-    ]
-
-    response = get_response(payload, username, model)
-
-    if response:
-        main_logger.info(
-            f"Got response for user {username}. Response: {' '.join(response.split())}"
-        )
-        log_gui_model_message(model, username, " ".join(response.split()))
-        send_say_command_to_tf2(response, username, is_team_chat)
-
-
-def get_response(conversation_history: MessageHistory, username: str, model) -> str | None:
-    attempts = 0
-    max_attempts = 2
-
-    while attempts < max_attempts:
-        try:
-            response = send_gpt_completion_request(conversation_history, username, model=model)
-            filtered_response = remove_hashtags(response)
-            return filtered_response
-        except openai.error.RateLimitError:
-            log_gui_general_message("Rate limited! Trying again...")
-            main_logger(f"User is rate limited.")
-            time.sleep(2)
-            attempts += 1
-        except openai.error.APIError as e:
-            log_gui_general_message(f"Wasn't able to connect to OpenAI API. Cancelling...")
-            main_logger.error(f"APIError happened. [{e}]")
-            return
-        except openai.error.AuthenticationError:
-            log_gui_general_message("Your OpenAI api key is invalid.")
-            main_logger.error("OpenAI API key is invalid.")
-            return
-        except Exception as e:
-            log_gui_general_message(f"Unhandled error happened! Cancelling ({e})")
-            main_logger.error(f"Unhandled error happened! Cancelling ({e})")
-            return
-
-    if attempts == max_attempts:
-        log_gui_general_message("Max number of attempts reached! Try again later!")
-        main_logger(f"Max number of attempts reached. [{max_attempts}/{max_attempts}]")
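
A minimal usage sketch of the relocated completion helper follows. The import path (modules.api.openai), the model name, and the message payload are assumptions for illustration; the diff only shows that OpenAILLMProvider.get_completion_text(conversation_history, username, model) exists and passes conversation_history straight through to openai.ChatCompletion.create (openai 0.x SDK), so the history should be the usual list of role/content dicts.

# Hypothetical usage sketch; the module path and payload below are assumptions.
from modules.api.openai import OpenAILLMProvider  # assumed location of this file

history = [
    {"role": "system", "content": "You are a helpful TF2 chat bot."},
    {"role": "user", "content": "What is a good Soldier loadout?"},
]

# The username is never sent verbatim: get_completion_text() forwards only
# an MD5 hash of it as the `user` field, for OpenAI-side abuse tracking.
reply = OpenAILLMProvider.get_completion_text(history, "some_player", "gpt-3.5-turbo")
print(reply)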
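
For context, is_flagged() (renamed from is_violated_tos()) wraps the openai 0.x moderation endpoint. Below is a hedged sketch of the underlying call; the input= argument name comes from the 0.x SDK documentation, since this hunk elides the call's arguments.

import openai

openai.api_key = "sk-..."  # placeholder; the real code reads config.OPENAI_API_KEY

# openai 0.x moderation call; the response carries per-category scores plus a
# top-level "flagged" boolean, which is what is_flagged() ultimately returns.
response = openai.Moderation.create(input="some chat message")
print(response.results[0]["flagged"])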