Skip to content

Commit 8a02bc9

Browse files
Merge pull request #357 from BCurbs/type-checking
Oops forgot about this
2 parents 17079ad + 10a3409 commit 8a02bc9

31 files changed

+736
-468
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,7 @@ venv/
1919
*~
2020

2121
#VSCode
22-
settings.json
22+
settings.json
23+
24+
#for mypy
25+
/.mypy_cache/

dozer/__main__.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
"""Initializes the bot and deals with the configuration file"""
22

3+
import asyncio
34
import json
45
import os
56
import sys
6-
import asyncio
77

88
import discord
99
import sentry_sdk
1010

1111
from .db import db_init, db_migrate
1212

13-
from . import db
14-
1513
config = {
1614
'prefix': '&', 'developers': [],
1715
'cache_size': 20000,
@@ -61,8 +59,8 @@
6159
json.dump(config, f, indent='\t')
6260

6361
if config['sentry_url'] != "":
64-
sentry_sdk.init(
65-
config['sentry_url'],
62+
sentry_sdk.init( # pylint: disable=abstract-class-instantiated # noqa: E0110
63+
str(config['sentry_url']),
6664
traces_sample_rate=1.0,
6765
)
6866

@@ -78,11 +76,10 @@
7876

7977
intents = discord.Intents.default()
8078
intents.members = True
81-
intents.presences = config['presences_intents']
79+
intents.presences = bool(config['presences_intents'])
8280

8381
bot = Dozer(config, intents=intents, max_messages=config['cache_size'])
8482

85-
8683
for ext in os.listdir('dozer/cogs'):
8784
if not ext.startswith(('_', '.')):
8885
bot.load_extension('dozer.cogs.' + ext[:-3]) # Remove '.py'

dozer/bot.py

+24-23
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
import sys
66
import traceback
7+
from typing import Pattern
78

89
import discord
910
from discord.ext import commands
@@ -12,6 +13,7 @@
1213

1314
from . import utils
1415
from .cogs import _utils
16+
from .context import DozerContext
1517

1618
DOZER_LOGGER = logging.getLogger('dozer')
1719
DOZER_LOGGER.level = logging.INFO
@@ -36,19 +38,11 @@ class InvalidContext(commands.CheckFailure):
3638
"""
3739

3840

39-
class DozerContext(commands.Context):
40-
"""Cleans all messages before sending"""
41-
async def send(self, content=None, **kwargs): # pylint: disable=arguments-differ
42-
if content is not None:
43-
content = utils.clean(self, content, mass=True, member=False, role=False, channel=False)
44-
return await super().send(content, **kwargs)
45-
46-
4741
class Dozer(commands.Bot):
4842
"""Botty things that are critical to Dozer working"""
4943
_global_cooldown = commands.Cooldown(1, 1, commands.BucketType.user) # One command per second per user
5044

51-
def __init__(self, config, *args, **kwargs):
45+
def __init__(self, config: dict, *args, **kwargs):
5246
self.dynamic_prefix = _utils.PrefixHandler(config['prefix'])
5347
super().__init__(command_prefix=self.dynamic_prefix.handler, *args, **kwargs)
5448
self.slash = SlashCommand(self, sync_commands=True, override_type=True)
@@ -78,45 +72,52 @@ async def on_ready(self):
7872
DOZER_LOGGER.warning("You are running an older version of the discord.py rewrite (with breaking changes)! "
7973
"To upgrade, run `pip install -r requirements.txt --upgrade`")
8074

81-
async def get_context(self, message, *, cls=DozerContext):
75+
async def get_context(self, message: discord.Message, *, cls=DozerContext):
8276
ctx = await super().get_context(message, cls=cls)
8377
return ctx
8478

85-
async def on_command_error(self, context, exception):
79+
async def on_command_error(self, context: DozerContext, exception):
8680
if isinstance(exception, commands.NoPrivateMessage):
8781
await context.send('{}, This command cannot be used in DMs.'.format(context.author.mention))
8882
elif isinstance(exception, commands.UserInputError):
8983
await context.send('{}, {}'.format(context.author.mention, self.format_error(context, exception)))
9084
elif isinstance(exception, commands.NotOwner):
9185
await context.send('{}, {}'.format(context.author.mention, exception.args[0]))
9286
elif isinstance(exception, commands.MissingPermissions):
93-
permission_names = [name.replace('guild', 'server').replace('_', ' ').title() for name in exception.missing_perms]
87+
permission_names = [name.replace('guild', 'server').replace('_', ' ').title() for name in
88+
exception.missing_perms]
9489
await context.send('{}, you need {} permissions to run this command!'.format(
9590
context.author.mention, utils.pretty_concat(permission_names)))
9691
elif isinstance(exception, commands.BotMissingPermissions):
97-
permission_names = [name.replace('guild', 'server').replace('_', ' ').title() for name in exception.missing_perms]
92+
permission_names = [name.replace('guild', 'server').replace('_', ' ').title() for name in
93+
exception.missing_perms]
9894
await context.send('{}, I need {} permissions to run this command!'.format(
9995
context.author.mention, utils.pretty_concat(permission_names)))
10096
elif isinstance(exception, commands.CommandOnCooldown):
10197
await context.send(
102-
'{}, That command is on cooldown! Try again in {:.2f}s!'.format(context.author.mention, exception.retry_after))
98+
'{}, That command is on cooldown! Try again in {:.2f}s!'.format(context.author.mention,
99+
exception.retry_after))
103100
elif isinstance(exception, commands.MaxConcurrencyReached):
104-
types = {discord.ext.commands.BucketType.default: "`Global`", discord.ext.commands.BucketType.guild: "`Guild`",
105-
discord.ext.commands.BucketType.channel: "`Channel`", discord.ext.commands.BucketType.category: "`Category`",
101+
types = {discord.ext.commands.BucketType.default: "`Global`",
102+
discord.ext.commands.BucketType.guild: "`Guild`",
103+
discord.ext.commands.BucketType.channel: "`Channel`",
104+
discord.ext.commands.BucketType.category: "`Category`",
106105
discord.ext.commands.BucketType.member: "`Member`", discord.ext.commands.BucketType.user: "`User`"}
107106
await context.send(
108107
'{}, That command has exceeded the max {} concurrency limit of `{}` instance! Please try again later.'.format(
109108
context.author.mention, types[exception.per], exception.number))
110109
elif isinstance(exception, (commands.CommandNotFound, InvalidContext)):
111110
pass # Silent ignore
112111
else:
113-
await context.send('```\n%s\n```' % ''.join(traceback.format_exception_only(type(exception), exception)).strip())
112+
await context.send(
113+
'```\n%s\n```' % ''.join(traceback.format_exception_only(type(exception), exception)).strip())
114114
if isinstance(context.channel, discord.TextChannel):
115115
DOZER_LOGGER.error('Error in command <%d> (%d.name!r(%d.id) %d(%d.id) %d(%d.id) %d)',
116116
context.command, context.guild, context.guild, context.channel, context.channel,
117117
context.author, context.author, context.message.content)
118118
else:
119-
DOZER_LOGGER.error('Error in command <%d> (DM %d(%d.id) %d)', context.command, context.channel.recipient,
119+
DOZER_LOGGER.error('Error in command <%d> (DM %d(%d.id) %d)', context.command,
120+
context.channel.recipient,
120121
context.channel.recipient, context.message.content)
121122
DOZER_LOGGER.error(''.join(traceback.format_exception(type(exception), exception, exception.__traceback__)))
122123

@@ -126,12 +127,12 @@ async def on_error(self, event_method, *args, **kwargs):
126127
traceback.print_exc()
127128
capture_exception()
128129

129-
async def on_slash_command_error(self, ctx, ex):
130+
async def on_slash_command_error(self, ctx: DozerContext, ex: Exception):
130131
"""Passes slash command errors to primary command handler"""
131132
await self.on_command_error(ctx, ex)
132133

133134
@staticmethod
134-
def format_error(ctx, err, *, word_re=re.compile('[A-Z][a-z]+')):
135+
def format_error(ctx: DozerContext, err: Exception, *, word_re: Pattern = re.compile('[A-Z][a-z]+')):
135136
"""Turns an exception into a user-friendly (or -friendlier, at least) error message."""
136137
type_words = word_re.findall(type(err).__name__)
137138
type_msg = ' '.join(map(str.lower, type_words))
@@ -141,12 +142,12 @@ def format_error(ctx, err, *, word_re=re.compile('[A-Z][a-z]+')):
141142
else:
142143
return type_msg
143144

144-
def global_checks(self, ctx):
145+
def global_checks(self, ctx: DozerContext):
145146
"""Checks that should be executed before passed to the command"""
146147
if ctx.author.bot:
147148
raise InvalidContext('Bots cannot run commands!')
148149
retry_after = self._global_cooldown.update_rate_limit()
149-
if retry_after and not hasattr(ctx, "is_pseudo"): # bypass ratelimit for su'ed commands
150+
if retry_after and not hasattr(ctx, "is_pseudo"): # bypass ratelimit for su'ed commands
150151
raise InvalidContext('Global rate-limit exceeded!')
151152
return True
152153

@@ -155,7 +156,7 @@ def run(self, *args, **kwargs):
155156
del self.config['discord_token'] # Prevent token dumping
156157
super().run(token)
157158

158-
async def shutdown(self, restart=False):
159+
async def shutdown(self, restart: bool = False):
159160
"""Shuts down the bot"""
160161
self._restarting = restart
161162
await self.logout()

dozer/cogs/_utils.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
import logging
55
import typing
66
from collections.abc import Mapping
7+
from typing import Dict, Union
78

89
import discord
910
from discord.ext import commands
1011

1112
from dozer import db
13+
from dozer.context import DozerContext
1214

13-
__all__ = ['bot_has_permissions', 'command', 'group', 'Cog', 'Reactor', 'Paginator', 'paginate', 'chunk', 'dev_check', 'DynamicPrefixEntry']
15+
__all__ = ['bot_has_permissions', 'command', 'group', 'Cog', 'Reactor', 'Paginator', 'paginate', 'chunk', 'dev_check',
16+
'DynamicPrefixEntry']
1417

1518
DOZER_LOGGER = logging.getLogger("dozer")
1619

@@ -81,15 +84,15 @@ def group(**kwargs):
8184
class Cog(commands.Cog):
8285
"""Initiates cogs."""
8386

84-
def __init__(self, bot):
87+
def __init__(self, bot: commands.Bot):
8588
super().__init__()
8689
self.bot = bot
8790

8891

8992
def dev_check():
9093
"""Function decorator to check that the calling user is a developer"""
9194

92-
async def predicate(ctx):
95+
async def predicate(ctx: DozerContext):
9396
if ctx.author.id not in ctx.bot.config['developers']:
9497
raise commands.NotOwner('you are not a developer!')
9598
return True
@@ -118,7 +121,7 @@ class Reactor:
118121
"""
119122
_stop_reaction = object()
120123

121-
def __init__(self, ctx, initial_reactions, *, auto_remove=True, timeout=60):
124+
def __init__(self, ctx: DozerContext, initial_reactions, *, auto_remove: bool = True, timeout: int = 60):
122125
"""
123126
ctx: command context
124127
initial_reactions: iterable of emoji to react with on start
@@ -142,7 +145,8 @@ async def __aiter__(self):
142145
await self.message.add_reaction(emoji)
143146
while True:
144147
try:
145-
reaction, reacting_member = await self.bot.wait_for('reaction_add', check=self._check_reaction, timeout=self.timeout)
148+
reaction, reacting_member = await self.bot.wait_for('reaction_add', check=self._check_reaction,
149+
timeout=self.timeout)
146150
except asyncio.TimeoutError:
147151
break
148152

@@ -170,8 +174,9 @@ def stop(self):
170174
"""Listener for stop reactions."""
171175
self._action = self._stop_reaction
172176

173-
def _check_reaction(self, reaction, member):
174-
return reaction.message.id == self.message.id and member.id == self.caller.id
177+
def _check_reaction(self, reaction: discord.Reaction, member: discord.Member):
178+
if self.message is not None:
179+
return reaction.message.id == self.message.id and member.id == self.caller.id
175180

176181

177182
class Paginator(Reactor):
@@ -198,7 +203,8 @@ class Paginator(Reactor):
198203
'\N{BLACK SQUARE FOR STOP}' # :stop_button:
199204
)
200205

201-
def __init__(self, ctx, initial_reactions, pages, *, start=0, auto_remove=True, timeout=60):
206+
def __init__(self, ctx: DozerContext, initial_reactions, pages, *, start: int = 0, auto_remove: bool = True,
207+
timeout: int = 60):
202208
all_reactions = list(initial_reactions)
203209
ind = all_reactions.index(Ellipsis)
204210
all_reactions[ind:ind + 1] = self.pagination_reactions
@@ -232,40 +238,41 @@ async def __aiter__(self):
232238
else: # Only valid option left is 4
233239
self.stop()
234240

235-
def go_to_page(self, page):
241+
def go_to_page(self, page: Union[int, str]):
236242
"""Goes to a specific help page"""
237243
if isinstance(page, int):
238244
page = page % self.len_pages
239245
if page < 0:
240246
page += self.len_pages
241247
self.page = page
242-
self.do(self.message.edit(embed=self.pages[self.page]))
248+
if self.message is not None:
249+
self.do(self.message.edit(embed=self.pages[self.page]))
243250

244-
def next(self, amt=1):
251+
def next(self, amt: int = 1):
245252
"""Goes to the next help page"""
246253
if isinstance(self.page, int):
247254
self.go_to_page(self.page + amt)
248255
else:
249256
self.go_to_page(amt - 1)
250257

251-
def prev(self, amt=1):
258+
def prev(self, amt: int = 1):
252259
"""Goes to the previous help page"""
253260
if isinstance(self.page, int):
254261
self.go_to_page(self.page - amt)
255262
else:
256263
self.go_to_page(-amt)
257264

258265

259-
async def paginate(ctx, pages, *, start=0, auto_remove=True, timeout=60):
266+
async def paginate(ctx: DozerContext, pages, *, start: int = 0, auto_remove: bool = True, timeout: int = 60):
260267
"""
261268
Simple pagination based on Paginator. Pagination is handled normally and other reactions are ignored.
262269
"""
263-
paginator = Paginator(ctx, (...,), pages, start=start, auto_remove=auto_remove, timeout=timeout)
270+
paginator = Paginator(ctx, ..., pages, start=start, auto_remove=auto_remove, timeout=timeout)
264271
async for reaction in paginator:
265272
pass # The normal pagination reactions are handled - just drop anything else
266273

267274

268-
def chunk(iterable, size):
275+
def chunk(iterable, size: int):
269276
"""
270277
Break an iterable into chunks of a fixed size. Returns an iterable of iterables.
271278
Almost-inverse of itertools.chain.from_iterable - passing the output of this into that function will reconstruct the original iterable.
@@ -279,7 +286,7 @@ def chunk(iterable, size):
279286
def bot_has_permissions(**required):
280287
"""Decorator to check if bot has certain permissions when added to a command"""
281288

282-
def predicate(ctx):
289+
def predicate(ctx: DozerContext):
283290
"""Function to tell the bot if it has the right permissions"""
284291
given = ctx.channel.permissions_for((ctx.guild or ctx.channel).me)
285292
missing = [name for name, value in required.items() if getattr(given, name) != value]
@@ -309,11 +316,11 @@ def decorator(func):
309316
class PrefixHandler:
310317
"""Handles dynamic prefixes"""
311318

312-
def __init__(self, default_prefix):
319+
def __init__(self, default_prefix: str):
313320
self.default_prefix = default_prefix
314-
self.prefix_cache = {}
321+
self.prefix_cache: Dict[int, DynamicPrefixEntry] = {}
315322

316-
def handler(self, bot, message):
323+
def handler(self, bot, message: discord.Message):
317324
"""Process the dynamic prefix for each message"""
318325
dynamic = self.prefix_cache.get(message.guild.id) if message.guild else self.default_prefix
319326
# <@!> is a nickname mention which discord.py doesn't make by default
@@ -330,7 +337,7 @@ async def refresh(self):
330337
class DynamicPrefixEntry(db.DatabaseTable):
331338
"""Holds the custom prefixes for guilds"""
332339
__tablename__ = 'dynamic_prefixes'
333-
__uniques__ = 'guild_id'
340+
__uniques__ = ['guild_id']
334341

335342
@classmethod
336343
async def initial_create(cls):
@@ -343,7 +350,7 @@ async def initial_create(cls):
343350
PRIMARY KEY (guild_id)
344351
)""")
345352

346-
def __init__(self, guild_id, prefix):
353+
def __init__(self, guild_id: int, prefix: str):
347354
super().__init__()
348355
self.guild_id = guild_id
349356
self.prefix = prefix

0 commit comments

Comments
 (0)