Skip to content

Commit

Permalink
Changed cogs to use the utils module functions instead
Browse files Browse the repository at this point in the history
  • Loading branch information
LobaDK committed May 15, 2024
1 parent a32fa0c commit 378f88d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 145 deletions.
161 changes: 52 additions & 109 deletions cogs/Chat.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
from openai import AsyncOpenAI as OpenAI, OpenAIError
import tiktoken
import os
import requests
import datetime
import calendar
import sys

from discord.ext import commands

from sql import database
from sql import crud, schemas
from decorators import requires_tos_acceptance
from cogs.utils.utils import (
get_urls_in_message,
calculate_tokens,
get_usage,
split_message_by_sentence,
get_server_id_and_name,
get_image_as_base64,
UnsupportedImageFormatError,
FileSizeLimitError,
SUPPORTED_IMAGE_FORMATS,
)

from QuantumKat import log_helper, misc_helper, discord_helper

sys.path.append(".")


class Chat(commands.Cog):
def __init__(self, bot: commands.Bot):
Expand All @@ -33,8 +38,6 @@ def __init__(self, bot: commands.Bot):
)
)

# Set the model encoding for tiktoken
self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
# Set the system message for ChatGPT. This controls the personality of the bot and how it responds.
self.system_message = """You are QuantumKat, a cybernetic cat with void-like black fur.
You still retain the bodily parts and functions of a normal cat.
Expand All @@ -44,7 +47,7 @@ def __init__(self, bot: commands.Bot):
however, you are still young and don't quite know how to do it correctly, as well as being clumsy, often resulting in you appearing in an incorrect location or doing the wrong thing.
You are quite sarcastic. You are allowed to have opinions on things.
Many people interact with you, and it's over the chat program called Discord, so you should never exceed 1950 characters in a response.
Currently, {user} is talking to you.
You are currently talking to {user}.
You are currently running on version {version}.
Avoid repeating yourself."""
# Attempt to get the OpenAI API key from the environment variables
Expand All @@ -65,22 +68,6 @@ def __init__(self, bot: commands.Bot):
"OpenAI Session key not found. Chatstatus command will not work."
)

async def calculate_tokens(self, user_message: str) -> int:
"""
Calculates the number of tokens in a given user message.
Parameters:
- user_message (str): The user message to calculate tokens for.
Returns:
- int: The number of tokens in the user message.
"""
messages = [user_message, self.system_message]
tokens = 0
for message in messages:
tokens += len(self.encoding.encode(message))
return tokens

