From 378f88dd1bab67e46293329addaa88d406ee2fdf Mon Sep 17 00:00:00 2001 From: "Nicklas H." Date: Wed, 15 May 2024 13:54:28 +0200 Subject: [PATCH] Changed cogs to use the utils module functions instead --- cogs/Chat.py | 161 ++++++++++++++----------------------------- cogs/Entanglement.py | 38 +--------- 2 files changed, 54 insertions(+), 145 deletions(-) diff --git a/cogs/Chat.py b/cogs/Chat.py index 8142440..36adb17 100644 --- a/cogs/Chat.py +++ b/cogs/Chat.py @@ -1,21 +1,26 @@ from openai import AsyncOpenAI as OpenAI, OpenAIError -import tiktoken import os import requests -import datetime -import calendar -import sys from discord.ext import commands from sql import database from sql import crud, schemas from decorators import requires_tos_acceptance +from cogs.utils.utils import ( + get_urls_in_message, + calculate_tokens, + get_usage, + split_message_by_sentence, + get_server_id_and_name, + get_image_as_base64, + UnsupportedImageFormatError, + FileSizeLimitError, + SUPPORTED_IMAGE_FORMATS, +) from QuantumKat import log_helper, misc_helper, discord_helper -sys.path.append(".") - class Chat(commands.Cog): def __init__(self, bot: commands.Bot): @@ -33,8 +38,6 @@ def __init__(self, bot: commands.Bot): ) ) - # Set the model encoding for tiktoken - self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") # Set the system message for ChatGPT. This controls the personality of the bot and how it responds. self.system_message = """You are QuantumKat, a cybernetic cat with void-like black fur. You still retain the bodily parts and functions of a normal cat. @@ -44,7 +47,7 @@ def __init__(self, bot: commands.Bot): however, you are still young and don't quite know how to do it correctly, as well as being clumsy, often resulting in you appearing in an incorrect location or doing the wrong thing. You are quite sarcastic. You are allowed to have opinions on things. Many people interact with you, and it's over the chat program called Discord, so you should never exceed 1950 characters in a response. - Currently, {user} is talking to you. + You are currently talking to {user}. You are currently running on version {version}. Avoid repeating yourself.""" # Attempt to get the OpenAI API key from the environment variables @@ -65,22 +68,6 @@ def __init__(self, bot: commands.Bot): "OpenAI Session key not found. Chatstatus command will not work." ) - async def calculate_tokens(self, user_message: str) -> int: - """ - Calculates the number of tokens in a given user message. - - Parameters: - - user_message (str): The user message to calculate tokens for. - - Returns: - - int: The number of tokens in the user message. - """ - messages = [user_message, self.system_message] - tokens = 0 - for message in messages: - tokens += len(self.encoding.encode(message)) - return tokens - async def database_add( self, ctx: commands.Context, @@ -100,7 +87,7 @@ async def database_add( Returns: None """ - server_id, server_name = await self.get_server_id_and_name(ctx) + server_id, server_name = get_server_id_and_name(ctx) if server_name == "DM": if not await crud.check_server_exists( database.AsyncSessionLocal, schemas.Server.Get(server_id=server_id) @@ -133,7 +120,7 @@ async def database_read(self, ctx: commands.Context, shared_chat: bool) -> list: list: A list of dictionaries containing the user and assistant messages. """ - server_id, _ = await self.get_server_id_and_name(ctx) + server_id, _ = get_server_id_and_name(ctx) if shared_chat: result = await crud.get_shared_chats_for_server( database.AsyncSessionLocal, @@ -164,7 +151,7 @@ async def database_remove(self, ctx: commands.Context, shared_chat: bool): Returns: None """ - server_id, server_name = await self.get_server_id_and_name(ctx) + server_id, server_name = get_server_id_and_name(ctx) if shared_chat: await crud.delete_shared_chat( database.AsyncSessionLocal, schemas.Chat.Delete(server_id=server_id) @@ -175,71 +162,6 @@ async def database_remove(self, ctx: commands.Context, shared_chat: bool): schemas.Chat.Delete(server_id=server_id, user_id=ctx.author.id), ) - async def get_usage(self) -> dict: - """ - Retrieves the usage statistics for the OpenAI API key. - - Returns: - dict: A dictionary containing the usage statistics for the OpenAI API key. - """ - month = datetime.datetime.now().month - month = f"{month:02}" - year = datetime.datetime.now().year - last_day = calendar.monthrange(year, int(month))[1] - response = requests.get( - f"https://api.openai.com/dashboard/billing/usage?end_date={year}-{month}-{last_day}&start_date={year}-{month}-01", - headers={"Authorization": f"Bearer {self.session_key}"}, - ) - response.raise_for_status() - return response.json() - - async def get_server_id_and_name(self, ctx: commands.Context) -> tuple: - """ - Retrieves the server ID and name from the context object. - - Args: - ctx (commands.Context): The context object representing the invocation context of the command. - - Returns: - tuple: A tuple containing the server ID and name. - """ - if not discord_helper.is_dm(ctx): - server_id = ctx.guild.id - server_name = ctx.guild.name - else: - server_id = ctx.channel.id - server_name = "DM" - return server_id, server_name - - async def split_message_by_sentence(self, message: str) -> list: - """ - Splits a given message by sentence, into multiple messages with a maximum length of 2000 characters. - - Args: - message (str): The message to be split into sentences. - - Returns: - list: A list of sentences, each with a maximum length of 2000 characters. - """ - sentences = message.split(". ") - current_length = 0 - messages = [] - current_message = "" - - for sentence in sentences: - if current_length + len(sentence) + 1 > 2000: # +1 for the period - messages.append(current_message) - current_length = 0 - current_message = "" - - current_message += sentence + ". " - current_length += len(sentence) + 1 - - if current_message: # Any leftover sentence - messages.append(current_message) - - return messages - async def initiateChat( self, ctx: commands.Context, user_message: str, shared_chat: bool ): @@ -256,7 +178,7 @@ async def initiateChat( """ if self.FOUND_API_KEY is True: if user_message: - tokens = await self.calculate_tokens(user_message) + tokens = calculate_tokens(user_message, self.system_message) if not tokens > 1024: command = ctx.invoked_with user_message = ctx.message.content.split( @@ -266,6 +188,31 @@ async def initiateChat( user_message = user_message.replace( member.mention, member.display_name ) + urls = get_urls_in_message(user_message) + if urls: + base64_images = [] + for url in urls: + try: + base64_images.extend(get_image_as_base64(url)) + except ( + UnsupportedImageFormatError, + FileSizeLimitError, + ) as e: + await ctx.reply( + str(e), + silent=True, + ) + return + + user_role = { + "role": "user", + "content": [ + user_message, + *map(lambda x: {"image": x}, base64_images), + ], + } + else: + user_role = {"role": "user", "content": user_message} conversation_history = await self.database_read(ctx, shared_chat) async with ctx.typing(): try: @@ -276,18 +223,18 @@ async def initiateChat( { "role": "system", "content": self.system_message.format( - user=ctx.author.name, + user=ctx.author.id, version=".".join( str(misc_helper.get_git_commit_count()) ), ), }, *conversation_history, - {"role": "user", "content": user_message}, + user_role, ] response = await self.openai.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-4o", messages=messages, temperature=1, max_tokens=512, @@ -318,9 +265,7 @@ async def initiateChat( username = ctx.author.name user_id = ctx.author.id - server_id, server_name = await self.get_server_id_and_name( - ctx - ) + server_id, server_name = get_server_id_and_name(ctx) self.historylogger.info( f"[User]: {username} ({user_id}) [Server]: {server_name} ({server_id}) [Message]: {user_message} [History]: {messages}." @@ -329,9 +274,7 @@ async def initiateChat( f"[User]: {username} ({user_id}) [Server]: {server_name} ({server_id}) [Message]: {user_message}. [Response]: {chat_response}. [Tokens]: {response.usage.total_tokens} tokens used in total." ) if len(chat_response) > 2000: - chat_response = await self.split_message_by_sentence( - chat_response - ) + chat_response = split_message_by_sentence(chat_response) for message in chat_response: await ctx.reply(message, silent=True) else: @@ -393,7 +336,7 @@ async def initiatechatview(self, ctx: commands.Context, shared_chat: bool): messages.append(f"{message['role'].title()}: {message['content']}") message = "\n".join(messages) if len(message) > 2000: - message = await self.split_message_by_sentence(message) + message = split_message_by_sentence(message) for msg in message: await ctx.reply(msg, silent=True) else: @@ -404,10 +347,10 @@ async def initiatechatview(self, ctx: commands.Context, shared_chat: bool): @commands.command( aliases=["sharedchat", "sharedtalk", "schat", "sc"], brief="Talk to QuantumKat in a shared chat.", - description="Talk to QuantumKat in a chat shared with all users, using the OpenAI API/ChatGPT. Is not shared between servers.", + description=f"Talk to QuantumKat in a chat shared with all users, using the OpenAI API/ChatGPT. Is not shared between servers. URLs of images and gifs are supported and will be analyzed by the AI. File size limit is 20MB and only {', '.join(SUPPORTED_IMAGE_FORMATS)} are supported", ) @requires_tos_acceptance - async def SharedChat(self, ctx: commands.Context, *, user_message=""): + async def SharedChat(self, ctx: commands.Context, *, user_message: str): """ Initiates a shared chat session with the bot. @@ -423,10 +366,10 @@ async def SharedChat(self, ctx: commands.Context, *, user_message=""): @commands.command( aliases=["chat", "talk", "c"], brief="Talk to QuantumKat.", - description="Talk to QuantumKat using the OpenAI API/ChatGPT. Each user has their own chat history. Is not shared between servers.", + description=f"Talk to QuantumKat using the OpenAI API/ChatGPT. Each user has their own chat history. Is not shared between servers. URLs of images and gifs are supported and will be analyzed by the AI. File size limit is 20MB and only {', '.join(SUPPORTED_IMAGE_FORMATS)} are supported.", ) @requires_tos_acceptance - async def Chat(self, ctx: commands.Context, *, user_message=""): + async def Chat(self, ctx: commands.Context, *, user_message: str): """ Initiates a user-separated chat session with the bot. @@ -538,7 +481,7 @@ async def ChatStatus(self, ctx: commands.Context): if self.session_key: try: - usage = await self.get_usage() + usage = get_usage(self.session_key) if usage: messages.append( "OpenAI API key usage: {:.2f}$ of tokens used this month.".format( diff --git a/cogs/Entanglement.py b/cogs/Entanglement.py index 4e1eead..ed702db 100644 --- a/cogs/Entanglement.py +++ b/cogs/Entanglement.py @@ -19,8 +19,6 @@ from alembic.util.exc import CommandError from sqlalchemy.exc import OperationalError -import mimetypes -import magic import discord import ast import astunparse @@ -30,6 +28,7 @@ from sql.database import AsyncSessionLocal from sql import crud, schemas +from cogs.utils.utils import get_file_type from QuantumKat import log_helper, misc_helper @@ -95,39 +94,6 @@ def extract_function_code(self, filename, function_name): if isinstance(node, ast.FunctionDef) and node.name == function_name: return astunparse.unparse(node.body) - async def get_mime_type(self, mime_type: str) -> str: - """ - Returns the file extension corresponding to the given MIME type. - - Parameters: - - mime_type (str): The MIME type for which to determine the file extension. - - Returns: - - str: The file extension corresponding to the given MIME type. - """ - return mimetypes.guess_extension(mime_type) - - async def get_file_type(self, ctx: commands.Context, filename: str) -> str: - """ - Retrieves the file type of a given filename. - - Parameters: - - ctx (commands.Context): The context of the command. - - filename (str): The name of the file. - - Returns: - - str: The file extension of the given file. - """ - mime = magic.Magic(mime=True) - try: - mime_type = mime.from_file(filename) - except OSError as e: - await ctx.reply(f"Error getting file type: {e}", silent=True) - self.logger.error("Error getting file type", exc_info=True) - return None - file_extension = await self.get_mime_type(mime_type) - return file_extension - async def parameter_kind_to_string(self, parameter: Parameter) -> str: """ Converts the parameter kind to a string representation. @@ -499,7 +465,7 @@ async def quantize( quantizer.write(block) if not Path(filename).suffix: - file_extension = await self.get_file_type( + file_extension = get_file_type( ctx, str(Path(data_dir, filename)) ) if not file_extension: