Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
LobaDK committed May 18, 2024
2 parents a543070 + fdd8d4e commit 6d05982
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 23 deletions.
28 changes: 15 additions & 13 deletions QuantumKat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ async def init_models():
await conn.run_sync(models.Base.metadata.create_all)


# run(init_models())

log_helper = LogHelper()
misc_helper = MiscHelper()
discord_helper = DiscordHelper()
Expand Down Expand Up @@ -92,7 +90,7 @@ async def setup(bot: commands.Bot):


async def is_authenticated(ctx: commands.Context) -> bool:
if ctx.author.id == int(OWNER_ID):
if ctx.author.id in bot.owner_ids:
return True
if not ctx.command.name.casefold() == "request_auth":
authenticated_server_ids = await crud.get_authenticated_servers(
Expand Down Expand Up @@ -120,7 +118,7 @@ async def is_authenticated(ctx: commands.Context) -> bool:


async def is_reboot_scheduled(ctx: commands.Context) -> bool:
if ctx.author.id == int(OWNER_ID):
if ctx.author.id in bot.owner_ids:
return True
if bot.reboot_scheduled:
await ctx.reply(
Expand All @@ -132,26 +130,28 @@ async def is_reboot_scheduled(ctx: commands.Context) -> bool:


async def is_banned(ctx: commands.Context) -> bool:
if ctx.author.id == int(OWNER_ID):
if ctx.author.id in bot.owner_ids:
return True
user = await crud.get_user(
AsyncSessionLocal, schemas.User.Get(user_id=ctx.author.id)
)
server = await crud.get_server(
AsyncSessionLocal, schemas.Server.Get(server_id=ctx.guild.id)
)
if user and user.is_banned:
await ctx.reply(
"You have been banned from using QuantumKat. Please contact the bot owner for more information.",
silent=True,
)
return False
if server and server.is_banned:
await ctx.reply(
"This server has been banned from using QuantumKat. Please contact the bot owner for more information.",
silent=True,

if not discord_helper.is_dm(ctx):
server = await crud.get_server(
AsyncSessionLocal, schemas.Server.Get(server_id=ctx.guild.id)
)
return False
if server and server.is_banned:
await ctx.reply(
"This server has been banned from using QuantumKat. Please contact the bot owner for more information.",
silent=True,
)
return False
return True


Expand All @@ -166,6 +166,8 @@ async def on_guild_join(guild):

@bot.event
async def on_ready():
# await init_models()

# Add all servers the bot is in to the database on startup in case the bot was added while offline
for guild in bot.guilds:
try:
Expand Down
14 changes: 9 additions & 5 deletions cogs/Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from QuantumKat import log_helper, misc_helper, discord_helper

TOKEN_LIMIT = 1024*2

class Chat(commands.Cog):
def __init__(self, bot: commands.Bot):
Expand Down Expand Up @@ -179,8 +180,11 @@ async def initiateChat(
"""
if self.FOUND_API_KEY is True:
if user_message:
tokens = calculate_tokens(user_message, self.system_message)
if not tokens > 1024:
system_message = self.system_message
if ctx.message.reference:
system_message += f"The user has included this message in their response, which was written by {ctx.message.reference.resolved.author.display_name}: `{ctx.message.reference.resolved.content}`. Use it as context for the response."
tokens = calculate_tokens(user_message, system_message)
if not tokens > TOKEN_LIMIT:
command = ctx.invoked_with
user_message = ctx.message.content.split(
f"{self.bot.command_prefix}{command}", 1
Expand Down Expand Up @@ -212,7 +216,7 @@ async def initiateChat(
for attachment in ctx.message.attachments:
try:
base64_images.extend(
get_image_as_base64(attachment.url)
get_image_as_base64(await attachment.read())
)
except (
UnsupportedImageFormatError,
Expand Down Expand Up @@ -246,7 +250,7 @@ async def initiateChat(
messages = [
{
"role": "system",
"content": self.system_message.format(
"content": system_message.format(
user=ctx.author.id,
version=".".join(
str(misc_helper.get_git_commit_count())
Expand Down Expand Up @@ -313,7 +317,7 @@ async def initiateChat(
)
else:
await ctx.reply(
f"Message is too long! Your message is {tokens} tokens long, but the maximum is 1024 tokens.",
f"Message is too long! Your message is {tokens} tokens long, but the maximum is {TOKEN_LIMIT} tokens.",
silent=True,
)
else:
Expand Down
6 changes: 1 addition & 5 deletions cogs/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def header_mime_type(self) -> str:
str: The MIME type from the 'Content-Type' header, or None if not present.
"""
try:
return self.header["Content-Type"]
return self.header["Content-Type"].split(";")[0]
except KeyError:
return None

Expand Down Expand Up @@ -189,10 +189,6 @@ def strip_embed_disabler(url: str) -> str:
return url.replace("<", "").replace(">", "")


# TODO: It shouldn't necessarily be added here, but add a function to detect and get files attached to a ctx object (ctx.message.attachments)
# TODO: While we're at it with the above, add the ability (unsure if function is necessary) for the bot to include the message a user is replying to i.e. they initiated the chat command while replying to a message
# TODO: Improve file type detection and minimize the amount of data being fetched before the checks (check the Content-Type header of the response)
# TODO: Add some fallbacks where we may still download the file if the Content-Type header is not available, but only the first 1 KB or so
def get_image_as_base64(url_or_byte_stream: str | bytes) -> list[str]:
"""
Converts an image from a URL or byte stream into a base64 encoded string.
Expand Down

0 comments on commit 6d05982

Please sign in to comment.