Skip to content

Commit

Permalink
refactor: Isolate connection state
Browse files Browse the repository at this point in the history
  • Loading branch information
empicano committed Dec 26, 2023
1 parent db67c14 commit 76ebb21
Showing 1 changed file with 29 additions and 31 deletions.
60 changes: 29 additions & 31 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from typing_extensions import Concatenate, ParamSpec, TypeAlias


MAX_TOPIC_LENGTH = 65535
MQTT_LOGGER = logging.getLogger("mqtt")
MQTT_LOGGER.setLevel(logging.WARNING)

Expand Down Expand Up @@ -122,9 +123,6 @@ async def decorated(self: ClientT, /, *args: P.args, **kwargs: P.kwargs) -> T:
return decorated


MAX_TOPIC_LENGTH = 65535


@dataclass(frozen=True)
class Wildcard:
"""MQTT wildcard that can be subscribed to, but not published to.
Expand All @@ -147,7 +145,7 @@ def __str__(self) -> str:
def __post_init__(self) -> None:
"""Validate the wildcard."""
if not isinstance(self.value, str):
msg = "wildcard must be a string"
msg = "Wildcard must be of type str"
raise TypeError(msg)
if (
len(self.value) == 0
Expand Down Expand Up @@ -180,7 +178,7 @@ class Topic(Wildcard):
def __post_init__(self) -> None:
"""Validate the topic."""
if not isinstance(self.value, str):
msg = "topic must be a string"
msg = "Topic must be of type str"
raise TypeError(msg)
if (
len(self.value) == 0
Expand Down Expand Up @@ -209,17 +207,18 @@ def matches(self, wildcard: WildcardLike) -> bool:
# Shared subscriptions use the topic structure: $share/<group_id>/<topic>
wildcard_levels = wildcard_levels[2:]

def recurse(x: list[str], y: list[str]) -> bool:
if not x:
if not y or y[0] == "#":
def recurse(tl: list[str], wl: list[str]) -> bool:
"""Recursively match topic levels with wildcard levels."""
if not tl:
if not wl or wl[0] == "#":
return True
return False
if not y:
if not wl:
return False
if y[0] == "#":
if wl[0] == "#":
return True
if x[0] == y[0] or y[0] == "+":
return recurse(x[1:], y[1:])
if tl[0] == wl[0] or wl[0] == "+":
return recurse(tl[1:], wl[1:])
return False

return recurse(topic_levels, wildcard_levels)
Expand Down Expand Up @@ -368,12 +367,12 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
self._clean_start = clean_start
self._properties = properties
self._loop = asyncio.get_event_loop()
self._connected: asyncio.Future[
int | mqtt.ReasonCodes | None
] = asyncio.Future()
self._disconnected: asyncio.Future[
int | mqtt.ReasonCodes | None
] = asyncio.Future()

# Connection state
self._connected: asyncio.Future[None] = asyncio.Future()
self._disconnected: asyncio.Future[None] = asyncio.Future()
self._lock: asyncio.Lock = asyncio.Lock()

# Pending subscribe, unsubscribe, and publish calls
self._pending_subscribes: dict[
int, asyncio.Future[tuple[int] | list[mqtt.ReasonCodes]]
Expand Down Expand Up @@ -462,7 +461,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
if socket_options is None:
socket_options = ()
self._socket_options = tuple(socket_options)
self._lock: asyncio.Lock = asyncio.Lock()

if timeout is None:
timeout = 10
Expand Down Expand Up @@ -717,7 +715,7 @@ async def unfiltered_messages(
)
# Early out
if self._unfiltered_messages_callback is not None:
msg = "Only a single unfiltered_messages generator can be used at a time."
msg = "Only a single unfiltered_messages generator can be used at a time"
raise RuntimeError(msg)
callback, generator = self._deprecated_callback_and_generator(
log_context="unfiltered", queue_maxsize=queue_maxsize
Expand Down Expand Up @@ -906,15 +904,15 @@ def _on_connect( # noqa: PLR0913
rc: int | mqtt.ReasonCodes,
properties: mqtt.Properties | None = None,
) -> None:
"""Called when we receive a CONNACK message from the broker."""
# Return early if already connected. Sometimes, paho-mqtt calls _on_connect
# multiple times. Maybe because we receive multiple CONNACK messages
# from the server. In any case, we return early so that we don't set
# self._connected twice (as it raises an asyncio.InvalidStateError).
if self._connected.done():
return

if rc == mqtt.CONNACK_ACCEPTED:
self._connected.set_result(rc)
self._connected.set_result(None)
else:
self._connected.set_exception(MqttConnectError(rc))

Expand Down Expand Up @@ -943,9 +941,11 @@ def _on_disconnect(
if not self._connected.done() or self._connected.exception() is not None:
return
if rc == mqtt.MQTT_ERR_SUCCESS:
self._disconnected.set_result(rc)
self._disconnected.set_result(None)
else:
self._disconnected.set_exception(MqttCodeError(rc, "Unexpected disconnect"))
self._disconnected.set_exception(
MqttCodeError(rc, "Unexpected disconnection")
)

def _on_subscribe( # noqa: PLR0913
self,
Expand Down Expand Up @@ -1060,7 +1060,7 @@ async def _misc_loop(self) -> None:
async def __aenter__(self) -> Client:
"""Connect to the broker."""
if self._lock.locked():
msg = "The client context manager is reusable, but not reentrant."
msg = "The client context manager is reusable, but not reentrant"
raise MqttReentrantError(msg)
await self._lock.acquire()
try:
Expand All @@ -1078,14 +1078,12 @@ async def __aenter__(self) -> Client:
self._clean_start,
self._properties,
)
client_socket = self._client.socket()
_set_client_socket_defaults(client_socket, self._socket_options)
# paho.mqtt.Client.connect may raise one of several exceptions.
# We convert all of them to the common MqttError for user convenience.
_set_client_socket_defaults(self._client.socket(), self._socket_options)
# Convert all possible paho-mqtt Client.connect exceptions to our MqttError
# See: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770
except (OSError, mqtt.WebsocketConnectionError) as error:
except (OSError, mqtt.WebsocketConnectionError) as exc:
self._lock.release()
raise MqttError(str(error)) from None
raise MqttError(str(exc)) from None
await self._wait_for(self._connected, timeout=None)
# Reset `_disconnected` if it's already in completed state after connecting
if self._disconnected.done():
Expand Down

0 comments on commit 76ebb21

Please sign in to comment.