From bcf82c34dc10d1a6ff280d728d77761b3ab63804 Mon Sep 17 00:00:00 2001 From: nova <110734810+novanai@users.noreply.github.com> Date: Wed, 15 Jan 2025 00:00:51 +0000 Subject: [PATCH] Add workflow for formatting & format codebase (#63) --- .github/workflows/format.yml | 34 ++++++++++++++++ noxfile.py | 29 ++++++++++++++ pyproject.toml | 32 +++++++++++++++ requirements.txt | 2 +- requirements_dev.txt | 5 ++- src/bot.py | 7 ++-- src/examples/commands.py | 3 +- src/examples/components.py | 12 ++++-- src/examples/modals.py | 3 +- src/examples/options.py | 3 +- src/extensions/action_items.py | 71 +++++++++++++++++++--------------- src/extensions/agenda.py | 46 ++++++++++++++-------- src/extensions/boosts.py | 5 ++- src/extensions/figlet.py | 1 - src/extensions/fortune.py | 2 +- src/extensions/gerry.py | 3 +- src/extensions/help.py | 13 ++++--- src/extensions/uptime.py | 12 +++--- src/extensions/user_roles.py | 16 +++++--- src/hooks.py | 13 ++++--- src/utils.py | 14 +++++-- 21 files changed, 237 insertions(+), 89 deletions(-) create mode 100644 .github/workflows/format.yml create mode 100644 noxfile.py diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000..0ef6534 --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,34 @@ +name: Format & Type Check + +on: [push, pull_request] + +jobs: + formatting: + runs-on: ubuntu-latest + name: "Check code style" + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Run ruff via nox + run: | + python -m pip install nox + python -m nox -s format_check + + pyright: + runs-on: ubuntu-latest + name: "Type checking" + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Run pyright via nox + run: | + python -m pip install nox + python -m nox -s pyright diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000..9e478fb --- /dev/null +++ b/noxfile.py @@ -0,0 +1,29 @@ +import os + +import nox +from nox import options + +PROJECT_PATH = os.path.join(".", "src") +SCRIPT_PATHS = [PROJECT_PATH, "noxfile.py"] + +options.sessions = ["format_fix", "pyright"] + + +@nox.session() +def format_fix(session: nox.Session) -> None: + session.install("-r", "requirements_dev.txt") + session.run("python", "-m", "ruff", "format", *SCRIPT_PATHS) + session.run("python", "-m", "ruff", "check", *SCRIPT_PATHS, "--fix") + + +@nox.session() +def format_check(session: nox.Session) -> None: + session.install("-r", "requirements_dev.txt") + session.run("python", "-m", "ruff", "format", *SCRIPT_PATHS, "--check") + session.run("python", "-m", "ruff", "check", *SCRIPT_PATHS) + + +@nox.session() +def pyright(session: nox.Session) -> None: + session.install("-r", "requirements_dev.txt", "-r", "requirements.txt") + session.run("pyright", *SCRIPT_PATHS) diff --git a/pyproject.toml b/pyproject.toml index 37df71c..9a8f80e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,5 +5,37 @@ target-version = "py312" docstring-code-format = true line-ending = "lf" +[tool.ruff.lint] +select = [ + "F", # Pyflakes + "E", # Error (pycodestyle) + "W", # Warning (pycodestyle) + "I", # isort + "N", # pep8-naming + "ANN", # flake8-annotations + "ASYNC", # flake8-async + "A", # flake8-builtins + "COM", # flake8-commas + "C4", # flake8-comprehensions + "DTZ", # flake8-datetimez + "ICN", # flake8-import-conventions + "Q", # flake8-quotes + "RET", # flake8-return + "SIM", # flake8-simplify + "TID", # flake8-tidy-imports + "TC", # flake8-type-checking + "ARG", # flake8-unused-arguments + "ERA", # eradicate + "PL", # Pylint + "PERF", # Perflint + "RUF", # Ruff-specific rules +] +ignore = [ + "E501", # line-too-long + "PLR2004", # magic-value-comparison + "PLR0913", # too-many-arguments + "COM812", # missing-trailing-comma +] + [tool.ruff.lint.pydocstyle] convention = "numpy" diff --git a/requirements.txt b/requirements.txt index b9ad254..ee0d9ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aiohttp==3.10.10 +aiohttp==3.11.11 fortune-python==1.1.1 hikari==2.1.0 hikari-arc==1.4.0 diff --git a/requirements_dev.txt b/requirements_dev.txt index 28ccc66..1ef6e48 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,2 +1,5 @@ -ruff==0.7.3 +nox==2024.10.9 +ruff==0.9.1 pre-commit==4.0.1 +pyright==1.1.391 + diff --git a/src/bot.py b/src/bot.py index e24b875..18f01dd 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,8 +1,8 @@ import logging +import aiohttp import arc import hikari -import aiohttp import miru from src.config import DEBUG, TOKEN @@ -28,7 +28,7 @@ @client.listen(hikari.StartingEvent) -async def on_start(event: hikari.StartingEvent) -> None: +async def on_start(_: hikari.StartingEvent) -> None: # Create an aiohttp ClientSession to use for web requests aiohttp_client = aiohttp.ClientSession() client.set_type_dependency(aiohttp.ClientSession, aiohttp_client) @@ -39,7 +39,8 @@ async def on_start(event: hikari.StartingEvent) -> None: # so dependency injection must be enabled manually for this event listener @client.inject_dependencies async def on_stop( - event: hikari.StoppedEvent, aiohttp_client: aiohttp.ClientSession = arc.inject() + _: hikari.StoppedEvent, + aiohttp_client: aiohttp.ClientSession = arc.inject(), ) -> None: await aiohttp_client.close() diff --git a/src/examples/commands.py b/src/examples/commands.py index 4103d81..4f957b5 100644 --- a/src/examples/commands.py +++ b/src/examples/commands.py @@ -11,7 +11,8 @@ async def hello(ctx: arc.GatewayContext) -> None: group = plugin.include_slash_group( - "base_group", "A base command group, with sub groups and sub commands." + "base_group", + "A base command group, with sub groups and sub commands.", ) diff --git a/src/examples/components.py b/src/examples/components.py index 21994e6..5979837 100644 --- a/src/examples/components.py +++ b/src/examples/components.py @@ -12,7 +12,7 @@ def __init__(self, user_id: int) -> None: super().__init__(timeout=60) @miru.button("Click me!", custom_id="click_me") - async def click_button(self, ctx: miru.ViewContext, button: miru.Button) -> None: + async def click_button(self, ctx: miru.ViewContext, _: miru.Button) -> None: await ctx.respond(f"{ctx.user.mention}, you clicked me!") # Defining select menus: https://miru.hypergonial.com/guides/selects/ @@ -35,7 +35,9 @@ async def click_button(self, ctx: miru.ViewContext, button: miru.Button) -> None ], ) async def colour_select( - self, ctx: miru.ViewContext, select: miru.TextSelect + self, + ctx: miru.ViewContext, + select: miru.TextSelect, ) -> None: await ctx.respond(f"Your favourite colours are: {', '.join(select.values)}!") @@ -46,7 +48,8 @@ async def view_check(self, ctx: miru.ViewContext) -> bool: # For every other user they will receive an error message. if ctx.user.id != self.user_id: await ctx.respond( - "You can't press this!", flags=hikari.MessageFlag.EPHEMERAL + "You can't press this!", + flags=hikari.MessageFlag.EPHEMERAL, ) return False @@ -69,7 +72,8 @@ async def on_timeout(self) -> None: @plugin.include @arc.slash_command("components", "A command with components.") async def components_cmd( - ctx: arc.GatewayContext, miru_client: miru.Client = arc.inject() + ctx: arc.GatewayContext, + miru_client: miru.Client = arc.inject(), ) -> None: view = View(ctx.user.id) response = await ctx.respond("Here are some components...", components=view) diff --git a/src/examples/modals.py b/src/examples/modals.py index aff22eb..ccce656 100644 --- a/src/examples/modals.py +++ b/src/examples/modals.py @@ -31,7 +31,8 @@ async def callback(self, ctx: miru.ModalContext) -> None: @plugin.include @arc.slash_command("modal", "A command with a modal response.") async def modal_command( - ctx: arc.GatewayContext, miru_client: miru.Client = arc.inject() + ctx: arc.GatewayContext, + miru_client: miru.Client = arc.inject(), ) -> None: modal = MyModal() builder = modal.build_response(miru_client) diff --git a/src/examples/options.py b/src/examples/options.py index af91dac..a9d08ca 100644 --- a/src/examples/options.py +++ b/src/examples/options.py @@ -11,7 +11,8 @@ async def options( ctx: arc.GatewayContext, str_option: arc.Option[str, arc.StrParams("A string option.", name="string")], int_option: arc.Option[ - int, arc.IntParams("An integer option.", name="integer", min=5, max=150) + int, + arc.IntParams("An integer option.", name="integer", min=5, max=150), ], attachment_option: arc.Option[ hikari.Attachment, diff --git a/src/extensions/action_items.py b/src/extensions/action_items.py index a972ca5..eecb1ed 100644 --- a/src/extensions/action_items.py +++ b/src/extensions/action_items.py @@ -1,13 +1,13 @@ -import arc -import hikari import re -import aiohttp from urllib.parse import urlparse -from src.utils import role_mention, hedgedoc_login -from src.hooks import restrict_to_channels, restrict_to_roles -from src.config import CHANNEL_IDS, ROLE_IDS, UID_MAPS +import aiohttp +import arc +import hikari +from src.config import CHANNEL_IDS, ROLE_IDS, UID_MAPS +from src.hooks import restrict_to_channels, restrict_to_roles +from src.utils import hedgedoc_login, role_mention action_items = arc.GatewayPlugin(name="Action Items") @@ -54,7 +54,9 @@ async def get_action_items( # extract the action items section from the minutes action_items_section = re.search( - r"# Action Items:?\n(.*?)(\n# |\n---|$)", content, re.DOTALL + r"# Action Items:?\n(.*?)(\n# |\n---|$)", + content, + re.DOTALL, ) if not action_items_section: @@ -78,13 +80,13 @@ async def get_action_items( # Replace user names with user mentions for i, item in enumerate(formatted_bullet_points): for name, uid in UID_MAPS.items(): - item = item.replace(f"`{name}`", f"<@{uid}>") + item = item.replace(f"`{name}`", f"<@{uid}>") # noqa: PLW2901 formatted_bullet_points[i] = item # Replace role names with role mentions for i, item in enumerate(formatted_bullet_points): for role, role_id in ROLE_IDS.items(): - item = item.replace(f"`{role}`", role_mention(role_id)) + item = item.replace(f"`{role}`", role_mention(role_id)) # noqa: PLW2901 formatted_bullet_points[i] = item # Send title to the action-items channel @@ -95,7 +97,7 @@ async def get_action_items( # send each bullet point separately for item in formatted_bullet_points: - item = await action_items.client.rest.create_message( + message = await action_items.client.rest.create_message( CHANNEL_IDS["action-items"], mentions_everyone=False, user_mentions=True, @@ -104,8 +106,8 @@ async def get_action_items( ) await action_items.client.rest.add_reaction( - channel=item.channel_id, - message=item.id, + channel=message.channel_id, + message=message.id, emoji="✅", ) @@ -136,15 +138,14 @@ async def check_valid_reaction( assert message.author # it will always be available - # ignore messages not sent by the bot and messages with no content - if message.author.id != bot_user.id or not message.content: - return False - - return True + # verify it's a message sent by the bot and has content + return message.author.id == bot_user.id and message.content is not None async def validate_user_reaction( - user_id: int, message_content: str, guild_id: int + user_id: int, + message_content: str, + guild_id: int, ) -> bool: # extract user and role mentions from the message content mention_regex = r"<@[!&]?(\d+)>" @@ -153,26 +154,27 @@ async def validate_user_reaction( # make a list of all mentions mentioned_ids = [int(id_) for id_ in mentions] + # user is mentioned if user_id in mentioned_ids: return True member = action_items.client.cache.get_member( - guild_id, user_id + guild_id, + user_id, ) or await action_items.client.rest.fetch_member(guild_id, user_id) - if any(role_id in mentioned_ids for role_id in member.role_ids): - return True - - return False + # user's role is mentioned + return any(role_id in mentioned_ids for role_id in member.role_ids) @action_items.listen() async def reaction_add(event: hikari.GuildReactionAddEvent) -> None: # retrieve the message that was reacted to message = action_items.client.cache.get_message( - event.message_id + event.message_id, ) or await action_items.client.rest.fetch_message( - event.channel_id, event.message_id + event.channel_id, + event.message_id, ) is_valid_reaction = await check_valid_reaction(event, message) @@ -182,7 +184,9 @@ async def reaction_add(event: hikari.GuildReactionAddEvent) -> None: assert message.content # check_valid_reaction verifies the message content exists is_valid_reaction = await validate_user_reaction( - event.user_id, message.content, event.guild_id + event.user_id, + message.content, + event.guild_id, ) if not is_valid_reaction: return @@ -192,7 +196,9 @@ async def reaction_add(event: hikari.GuildReactionAddEvent) -> None: # add strikethrough and checkmark updated_content = f"- ✅ ~~{message.content[2:]}~~" await action_items.client.rest.edit_message( - event.channel_id, event.message_id, content=updated_content + event.channel_id, + event.message_id, + content=updated_content, ) @@ -201,7 +207,8 @@ async def reaction_remove(event: hikari.GuildReactionDeleteEvent) -> None: # retrieve the message that was un-reacted to # NOTE: cannot use cached message as the reaction count will be outdated message = await action_items.client.rest.fetch_message( - event.channel_id, event.message_id + event.channel_id, + event.message_id, ) is_valid_reaction = await check_valid_reaction(event, message) @@ -225,8 +232,8 @@ async def reaction_remove(event: hikari.GuildReactionDeleteEvent) -> None: filter( lambda r: r is True, reactions, - ) - ) + ), + ), ) assert message.content # check_valid_reaction verifies the message content exists @@ -236,7 +243,9 @@ async def reaction_remove(event: hikari.GuildReactionDeleteEvent) -> None: # add strikethrough and checkmark updated_content = f"- {message.content[6:-2]}" await action_items.client.rest.edit_message( - event.channel_id, event.message_id, content=updated_content + event.channel_id, + event.message_id, + content=updated_content, ) diff --git a/src/extensions/agenda.py b/src/extensions/agenda.py index 3b6f066..5ae3c58 100644 --- a/src/extensions/agenda.py +++ b/src/extensions/agenda.py @@ -1,13 +1,13 @@ +import datetime +from urllib.parse import urlparse + +import aiohttp import arc import hikari -import aiohttp -from urllib.parse import urlparse -import datetime -from src.utils import role_mention, hedgedoc_login +from src.config import AGENDA_TEMPLATE_URL, CHANNEL_IDS, ROLE_IDS, UID_MAPS from src.hooks import restrict_to_channels, restrict_to_roles -from src.config import CHANNEL_IDS, ROLE_IDS, UID_MAPS, AGENDA_TEMPLATE_URL - +from src.utils import hedgedoc_login, role_mention, utcnow plugin = arc.GatewayPlugin(name="Agenda") @@ -15,10 +15,10 @@ async def generate_date_choices( - data: arc.AutocompleteData[arc.GatewayClient, str], + _: arc.AutocompleteData[arc.GatewayClient, str], ) -> list[str]: """Generate date options for the next 7 days.""" - today = datetime.date.today() + today = utcnow().today() return [ (today + datetime.timedelta(days=i)).strftime("%A %d/%m/%Y") for i in range(7) ] @@ -31,7 +31,7 @@ def generate_time_choices() -> list[str]: for hour in range(24): current_time = ( - datetime.datetime.combine(datetime.date.today(), base_time) + datetime.datetime.combine(utcnow().today(), base_time) + datetime.timedelta(hours=hour) ).time() times.append(current_time.strftime("%H:%M")) @@ -46,8 +46,8 @@ def generate_time_choices() -> list[str]: CHANNEL_IDS["bots-cmt"], CHANNEL_IDS["committee-announcements"], CHANNEL_IDS["cowboys-and-cowgirls-committee"], - ] - ) + ], + ), ) @arc.with_hook(restrict_to_roles(role_ids=[ROLE_IDS["committee"]])) @arc.slash_subcommand( @@ -64,7 +64,8 @@ async def gen_agenda( time: arc.Option[ str, arc.StrParams( - "Enter the time in HH:MM format.", choices=generate_time_choices() + "Enter the time in HH:MM format.", + choices=generate_time_choices(), ), ], room: arc.Option[ @@ -75,14 +76,23 @@ async def gen_agenda( str | None, arc.StrParams("Optional note to be included in the announcement.") ] = None, url: arc.Option[ - str, arc.StrParams("URL of the agenda template from the MD") + str, + arc.StrParams("URL of the agenda template from the MD"), ] = AGENDA_TEMPLATE_URL, aiohttp_client: aiohttp.ClientSession = arc.inject(), ) -> None: """Generate a new agenda for committee meetings.""" - parsed_date = datetime.datetime.strptime(date, "%A %d/%m/%Y").date() - parsed_time = datetime.datetime.strptime(time, "%H:%M").time() + parsed_date = ( + datetime.datetime.strptime(date, "%A %d/%m/%Y") + .replace(tzinfo=datetime.timezone.utc) + .date() + ) + parsed_time = ( + datetime.datetime.strptime(time, "%H:%M") + .replace(tzinfo=datetime.timezone.utc) + .time() + ) parsed_datetime = datetime.datetime.combine(parsed_date, parsed_time) @@ -115,7 +125,9 @@ async def gen_agenda( content = await response.text() modified_content = content.format( - DATE=formatted_date, TIME=formatted_time, ROOM=room + DATE=formatted_date, + TIME=formatted_time, + ROOM=room, ) post_url = f"{parsed_url.scheme}://{parsed_url.hostname}/new" @@ -190,7 +202,7 @@ async def view_template( colour=0x5865F2, ) embed = embed.set_image( - "https://cdn.redbrick.dcu.ie/hedgedoc-uploads/sonic-the-hedgedoc.png" + "https://cdn.redbrick.dcu.ie/hedgedoc-uploads/sonic-the-hedgedoc.png", ) await ctx.respond( diff --git a/src/extensions/boosts.py b/src/extensions/boosts.py index 6222c48..b7ad3bb 100644 --- a/src/extensions/boosts.py +++ b/src/extensions/boosts.py @@ -12,8 +12,9 @@ hikari.MessageType.USER_PREMIUM_GUILD_SUBSCRIPTION_TIER_3, ] -BOOST_MESSAGE_TYPES: list[hikari.MessageType] = BOOST_TIERS + [ - hikari.MessageType.USER_PREMIUM_GUILD_SUBSCRIPTION +BOOST_MESSAGE_TYPES: list[hikari.MessageType] = [ + *BOOST_TIERS, + hikari.MessageType.USER_PREMIUM_GUILD_SUBSCRIPTION, ] diff --git a/src/extensions/figlet.py b/src/extensions/figlet.py index c0c6f3c..b143996 100644 --- a/src/extensions/figlet.py +++ b/src/extensions/figlet.py @@ -1,6 +1,5 @@ import arc import hikari - from pyfiglet import Figlet plugin = arc.GatewayPlugin(name="figlet") diff --git a/src/extensions/fortune.py b/src/extensions/fortune.py index 8f581c6..da8eef7 100644 --- a/src/extensions/fortune.py +++ b/src/extensions/fortune.py @@ -9,7 +9,7 @@ @arc.slash_command("fortune", "Send a user a random Fortune!") async def fortune_command( ctx: arc.GatewayContext, - user: arc.Option[hikari.User, arc.UserParams("A user")] = None, + user: arc.Option[hikari.User | None, arc.UserParams("A user")] = None, ) -> None: """Send a random Fortune!""" diff --git a/src/extensions/gerry.py b/src/extensions/gerry.py index c619020..406abc8 100644 --- a/src/extensions/gerry.py +++ b/src/extensions/gerry.py @@ -11,7 +11,8 @@ async def gerry_command( ctx: arc.GatewayContext, user: arc.Option[ - hikari.User, arc.UserParams("The user to send a gerry to.") + hikari.User | None, + arc.UserParams("The user to send a gerry to."), ] = None, ) -> None: """Send a gerry!""" diff --git a/src/extensions/help.py b/src/extensions/help.py index 6f4302f..2fbf1c4 100644 --- a/src/extensions/help.py +++ b/src/extensions/help.py @@ -1,7 +1,8 @@ -import hikari -import arc -import itertools import collections +import itertools + +import arc +import hikari plugin = arc.GatewayPlugin(name="Help Command Plugin") @@ -18,7 +19,7 @@ def gather_commands() -> dict[str | None, list[str]]: continue plugin_commands[plugin_.name if plugin_ else None].append( - f"{cmd.make_mention()} - {cmd.description}" + f"{cmd.make_mention()} - {cmd.description}", ) return plugin_commands @@ -34,7 +35,9 @@ async def help_command(ctx: arc.GatewayContext) -> None: for plugin_, commands in plugin_commands.items(): embed.add_field( - name=plugin_ or "No plugin", value="\n".join(commands), inline=False + name=plugin_ or "No plugin", + value="\n".join(commands), + inline=False, ) await ctx.respond(embed=embed) diff --git a/src/extensions/uptime.py b/src/extensions/uptime.py index 5afa95a..f75df48 100644 --- a/src/extensions/uptime.py +++ b/src/extensions/uptime.py @@ -1,8 +1,8 @@ -from datetime import datetime - import arc -start_time = datetime.now() +from src.utils import utcnow + +start_time = utcnow() plugin = arc.GatewayPlugin("Blockbot Uptime") @@ -10,16 +10,16 @@ @plugin.include @arc.slash_command("uptime", "Show formatted uptime of Blockbot") async def uptime(ctx: arc.GatewayContext) -> None: - up_time = datetime.now() - start_time + up_time = utcnow() - start_time d = up_time.days h, ms = divmod(up_time.seconds, 3600) m, s = divmod(ms, 60) - def format(val: int, s: str): + def format_time(val: int, s: str) -> str: return f"{val} {s}{'s' if val != 1 else ''}" message_parts = [(d, "day"), (h, "hour"), (m, "minute"), (s, "second")] - formatted_parts = [format(val, str) for val, str in message_parts if val] + formatted_parts = [format_time(val, text) for val, text in message_parts if val] await ctx.respond(f"Uptime: **{', '.join(formatted_parts)}**") diff --git a/src/extensions/user_roles.py b/src/extensions/user_roles.py index b3130b7..e86ef13 100644 --- a/src/extensions/user_roles.py +++ b/src/extensions/user_roles.py @@ -1,9 +1,8 @@ import arc import hikari -from src.utils import role_mention - from src.config import ASSIGNABLE_ROLES +from src.utils import role_mention plugin = arc.GatewayPlugin(name="User Roles") @@ -33,7 +32,10 @@ async def add_role( return await ctx.client.rest.add_role_to_member( - ctx.guild_id, ctx.author, int(role), reason="Self-service role." + ctx.guild_id, + ctx.author, + int(role), + reason="Self-service role.", ) await ctx.respond( f"Done! Added {role_mention(role)} to your roles.", @@ -59,7 +61,10 @@ async def remove_role( return await ctx.client.rest.remove_role_from_member( - ctx.guild_id, ctx.author, int(role), reason=f"{ctx.author} removed role." + ctx.guild_id, + ctx.author, + int(role), + reason=f"{ctx.author} removed role.", ) await ctx.respond( f"Done! Removed {role_mention(role)} from your roles.", @@ -81,7 +86,8 @@ async def role_error_handler(ctx: arc.GatewayContext, exc: Exception) -> None: if isinstance(exc, hikari.NotFoundError): await ctx.respond( - "❌ Blockbot can't find that role.", flags=hikari.MessageFlag.EPHEMERAL + "❌ Blockbot can't find that role.", + flags=hikari.MessageFlag.EPHEMERAL, ) return diff --git a/src/hooks.py b/src/hooks.py index 5cbc949..cada893 100644 --- a/src/hooks.py +++ b/src/hooks.py @@ -1,10 +1,12 @@ +import typing + import arc import hikari -import typing async def _restrict_to_roles( - ctx: arc.GatewayContext, role_ids: typing.Sequence[hikari.Snowflake] + ctx: arc.GatewayContext, + role_ids: typing.Sequence[int], ) -> arc.HookResult: assert ctx.member @@ -20,7 +22,7 @@ async def _restrict_to_roles( # TODO: make response type a TypeVar for reuse (WrappedHookResult) def restrict_to_roles( - role_ids: typing.Sequence[hikari.Snowflake], + role_ids: typing.Sequence[int], ) -> typing.Callable[[arc.GatewayContext], typing.Awaitable[arc.HookResult]]: """Any command which uses this hook requires that the command be disabled in DMs as a guild role is required for this hook to function.""" @@ -31,7 +33,8 @@ async def func(ctx: arc.GatewayContext) -> arc.HookResult: async def _restrict_to_channels( - ctx: arc.GatewayContext, channel_ids: typing.Sequence[hikari.Snowflake] + ctx: arc.GatewayContext, + channel_ids: typing.Sequence[int], ) -> arc.HookResult: if ctx.channel_id not in channel_ids: await ctx.respond( @@ -44,7 +47,7 @@ async def _restrict_to_channels( def restrict_to_channels( - channel_ids: typing.Sequence[hikari.Snowflake], + channel_ids: typing.Sequence[int], ) -> typing.Callable[[arc.GatewayContext], typing.Awaitable[arc.HookResult]]: async def func(ctx: arc.GatewayContext) -> arc.HookResult: return await _restrict_to_channels(ctx, channel_ids) diff --git a/src/utils.py b/src/utils.py index b014819..7cb87d9 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,11 +1,15 @@ +import datetime + +import aiohttp import hikari from arc import GatewayClient -import aiohttp -from src.config import LDAP_USERNAME, LDAP_PASSWORD + +from src.config import LDAP_PASSWORD, LDAP_USERNAME async def get_guild( - client: GatewayClient, event: hikari.GuildMessageCreateEvent + client: GatewayClient, + event: hikari.GuildMessageCreateEvent, ) -> hikari.GatewayGuild | hikari.RESTGuild: return event.get_guild() or await client.rest.fetch_guild(event.guild_id) @@ -21,3 +25,7 @@ async def hedgedoc_login(aiohttp_client: aiohttp.ClientSession) -> None: } await aiohttp_client.post("https://md.redbrick.dcu.ie/auth/ldap", data=data) + + +def utcnow() -> datetime.datetime: + return datetime.datetime.now(datetime.timezone.utc)