Skip to content

Commit

Permalink
refactor: Remove deprecated connect and disconnect methods
Browse files Browse the repository at this point in the history
  • Loading branch information
empicano committed Dec 26, 2023
1 parent 76ebb21 commit 7c95341
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 197 deletions.
83 changes: 0 additions & 83 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,89 +485,6 @@ def _pending_calls(self) -> Generator[int, None, None]:
yield from self._pending_unsubscribes.keys()
yield from self._pending_publishes.keys()

async def connect(self, *, timeout: float | None = None) -> None:
self._logger.warning(
"The manual `connect` and `disconnect` methods are deprecated and will be"
" removed in a future version. The preferred way to connect and disconnect"
" the client is to use the context manager interface via `async with`. In"
" case your use case needs to connect and disconnect manually, you can call"
" the context manager's `__aenter__` and `__aexit__` methods as an escape"
" hatch instead. `__aenter__` is equivalent to `connect`. `__aexit__` is"
" equivalent to `disconnect` except that it forces disconnection instead"
" of throwing an exception in case the client cannot disconnect cleanly."
" `__aexit__` expects three arguments: `exc_type`, `exc`, and `tb`. These"
" arguments describe the exception that caused the context manager to exit,"
" if any. You can pass `None` to all of these arguments in a manual call to"
" `__aexit__`."
)
try:
loop = asyncio.get_running_loop()

# [3] Run connect() within an executor thread, since it blocks on socket
# connection for up to `keepalive` seconds: https://git.io/Jt5Yc
await loop.run_in_executor(
None,
self._client.connect,
self._hostname,
self._port,
self._keepalive,
self._bind_address,
self._bind_port,
self._clean_start,
self._properties,
)
client_socket = self._client.socket()
_set_client_socket_defaults(client_socket, self._socket_options)
# paho.mqtt.Client.connect may raise one of several exceptions.
# We convert all of them to the common MqttError for user convenience.
# See: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770
except (OSError, mqtt.WebsocketConnectionError) as error:
raise MqttError(str(error)) from None
await self._wait_for(self._connected, timeout=timeout)
# If _disconnected is already completed after connecting, reset it.
if self._disconnected.done():
self._disconnected = asyncio.Future()

def _early_out_on_disconnected(self) -> bool:
# Early out if already disconnected...
if self._disconnected.done():
disc_exc = self._disconnected.exception()
if disc_exc is not None:
# ...by raising the error that caused the disconnect
raise disc_exc
# ...by returning since the disconnect was intentional
return True
return False

async def disconnect(self, *, timeout: float | None = None) -> None:
"""Disconnect from the broker."""
self._logger.warning(
"The manual `connect` and `disconnect` methods are deprecated and will be"
" removed in a future version. The preferred way to connect and disconnect"
" the client is to use the context manager interface via `async with`. In"
" case your use case needs to connect and disconnect manually, you can call"
" the context manager's `__aenter__` and `__aexit__` methods as an escape"
" hatch instead. `__aenter__` is equivalent to `connect`. `__aexit__` is"
" equivalent to `disconnect` except that it forces disconnection instead"
" of throwing an exception in case the client cannot disconnect cleanly."
" `__aexit__` expects three arguments: `exc_type`, `exc`, and `tb`. These"
" arguments describe the exception that caused the context manager to exit,"
" if any. You can pass `None` to all of these arguments in a manual call to"
" `__aexit__`."
)
if self._early_out_on_disconnected():
return
# Try to gracefully disconnect from the broker
rc = self._client.disconnect()
# Early out on error
if rc != mqtt.MQTT_ERR_SUCCESS:
raise MqttCodeError(rc, "Could not disconnect")
# Wait for acknowledgement
await self._wait_for(self._disconnected, timeout=timeout)
# If _connected is still in the completed state after disconnection, reset it
if self._connected.done():
self._connected = asyncio.Future()

