From 4102863f087acbfb42bb0e4cc180dcfcbb8e3386 Mon Sep 17 00:00:00 2001 From: Pierre Fersing Date: Mon, 1 Jan 2024 19:59:48 +0100 Subject: [PATCH] Add mypy types to client.py --- ChangeLog.txt | 10 + pyproject.toml | 9 + src/paho/mqtt/client.py | 1299 +++++++++++++--------- src/paho/mqtt/publish.py | 106 +- src/paho/mqtt/subscribeoptions.py | 8 +- tests/lib/clients/03-publish-b2c-qos1.py | 2 +- tests/test_client.py | 41 + tox.ini | 6 +- 8 files changed, 934 insertions(+), 547 deletions(-) diff --git a/ChangeLog.txt b/ChangeLog.txt index c86f40c4..4b2ebc3b 100644 --- a/ChangeLog.txt +++ b/ChangeLog.txt @@ -3,6 +3,16 @@ v2.0.0 - 2023-xx-xx - **BREAKING** Drop support for Python 2.7, Python 3.5 and Python 3.6 Minimum tested version is Python 3.7 +- **BREAKING** connnect_srv changed it signature to take an additional bind_port parameter. + This is a breaking change, but in previous version connect_srv was broken anyway. + Closes #493. +- Add types to Client class, which caused few change which should be compatible. + Known risk of breaking changes: + - Use enum for returned error code (like MQTT_ERR_SUCCESS). It use an IntEnum + which should be a drop-in replacement. Excepted if someone is doing "rc is 0" instead of "rc == 0". + - reason in on_connect callback when using MQTTv5 is now always a ReasonCode object. It used to possibly be + an integer with the value 132. + - MQTTMessage field "dup" and "retain" used to be integer with value 0 and 1. They are now boolean. - Add on_pre_connect() callback, which is called immediately before a connection attempt is made. diff --git a/pyproject.toml b/pyproject.toml index 20783582..70f457bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,15 @@ include = [ "src/paho", ] +[tool.mypy] + +[[tool.mypy.overrides]] +module = "paho.mqtt.client" +# check_untyped_defs = true +# disallow_untyped_calls = true +# disallow_incomplete_defs = true +disallow_untyped_defs = true + [tool.pytest.ini_options] addopts = ["-r", "xs"] testpaths = "tests src" diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index d0cd7a09..ed1d102a 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -19,6 +19,7 @@ import base64 import collections +import enum import errno import hashlib import logging @@ -30,6 +31,7 @@ import struct import threading import time +import typing import urllib.parse import urllib.request import uuid @@ -42,13 +44,13 @@ try: import ssl except ImportError: - ssl = None + ssl = None # type: ignore[assignment] try: - import socks + import socks # type: ignore[import-untyped] except ImportError: - socks = None + socks = None # type: ignore[assignment] try: @@ -66,7 +68,7 @@ if platform.system() == "Windows": - EAGAIN = errno.WSAEWOULDBLOCK + EAGAIN = errno.WSAEWOULDBLOCK # type: ignore[attr-defined] else: EAGAIN = errno.EAGAIN @@ -131,25 +133,49 @@ mqtt_ms_send_pubrec = 8 mqtt_ms_queued = 9 + # Error values -MQTT_ERR_AGAIN = -1 -MQTT_ERR_SUCCESS = 0 -MQTT_ERR_NOMEM = 1 -MQTT_ERR_PROTOCOL = 2 -MQTT_ERR_INVAL = 3 -MQTT_ERR_NO_CONN = 4 -MQTT_ERR_CONN_REFUSED = 5 -MQTT_ERR_NOT_FOUND = 6 -MQTT_ERR_CONN_LOST = 7 -MQTT_ERR_TLS = 8 -MQTT_ERR_PAYLOAD_SIZE = 9 -MQTT_ERR_NOT_SUPPORTED = 10 -MQTT_ERR_AUTH = 11 -MQTT_ERR_ACL_DENIED = 12 -MQTT_ERR_UNKNOWN = 13 -MQTT_ERR_ERRNO = 14 -MQTT_ERR_QUEUE_SIZE = 15 -MQTT_ERR_KEEPALIVE = 16 +class MQTTErrorCode(enum.IntEnum): + MQTT_ERR_AGAIN = -1 + MQTT_ERR_SUCCESS = 0 + MQTT_ERR_NOMEM = 1 + MQTT_ERR_PROTOCOL = 2 + MQTT_ERR_INVAL = 3 + MQTT_ERR_NO_CONN = 4 + MQTT_ERR_CONN_REFUSED = 5 + MQTT_ERR_NOT_FOUND = 6 + MQTT_ERR_CONN_LOST = 7 + MQTT_ERR_TLS = 8 + MQTT_ERR_PAYLOAD_SIZE = 9 + MQTT_ERR_NOT_SUPPORTED = 10 + MQTT_ERR_AUTH = 11 + MQTT_ERR_ACL_DENIED = 12 + MQTT_ERR_UNKNOWN = 13 + MQTT_ERR_ERRNO = 14 + MQTT_ERR_QUEUE_SIZE = 15 + MQTT_ERR_KEEPALIVE = 16 + + +# This probably do the same as @global_enum, but this decorator require Python 3.11 +MQTT_ERR_AGAIN = MQTTErrorCode.MQTT_ERR_AGAIN +MQTT_ERR_SUCCESS = MQTTErrorCode.MQTT_ERR_SUCCESS +MQTT_ERR_NOMEM = MQTTErrorCode.MQTT_ERR_NOMEM +MQTT_ERR_PROTOCOL = MQTTErrorCode.MQTT_ERR_PROTOCOL +MQTT_ERR_INVAL = MQTTErrorCode.MQTT_ERR_INVAL +MQTT_ERR_NO_CONN = MQTTErrorCode.MQTT_ERR_NO_CONN +MQTT_ERR_CONN_REFUSED = MQTTErrorCode.MQTT_ERR_CONN_REFUSED +MQTT_ERR_NOT_FOUND = MQTTErrorCode.MQTT_ERR_NOT_FOUND +MQTT_ERR_CONN_LOST = MQTTErrorCode.MQTT_ERR_CONN_LOST +MQTT_ERR_TLS = MQTTErrorCode.MQTT_ERR_TLS +MQTT_ERR_PAYLOAD_SIZE = MQTTErrorCode.MQTT_ERR_PAYLOAD_SIZE +MQTT_ERR_NOT_SUPPORTED = MQTTErrorCode.MQTT_ERR_NOT_SUPPORTED +MQTT_ERR_AUTH = MQTTErrorCode.MQTT_ERR_AUTH +MQTT_ERR_ACL_DENIED = MQTTErrorCode.MQTT_ERR_ACL_DENIED +MQTT_ERR_UNKNOWN = MQTTErrorCode.MQTT_ERR_UNKNOWN +MQTT_ERR_ERRNO = MQTTErrorCode.MQTT_ERR_ERRNO +MQTT_ERR_QUEUE_SIZE = MQTTErrorCode.MQTT_ERR_QUEUE_SIZE +MQTT_ERR_KEEPALIVE = MQTTErrorCode.MQTT_ERR_KEEPALIVE + MQTT_CLIENT = 0 MQTT_BRIDGE = 1 @@ -159,12 +185,75 @@ sockpair_data = b"0" +# Payload support all those type and will be converted to bytes: +# * str are utf8 encoded +# * int/float are converted to string and utf8 encoded (e.g. 1 is converted to b"1") +# * None is converted to a zero-length payload (i.e. b"") +PayloadType = typing.Union[str, bytes, bytearray, int, float, None] + +HTTPHeader = dict[str, str] +WebSocketHeaders = typing.Union[typing.Callable[[HTTPHeader], HTTPHeader], HTTPHeader] + +SocketLike = typing.Union[socket.socket, "ssl.SSLSocket", "WebsocketWrapper"] + + +CallbackOnConnect = typing.Union[ + typing.Callable[["Client", typing.Any, ReasonCodes, Properties], None], + typing.Callable[["Client", typing.Any, MQTTErrorCode], None], +] +CallbackOnConnectFail = typing.Callable[["Client", typing.Any], None] +CallbackOnDisconnect = typing.Union[ + typing.Callable[ + ["Client", typing.Any, dict[str, typing.Any], ReasonCodes, Properties], None + ], + typing.Callable[["Client", typing.Any, dict[str, typing.Any], MQTTErrorCode], None], +] +CallbackOnLog = typing.Callable[["Client", typing.Any, int, str], None] +CallbackOnMessage = typing.Callable[["Client", typing.Any, "MQTTMessage"], None] +CallbackOnPreConnect = typing.Callable[["Client", typing.Any], None] +CallbackOnPublish = typing.Callable[["Client", typing.Any, int], None] +CallbackOnSocket = typing.Callable[["Client", typing.Any, SocketLike], None] +CallbackOnSubscribe = typing.Union[ + typing.Callable[ + ["Client", typing.Any, Properties, list[ReasonCodes], Properties], None + ], + typing.Callable[["Client", typing.Any, int, tuple[int, ...]], None], +] +CallbackOnUnsubscribe = typing.Union[ + typing.Callable[["Client", typing.Any, Properties, ReasonCodes], None], + typing.Callable[["Client", typing.Any, int], None], +] + +# This is needed for typing because class Client redefined the name "socket" +_socket = socket + + +class _InPacket(typing.TypedDict): + command: int + have_remaining: int + remaining_count: list[int] + remaining_mult: int + remaining_length: int + packet: bytearray + to_process: int + pos: int + + +class _OutPacket(typing.TypedDict): + command: int + mid: int + qos: int + pos: int + to_process: int + packet: bytes + info: typing.Optional["MQTTMessageInfo"] + class WebsocketConnectionError(ValueError): pass -def error_string(mqtt_errno): +def error_string(mqtt_errno: MQTTErrorCode) -> str: """Return the error string associated with an mqtt error number.""" if mqtt_errno == MQTT_ERR_SUCCESS: return "No error." @@ -204,7 +293,7 @@ def error_string(mqtt_errno): return "Unknown error." -def connack_string(connack_code): +def connack_string(connack_code: int) -> str: """Return the string associated with a CONNACK result.""" if connack_code == CONNACK_ACCEPTED: return "Connection Accepted." @@ -222,7 +311,9 @@ def connack_string(connack_code): return "Connection Refused: unknown reason." -def base62(num, base=string.digits + string.ascii_letters, padding=1): +def base62( + num: int, base: str = string.digits + string.ascii_letters, padding: int = 1 +) -> str: """Convert a number to base-62 representation.""" if num < 0: raise ValueError("Number must be positive or zero") @@ -234,7 +325,7 @@ def base62(num, base=string.digits + string.ascii_letters, padding=1): return "".join(reversed(digits)) -def topic_matches_sub(sub, topic): +def topic_matches_sub(sub: str, topic: str) -> bool: """Check whether a topic matches a subscription. For example: @@ -251,7 +342,7 @@ def topic_matches_sub(sub, topic): return False -def _socketpair_compat(): +def _socketpair_compat() -> tuple[socket.socket, socket.socket]: """TCP/IP socketpair including Windows support""" listensock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) listensock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -260,13 +351,13 @@ def _socketpair_compat(): iface, port = listensock.getsockname() sock1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) - sock1.setblocking(0) + sock1.setblocking(False) try: sock1.connect(("127.0.0.1", port)) except BlockingIOError: pass sock2, address = listensock.accept() - sock2.setblocking(0) + sock2.setblocking(False) listensock.close() return (sock1, sock2) @@ -279,24 +370,24 @@ class MQTTMessageInfo: __slots__ = "mid", "_published", "_condition", "rc", "_iterpos" - def __init__(self, mid): + def __init__(self, mid: int): self.mid = mid self._published = False self._condition = threading.Condition() - self.rc = 0 + self.rc: MQTTErrorCode = MQTTErrorCode.MQTT_ERR_SUCCESS self._iterpos = 0 - def __str__(self): + def __str__(self) -> str: return str((self.rc, self.mid)) - def __iter__(self): + def __iter__(self) -> typing.Iterator[typing.Union[MQTTErrorCode, int]]: self._iterpos = 0 return self - def __next__(self): + def __next__(self) -> typing.Union[MQTTErrorCode, int]: return self.next() - def next(self): + def next(self) -> typing.Union[MQTTErrorCode, int]: if self._iterpos == 0: self._iterpos = 1 return self.rc @@ -306,7 +397,7 @@ def next(self): else: raise StopIteration - def __getitem__(self, index): + def __getitem__(self, index: int) -> typing.Union[MQTTErrorCode, int]: if index == 0: return self.rc elif index == 1: @@ -314,12 +405,12 @@ def __getitem__(self, index): else: raise IndexError("index out of range") - def _set_as_published(self): + def _set_as_published(self) -> None: with self._condition: self._published = True self._condition.notify() - def wait_for_publish(self, timeout=None): + def wait_for_publish(self, timeout: typing.Optional[float] = None) -> None: """Block until the message associated with this object is published, or until the timeout occurs. If timeout is None, this will never time out. Set timeout to a positive number of seconds, e.g. 1.2, to enable the @@ -341,19 +432,19 @@ def wait_for_publish(self, timeout=None): timeout_time = None if timeout is None else time_func() + timeout timeout_tenth = None if timeout is None else timeout / 10.0 - def timed_out(): - return False if timeout is None else time_func() > timeout_time + def timed_out() -> bool: + return False if timeout_time is None else time_func() > timeout_time with self._condition: while not self._published and not timed_out(): self._condition.wait(timeout_tenth) - def is_published(self): + def is_published(self) -> bool: """Returns True if the message associated with this object has been published, else returns False.""" - if self.rc == MQTT_ERR_QUEUE_SIZE: + if self.rc == MQTTErrorCode.MQTT_ERR_QUEUE_SIZE: raise ValueError("Message is not queued due to ERR_QUEUE_SIZE") - elif self.rc == MQTT_ERR_AGAIN: + elif self.rc == MQTTErrorCode.MQTT_ERR_AGAIN: pass elif self.rc > 0: raise RuntimeError(f"Message publish failed: {error_string(self.rc)}") @@ -389,8 +480,8 @@ class MQTTMessage: "properties", ) - def __init__(self, mid=0, topic=b""): - self.timestamp = 0 + def __init__(self, mid: int = 0, topic: bytes = b""): + self.timestamp = 0.0 self.state = mqtt_ms_invalid self.dup = False self.mid = mid @@ -399,23 +490,24 @@ def __init__(self, mid=0, topic=b""): self.qos = 0 self.retain = False self.info = MQTTMessageInfo(mid) + self.properties: typing.Optional[Properties] = None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """Override the default Equals behavior""" if isinstance(other, self.__class__): return self.mid == other.mid return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: """Define a non-equality test""" return not self.__eq__(other) @property - def topic(self): + def topic(self) -> str: return self._topic.decode("utf-8") @topic.setter - def topic(self, value): + def topic(self, value: bytes) -> None: self._topic = value @@ -481,14 +573,14 @@ def on_connect(client, userdata, flags, rc): def __init__( self, - client_id="", - clean_session=None, - userdata=None, - protocol=MQTTv311, - transport="tcp", - reconnect_on_failure=True, - manual_ack=False, - ): + client_id: str = "", + clean_session: typing.Optional[bool] = None, + userdata: typing.Any = None, + protocol: int = MQTTv311, + transport: str = "tcp", + reconnect_on_failure: bool = True, + manual_ack: bool = False, + ) -> None: """client_id is the unique client id string used when connecting to the broker. If client_id is zero length or None, then the behaviour is defined by which protocol version is in use. If using MQTT v3.1.1, then @@ -538,11 +630,11 @@ def __init__( self._transport = transport.lower() self._protocol = protocol self._userdata = userdata - self._sock = None - self._sockpairR, self._sockpairW = ( - None, - None, - ) + self._sock: typing.Union[ + socket.socket, WebsocketWrapper, "ssl.SSLSocket", None + ] = None + self._sockpairR: typing.Optional[socket.socket] = None + self._sockpairW: typing.Optional[socket.socket] = None self._keepalive = 60 self._connect_timeout = 5.0 self._client_mode = MQTT_CLIENT @@ -562,43 +654,50 @@ def __init__( # [MQTT-3.1.3-4] Client Id must be UTF-8 encoded string. if client_id == "" or client_id is None: if protocol == MQTTv31: - self._client_id = base62(uuid.uuid4().int, padding=22) + self._client_id = base62(uuid.uuid4().int, padding=22).encode("utf8") else: self._client_id = b"" else: - self._client_id = client_id - if isinstance(self._client_id, str): - self._client_id = self._client_id.encode("utf-8") - - self._username = None - self._password = None - self._in_packet = { - "command": 0, - "have_remaining": 0, - "remaining_count": [], - "remaining_mult": 1, - "remaining_length": 0, - "packet": bytearray(b""), - "to_process": 0, - "pos": 0, - } - self._out_packet = collections.deque() + if isinstance(client_id, str): + self._client_id = client_id.encode("utf-8") + else: + self._client_id = client_id + + self._username: typing.Optional[bytes] = None + self._password: typing.Optional[bytes] = None + self._in_packet = _InPacket( + { + "command": 0, + "have_remaining": 0, + "remaining_count": [], + "remaining_mult": 1, + "remaining_length": 0, + "packet": bytearray(b""), + "to_process": 0, + "pos": 0, + } + ) + self._out_packet: collections.deque[_OutPacket] = collections.deque() self._last_msg_in = time_func() self._last_msg_out = time_func() self._reconnect_min_delay = 1 self._reconnect_max_delay = 120 - self._reconnect_delay = None + self._reconnect_delay: typing.Optional[int] = None self._reconnect_on_failure = reconnect_on_failure - self._ping_t = 0 + self._ping_t = 0.0 self._last_mid = 0 self._state = mqtt_cs_new - self._out_messages = collections.OrderedDict() - self._in_messages = collections.OrderedDict() + self._out_messages: collections.OrderedDict[ + int, MQTTMessage + ] = collections.OrderedDict() + self._in_messages: collections.OrderedDict[ + int, MQTTMessage + ] = collections.OrderedDict() self._max_inflight_messages = 20 self._inflight_messages = 0 self._max_queued_messages = 0 - self._connect_properties = None - self._will_properties = None + self._connect_properties: typing.Optional[Properties] = None + self._will_properties: typing.Optional[Properties] = None self._will = False self._will_topic = b"" self._will_payload = b"" @@ -609,7 +708,7 @@ def __init__( self._port = 1883 self._bind_address = "" self._bind_port = 0 - self._proxy = {} + self._proxy: typing.Any = {} self._in_callback_mutex = threading.Lock() self._callback_mutex = threading.RLock() self._msgtime_mutex = threading.Lock() @@ -617,38 +716,41 @@ def __init__( self._in_message_mutex = threading.Lock() self._reconnect_delay_mutex = threading.Lock() self._mid_generate_mutex = threading.Lock() - self._thread = None + self._thread: typing.Optional[threading.Thread] = None self._thread_terminate = False self._ssl = False - self._ssl_context = None + self._ssl_context: typing.Optional["ssl.SSLContext"] = None # Only used when SSL context does not have check_hostname attribute self._tls_insecure = False - self._logger = None + self._logger: typing.Optional[logging.Logger] = None self._registered_write = False # No default callbacks - self._on_log = None - self._on_pre_connect = None - self._on_connect = None - self._on_connect_fail = None - self._on_subscribe = None - self._on_message = None - self._on_publish = None - self._on_unsubscribe = None - self._on_disconnect = None - self._on_socket_open = None - self._on_socket_close = None - self._on_socket_register_write = None - self._on_socket_unregister_write = None + self._on_log: typing.Optional[CallbackOnLog] = None + self._on_pre_connect: typing.Optional[CallbackOnPreConnect] = None + self._on_connect: typing.Optional[CallbackOnConnect] = None + self._on_connect_fail: typing.Optional[CallbackOnConnectFail] = None + self._on_subscribe: typing.Optional[CallbackOnSubscribe] = None + self._on_message: typing.Optional[CallbackOnMessage] = None + self._on_publish: typing.Optional[CallbackOnPublish] = None + self._on_unsubscribe: typing.Optional[CallbackOnUnsubscribe] = None + self._on_disconnect: typing.Optional[CallbackOnDisconnect] = None + self._on_socket_open: typing.Optional[CallbackOnSocket] = None + self._on_socket_close: typing.Optional[CallbackOnSocket] = None + self._on_socket_register_write: typing.Optional[CallbackOnSocket] = None + self._on_socket_unregister_write: typing.Optional[CallbackOnSocket] = None self._websocket_path = "/mqtt" - self._websocket_extra_headers = None + self._websocket_extra_headers: typing.Optional[WebSocketHeaders] = None # for clean_start == MQTT_CLEAN_START_FIRST_ONLY self._mqttv5_first_connect = True self.suppress_exceptions = False # For callbacks - def __del__(self): + def __del__(self) -> None: self._reset_sockets() - def _sock_recv(self, bufsize): + def _sock_recv(self, bufsize: int) -> bytes: + if self._sock is None: + raise ConnectionError("self._sock is None") + try: return self._sock.recv(bufsize) except ssl.SSLWantReadError as err: @@ -660,7 +762,10 @@ def _sock_recv(self, bufsize): self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) raise ConnectionError() from err - def _sock_send(self, buf): + def _sock_send(self, buf: bytes) -> int: + if self._sock is None: + raise ConnectionError("self._sock is None") + try: return self._sock.send(buf) except ssl.SSLWantReadError as err: @@ -672,7 +777,7 @@ def _sock_send(self, buf): self._call_socket_register_write() raise BlockingIOError() from err - def _sock_close(self): + def _sock_close(self) -> None: """Close the connection to the server.""" if not self._sock: return @@ -686,7 +791,7 @@ def _sock_close(self): # In case a callback fails, still close the socket to avoid leaking the file descriptor. sock.close() - def _reset_sockets(self, sockpair_only=False): + def _reset_sockets(self, sockpair_only: bool = False) -> None: if not sockpair_only: self._sock_close() @@ -697,12 +802,19 @@ def _reset_sockets(self, sockpair_only=False): self._sockpairW.close() self._sockpairW = None - def reinitialise(self, client_id="", clean_session=True, userdata=None): + def reinitialise( + self, + client_id: str = "", + clean_session: bool = True, + userdata: typing.Any = None, + ) -> None: self._reset_sockets() - self.__init__(client_id, clean_session, userdata) + self.__init__(client_id, clean_session, userdata) # type: ignore[misc] - def ws_set_options(self, path="/mqtt", headers=None): + def ws_set_options( + self, path: str = "/mqtt", headers: typing.Optional[WebSocketHeaders] = None + ) -> None: """Set the path and headers for a websocket connection path is a string starting with / which should be the endpoint of the @@ -723,7 +835,9 @@ def ws_set_options(self, path="/mqtt", headers=None): "'headers' option to ws_set_options has to be either a dictionary or callable" ) - def tls_set_context(self, context=None): + def tls_set_context( + self, context: typing.Optional["ssl.SSLContext"] = None + ) -> None: """Configure network encryption and authentication context. Enables SSL/TLS support. context : an ssl.SSLContext object. By default this is given by @@ -745,14 +859,14 @@ def tls_set_context(self, context=None): def tls_set( self, - ca_certs=None, - certfile=None, - keyfile=None, - cert_reqs=None, - tls_version=None, - ciphers=None, - keyfile_password=None, - ): + ca_certs: typing.Optional[str] = None, + certfile: typing.Optional[str] = None, + keyfile: typing.Optional[str] = None, + cert_reqs: typing.Optional["ssl.VerifyMode"] = None, + tls_version: typing.Optional[int] = None, + ciphers: typing.Optional[str] = None, + keyfile_password: typing.Optional[str] = None, + ) -> None: """Configure network encryption and authentication options. Enables SSL/TLS support. ca_certs : a string path to the Certificate Authority certificate files @@ -842,7 +956,7 @@ def tls_set( # But with ssl.CERT_NONE, we can not check_hostname self.tls_insecure_set(True) - def tls_insecure_set(self, value): + def tls_insecure_set(self, value: bool) -> None: """Configure verification of the server hostname in the server certificate. If value is set to true, it is impossible to guarantee that the host @@ -870,7 +984,7 @@ def tls_insecure_set(self, value): # If verify_mode is CERT_NONE then the host name will never be checked self._ssl_context.check_hostname = not value - def proxy_set(self, **proxy_args): + def proxy_set(self, **proxy_args: typing.Any) -> None: """Configure proxying of MQTT connection. Enables support for SOCKS or HTTP proxies. @@ -895,7 +1009,7 @@ def proxy_set(self, **proxy_args): else: self._proxy = proxy_args - def enable_logger(self, logger=None): + def enable_logger(self, logger: typing.Optional[logging.Logger] = None) -> None: """Enables a logger to send log messages to""" if logger is None: if self._logger is not None: @@ -904,19 +1018,21 @@ def enable_logger(self, logger=None): logger = logging.getLogger(__name__) self._logger = logger - def disable_logger(self): + def disable_logger(self) -> None: self._logger = None def connect( self, - host, - port=1883, - keepalive=60, - bind_address="", - bind_port=0, - clean_start=MQTT_CLEAN_START_FIRST_ONLY, - properties=None, - ): + host: str, + port: int = 1883, + keepalive: int = 60, + bind_address: str = "", + bind_port: int = 0, + clean_start: typing.Union[ + bool, typing.Literal[3] + ] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: """Connect to a remote broker. This is a blocking call that establishes the underlying connection and transmits a CONNECT packet. @@ -950,12 +1066,15 @@ def connect( def connect_srv( self, - domain=None, - keepalive=60, - bind_address="", - clean_start=MQTT_CLEAN_START_FIRST_ONLY, - properties=None, - ): + domain: typing.Optional[str] = None, + keepalive: int = 60, + bind_address: str = "", + bind_port: int = 0, + clean_start: typing.Union[ + bool, typing.Literal[3] + ] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: """Connect to a remote broker. domain is the DNS domain to search for SRV records; if None, @@ -994,7 +1113,13 @@ def connect_srv( try: return self.connect( - host, port, keepalive, bind_address, clean_start, properties + host, + port, + keepalive, + bind_address, + bind_port, + clean_start, + properties, ) except Exception: # noqa: S110 pass @@ -1003,14 +1128,16 @@ def connect_srv( def connect_async( self, - host, - port=1883, - keepalive=60, - bind_address="", - bind_port=0, - clean_start=MQTT_CLEAN_START_FIRST_ONLY, - properties=None, - ): + host: str, + port: int = 1883, + keepalive: int = 60, + bind_address: str = "", + bind_port: int = 0, + clean_start: typing.Union[ + bool, typing.Literal[3] + ] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore + properties: typing.Optional[Properties] = None, + ) -> None: """Connect to a remote broker asynchronously. This is a non-blocking connect call that can be used with loop_start() to provide very quick start. @@ -1047,7 +1174,7 @@ def connect_async( self._connect_properties = properties self._state = mqtt_cs_connect_async - def reconnect_delay_set(self, min_delay=1, max_delay=120): + def reconnect_delay_set(self, min_delay: int = 1, max_delay: int = 120) -> None: """Configure the exponential reconnect delay When connection is lost, wait initially min_delay seconds and @@ -1060,7 +1187,7 @@ def reconnect_delay_set(self, min_delay=1, max_delay=120): self._reconnect_max_delay = max_delay self._reconnect_delay = None - def reconnect(self): + def reconnect(self) -> MQTTErrorCode: """Reconnect the client after a disconnect. Can only be called after connect()/connect_async().""" if len(self._host) == 0: @@ -1068,16 +1195,18 @@ def reconnect(self): if self._port <= 0: raise ValueError("Invalid port number.") - self._in_packet = { - "command": 0, - "have_remaining": 0, - "remaining_count": [], - "remaining_mult": 1, - "remaining_length": 0, - "packet": bytearray(b""), - "to_process": 0, - "pos": 0, - } + self._in_packet = _InPacket( + { + "command": 0, + "have_remaining": 0, + "remaining_count": [], + "remaining_mult": 1, + "remaining_length": 0, + "packet": bytearray(b""), + "to_process": 0, + "pos": 0, + } + ) self._out_packet = collections.deque() @@ -1085,7 +1214,7 @@ def reconnect(self): self._last_msg_in = time_func() self._last_msg_out = time_func() - self._ping_t = 0 + self._ping_t = 0.0 self._state = mqtt_cs_new self._sock_close() @@ -1106,16 +1235,19 @@ def reconnect(self): if not self.suppress_exceptions: raise - sock = self._create_socket_connection() + tcp_sock = self._create_socket_connection() if self._ssl: - # SSL is only supported when SSLContext is available (implies Python >= 2.7.9 or >= 3.2) + if self._ssl_context is None: + raise ValueError( + "Impossible condition. _ssl_context should never be None if _ssl is True" + ) verify_host = not self._tls_insecure try: # Try with server_hostname, even it's not supported in certain scenarios - sock = self._ssl_context.wrap_socket( - sock, + ssl_sock = self._ssl_context.wrap_socket( + tcp_sock, server_hostname=self._host, do_handshake_on_connect=False, ) @@ -1124,43 +1256,52 @@ def reconnect(self): raise except ValueError: # Python version requires SNI in order to handle server_hostname, but SNI is not available - sock = self._ssl_context.wrap_socket( - sock, + ssl_sock = self._ssl_context.wrap_socket( + tcp_sock, do_handshake_on_connect=False, ) else: # If SSL context has already checked hostname, then don't need to do it again if ( hasattr(self._ssl_context, "check_hostname") - and self._ssl_context.check_hostname + and self._ssl_context.check_hostname # type: ignore ): verify_host = False - sock.settimeout(self._keepalive) - sock.do_handshake() + ssl_sock.settimeout(self._keepalive) + ssl_sock.do_handshake() if verify_host: - ssl.match_hostname(sock.getpeercert(), self._host) + # TODO: this type error is a true error: + # error: Module has no attribute "match_hostname" [attr-defined] + # Python 3.12 no longer have this method. + ssl.match_hostname(ssl_sock.getpeercert(), self._host) # type: ignore + + sock_without_ws: typing.Union[socket.socket, "ssl.SSLSocket"] = ssl_sock + else: + sock_without_ws = tcp_sock if self._transport == "websockets": - sock.settimeout(self._keepalive) - sock = WebsocketWrapper( - sock, + sock_without_ws.settimeout(self._keepalive) + ws_sock = WebsocketWrapper( + sock_without_ws, self._host, self._port, self._ssl, self._websocket_path, self._websocket_extra_headers, ) + self._sock = ws_sock + else: + self._sock = sock_without_ws - self._sock = sock - self._sock.setblocking(0) + self._sock.setblocking(False) # type: ignore[attr-defined] self._registered_write = False self._call_socket_open() return self._send_connect(self._keepalive) - def loop(self, timeout=1.0, max_packets=1): + def loop(self, timeout: float = 1.0, max_packets: int = 1) -> MQTTErrorCode: """Process network events. It is strongly recommended that you use loop_start(), or @@ -1192,7 +1333,7 @@ def loop(self, timeout=1.0, max_packets=1): return self._loop(timeout) - def _loop(self, timeout=1.0): + def _loop(self, timeout: float = 1.0) -> MQTTErrorCode: if timeout < 0.0: raise ValueError("Invalid timeout.") @@ -1206,7 +1347,7 @@ def _loop(self, timeout=1.0): # used to check if there are any bytes left in the (SSL) socket pending_bytes = 0 if hasattr(self._sock, "pending"): - pending_bytes = self._sock.pending() + pending_bytes = self._sock.pending() # type: ignore[union-attr] # if bytes are pending do not wait in select if pending_bytes > 0: @@ -1223,15 +1364,15 @@ def _loop(self, timeout=1.0): socklist = select.select(rlist, wlist, [], timeout) except TypeError: # Socket isn't correct type, in likelihood connection is lost - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST except ValueError: # Can occur if we just reconnected but rlist/wlist contain a -1 for # some reason. - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST except Exception: # Note that KeyboardInterrupt, etc. can still terminate since they # are not derived from Exception - return MQTT_ERR_UNKNOWN + return MQTTErrorCode.MQTT_ERR_UNKNOWN if self._sock in socklist[0] or pending_bytes > 0: rc = self.loop_read() @@ -1257,7 +1398,14 @@ def _loop(self, timeout=1.0): return self.loop_misc() - def publish(self, topic, payload=None, qos=0, retain=False, properties=None): + def publish( + self, + topic: str, + payload: PayloadType = None, + qos: int = 0, + retain: bool = False, + properties: typing.Optional[Properties] = None, + ) -> MQTTMessageInfo: """Publish a message on a topic. This causes a message to be sent to the broker and subsequently from @@ -1300,9 +1448,12 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): if topic is None or len(topic) == 0: raise ValueError("Invalid topic.") - topic = topic.encode("utf-8") + topic_bytes = topic.encode("utf-8") - if self._topic_wildcard_len_check(topic) != MQTT_ERR_SUCCESS: + if ( + self._topic_wildcard_len_check(topic_bytes) + != MQTTErrorCode.MQTT_ERR_SUCCESS + ): raise ValueError("Publish topic cannot contain wildcards.") if qos < 0 or qos > 2: @@ -1327,12 +1478,19 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): if qos == 0: info = MQTTMessageInfo(local_mid) rc = self._send_publish( - local_mid, topic, local_payload, qos, retain, False, info, properties + local_mid, + topic_bytes, + local_payload, + qos, + retain, + False, + info, + properties, ) info.rc = rc return info else: - message = MQTTMessage(local_mid, topic) + message = MQTTMessage(local_mid, topic_bytes) message.timestamp = time_func() message.payload = local_payload message.qos = qos @@ -1345,11 +1503,11 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): self._max_queued_messages > 0 and len(self._out_messages) >= self._max_queued_messages ): - message.info.rc = MQTT_ERR_QUEUE_SIZE + message.info.rc = MQTTErrorCode.MQTT_ERR_QUEUE_SIZE return message.info if local_mid in self._out_messages: - message.info.rc = MQTT_ERR_QUEUE_SIZE + message.info.rc = MQTTErrorCode.MQTT_ERR_QUEUE_SIZE return message.info self._out_messages[message.mid] = message @@ -1365,7 +1523,7 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): rc = self._send_publish( message.mid, - topic, + topic_bytes, message.payload, message.qos, message.retain, @@ -1375,7 +1533,7 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): ) # remove from inflight messages so it will be send after a connection is made - if rc is MQTT_ERR_NO_CONN: + if rc == MQTTErrorCode.MQTT_ERR_NO_CONN: self._inflight_messages -= 1 message.state = mqtt_ms_publish @@ -1383,10 +1541,12 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): return message.info else: message.state = mqtt_ms_queued - message.info.rc = MQTT_ERR_SUCCESS + message.info.rc = MQTTErrorCode.MQTT_ERR_SUCCESS return message.info - def username_pw_set(self, username, password=None): + def username_pw_set( + self, username: typing.Optional[str], password: typing.Optional[str] = None + ) -> None: """Set a username and optionally a password for broker authentication. Must be called before connect() to have any effect. @@ -1401,11 +1561,12 @@ def username_pw_set(self, username, password=None): # [MQTT-3.1.3-11] User name must be UTF-8 encoded string self._username = None if username is None else username.encode("utf-8") - self._password = password - if isinstance(self._password, str): - self._password = self._password.encode("utf-8") + if isinstance(password, str): + self._password = password.encode("utf-8") + else: + self._password = password - def enable_bridge_mode(self): + def enable_bridge_mode(self) -> None: """Sets the client in a bridge mode instead of client mode. Must be called before connect() to have any effect. @@ -1421,7 +1582,7 @@ def enable_bridge_mode(self): """ self._client_mode = MQTT_BRIDGE - def is_connected(self): + def is_connected(self) -> bool: """Returns the current status of the connection True if connection exists @@ -1429,7 +1590,11 @@ def is_connected(self): """ return self._state == mqtt_cs_connected - def disconnect(self, reasoncode=None, properties=None): + def disconnect( + self, + reasoncode: typing.Optional[ReasonCodes] = None, + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: """Disconnect a connected client from the broker. reasoncode: (MQTT v5.0 only) a ReasonCodes instance setting the MQTT v5.0 reasoncode to be sent with the disconnect. It is optional, the receiver @@ -1444,7 +1609,19 @@ def disconnect(self, reasoncode=None, properties=None): return self._send_disconnect(reasoncode, properties) - def subscribe(self, topic, qos=0, options=None, properties=None): + def subscribe( + self, + topic: typing.Union[ + str, + tuple[str, int], + tuple[str, SubscribeOptions], + list[tuple[str, int]], + list[tuple[str, SubscribeOptions]], + ], + qos: int = 0, + options: typing.Optional[SubscribeOptions] = None, + properties: typing.Optional[Properties] = None, + ) -> tuple[MQTTErrorCode, typing.Optional[int]]: """Subscribe the client to one or more topics. This function may be called in three different ways (and a further three for MQTT v5.0): @@ -1528,13 +1705,13 @@ def subscribe(self, topic, qos=0, options=None, properties=None): if isinstance(topic, tuple): if self._protocol == MQTTv5: - topic, options = topic + topic, options = topic # type: ignore if not isinstance(options, SubscribeOptions): raise ValueError( "Subscribe options must be instance of SubscribeOptions class." ) else: - topic, qos = topic + topic, qos = topic # type: ignore if isinstance(topic, (bytes, str)): if qos < 0 or qos > 2: @@ -1555,7 +1732,7 @@ def subscribe(self, topic, qos=0, options=None, properties=None): else: if topic is None or len(topic) == 0: raise ValueError("Invalid topic.") - topic_qos_list = [(topic.encode("utf-8"), qos)] + topic_qos_list = [(topic.encode("utf-8"), qos)] # type: ignore elif isinstance(topic, list): topic_qos_list = [] if self._protocol == MQTTv5: @@ -1568,11 +1745,11 @@ def subscribe(self, topic, qos=0, options=None, properties=None): topic_qos_list.append((t.encode("utf-8"), o)) else: for t, q in topic: - if q < 0 or q > 2: + if isinstance(q, SubscribeOptions) or q < 0 or q > 2: raise ValueError("Invalid QoS level.") if t is None or len(t) == 0 or not isinstance(t, (bytes, str)): raise ValueError("Invalid topic.") - topic_qos_list.append((t.encode("utf-8"), q)) + topic_qos_list.append((t.encode("utf-8"), q)) # type: ignore if topic_qos_list is None: raise ValueError("No topic specified, or incorrect topic type.") @@ -1588,7 +1765,9 @@ def subscribe(self, topic, qos=0, options=None, properties=None): return self._send_subscribe(False, topic_qos_list, properties) - def unsubscribe(self, topic, properties=None): + def unsubscribe( + self, topic: str, properties: typing.Optional[Properties] = None + ) -> tuple[MQTTErrorCode, typing.Optional[int]]: """Unsubscribe the client from one or more topics. topic: A single string, or list of strings that are the subscription @@ -1624,11 +1803,11 @@ def unsubscribe(self, topic, properties=None): raise ValueError("No topic specified, or incorrect topic type.") if self._sock is None: - return (MQTT_ERR_NO_CONN, None) + return (MQTTErrorCode.MQTT_ERR_NO_CONN, None) return self._send_unsubscribe(False, topic_list, properties) - def loop_read(self, max_packets=1): + def loop_read(self, max_packets: int = 1) -> MQTTErrorCode: """Process read network events. Use in place of calling loop() if you wish to handle your client reads as part of your own application. @@ -1637,7 +1816,7 @@ def loop_read(self, max_packets=1): Do not use if you are using the threaded interface loop_start().""" if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN max_packets = len(self._out_messages) + len(self._in_messages) if max_packets < 1: @@ -1645,15 +1824,15 @@ def loop_read(self, max_packets=1): for _ in range(0, max_packets): if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN rc = self._packet_read() if rc > 0: - return self._loop_rc_handle(rc) - elif rc == MQTT_ERR_AGAIN: - return MQTT_ERR_SUCCESS - return MQTT_ERR_SUCCESS + return self._loop_rc_handle(rc) # type: ignore + elif rc == MQTTErrorCode.MQTT_ERR_AGAIN: + return MQTTErrorCode.MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def loop_write(self, max_packets=1): + def loop_write(self, max_packets: int = 1) -> MQTTErrorCode: """Process write network events. Use in place of calling loop() if you wish to handle your client writes as part of your own application. @@ -1664,23 +1843,23 @@ def loop_write(self, max_packets=1): Do not use if you are using the threaded interface loop_start().""" if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN try: rc = self._packet_write() - if rc == MQTT_ERR_AGAIN: - return MQTT_ERR_SUCCESS + if rc == MQTTErrorCode.MQTT_ERR_AGAIN: + return MQTTErrorCode.MQTT_ERR_SUCCESS elif rc > 0: - return self._loop_rc_handle(rc) + return self._loop_rc_handle(rc) # type: ignore else: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS finally: if self.want_write(): self._call_socket_register_write() else: self._call_socket_unregister_write() - def want_write(self): + def want_write(self) -> bool: """Call to determine if there is network data waiting to be written. Useful if you are calling select() yourself rather than using loop(). """ @@ -1691,13 +1870,13 @@ def want_write(self): except IndexError: return False - def loop_misc(self): + def loop_misc(self) -> MQTTErrorCode: """Process miscellaneous network events. Use in place of calling loop() if you wish to call select() or equivalent on. Do not use if you are using the threaded interface loop_start().""" if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN now = time_func() self._check_keepalive() @@ -1708,24 +1887,24 @@ def loop_misc(self): self._sock_close() if self._state == mqtt_cs_disconnecting: - rc = MQTT_ERR_SUCCESS + rc = MQTTErrorCode.MQTT_ERR_SUCCESS else: - rc = MQTT_ERR_KEEPALIVE + rc = MQTTErrorCode.MQTT_ERR_KEEPALIVE self._do_on_disconnect(rc) - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def max_inflight_messages_set(self, inflight): + def max_inflight_messages_set(self, inflight: int) -> None: """Set the maximum number of messages with QoS>0 that can be part way through their network flow at once. Defaults to 20.""" if inflight < 0: raise ValueError("Invalid inflight.") self._max_inflight_messages = inflight - def max_queued_messages_set(self, queue_size): + def max_queued_messages_set(self, queue_size: int) -> "Client": """Set the maximum number of messages in the outgoing message queue. 0 means unlimited.""" if queue_size < 0: @@ -1735,19 +1914,26 @@ def max_queued_messages_set(self, queue_size): self._max_queued_messages = queue_size return self - def message_retry_set(self, retry): + def message_retry_set(self, retry): # type: ignore """No longer used, remove in version 2.0""" pass - def user_data_set(self, userdata): + def user_data_set(self, userdata: typing.Any) -> None: """Set the user data variable passed to callbacks. May be any data type.""" self._userdata = userdata - def user_data_get(self): + def user_data_get(self) -> typing.Any: """Get the user data variable passed to callbacks. May be any data type.""" return self._userdata - def will_set(self, topic, payload=None, qos=0, retain=False, properties=None): + def will_set( + self, + topic: str, + payload: PayloadType = None, + qos: int = 0, + retain: bool = False, + properties: typing.Optional[Properties] = None, + ) -> None: """Set a Will to be sent by the broker in case the client disconnects unexpectedly. This must be called before connect() to have any effect. @@ -1795,7 +1981,7 @@ def will_set(self, topic, payload=None, qos=0, retain=False, properties=None): self._will_retain = retain self._will_properties = properties - def will_clear(self): + def will_clear(self) -> None: """Removes a will that was previously configured with will_set(). Must be called before connect() to have any effect.""" @@ -1805,11 +1991,16 @@ def will_clear(self): self._will_qos = 0 self._will_retain = False - def socket(self): + def socket(self) -> typing.Optional[SocketLike]: """Return the socket or ssl object for this client.""" return self._sock - def loop_forever(self, timeout=1.0, max_packets=1, retry_first_connection=False): + def loop_forever( + self, + timeout: float = 1.0, + max_packets: int = 1, + retry_first_connection: bool = False, + ) -> MQTTErrorCode: """This function calls the network loop functions for you in an infinite blocking loop. It is useful for the case where you only want to run the MQTT client loop in your program. @@ -1847,8 +2038,8 @@ def loop_forever(self, timeout=1.0, max_packets=1, retry_first_connection=False) break while run: - rc = MQTT_ERR_SUCCESS - while rc == MQTT_ERR_SUCCESS: + rc = MQTTErrorCode.MQTT_ERR_SUCCESS + while rc == MQTTErrorCode.MQTT_ERR_SUCCESS: rc = self._loop(timeout) # We don't need to worry about locking here, because we've # either called loop_forever() when in single threaded mode, or @@ -1859,10 +2050,10 @@ def loop_forever(self, timeout=1.0, max_packets=1, retry_first_connection=False) and len(self._out_packet) == 0 and len(self._out_messages) == 0 ): - rc = 1 + rc = MQTTErrorCode.MQTT_ERR_NOMEM run = False - def should_exit(): + def should_exit() -> bool: # B023: uses the run variable from the outer scope on purpose return ( self._state == mqtt_cs_disconnecting @@ -1886,13 +2077,13 @@ def should_exit(): return rc - def loop_start(self): + def loop_start(self) -> MQTTErrorCode: """This is part of the threaded client interface. Call this once to start a new thread to process network traffic. This provides an alternative to repeatedly calling loop() yourself. """ if self._thread is not None: - return MQTT_ERR_INVAL + return MQTTErrorCode.MQTT_ERR_INVAL self._sockpairR, self._sockpairW = _socketpair_compat() self._thread_terminate = False @@ -1903,7 +2094,9 @@ def loop_start(self): self._thread.daemon = True self._thread.start() - def loop_stop(self, force=False): + return MQTTErrorCode.MQTT_ERR_SUCCESS + + def loop_stop(self, force: bool = False) -> MQTTErrorCode: """This is part of the threaded client interface. Call this once to stop the network thread previously created with loop_start(). This call will block until the network thread finishes. @@ -1911,21 +2104,23 @@ def loop_stop(self, force=False): The force parameter is currently ignored. """ if self._thread is None: - return MQTT_ERR_INVAL + return MQTTErrorCode.MQTT_ERR_INVAL self._thread_terminate = True if threading.current_thread() != self._thread: self._thread.join() self._thread = None + return MQTTErrorCode.MQTT_ERR_SUCCESS + @property - def on_log(self): + def on_log(self) -> typing.Optional[CallbackOnLog]: """If implemented, called when the client has log information. Defined to allow debugging.""" return self._on_log @on_log.setter - def on_log(self, func): + def on_log(self, func: typing.Optional[CallbackOnLog]) -> None: """Define the logging callback implementation. Expected signature is: @@ -1943,21 +2138,21 @@ def on_log(self, func): """ self._on_log = func - def log_callback(self): - def decorator(func): + def log_callback(self) -> typing.Callable[[CallbackOnLog], CallbackOnLog]: + def decorator(func: CallbackOnLog) -> CallbackOnLog: self.on_log = func return func return decorator @property - def on_pre_connect(self): + def on_pre_connect(self) -> typing.Optional[CallbackOnPreConnect]: """If implemented, called immediately prior to the connection is made request.""" return self._on_pre_connect @on_pre_connect.setter - def on_pre_connect(self, func): + def on_pre_connect(self, func: typing.Optional[CallbackOnPreConnect]) -> None: """Define the pre_connect callback implementation. Expected signature: @@ -1973,21 +2168,23 @@ def on_pre_connect(self, func): with self._callback_mutex: self._on_pre_connect = func - def pre_connect_callback(self): - def decorator(func): + def pre_connect_callback( + self, + ) -> typing.Callable[[CallbackOnPreConnect], CallbackOnPreConnect]: + def decorator(func: CallbackOnPreConnect) -> CallbackOnPreConnect: self.on_pre_connect = func return func return decorator @property - def on_connect(self): + def on_connect(self) -> typing.Optional[CallbackOnConnect]: """If implemented, called when the broker responds to our connection request.""" return self._on_connect @on_connect.setter - def on_connect(self, func): + def on_connect(self, func: typing.Optional[CallbackOnConnect]) -> None: """Define the connect callback implementation. Expected signature for MQTT v3.1 and v3.1.1 is: @@ -2030,21 +2227,23 @@ def on_connect(self, func): with self._callback_mutex: self._on_connect = func - def connect_callback(self): - def decorator(func): + def connect_callback( + self, + ) -> typing.Callable[[CallbackOnConnect], CallbackOnConnect]: + def decorator(func: CallbackOnConnect) -> CallbackOnConnect: self.on_connect = func return func return decorator @property - def on_connect_fail(self): + def on_connect_fail(self) -> typing.Optional[CallbackOnConnectFail]: """If implemented, called when the client failed to connect to the broker.""" return self._on_connect_fail @on_connect_fail.setter - def on_connect_fail(self, func): + def on_connect_fail(self, func: typing.Optional[CallbackOnConnectFail]) -> None: """Define the connection failure callback implementation Expected signature is: @@ -2060,21 +2259,23 @@ def on_connect_fail(self, func): with self._callback_mutex: self._on_connect_fail = func - def connect_fail_callback(self): - def decorator(func): + def connect_fail_callback( + self, + ) -> typing.Callable[[CallbackOnConnectFail], CallbackOnConnectFail]: + def decorator(func: CallbackOnConnectFail) -> CallbackOnConnectFail: self.on_connect_fail = func return func return decorator @property - def on_subscribe(self): + def on_subscribe(self) -> typing.Optional[CallbackOnSubscribe]: """If implemented, called when the broker responds to a subscribe request.""" return self._on_subscribe @on_subscribe.setter - def on_subscribe(self, func): + def on_subscribe(self, func: typing.Optional[CallbackOnSubscribe]) -> None: """Define the subscribe callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: @@ -2100,15 +2301,17 @@ def on_subscribe(self, func): with self._callback_mutex: self._on_subscribe = func - def subscribe_callback(self): - def decorator(func): + def subscribe_callback( + self, + ) -> typing.Callable[[CallbackOnSubscribe], CallbackOnSubscribe]: + def decorator(func: CallbackOnSubscribe) -> CallbackOnSubscribe: self.on_subscribe = func return func return decorator @property - def on_message(self): + def on_message(self) -> typing.Optional[CallbackOnMessage]: """If implemented, called when a message has been received on a topic that the client subscribes to. @@ -2118,7 +2321,7 @@ def on_message(self): return self._on_message @on_message.setter - def on_message(self, func): + def on_message(self, func: typing.Optional[CallbackOnMessage]) -> None: """Define the message received callback implementation. Expected signature is: @@ -2136,15 +2339,17 @@ def on_message(self, func): with self._callback_mutex: self._on_message = func - def message_callback(self): - def decorator(func): + def message_callback( + self, + ) -> typing.Callable[[CallbackOnMessage], CallbackOnMessage]: + def decorator(func: CallbackOnMessage) -> CallbackOnMessage: self.on_message = func return func return decorator @property - def on_publish(self): + def on_publish(self) -> typing.Optional[CallbackOnPublish]: """If implemented, called when a message that was to be sent using the publish() call has completed transmission to the broker. @@ -2156,7 +2361,7 @@ def on_publish(self): return self._on_publish @on_publish.setter - def on_publish(self, func): + def on_publish(self, func: typing.Optional[CallbackOnPublish]) -> None: """Define the published message callback implementation. Expected signature is: @@ -2174,21 +2379,23 @@ def on_publish(self, func): with self._callback_mutex: self._on_publish = func - def publish_callback(self): - def decorator(func): + def publish_callback( + self, + ) -> typing.Callable[[CallbackOnPublish], CallbackOnPublish]: + def decorator(func: CallbackOnPublish) -> CallbackOnPublish: self.on_publish = func return func return decorator @property - def on_unsubscribe(self): + def on_unsubscribe(self) -> typing.Optional[CallbackOnUnsubscribe]: """If implemented, called when the broker responds to an unsubscribe request.""" return self._on_unsubscribe @on_unsubscribe.setter - def on_unsubscribe(self, func): + def on_unsubscribe(self, func: typing.Optional[CallbackOnUnsubscribe]) -> None: """Define the unsubscribe callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: @@ -2212,20 +2419,22 @@ def on_unsubscribe(self, func): with self._callback_mutex: self._on_unsubscribe = func - def unsubscribe_callback(self): - def decorator(func): + def unsubscribe_callback( + self, + ) -> typing.Callable[[CallbackOnUnsubscribe], CallbackOnUnsubscribe]: + def decorator(func: CallbackOnUnsubscribe) -> CallbackOnUnsubscribe: self.on_unsubscribe = func return func return decorator @property - def on_disconnect(self): + def on_disconnect(self) -> typing.Optional[CallbackOnDisconnect]: """If implemented, called when the client disconnects from the broker.""" return self._on_disconnect @on_disconnect.setter - def on_disconnect(self, func): + def on_disconnect(self, func: typing.Optional[CallbackOnDisconnect]) -> None: """Define the disconnect callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: @@ -2249,20 +2458,22 @@ def on_disconnect(self, func): with self._callback_mutex: self._on_disconnect = func - def disconnect_callback(self): - def decorator(func): + def disconnect_callback( + self, + ) -> typing.Callable[[CallbackOnDisconnect], CallbackOnDisconnect]: + def decorator(func: CallbackOnDisconnect) -> CallbackOnDisconnect: self.on_disconnect = func return func return decorator @property - def on_socket_open(self): + def on_socket_open(self) -> typing.Optional[CallbackOnSocket]: """If implemented, called just after the socket was opend.""" return self._on_socket_open @on_socket_open.setter - def on_socket_open(self, func): + def on_socket_open(self, func: typing.Optional[CallbackOnSocket]) -> None: """Define the socket_open callback implementation. This should be used to register the socket to an external event loop for reading. @@ -2280,20 +2491,29 @@ def on_socket_open(self, func): with self._callback_mutex: self._on_socket_open = func - def socket_open_callback(self): - def decorator(func): + def socket_open_callback( + self, + ) -> typing.Callable[[CallbackOnSocket], CallbackOnSocket]: + def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self.on_socket_open = func return func return decorator - def _call_socket_open(self): + def _call_socket_open(self) -> None: """Call the socket_open callback with the just-opened socket""" with self._callback_mutex: on_socket_open = self.on_socket_open if on_socket_open: with self._in_callback_mutex: + if self._sock is None: + self._easy_log( + MQTT_LOG_ERR, + "socket() is None in _call_socket_open", + ) + return + try: on_socket_open(self, self._userdata, self._sock) except Exception as err: @@ -2304,12 +2524,12 @@ def _call_socket_open(self): raise @property - def on_socket_close(self): + def on_socket_close(self) -> typing.Optional[CallbackOnSocket]: """If implemented, called just before the socket is closed.""" return self._on_socket_close @on_socket_close.setter - def on_socket_close(self, func): + def on_socket_close(self, func: typing.Optional[CallbackOnSocket]) -> None: """Define the socket_close callback implementation. This should be used to unregister the socket from an external event loop for reading. @@ -2327,14 +2547,16 @@ def on_socket_close(self, func): with self._callback_mutex: self._on_socket_close = func - def socket_close_callback(self): - def decorator(func): + def socket_close_callback( + self, + ) -> typing.Callable[[CallbackOnSocket], CallbackOnSocket]: + def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self.on_socket_close = func return func return decorator - def _call_socket_close(self, sock): + def _call_socket_close(self, sock: SocketLike) -> None: """Call the socket_close callback with the about-to-be-closed socket""" with self._callback_mutex: on_socket_close = self.on_socket_close @@ -2351,12 +2573,12 @@ def _call_socket_close(self, sock): raise @property - def on_socket_register_write(self): + def on_socket_register_write(self) -> typing.Optional[CallbackOnSocket]: """If implemented, called when the socket needs writing but can't.""" return self._on_socket_register_write @on_socket_register_write.setter - def on_socket_register_write(self, func): + def on_socket_register_write(self, func: typing.Optional[CallbackOnSocket]) -> None: """Define the socket_register_write callback implementation. This should be used to register the socket with an external event loop for writing. @@ -2374,14 +2596,16 @@ def on_socket_register_write(self, func): with self._callback_mutex: self._on_socket_register_write = func - def socket_register_write_callback(self): - def decorator(func): + def socket_register_write_callback( + self, + ) -> typing.Callable[[CallbackOnSocket], CallbackOnSocket]: + def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self._on_socket_register_write = func return func return decorator - def _call_socket_register_write(self): + def _call_socket_register_write(self) -> None: """Call the socket_register_write callback with the unwritable socket""" if not self._sock or self._registered_write: return @@ -2402,12 +2626,16 @@ def _call_socket_register_write(self): raise @property - def on_socket_unregister_write(self): + def on_socket_unregister_write( + self, + ) -> typing.Optional[CallbackOnSocket]: """If implemented, called when the socket doesn't need writing anymore.""" return self._on_socket_unregister_write @on_socket_unregister_write.setter - def on_socket_unregister_write(self, func): + def on_socket_unregister_write( + self, func: typing.Optional[CallbackOnSocket] + ) -> None: """Define the socket_unregister_write callback implementation. This should be used to unregister the socket from an external event loop for writing. @@ -2425,14 +2653,20 @@ def on_socket_unregister_write(self, func): with self._callback_mutex: self._on_socket_unregister_write = func - def socket_unregister_write_callback(self): - def decorator(func): + def socket_unregister_write_callback( + self, + ) -> typing.Callable[[CallbackOnSocket], CallbackOnSocket]: + def decorator( + func: CallbackOnSocket, + ) -> CallbackOnSocket: self._on_socket_unregister_write = func return func return decorator - def _call_socket_unregister_write(self, sock=None): + def _call_socket_unregister_write( + self, sock: typing.Optional[SocketLike] = None + ) -> None: """Call the socket_unregister_write callback with the writable socket""" sock = sock or self._sock if not sock or not self._registered_write: @@ -2454,7 +2688,7 @@ def _call_socket_unregister_write(self, sock=None): if not self.suppress_exceptions: raise - def message_callback_add(self, sub, callback): + def message_callback_add(self, sub: str, callback: CallbackOnMessage) -> None: """Register a message callback for a specific topic. Messages that match 'sub' will be passed to 'callback'. Any non-matching messages will be passed to the default on_message @@ -2471,14 +2705,16 @@ def message_callback_add(self, sub, callback): with self._callback_mutex: self._on_message_filtered[sub] = callback - def topic_callback(self, sub): - def decorator(func): + def topic_callback( + self, sub: str + ) -> typing.Callable[[CallbackOnMessage], CallbackOnMessage]: + def decorator(func: CallbackOnMessage) -> CallbackOnMessage: self.message_callback_add(sub, func) return func return decorator - def message_callback_remove(self, sub): + def message_callback_remove(self, sub: str) -> None: """Remove a message callback previously registered with message_callback_add().""" if sub is None: @@ -2494,18 +2730,22 @@ def message_callback_remove(self, sub): # Private functions # ============================================================ - def _loop_rc_handle(self, rc, properties=None): + def _loop_rc_handle( + self, + rc: typing.Union[MQTTErrorCode, ReasonCodes, None], + properties: typing.Optional[Properties] = None, + ) -> typing.Union[MQTTErrorCode, ReasonCodes, None]: if rc: self._sock_close() if self._state == mqtt_cs_disconnecting: - rc = MQTT_ERR_SUCCESS + rc = MQTTErrorCode.MQTT_ERR_SUCCESS self._do_on_disconnect(rc, properties) return rc - def _packet_read(self): + def _packet_read(self) -> MQTTErrorCode: # This gets called if pselect() indicates that there is network data # available - ie. at least one byte. What we do depends on what data we # already have. @@ -2523,18 +2763,18 @@ def _packet_read(self): try: command = self._sock_recv(1) except BlockingIOError: - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN except ConnectionError as err: self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST except TimeoutError as err: self._easy_log(MQTT_LOG_ERR, "timeout on socket: %s", err) - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(command) == 0: - return MQTT_ERR_CONN_LOST - (command,) = struct.unpack("!B", command) - self._in_packet["command"] = command + return MQTTErrorCode.MQTT_ERR_CONN_LOST + (command_value,) = struct.unpack("!B", command) + self._in_packet["command"] = command_value if self._in_packet["have_remaining"] == 0: # Read remaining @@ -2544,28 +2784,28 @@ def _packet_read(self): try: byte = self._sock_recv(1) except BlockingIOError: - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN except ConnectionError as err: self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(byte) == 0: - return MQTT_ERR_CONN_LOST - (byte,) = struct.unpack("!B", byte) - self._in_packet["remaining_count"].append(byte) + return MQTTErrorCode.MQTT_ERR_CONN_LOST + (byte_value,) = struct.unpack("!B", byte) + self._in_packet["remaining_count"].append(byte_value) # Max 4 bytes length for remaining length as defined by protocol. # Anything more likely means a broken/malicious client. if len(self._in_packet["remaining_count"]) > 4: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL self._in_packet["remaining_length"] += ( - byte & 127 + byte_value & 127 ) * self._in_packet["remaining_mult"] self._in_packet["remaining_mult"] = ( self._in_packet["remaining_mult"] * 128 ) - if (byte & 128) == 0: + if (byte_value & 128) == 0: break self._in_packet["have_remaining"] = 1 @@ -2576,60 +2816,62 @@ def _packet_read(self): try: data = self._sock_recv(self._in_packet["to_process"]) except BlockingIOError: - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN except ConnectionError as err: self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(data) == 0: - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST self._in_packet["to_process"] -= len(data) self._in_packet["packet"] += data count -= 1 if count == 0: with self._msgtime_mutex: self._last_msg_in = time_func() - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN # All data for this packet is read. self._in_packet["pos"] = 0 rc = self._packet_handle() # Free data and reset values - self._in_packet = { - "command": 0, - "have_remaining": 0, - "remaining_count": [], - "remaining_mult": 1, - "remaining_length": 0, - "packet": bytearray(b""), - "to_process": 0, - "pos": 0, - } + self._in_packet = _InPacket( + { + "command": 0, + "have_remaining": 0, + "remaining_count": [], + "remaining_mult": 1, + "remaining_length": 0, + "packet": bytearray(b""), + "to_process": 0, + "pos": 0, + } + ) with self._msgtime_mutex: self._last_msg_in = time_func() return rc - def _packet_write(self): + def _packet_write(self) -> MQTTErrorCode: while True: try: packet = self._out_packet.popleft() except IndexError: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS try: write_length = self._sock_send(packet["packet"][packet["pos"] :]) except (AttributeError, ValueError): self._out_packet.appendleft(packet) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS except BlockingIOError: self._out_packet.appendleft(packet) - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN except ConnectionError as err: self._out_packet.appendleft(packet) self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST if write_length > 0: packet["to_process"] -= write_length @@ -2653,15 +2895,19 @@ def _packet_write(self): if not self.suppress_exceptions: raise - packet["info"]._set_as_published() + # TODO: Something is odd here. I don't see why packet["info"] can't be None. + # A packet could be produced by _handle_connack with qos=0 and no info + # (around line 3645). Ignore the mypy check for now but I feel their is a bug + # somewhere. + packet["info"]._set_as_published() # type: ignore if (packet["command"] & 0xF0) == DISCONNECT: with self._msgtime_mutex: self._last_msg_out = time_func() - self._do_on_disconnect(MQTT_ERR_SUCCESS) + self._do_on_disconnect(MQTTErrorCode.MQTT_ERR_SUCCESS) self._sock_close() - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS else: # We haven't finished with this packet @@ -2672,9 +2918,9 @@ def _packet_write(self): with self._msgtime_mutex: self._last_msg_out = time_func() - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _easy_log(self, level, fmt, *args): + def _easy_log(self, level: int, fmt: str, *args: typing.Any) -> None: if self.on_log is not None: buf = fmt % args try: @@ -2686,9 +2932,9 @@ def _easy_log(self, level, fmt, *args): level_std = LOGGING_LEVEL[level] self._logger.log(level_std, fmt, *args) - def _check_keepalive(self): + def _check_keepalive(self) -> None: if self._keepalive == 0: - return MQTT_ERR_SUCCESS + return now = time_func() @@ -2714,13 +2960,13 @@ def _check_keepalive(self): self._sock_close() if self._state == mqtt_cs_disconnecting: - rc = MQTT_ERR_SUCCESS + rc = MQTTErrorCode.MQTT_ERR_SUCCESS else: - rc = MQTT_ERR_KEEPALIVE + rc = MQTTErrorCode.MQTT_ERR_KEEPALIVE self._do_on_disconnect(rc) - def _mid_generate(self): + def _mid_generate(self) -> int: with self._mid_generate_mutex: self._last_mid += 1 if self._last_mid == 65536: @@ -2728,47 +2974,49 @@ def _mid_generate(self): return self._last_mid @staticmethod - def _topic_wildcard_len_check(topic): + def _topic_wildcard_len_check(topic: bytes) -> MQTTErrorCode: # Search for + or # in a topic. Return MQTT_ERR_INVAL if found. # Also returns MQTT_ERR_INVAL if the topic string is too long. # Returns MQTT_ERR_SUCCESS if everything is fine. if b"+" in topic or b"#" in topic or len(topic) > 65535: - return MQTT_ERR_INVAL + return MQTTErrorCode.MQTT_ERR_INVAL else: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS @staticmethod - def _filter_wildcard_len_check(sub): + def _filter_wildcard_len_check(sub: bytes) -> MQTTErrorCode: if ( len(sub) == 0 or len(sub) > 65535 or any(b"+" in p or b"#" in p for p in sub.split(b"/") if len(p) > 1) or b"#/" in sub ): - return MQTT_ERR_INVAL + return MQTTErrorCode.MQTT_ERR_INVAL else: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _send_pingreq(self): + def _send_pingreq(self) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PINGREQ") rc = self._send_simple_command(PINGREQ) - if rc == MQTT_ERR_SUCCESS: + if rc == MQTTErrorCode.MQTT_ERR_SUCCESS: self._ping_t = time_func() return rc - def _send_pingresp(self): + def _send_pingresp(self) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PINGRESP") return self._send_simple_command(PINGRESP) - def _send_puback(self, mid): + def _send_puback(self, mid: int) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PUBACK (Mid: %d)", mid) return self._send_command_with_mid(PUBACK, mid, False) - def _send_pubcomp(self, mid): + def _send_pubcomp(self, mid: int) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PUBCOMP (Mid: %d)", mid) return self._send_command_with_mid(PUBCOMP, mid, False) - def _pack_remaining_length(self, packet, remaining_length): + def _pack_remaining_length( + self, packet: bytearray, remaining_length: int + ) -> bytearray: remaining_bytes = [] while True: byte = remaining_length % 128 @@ -2783,7 +3031,7 @@ def _pack_remaining_length(self, packet, remaining_length): # FIXME - this doesn't deal with incorrectly large payloads return packet - def _pack_str16(self, packet, data): + def _pack_str16(self, packet: bytearray, data: typing.Union[bytes, str]) -> None: if isinstance(data, str): data = data.encode("utf-8") packet.extend(struct.pack("!H", len(data))) @@ -2791,15 +3039,15 @@ def _pack_str16(self, packet, data): def _send_publish( self, - mid, - topic, - payload=b"", - qos=0, - retain=False, - dup=False, - info=None, - properties=None, - ): + mid: int, + topic: bytes, + payload: bytes = b"", + qos: int = 0, + retain: bool = False, + dup: bool = False, + info: typing.Optional[MQTTMessageInfo] = None, + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: # we assume that topic and payload are already properly encoded if not isinstance(topic, bytes): raise TypeError("topic must be bytes, not str") @@ -2807,7 +3055,7 @@ def _send_publish( raise TypeError("payload must be bytes if set") if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN command = PUBLISH | ((dup & 0x1) << 3) | (qos << 1) | retain packet = bytearray() @@ -2888,15 +3136,15 @@ def _send_publish( return self._packet_queue(PUBLISH, packet, mid, qos, info) - def _send_pubrec(self, mid): + def _send_pubrec(self, mid: int) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PUBREC (Mid: %d)", mid) return self._send_command_with_mid(PUBREC, mid, False) - def _send_pubrel(self, mid): + def _send_pubrel(self, mid: int) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PUBREL (Mid: %d)", mid) return self._send_command_with_mid(PUBREL | 2, mid, False) - def _send_command_with_mid(self, command, mid, dup): + def _send_command_with_mid(self, command: int, mid: int, dup: int) -> MQTTErrorCode: # For PUBACK, PUBCOMP, PUBREC, and PUBREL if dup: command |= 0x8 @@ -2905,13 +3153,13 @@ def _send_command_with_mid(self, command, mid, dup): packet = struct.pack("!BBH", command, remaining_length, mid) return self._packet_queue(command, packet, mid, 1) - def _send_simple_command(self, command): + def _send_simple_command(self, command: int) -> MQTTErrorCode: # For DISCONNECT, PINGREQ and PINGRESP remaining_length = 0 packet = struct.pack("!BB", command, remaining_length) return self._packet_queue(command, packet, 0, 0) - def _send_connect(self, keepalive): + def _send_connect(self, keepalive: int) -> MQTTErrorCode: proto_ver = self._protocol # hard-coded UTF-8 encoded string protocol = b"MQTT" if proto_ver >= MQTTv311 else b"MQIsdp" @@ -3026,7 +3274,11 @@ def _send_connect(self, keepalive): ) return self._packet_queue(command, packet, 0, 0) - def _send_disconnect(self, reasoncode=None, properties=None): + def _send_disconnect( + self, + reasoncode: typing.Optional[ReasonCodes] = None, + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: if self._protocol == MQTTv5: self._easy_log( MQTT_LOG_DEBUG, @@ -3062,7 +3314,12 @@ def _send_disconnect(self, reasoncode=None, properties=None): return self._packet_queue(command, packet, 0, 0) - def _send_subscribe(self, dup, topics, properties=None): + def _send_subscribe( + self, + dup: int, + topics: typing.Sequence[tuple[bytes, typing.Union[SubscribeOptions, int]]], + properties: typing.Optional[Properties] = None, + ) -> tuple[MQTTErrorCode, int]: remaining_length = 2 if self._protocol == MQTTv5: if properties is None: @@ -3086,9 +3343,9 @@ def _send_subscribe(self, dup, topics, properties=None): for t, q in topics: self._pack_str16(packet, t) if self._protocol == MQTTv5: - packet += q.pack() + packet += q.pack() # type: ignore else: - packet.append(q) + packet.append(q) # type: ignore self._easy_log( MQTT_LOG_DEBUG, @@ -3099,7 +3356,12 @@ def _send_subscribe(self, dup, topics, properties=None): ) return (self._packet_queue(command, packet, local_mid, 1), local_mid) - def _send_unsubscribe(self, dup, topics, properties=None): + def _send_unsubscribe( + self, + dup: int, + topics: list[bytes], + properties: typing.Optional[Properties] = None, + ) -> tuple[MQTTErrorCode, int]: remaining_length = 2 if self._protocol == MQTTv5: if properties is None: @@ -3143,16 +3405,16 @@ def _send_unsubscribe(self, dup, topics, properties=None): ) return (self._packet_queue(command, packet, local_mid, 1), local_mid) - def _check_clean_session(self): + def _check_clean_session(self) -> bool: if self._protocol == MQTTv5: if self._clean_start == MQTT_CLEAN_START_FIRST_ONLY: return self._mqttv5_first_connect else: - return self._clean_start + return self._clean_start # type: ignore else: return self._clean_session - def _messages_reconnect_reset_out(self): + def _messages_reconnect_reset_out(self) -> None: with self._out_message_mutex: self._inflight_messages = 0 for m in self._out_messages.values(): @@ -3184,7 +3446,7 @@ def _messages_reconnect_reset_out(self): else: m.state = mqtt_ms_queued - def _messages_reconnect_reset_in(self): + def _messages_reconnect_reset_in(self) -> None: with self._in_message_mutex: if self._check_clean_session(): self._in_messages = collections.OrderedDict() @@ -3197,20 +3459,29 @@ def _messages_reconnect_reset_in(self): # Preserve current state pass - def _messages_reconnect_reset(self): + def _messages_reconnect_reset(self) -> None: self._messages_reconnect_reset_out() self._messages_reconnect_reset_in() - def _packet_queue(self, command, packet, mid, qos, info=None): - mpkt = { - "command": command, - "mid": mid, - "qos": qos, - "pos": 0, - "to_process": len(packet), - "packet": packet, - "info": info, - } + def _packet_queue( + self, + command: int, + packet: bytes, + mid: int, + qos: int, + info: typing.Optional[MQTTMessageInfo] = None, + ) -> MQTTErrorCode: + mpkt = _OutPacket( + { + "command": command, + "mid": mid, + "qos": qos, + "pos": 0, + "to_process": len(packet), + "packet": packet, + "info": info, + } + ) self._out_packet.append(mpkt) @@ -3231,9 +3502,9 @@ def _packet_queue(self, command, packet, mid, qos, info=None): self._call_socket_register_write() - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _packet_handle(self): + def _packet_handle(self) -> MQTTErrorCode: cmd = self._in_packet["command"] & 0xF0 if cmd == PINGREQ: return self._handle_pingreq() @@ -3260,37 +3531,37 @@ def _packet_handle(self): else: # If we don't recognise the command, return an error straight away. self._easy_log(MQTT_LOG_ERR, "Error: Unrecognised command %s", cmd) - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL - def _handle_pingreq(self): + def _handle_pingreq(self) -> MQTTErrorCode: if self._in_packet["remaining_length"] != 0: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL self._easy_log(MQTT_LOG_DEBUG, "Received PINGREQ") return self._send_pingresp() - def _handle_pingresp(self): + def _handle_pingresp(self) -> MQTTErrorCode: if self._in_packet["remaining_length"] != 0: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL # No longer waiting for a PINGRESP. self._ping_t = 0 self._easy_log(MQTT_LOG_DEBUG, "Received PINGRESP") - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_connack(self): + def _handle_connack(self) -> MQTTErrorCode: if self._protocol == MQTTv5: if self._in_packet["remaining_length"] < 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL elif self._in_packet["remaining_length"] != 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL if self._protocol == MQTTv5: (flags, result) = struct.unpack("!BB", self._in_packet["packet"][:2]) if result == 1: # This is probably a failure from a broker that doesn't support # MQTT v5. - reason = 132 # Unsupported protocol version + reason = ReasonCodes(CONNACK >> 4, aName="Unsupported protocol version") properties = None else: reason = ReasonCodes(CONNACK >> 4, identifier=result) @@ -3322,7 +3593,7 @@ def _handle_connack(self): flags, result, ) - self._client_id = base62(uuid.uuid4().int, padding=22) + self._client_id = base62(uuid.uuid4().int, padding=22).encode("utf8") return self.reconnect() if result == 0: @@ -3352,9 +3623,9 @@ def _handle_connack(self): with self._in_callback_mutex: try: if self._protocol == MQTTv5: - on_connect(self, self._userdata, flags_dict, reason, properties) + on_connect(self, self._userdata, flags_dict, reason, properties) # type: ignore else: - on_connect(self, self._userdata, flags_dict, result) + on_connect(self, self._userdata, flags_dict, result) # type: ignore except Exception as err: self._easy_log( MQTT_LOG_ERR, "Caught exception in on_connect: %s", err @@ -3363,7 +3634,7 @@ def _handle_connack(self): raise if result == 0: - rc = 0 + rc = MQTTErrorCode.MQTT_ERR_SUCCESS with self._out_message_mutex: for m in self._out_messages.values(): m.timestamp = time_func() @@ -3382,7 +3653,7 @@ def _handle_connack(self): m.dup, properties=m.properties, ) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc elif m.qos == 1: if m.state == mqtt_ms_publish: @@ -3398,7 +3669,7 @@ def _handle_connack(self): m.dup, properties=m.properties, ) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc elif m.qos == 2: if m.state == mqtt_ms_publish: @@ -3414,24 +3685,24 @@ def _handle_connack(self): m.dup, properties=m.properties, ) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc elif m.state == mqtt_ms_resend_pubrel: self._inflight_messages += 1 m.state = mqtt_ms_wait_for_pubcomp with self._in_callback_mutex: # Don't call loop_write after _send_publish() rc = self._send_pubrel(m.mid) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc self.loop_write() # Process outgoing messages that have just been queued up return rc elif result > 0 and result < 6: - return MQTT_ERR_CONN_REFUSED + return MQTTErrorCode.MQTT_ERR_CONN_REFUSED else: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL - def _handle_disconnect(self): + def _handle_disconnect(self) -> typing.Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]: packet_type = DISCONNECT >> 4 reasonCode = properties = None if self._in_packet["remaining_length"] > 2: @@ -3446,9 +3717,9 @@ def _handle_disconnect(self): self._loop_rc_handle(reasonCode, properties) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_suback(self): + def _handle_suback(self) -> typing.Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]: self._easy_log(MQTT_LOG_DEBUG, "Received SUBACK") pack_format = f"!H{len(self._in_packet['packet']) - 2}s" (mid, packet) = struct.unpack(pack_format, self._in_packet["packet"]) @@ -3470,9 +3741,9 @@ def _handle_suback(self): with self._in_callback_mutex: # Don't call loop_write after _send_publish() try: if self._protocol == MQTTv5: - on_subscribe(self, self._userdata, mid, reasoncodes, properties) + on_subscribe(self, self._userdata, mid, reasoncodes, properties) # type: ignore else: - on_subscribe(self, self._userdata, mid, granted_qos) + on_subscribe(self, self._userdata, mid, granted_qos) # type: ignore except Exception as err: self._easy_log( MQTT_LOG_ERR, "Caught exception in on_subscribe: %s", err @@ -3480,16 +3751,14 @@ def _handle_suback(self): if not self.suppress_exceptions: raise - return MQTT_ERR_SUCCESS - - def _handle_publish(self): - rc = 0 + return MQTTErrorCode.MQTT_ERR_SUCCESS + def _handle_publish(self) -> MQTTErrorCode: header = self._in_packet["command"] message = MQTTMessage() - message.dup = (header & 0x08) >> 3 + message.dup = ((header & 0x08) >> 3) != 0 message.qos = (header & 0x06) >> 1 - message.retain = header & 0x01 + message.retain = (header & 0x01) != 0 pack_format = f"!H{len(self._in_packet['packet']) - 2}s" (slen, packet) = struct.unpack(pack_format, self._in_packet["packet"]) @@ -3497,7 +3766,7 @@ def _handle_publish(self): (topic, packet) = struct.unpack(pack_format, packet) if self._protocol != MQTTv5 and len(topic) == 0: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL # Handle topics with invalid UTF-8 # This replaces an invalid topic with a message and the hex @@ -3548,11 +3817,11 @@ def _handle_publish(self): message.timestamp = time_func() if message.qos == 0: self._handle_on_message(message) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS elif message.qos == 1: self._handle_on_message(message) if self._manual_ack: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS else: return self._send_puback(message.mid) elif message.qos == 2: @@ -3564,9 +3833,9 @@ def _handle_publish(self): return rc else: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL - def ack(self, mid: int, qos: int) -> int: + def ack(self, mid: int, qos: int) -> MQTTErrorCode: """ send an acknowledgement for a given message id. (stored in message.mid ) only useful in QoS=1 and auto_ack=False @@ -3577,9 +3846,9 @@ def ack(self, mid: int, qos: int) -> int: elif qos == 2: return self._send_pubcomp(mid) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def manual_ack_set(self, on): + def manual_ack_set(self, on: bool) -> None: """ The paho library normally acknowledges messages as soon as they are delivered to the caller. If manual_ack is turned on, then the caller MUST manually acknowledge every message once @@ -3587,12 +3856,12 @@ def manual_ack_set(self, on): """ self._manual_ack = on - def _handle_pubrel(self): + def _handle_pubrel(self) -> MQTTErrorCode: if self._protocol == MQTTv5: if self._in_packet["remaining_length"] < 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL elif self._in_packet["remaining_length"] != 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL (mid,) = struct.unpack("!H", self._in_packet["packet"]) self._easy_log(MQTT_LOG_DEBUG, "Received PUBREL (Mid: %d)", mid) @@ -3607,7 +3876,7 @@ def _handle_pubrel(self): if self._max_inflight_messages > 0: with self._out_message_mutex: rc = self._update_inflight() - if rc != MQTT_ERR_SUCCESS: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc # FIXME: this should only be done if the message is known @@ -3617,11 +3886,11 @@ def _handle_pubrel(self): # Choose to acknowledge this message (thus losing a message) but # avoid hanging. See #284. if self._manual_ack: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS else: return self._send_pubcomp(mid) - def _update_inflight(self): + def _update_inflight(self) -> MQTTErrorCode: # Dont lock message_mutex here for m in self._out_messages.values(): if self._inflight_messages < self._max_inflight_messages: @@ -3640,18 +3909,18 @@ def _update_inflight(self): m.dup, properties=m.properties, ) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc else: - return MQTT_ERR_SUCCESS - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_pubrec(self): + def _handle_pubrec(self) -> MQTTErrorCode: if self._protocol == MQTTv5: if self._in_packet["remaining_length"] < 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL elif self._in_packet["remaining_length"] != 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: @@ -3670,25 +3939,27 @@ def _handle_pubrec(self): msg.timestamp = time_func() return self._send_pubrel(mid) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_unsuback(self): + def _handle_unsuback(self) -> MQTTErrorCode: if self._protocol == MQTTv5: if self._in_packet["remaining_length"] < 4: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL elif self._in_packet["remaining_length"] != 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: packet = self._in_packet["packet"][2:] properties = Properties(UNSUBACK >> 4) props, props_len = properties.unpack(packet) - reasoncodes = [] + reasoncodes_list = [] for c in packet[props_len:]: - reasoncodes.append(ReasonCodes(UNSUBACK >> 4, identifier=c)) - if len(reasoncodes) == 1: - reasoncodes = reasoncodes[0] + reasoncodes_list.append(ReasonCodes(UNSUBACK >> 4, identifier=c)) + + reasoncodes: typing.Union[ReasonCodes, list[ReasonCodes]] = reasoncodes_list + if len(reasoncodes_list) == 1: + reasoncodes = reasoncodes_list[0] self._easy_log(MQTT_LOG_DEBUG, "Received UNSUBACK (Mid: %d)", mid) with self._callback_mutex: @@ -3699,10 +3970,10 @@ def _handle_unsuback(self): try: if self._protocol == MQTTv5: on_unsubscribe( - self, self._userdata, mid, properties, reasoncodes + self, self._userdata, mid, properties, reasoncodes # type: ignore ) else: - on_unsubscribe(self, self._userdata, mid) + on_unsubscribe(self, self._userdata, mid) # type: ignore except Exception as err: self._easy_log( MQTT_LOG_ERR, "Caught exception in on_unsubscribe: %s", err @@ -3710,9 +3981,13 @@ def _handle_unsuback(self): if not self.suppress_exceptions: raise - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _do_on_disconnect(self, rc, properties=None): + def _do_on_disconnect( + self, + rc: typing.Union[MQTTErrorCode, ReasonCodes], + properties: typing.Optional[Properties] = None, + ) -> None: with self._callback_mutex: on_disconnect = self.on_disconnect @@ -3720,9 +3995,9 @@ def _do_on_disconnect(self, rc, properties=None): with self._in_callback_mutex: try: if self._protocol == MQTTv5: - on_disconnect(self, self._userdata, rc, properties) + on_disconnect(self, self._userdata, rc, properties) # type: ignore else: - on_disconnect(self, self._userdata, rc) + on_disconnect(self, self._userdata, rc) # type: ignore except Exception as err: self._easy_log( MQTT_LOG_ERR, "Caught exception in on_disconnect: %s", err @@ -3730,7 +4005,7 @@ def _do_on_disconnect(self, rc, properties=None): if not self.suppress_exceptions: raise - def _do_on_publish(self, mid): + def _do_on_publish(self, mid: int) -> MQTTErrorCode: with self._callback_mutex: on_publish = self.on_publish @@ -3751,16 +4026,18 @@ def _do_on_publish(self, mid): self._inflight_messages -= 1 if self._max_inflight_messages > 0: rc = self._update_inflight() - if rc != MQTT_ERR_SUCCESS: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_pubackcomp(self, cmd): + def _handle_pubackcomp( + self, cmd: typing.Union[typing.Literal["PUBACK"], typing.Literal["PUBCOMP"]] + ) -> MQTTErrorCode: if self._protocol == MQTTv5: if self._in_packet["remaining_length"] < 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL elif self._in_packet["remaining_length"] != 2: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL packet_type = PUBACK if cmd == "PUBACK" else PUBCOMP packet_type = packet_type >> 4 @@ -3780,9 +4057,9 @@ def _handle_pubackcomp(self, cmd): rc = self._do_on_publish(mid) return rc - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_on_message(self, message): + def _handle_on_message(self, message: MQTTMessage) -> None: try: topic = message.topic except UnicodeDecodeError: @@ -3824,7 +4101,7 @@ def _handle_on_message(self, message): if not self.suppress_exceptions: raise - def _handle_on_connect_fail(self): + def _handle_on_connect_fail(self) -> None: with self._callback_mutex: on_connect_fail = self.on_connect_fail @@ -3837,10 +4114,10 @@ def _handle_on_connect_fail(self): MQTT_LOG_ERR, "Caught exception in on_connect_fail: %s", err ) - def _thread_main(self): + def _thread_main(self) -> None: self.loop_forever(retry_first_connection=True) - def _reconnect_wait(self): + def _reconnect_wait(self) -> None: # See reconnect_delay_set for details now = time_func() with self._reconnect_delay_mutex: @@ -3864,8 +4141,8 @@ def _reconnect_wait(self): remaining = target_time - time_func() @staticmethod - def _proxy_is_valid(p): - def check(t, a): + def _proxy_is_valid(p) -> bool: # type: ignore[no-untyped-def] + def check(t, a) -> bool: # type: ignore[no-untyped-def] return ( socks is not None and t in set([socks.HTTP, socks.SOCKS4, socks.SOCKS5]) @@ -3879,7 +4156,7 @@ def check(t, a): else: return False - def _get_proxy(self): + def _get_proxy(self) -> typing.Optional[dict[str, typing.Any]]: if socks is None: return None @@ -3930,7 +4207,7 @@ def _get_proxy(self): # None to indicate that the connection should be handled normally return None - def _create_socket_connection(self): + def _create_socket_connection(self) -> _socket.socket: proxy = self._get_proxy() addr = (self._host, self._port) source = (self._bind_address, self._bind_port) @@ -3953,7 +4230,15 @@ class WebsocketWrapper: OPCODE_PING = 0x9 OPCODE_PONG = 0xA - def __init__(self, socket, host, port, is_ssl, path, extra_headers): + def __init__( + self, + socket: typing.Union[socket.socket, "ssl.SSLSocket"], + host: str, + port: int, + is_ssl: bool, + path: str, + extra_headers: typing.Optional[WebSocketHeaders], + ): self.connected = False self._ssl = is_ssl @@ -3971,11 +4256,11 @@ def __init__(self, socket, host, port, is_ssl, path, extra_headers): self._do_handshake(extra_headers) - def __del__(self): - self._sendbuffer = None - self._readbuffer = None + def __del__(self) -> None: + self._sendbuffer = bytearray() + self._readbuffer = bytearray() - def _do_handshake(self, extra_headers): + def _do_handshake(self, extra_headers: typing.Optional[WebSocketHeaders]) -> None: sec_websocket_key = uuid.uuid4().bytes sec_websocket_key = base64.b64encode(sec_websocket_key) @@ -4035,15 +4320,17 @@ def _do_handshake(self, extra_headers): ): GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - server_hash = self._readbuffer.decode("utf-8").split(": ", 1)[1] - server_hash = server_hash.strip().encode("utf-8") + server_hash_str = self._readbuffer.decode("utf-8").split( + ": ", 1 + )[1] + server_hash = server_hash_str.strip().encode("utf-8") - client_hash = sec_websocket_key.decode("utf-8") + GUID + client_hash_key = sec_websocket_key.decode("utf-8") + GUID # Use of SHA-1 is OK here; it's according to the Websocket spec. - client_hash = hashlib.sha1( # noqa: S324 - client_hash.encode("utf-8") + client_hash_digest = hashlib.sha1( # noqa: S324 + client_hash_key.encode("utf-8") ) - client_hash = base64.b64encode(client_hash.digest()) + client_hash = base64.b64encode(client_hash_digest.digest()) if server_hash != client_hash: raise WebsocketConnectionError( @@ -4068,7 +4355,9 @@ def _do_handshake(self, extra_headers): self._readbuffer = bytearray() self.connected = True - def _create_frame(self, opcode, data, do_masking=1): + def _create_frame( + self, opcode: int, data: bytearray, do_masking: int = 1 + ) -> bytearray: header = bytearray() length = len(data) @@ -4099,7 +4388,7 @@ def _create_frame(self, opcode, data, do_masking=1): return header + data - def _buffered_read(self, length): + def _buffered_read(self, length: int) -> bytearray: # try to recv and store needed bytes wanted_bytes = length - (len(self._readbuffer) - self._readbuffer_head) if wanted_bytes > 0: @@ -4116,12 +4405,12 @@ def _buffered_read(self, length): self._readbuffer_head += length return self._readbuffer[self._readbuffer_head - length : self._readbuffer_head] - def _recv_impl(self, length): + def _recv_impl(self, length: int) -> bytes: # try to decode websocket payload part from data try: self._readbuffer_head = 0 - result = None + result = b"" chunk_startindex = self._payload_head chunk_endindex = self._payload_head + length @@ -4158,7 +4447,7 @@ def _recv_impl(self, length): payload = self._buffered_read(readindex) # unmask only the needed part - if maskbit: + if mask_key is not None: for index in range(chunk_startindex, readindex): payload[index] ^= mask_key[index % 4] @@ -4197,7 +4486,7 @@ def _recv_impl(self, length): self.connected = False return b"" - def _send_impl(self, data): + def _send_impl(self, data: bytes) -> int: # if previous frame was sent successfully if len(self._sendbuffer) == 0: # create websocket frame @@ -4217,32 +4506,32 @@ def _send_impl(self, data): # couldn't send whole data, request the same data again with 0 as sent length return 0 - def recv(self, length): + def recv(self, length: int) -> bytes: return self._recv_impl(length) - def read(self, length): + def read(self, length: int) -> bytes: return self._recv_impl(length) - def send(self, data): + def send(self, data: bytes) -> int: return self._send_impl(data) - def write(self, data): + def write(self, data: bytes) -> int: return self._send_impl(data) - def close(self): + def close(self) -> None: self._socket.close() - def fileno(self): + def fileno(self) -> int: return self._socket.fileno() - def pending(self): + def pending(self) -> int: # Fix for bug #131: a SSL socket may still have data available # for reading without select() being aware of it. if self._ssl: - return self._socket.pending() + return self._socket.pending() # type: ignore[union-attr] else: # normal socket rely only on select() return 0 - def setblocking(self, flag): + def setblocking(self, flag: bool) -> None: self._socket.setblocking(flag) diff --git a/src/paho/mqtt/publish.py b/src/paho/mqtt/publish.py index 0fcb65d9..9b79988b 100644 --- a/src/paho/mqtt/publish.py +++ b/src/paho/mqtt/publish.py @@ -20,13 +20,46 @@ """ import collections +import typing from collections.abc import Iterable from .. import mqtt from . import client as paho +if typing.TYPE_CHECKING: + try: + from typing import NotRequired, Required + except ImportError: + from typing_extensions import NotRequired, Required -def _do_publish(client): + +class AuthParamater(typing.TypedDict, total=False): + username: "Required[str]" + password: "NotRequired[str]" + + +class TLSParamater(typing.TypedDict, total=False): + ca_certs: "Required[str]" + certfile: "NotRequired[str]" + keyfile: "NotRequired[str]" + tls_version: "NotRequired[int]" + ciphers: "NotRequired[str]" + insecure: "NotRequired[bool]" + + +class MessageDict(typing.TypedDict, total=False): + topic: "Required[str]" + payload: "NotRequired[paho.PayloadType]" + qos: "NotRequired[int]" + retain: "NotRequired[bool]" + + +MessageTuple = typing.Tuple[str, paho.PayloadType, int, bool] + +MessagesList = list[typing.Union[MessageDict, MessageTuple]] + + +def _do_publish(client: paho.Client): """Internal function""" message = client._userdata.popleft() @@ -50,12 +83,14 @@ def _on_connect(client, userdata, flags, rc): raise mqtt.MQTTException(paho.connack_string(rc)) -def _on_connect_v5(client, userdata, flags, rc, properties): +def _on_connect_v5(client: paho.Client, userdata: MessagesList, flags, rc, properties): """Internal v5 callback""" _on_connect(client, userdata, flags, rc) -def _on_publish(client, userdata, mid): +def _on_publish( + client: paho.Client, userdata: collections.deque[MessagesList], mid: int +) -> None: """Internal callback""" # pylint: disable=unused-argument @@ -66,18 +101,18 @@ def _on_publish(client, userdata, mid): def multiple( - msgs, - hostname="localhost", - port=1883, - client_id="", - keepalive=60, - will=None, - auth=None, - tls=None, - protocol=paho.MQTTv311, - transport="tcp", - proxy_args=None, -): + msgs: MessagesList, + hostname: str = "localhost", + port: int = 1883, + client_id: str = "", + keepalive: int = 60, + will: typing.Optional[MessageDict] = None, + auth: typing.Optional[AuthParamater] = None, + tls: typing.Optional[TLSParamater] = None, + protocol: int = paho.MQTTv311, + transport: str = "tcp", + proxy_args: typing.Optional[typing.Any] = None, +) -> None: """Publish multiple messages to a broker, then disconnect cleanly. This function creates an MQTT client, connects to a broker and publishes a @@ -152,9 +187,9 @@ def multiple( client.on_publish = _on_publish if protocol == mqtt.client.MQTTv5: - client.on_connect = _on_connect_v5 + client.on_connect = _on_connect_v5 # type: ignore else: - client.on_connect = _on_connect + client.on_connect = _on_connect # type: ignore if proxy_args is not None: client.proxy_set(**proxy_args) @@ -175,7 +210,8 @@ def multiple( if tls is not None: if isinstance(tls, dict): insecure = tls.pop("insecure", False) - client.tls_set(**tls) + # mypy don't get the tls no longer contains the key insecure + client.tls_set(**tls) # type: ignore[misc] if insecure: # Must be set *after* the `client.tls_set()` call since it sets # up the SSL context that `client.tls_insecure_set` alters. @@ -189,21 +225,21 @@ def multiple( def single( - topic, - payload=None, - qos=0, - retain=False, - hostname="localhost", - port=1883, - client_id="", - keepalive=60, - will=None, - auth=None, - tls=None, - protocol=paho.MQTTv311, - transport="tcp", - proxy_args=None, -): + topic: str, + payload: paho.PayloadType = None, + qos: int = 0, + retain: bool = False, + hostname: str = "localhost", + port: int = 1883, + client_id: str = "", + keepalive: int = 60, + will: typing.Optional[MessageDict] = None, + auth: typing.Optional[AuthParamater] = None, + tls: typing.Optional[TLSParamater] = None, + protocol: int = paho.MQTTv311, + transport: str = "tcp", + proxy_args: typing.Optional[typing.Any] = None, +) -> None: """Publish a single message to a broker, then disconnect cleanly. This function creates an MQTT client, connects to a broker and publishes a @@ -259,7 +295,9 @@ def single( proxy_args: a dictionary that will be given to the client. """ - msg = {"topic": topic, "payload": payload, "qos": qos, "retain": retain} + msg = MessageDict( + {"topic": topic, "payload": payload, "qos": qos, "retain": retain} + ) multiple( [msg], diff --git a/src/paho/mqtt/subscribeoptions.py b/src/paho/mqtt/subscribeoptions.py index 20dad8ca..c4a76ca3 100644 --- a/src/paho/mqtt/subscribeoptions.py +++ b/src/paho/mqtt/subscribeoptions.py @@ -41,10 +41,10 @@ class SubscribeOptions: def __init__( self, - qos=0, - noLocal=False, - retainAsPublished=False, - retainHandling=RETAIN_SEND_ON_SUBSCRIBE, + qos: int = 0, + noLocal: bool = False, + retainAsPublished: bool = False, + retainHandling: int = RETAIN_SEND_ON_SUBSCRIBE, ): """ qos: 0, 1 or 2. 0 is the default. diff --git a/tests/lib/clients/03-publish-b2c-qos1.py b/tests/lib/clients/03-publish-b2c-qos1.py index 9efc928d..880ad2ff 100644 --- a/tests/lib/clients/03-publish-b2c-qos1.py +++ b/tests/lib/clients/03-publish-b2c-qos1.py @@ -10,7 +10,7 @@ def on_message(mqttc, obj, msg): assert msg.topic == "pub/qos1/receive", f"Invalid topic: ({msg.topic})" assert msg.payload == expected_payload, f"Invalid payload: ({msg.payload})" assert msg.qos == 1, f"Invalid qos: ({msg.qos})" - assert msg.retain is not False, f"Invalid retain: ({msg.retain})" + assert not msg.retain, f"Invalid retain: ({msg.retain})" def on_connect(mqttc, obj, flags, rc): diff --git a/tests/test_client.py b/tests/test_client.py index ea0e9a71..dd07b181 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -311,3 +311,44 @@ def callback2(client, userdata, msg): assert userdata["on_message"] == 1 assert userdata["callback1"] == 1 assert userdata["callback2"] == 2 + + +class Test_compatibility: + """ + Few test for backward compatibility + """ + + def test_change_error_code_to_enum(self): + """Make sure code don't break after MQTTErrorCode enum introduction""" + rc_ok = client.MQTTErrorCode.MQTT_ERR_SUCCESS + rc_again = client.MQTTErrorCode.MQTT_ERR_AGAIN + rc_err = client.MQTTErrorCode.MQTT_ERR_NOMEM + + # Access using old name still works + assert rc_ok == client.MQTT_ERR_SUCCESS + + # User might compare to 0 to check for success + assert rc_ok == 0 + assert not rc_err == 0 + assert not rc_again == 0 + assert not rc_ok != 0 + assert rc_err != 0 + assert rc_again != 0 + + # User might compare to specific code + assert rc_again == -1 + assert rc_err == 1 + + # User might just use "if rc:" + assert not rc_ok + assert rc_err + assert rc_again + + # User might do inequality with 0 (like "if rc > 0") + assert not (rc_ok > 0) + assert rc_err > 0 + assert rc_again < 0 + + # This might probably not be done: User might use rc as number in + # operation + assert rc_ok + 1 == 1 diff --git a/tox.ini b/tox.ini index bddabd0b..e5fbf849 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,6 @@ envlist = py{37,38,39,310,311,312} whitelist_externals = echo make deps = -rrequirements.txt - ruff==0.1.8 allowlist_externals = echo make @@ -16,12 +15,13 @@ env = [testenv:lint] deps = - -e . + -e .[proxy] + dnspython mypy pre-commit safety commands = # The "-" in front of command tells tox to ignore errors pre-commit run --all-files - - mypy --ignore-missing-imports src + mypy src safety check