async def database_add(
self,
ctx: commands.Context,
Expand All @@ -100,7 +87,7 @@ async def database_add(
Returns:
None
"""
server_id, server_name = await self.get_server_id_and_name(ctx)
server_id, server_name = get_server_id_and_name(ctx)
if server_name == "DM":
if not await crud.check_server_exists(
database.AsyncSessionLocal, schemas.Server.Get(server_id=server_id)
Expand Down Expand Up @@ -133,7 +120,7 @@ async def database_read(self, ctx: commands.Context, shared_chat: bool) -> list:
list: A list of dictionaries containing the user and assistant messages.
"""
server_id, _ = await self.get_server_id_and_name(ctx)
server_id, _ = get_server_id_and_name(ctx)
if shared_chat:
result = await crud.get_shared_chats_for_server(
database.AsyncSessionLocal,
Expand Down Expand Up @@ -164,7 +151,7 @@ async def database_remove(self, ctx: commands.Context, shared_chat: bool):
Returns:
None
"""
server_id, server_name = await self.get_server_id_and_name(ctx)
server_id, server_name = get_server_id_and_name(ctx)
if shared_chat:
await crud.delete_shared_chat(
database.AsyncSessionLocal, schemas.Chat.Delete(server_id=server_id)
Expand All @@ -175,71 +162,6 @@ async def database_remove(self, ctx: commands.Context, shared_chat: bool):
schemas.Chat.Delete(server_id=server_id, user_id=ctx.author.id),
)

async def get_usage(self) -> dict:
"""
Retrieves the usage statistics for the OpenAI API key.
Returns:
dict: A dictionary containing the usage statistics for the OpenAI API key.
"""
month = datetime.datetime.now().month
month = f"{month:02}"
year = datetime.datetime.now().year
last_day = calendar.monthrange(year, int(month))[1]
response = requests.get(
f"https://api.openai.com/dashboard/billing/usage?end_date={year}-{month}-{last_day}&start_date={year}-{month}-01",
headers={"Authorization": f"Bearer {self.session_key}"},
)
response.raise_for_status()
return response.json()

async def get_server_id_and_name(self, ctx: commands.Context) -> tuple:
"""
Retrieves the server ID and name from the context object.
Args:
ctx (commands.Context): The context object representing the invocation context of the command.
Returns:
tuple: A tuple containing the server ID and name.
"""
if not discord_helper.is_dm(ctx):
server_id = ctx.guild.id
server_name = ctx.guild.name
else:
server_id = ctx.channel.id
server_name = "DM"
return server_id, server_name

async def split_message_by_sentence(self, message: str) -> list:
"""
Splits a given message by sentence, into multiple messages with a maximum length of 2000 characters.
Args:
message (str): The message to be split into sentences.
Returns:
list: A list of sentences, each with a maximum length of 2000 characters.
"""
sentences = message.split(". ")
current_length = 0
messages = []
current_message = ""

for sentence in sentences:
if current_length + len(sentence) + 1 > 2000: # +1 for the period
messages.append(current_message)
current_length = 0
current_message = ""

current_message += sentence + ". "
current_length += len(sentence) + 1

if current_message: # Any leftover sentence
messages.append(current_message)

return messages

async def initiateChat(
self, ctx: commands.Context, user_message: str, shared_chat: bool
):
Expand All @@ -256,7 +178,7 @@ async def initiateChat(
"""
if self.FOUND_API_KEY is True:
if user_message:
tokens = await self.calculate_tokens(user_message)
tokens = calculate_tokens(user_message, self.system_message)
if not tokens > 1024:
command = ctx.invoked_with
user_message = ctx.message.content.split(
Expand All @@ -266,6 +188,31 @@ async def initiateChat(
user_message = user_message.replace(
member.mention, member.display_name
)
urls = get_urls_in_message(user_message)
if urls:
base64_images = []
for url in urls:
try:
base64_images.extend(get_image_as_base64(url))
except (
UnsupportedImageFormatError,
FileSizeLimitError,
) as e:
await ctx.reply(
str(e),
silent=True,
)
return

user_role = {
"role": "user",
"content": [
user_message,
*map(lambda x: {"image": x}, base64_images),
],
}
else:
user_role = {"role": "user", "content": user_message}
conversation_history = await self.database_read(ctx, shared_chat)
async with ctx.typing():
try:
Expand All @@ -276,18 +223,18 @@ async def initiateChat(
{
"role": "system",
"content": self.system_message.format(
user=ctx.author.name,
user=ctx.author.id,
version=".".join(
str(misc_helper.get_git_commit_count())
),
),
},
*conversation_history,
{"role": "user", "content": user_message},
user_role,
]

response = await self.openai.chat.completions.create(
model="gpt-3.5-turbo",
model="gpt-4o",
messages=messages,
temperature=1,
max_tokens=512,
Expand Down Expand Up @@ -318,9 +265,7 @@ async def initiateChat(

username = ctx.author.name
user_id = ctx.author.id
server_id, server_name = await self.get_server_id_and_name(
ctx
)
server_id, server_name = get_server_id_and_name(ctx)

self.historylogger.info(
f"[User]: {username} ({user_id}) [Server]: {server_name} ({server_id}) [Message]: {user_message} [History]: {messages}."
Expand All @@ -329,9 +274,7 @@ async def initiateChat(
f"[User]: {username} ({user_id}) [Server]: {server_name} ({server_id}) [Message]: {user_message}. [Response]: {chat_response}. [Tokens]: {response.usage.total_tokens} tokens used in total."
)
if len(chat_response) > 2000:
chat_response = await self.split_message_by_sentence(
chat_response
)
chat_response = split_message_by_sentence(chat_response)
for message in chat_response:
await ctx.reply(message, silent=True)
else:
Expand Down Expand Up @@ -393,7 +336,7 @@ async def initiatechatview(self, ctx: commands.Context, shared_chat: bool):
messages.append(f"{message['role'].title()}: {message['content']}")
message = "\n".join(messages)
if len(message) > 2000:
message = await self.split_message_by_sentence(message)
message = split_message_by_sentence(message)
for msg in message:
await ctx.reply(msg, silent=True)
else:
Expand All @@ -404,10 +347,10 @@ async def initiatechatview(self, ctx: commands.Context, shared_chat: bool):
@commands.command(
aliases=["sharedchat", "sharedtalk", "schat", "sc"],
brief="Talk to QuantumKat in a shared chat.",
description="Talk to QuantumKat in a chat shared with all users, using the OpenAI API/ChatGPT. Is not shared between servers.",
description=f"Talk to QuantumKat in a chat shared with all users, using the OpenAI API/ChatGPT. Is not shared between servers. URLs of images and gifs are supported and will be analyzed by the AI. File size limit is 20MB and only {', '.join(SUPPORTED_IMAGE_FORMATS)} are supported",
)
@requires_tos_acceptance
async def SharedChat(self, ctx: commands.Context, *, user_message=""):
async def SharedChat(self, ctx: commands.Context, *, user_message: str):
"""
Initiates a shared chat session with the bot.
Expand All @@ -423,10 +366,10 @@ async def SharedChat(self, ctx: commands.Context, *, user_message=""):
@commands.command(
aliases=["chat", "talk", "c"],
brief="Talk to QuantumKat.",
description="Talk to QuantumKat using the OpenAI API/ChatGPT. Each user has their own chat history. Is not shared between servers.",
description=f"Talk to QuantumKat using the OpenAI API/ChatGPT. Each user has their own chat history. Is not shared between servers. URLs of images and gifs are supported and will be analyzed by the AI. File size limit is 20MB and only {', '.join(SUPPORTED_IMAGE_FORMATS)} are supported.",
)
@requires_tos_acceptance
async def Chat(self, ctx: commands.Context, *, user_message=""):
async def Chat(self, ctx: commands.Context, *, user_message: str):
"""
Initiates a user-separated chat session with the bot.
Expand Down Expand Up @@ -538,7 +481,7 @@ async def ChatStatus(self, ctx: commands.Context):

if self.session_key:
try:
usage = await self.get_usage()
usage = get_usage(self.session_key)
if usage:
messages.append(
"OpenAI API key usage: {:.2f}$ of tokens used this month.".format(
Expand Down
38 changes: 2 additions & 36 deletions cogs/Entanglement.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from alembic.util.exc import CommandError
from sqlalchemy.exc import OperationalError

import mimetypes
import magic
import discord
import ast
import astunparse
Expand All @@ -30,6 +28,7 @@

from sql.database import AsyncSessionLocal
from sql import crud, schemas
from cogs.utils.utils import get_file_type

from QuantumKat import log_helper, misc_helper

Expand Down Expand Up @@ -95,39 +94,6 @@ def extract_function_code(self, filename, function_name):
if isinstance(node, ast.FunctionDef) and node.name == function_name:
return astunparse.unparse(node.body)

async def get_mime_type(self, mime_type: str) -> str:
"""
Returns the file extension corresponding to the given MIME type.
Parameters:
- mime_type (str): The MIME type for which to determine the file extension.
Returns:
- str: The file extension corresponding to the given MIME type.
"""
return mimetypes.guess_extension(mime_type)

async def get_file_type(self, ctx: commands.Context, filename: str) -> str:
"""
Retrieves the file type of a given filename.
Parameters:
- ctx (commands.Context): The context of the command.
- filename (str): The name of the file.
Returns:
- str: The file extension of the given file.
"""
mime = magic.Magic(mime=True)
try:
mime_type = mime.from_file(filename)
except OSError as e:
await ctx.reply(f"Error getting file type: {e}", silent=True)
self.logger.error("Error getting file type", exc_info=True)
return None
file_extension = await self.get_mime_type(mime_type)
return file_extension

async def parameter_kind_to_string(self, parameter: Parameter) -> str:
"""
Converts the parameter kind to a string representation.
Expand Down Expand Up @@ -499,7 +465,7 @@ async def quantize(
quantizer.write(block)

if not Path(filename).suffix:
file_extension = await self.get_file_type(
file_extension = get_file_type(
ctx, str(Path(data_dir, filename))
)
if not file_extension:
Expand Down

0 comments on commit 378f88d

Please sign in to comment.