From 76ebb21395bc527b089a78298df9b72e82f01c00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20B=C3=B6hm?= Date: Wed, 27 Dec 2023 00:00:41 +0100 Subject: [PATCH] refactor: Isolate connection state --- aiomqtt/client.py | 60 +++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/aiomqtt/client.py b/aiomqtt/client.py index f4fb3b2..c7c2403 100644 --- a/aiomqtt/client.py +++ b/aiomqtt/client.py @@ -37,6 +37,7 @@ from typing_extensions import Concatenate, ParamSpec, TypeAlias +MAX_TOPIC_LENGTH = 65535 MQTT_LOGGER = logging.getLogger("mqtt") MQTT_LOGGER.setLevel(logging.WARNING) @@ -122,9 +123,6 @@ async def decorated(self: ClientT, /, *args: P.args, **kwargs: P.kwargs) -> T: return decorated -MAX_TOPIC_LENGTH = 65535 - - @dataclass(frozen=True) class Wildcard: """MQTT wildcard that can be subscribed to, but not published to. @@ -147,7 +145,7 @@ def __str__(self) -> str: def __post_init__(self) -> None: """Validate the wildcard.""" if not isinstance(self.value, str): - msg = "wildcard must be a string" + msg = "Wildcard must be of type str" raise TypeError(msg) if ( len(self.value) == 0 @@ -180,7 +178,7 @@ class Topic(Wildcard): def __post_init__(self) -> None: """Validate the topic.""" if not isinstance(self.value, str): - msg = "topic must be a string" + msg = "Topic must be of type str" raise TypeError(msg) if ( len(self.value) == 0 @@ -209,17 +207,18 @@ def matches(self, wildcard: WildcardLike) -> bool: # Shared subscriptions use the topic structure: $share// wildcard_levels = wildcard_levels[2:] - def recurse(x: list[str], y: list[str]) -> bool: - if not x: - if not y or y[0] == "#": + def recurse(tl: list[str], wl: list[str]) -> bool: + """Recursively match topic levels with wildcard levels.""" + if not tl: + if not wl or wl[0] == "#": return True return False - if not y: + if not wl: return False - if y[0] == "#": + if wl[0] == "#": return True - if x[0] == y[0] or y[0] == "+": - return recurse(x[1:], y[1:]) + if tl[0] == wl[0] or wl[0] == "+": + return recurse(tl[1:], wl[1:]) return False return recurse(topic_levels, wildcard_levels) @@ -368,12 +367,12 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 self._clean_start = clean_start self._properties = properties self._loop = asyncio.get_event_loop() - self._connected: asyncio.Future[ - int | mqtt.ReasonCodes | None - ] = asyncio.Future() - self._disconnected: asyncio.Future[ - int | mqtt.ReasonCodes | None - ] = asyncio.Future() + + # Connection state + self._connected: asyncio.Future[None] = asyncio.Future() + self._disconnected: asyncio.Future[None] = asyncio.Future() + self._lock: asyncio.Lock = asyncio.Lock() + # Pending subscribe, unsubscribe, and publish calls self._pending_subscribes: dict[ int, asyncio.Future[tuple[int] | list[mqtt.ReasonCodes]] @@ -462,7 +461,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 if socket_options is None: socket_options = () self._socket_options = tuple(socket_options) - self._lock: asyncio.Lock = asyncio.Lock() if timeout is None: timeout = 10 @@ -717,7 +715,7 @@ async def unfiltered_messages( ) # Early out if self._unfiltered_messages_callback is not None: - msg = "Only a single unfiltered_messages generator can be used at a time." + msg = "Only a single unfiltered_messages generator can be used at a time" raise RuntimeError(msg) callback, generator = self._deprecated_callback_and_generator( log_context="unfiltered", queue_maxsize=queue_maxsize @@ -906,15 +904,15 @@ def _on_connect( # noqa: PLR0913 rc: int | mqtt.ReasonCodes, properties: mqtt.Properties | None = None, ) -> None: + """Called when we receive a CONNACK message from the broker.""" # Return early if already connected. Sometimes, paho-mqtt calls _on_connect # multiple times. Maybe because we receive multiple CONNACK messages # from the server. In any case, we return early so that we don't set # self._connected twice (as it raises an asyncio.InvalidStateError). if self._connected.done(): return - if rc == mqtt.CONNACK_ACCEPTED: - self._connected.set_result(rc) + self._connected.set_result(None) else: self._connected.set_exception(MqttConnectError(rc)) @@ -943,9 +941,11 @@ def _on_disconnect( if not self._connected.done() or self._connected.exception() is not None: return if rc == mqtt.MQTT_ERR_SUCCESS: - self._disconnected.set_result(rc) + self._disconnected.set_result(None) else: - self._disconnected.set_exception(MqttCodeError(rc, "Unexpected disconnect")) + self._disconnected.set_exception( + MqttCodeError(rc, "Unexpected disconnection") + ) def _on_subscribe( # noqa: PLR0913 self, @@ -1060,7 +1060,7 @@ async def _misc_loop(self) -> None: async def __aenter__(self) -> Client: """Connect to the broker.""" if self._lock.locked(): - msg = "The client context manager is reusable, but not reentrant." + msg = "The client context manager is reusable, but not reentrant" raise MqttReentrantError(msg) await self._lock.acquire() try: @@ -1078,14 +1078,12 @@ async def __aenter__(self) -> Client: self._clean_start, self._properties, ) - client_socket = self._client.socket() - _set_client_socket_defaults(client_socket, self._socket_options) - # paho.mqtt.Client.connect may raise one of several exceptions. - # We convert all of them to the common MqttError for user convenience. + _set_client_socket_defaults(self._client.socket(), self._socket_options) + # Convert all possible paho-mqtt Client.connect exceptions to our MqttError # See: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770 - except (OSError, mqtt.WebsocketConnectionError) as error: + except (OSError, mqtt.WebsocketConnectionError) as exc: self._lock.release() - raise MqttError(str(error)) from None + raise MqttError(str(exc)) from None await self._wait_for(self._connected, timeout=None) # Reset `_disconnected` if it's already in completed state after connecting if self._disconnected.done():