@_outgoing_call
async def subscribe( # noqa: PLR0913
self,
Expand Down
146 changes: 32 additions & 114 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

HOSTNAME = "test.mosquitto.org"
OS_PY_VERSION = sys.platform + "_" + ".".join(map(str, sys.version_info[:2]))
TOPIC_HEADER = OS_PY_VERSION + "/tests/aiomqtt/"
TOPIC_PREFIX = OS_PY_VERSION + "/tests/aiomqtt/"


async def test_topic_validation() -> None:
Expand Down Expand Up @@ -91,7 +91,7 @@ async def test_topic_matches() -> None:
@pytest.mark.network
async def test_multiple_messages_generators() -> None:
"""Test that multiple Client.messages() generators can be used at the same time."""
topic = TOPIC_HEADER + "multiple_messages_generators"
topic = TOPIC_PREFIX + "multiple_messages_generators"

async def handler(tg: anyio.abc.TaskGroup) -> None:
async with client.messages() as messages:
Expand All @@ -110,7 +110,7 @@ async def handler(tg: anyio.abc.TaskGroup) -> None:

@pytest.mark.network
async def test_client_filtered_messages() -> None:
topic_header = TOPIC_HEADER + "filtered_messages/"
topic_header = TOPIC_PREFIX + "filtered_messages/"
good_topic = topic_header + "good"
bad_topic = topic_header + "bad"

Expand All @@ -131,7 +131,7 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None:

@pytest.mark.network
async def test_client_unfiltered_messages() -> None:
topic_header = TOPIC_HEADER + "unfiltered_messages/"
topic_header = TOPIC_PREFIX + "unfiltered_messages/"
topic_filtered = topic_header + "filtered"
topic_unfiltered = topic_header + "unfiltered"

Expand All @@ -158,7 +158,7 @@ async def handle_filtered_messages() -> None:

@pytest.mark.network
async def test_client_unsubscribe() -> None:
topic_header = TOPIC_HEADER + "unsubscribe/"
topic_header = TOPIC_PREFIX + "unsubscribe/"
topic1 = topic_header + "1"
topic2 = topic_header + "2"

Expand Down Expand Up @@ -196,7 +196,7 @@ async def test_client_id(protocol: ProtocolVersion, length: int) -> None:

@pytest.mark.network
async def test_client_will() -> None:
topic = TOPIC_HEADER + "will"
topic = TOPIC_PREFIX + "will"
event = anyio.Event()

async def launch_client() -> None:
Expand All @@ -218,7 +218,7 @@ async def launch_client() -> None:

@pytest.mark.network
async def test_client_tls_context() -> None:
topic = TOPIC_HEADER + "tls_context"
topic = TOPIC_PREFIX + "tls_context"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(topic) as messages:
Expand All @@ -240,7 +240,7 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None:

@pytest.mark.network
async def test_client_tls_params() -> None:
topic = TOPIC_HEADER + "tls_params"
topic = TOPIC_PREFIX + "tls_params"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(topic) as messages:
Expand All @@ -264,7 +264,7 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None:

@pytest.mark.network
async def test_client_username_password() -> None:
topic = TOPIC_HEADER + "username_password"
topic = TOPIC_PREFIX + "username_password"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(topic) as messages:
Expand All @@ -291,7 +291,7 @@ async def test_client_logger() -> None:
async def test_client_max_concurrent_outgoing_calls(
monkeypatch: pytest.MonkeyPatch,
) -> None:
topic = TOPIC_HEADER + "max_concurrent_outgoing_calls"
topic = TOPIC_PREFIX + "max_concurrent_outgoing_calls"

class MockPahoClient(mqtt.Client):
def subscribe(
Expand Down Expand Up @@ -337,7 +337,7 @@ def publish( # noqa: PLR0913

@pytest.mark.network
async def test_client_websockets() -> None:
topic = TOPIC_HEADER + "websockets"
topic = TOPIC_PREFIX + "websockets"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(topic) as messages:
Expand All @@ -364,7 +364,7 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async def test_client_pending_calls_threshold(
pending_calls_threshold: int, caplog: pytest.LogCaptureFixture
) -> None:
topic = TOPIC_HEADER + "pending_calls_threshold"
topic = TOPIC_PREFIX + "pending_calls_threshold"

async with Client(HOSTNAME) as client:
client.pending_calls_threshold = pending_calls_threshold
Expand All @@ -390,7 +390,7 @@ async def test_client_no_pending_calls_warnings_with_max_concurrent_outgoing_cal
caplog: pytest.LogCaptureFixture,
) -> None:
topic = (
TOPIC_HEADER + "no_pending_calls_warnings_with_max_concurrent_outgoing_calls"
TOPIC_PREFIX + "no_pending_calls_warnings_with_max_concurrent_outgoing_calls"
)

async with Client(HOSTNAME, max_concurrent_outgoing_calls=1) as client:
Expand All @@ -405,32 +405,24 @@ async def test_client_no_pending_calls_warnings_with_max_concurrent_outgoing_cal


@pytest.mark.network
async def test_client_not_reentrant() -> None:
"""Test that the client raises an error when we try to reenter."""
client = Client(HOSTNAME)
with pytest.raises(MqttReentrantError): # noqa: PT012
async with client:
async with client:
pass


@pytest.mark.network
async def test_client_reusable() -> None:
"""Test that an instance of the client context manager can be reused."""
async def test_client_is_reusable() -> None:
"""Test that a client context manager instance is reusable."""
topic = TOPIC_PREFIX + "test_client_is_reusable"
client = Client(HOSTNAME)
async with client:
await client.publish("task/a", "task_a")
await client.publish(topic, "foo")
async with client:
await client.publish("task/b", "task_b")
await client.publish(topic, "bar")


@pytest.mark.network
async def test_client_connect_disconnect() -> None:
async def test_client_is_not_reentrant() -> None:
"""Test that a client context manager instance is not reentrant."""
client = Client(HOSTNAME)

await client.connect()
await client.publish("connect", "connect")
await client.disconnect()
async with client:
with pytest.raises(MqttReentrantError):
async with client:
pass


@pytest.mark.network
Expand Down Expand Up @@ -470,105 +462,31 @@ async def task_a_publisher() -> None:


@pytest.mark.network
async def test_client_use_connect_disconnect_multiple_message() -> None:
custom_client = Client(HOSTNAME)
publish_client = Client(HOSTNAME)

topic_a = TOPIC_HEADER + "task/a"
topic_b = TOPIC_HEADER + "task/b"

await custom_client.connect()
await publish_client.connect()

async def task_a_customer(
task_status: TaskStatus[None] = TASK_STATUS_IGNORED,
) -> None:
await custom_client.subscribe(topic_a)
async with custom_client.messages() as messages:
task_status.started()
async for message in messages:
assert message.payload == b"task_a"
return

async def task_b_customer(
task_status: TaskStatus[None] = TASK_STATUS_IGNORED,
) -> None:
num = 0
await custom_client.subscribe(topic_b)
async with custom_client.messages() as messages:
task_status.started()
async for message in messages:
assert message.payload in [b"task_a", b"task_b"]
num += 1
if num == 2: # noqa: PLR2004
return

async def task_publisher(topic: str, payload: PayloadType) -> None:
await publish_client.publish(topic, payload)

async with anyio.create_task_group() as tg:
await tg.start(task_a_customer)
await tg.start(task_b_customer)
tg.start_soon(task_publisher, topic_a, "task_a")
tg.start_soon(task_publisher, topic_b, "task_b")

await custom_client.disconnect()
await publish_client.disconnect()


@pytest.mark.network
async def test_client_disconnected_exception() -> None:
client = Client(HOSTNAME)
await client.connect()
client._disconnected.set_exception(RuntimeError)
with pytest.raises(RuntimeError):
await client.disconnect()


@pytest.mark.network
async def test_client_disconnected_done() -> None:
client = Client(HOSTNAME)
await client.connect()
client._disconnected.set_result(None)
await client.disconnect()


@pytest.mark.network
async def test_client_connecting_disconnected_done() -> None:
client = Client(HOSTNAME)
client._disconnected.set_result(None)
await client.connect()
await client.disconnect()


@pytest.mark.network
async def test_client_aenter_error_lock_release() -> None:
"""Test that the client's reusability lock is released on error in __aenter__."""
client = Client(hostname="aenter_connect_error_lock_release")
async def test_aenter_error_lock_release() -> None:
"""Test that the client's reusability lock is released on error in ``aenter``."""
client = Client(hostname="invalid")
with pytest.raises(MqttError):
await client.__aenter__()
assert not client._lock.locked()


@pytest.mark.network
async def test_aexit_without_prior_aenter() -> None:
"""Test that __aexit__ without prior (or unsuccessful) __aenter__ runs cleanly."""
"""Test that ``aexit`` without prior (or unsuccessful) ``aenter`` runs cleanly."""
client = Client(HOSTNAME)
await client.__aexit__(None, None, None)


@pytest.mark.network
async def test_aexit_client_is_already_disconnected_sucess() -> None:
"""Test that __aexit__ exits cleanly if client is already cleanly disconnected."""
client = Client(HOSTNAME)
await client.__aenter__()
client._disconnected.set_result(None)
await client.__aexit__(None, None, None)
"""Test that ``aexit`` exits cleanly if client is already cleanly disconnected."""
async with Client(HOSTNAME) as client:
client._disconnected.set_result(None)


@pytest.mark.network
async def test_aexit_client_is_already_disconnected_failure() -> None:
"""Test that __aexit__ reraises if client is already disconnected with an error."""
"""Test that ``aexit`` reraises if client is already disconnected with an error."""
client = Client(HOSTNAME)
await client.__aenter__()
client._disconnected.set_exception(RuntimeError)
Expand Down

0 comments on commit 7c95341

Please sign in to comment.