diff --git a/brewblox_devcon_spark/connection/mqtt_connection.py b/brewblox_devcon_spark/connection/mqtt_connection.py index 9918af8f..104ff39e 100644 --- a/brewblox_devcon_spark/connection/mqtt_connection.py +++ b/brewblox_devcon_spark/connection/mqtt_connection.py @@ -35,14 +35,44 @@ def __init__(self, self._handshake_topic = HANDSHAKE_TOPIC + device_id self._log_topic = LOG_TOPIC + device_id - async def _handshake_cb(self, client, topic, payload, qos, properties): + self._recv_msg_id: int = 0 + self._recv_chunks: list[str] + + def _reset_recv_buffer(self, msg_id: int): + self._recv_msg_id = msg_id + self._recv_chunks = [] + + async def _handshake_cb(self, client, topic, payload: bytes, qos, properties): if not payload: self.disconnected.set() - async def _resp_cb(self, client, topic, payload, qos, properties): - await self.on_response(payload.decode()) - - async def _log_cb(self, client, topic, payload, qos, properties): + async def _resp_cb(self, client, topic, payload: bytes, qos, properties): + try: + (msg_id, chunk_idx, chunk) = payload.decode().split(':') + msg_id = int(msg_id) + chunk_idx = int(chunk_idx) + except ValueError as ex: + LOGGER.error(f'Failed to parse MQTT payload "{payload}" with error {utils.strex(ex)}') + self._reset_recv_buffer(0) + return + + if msg_id != self._recv_msg_id: + self._reset_recv_buffer(msg_id) + + if len(self._recv_chunks) != chunk_idx: + LOGGER.error(f'Received unexpected MQTT message chunk with idx {chunk_idx}') + self._reset_recv_buffer(0) + return + + self._recv_chunks.append(chunk) + + # we found a message separator - message is done + if '\n' in chunk: + msg = ''.join(self._recv_chunks).rstrip() + self._reset_recv_buffer(0) + await self.on_response(msg) + + async def _log_cb(self, client, topic, payload: bytes, qos, properties): await self.on_event(payload.decode()) async def send_request(self, msg: str): diff --git a/firmware.ini b/firmware.ini index ba46103f..14144bf4 100644 --- a/firmware.ini +++ b/firmware.ini @@ -1,7 +1,7 @@ [FIRMWARE] -firmware_version=29c23dd7 -firmware_date=2024-05-10 -firmware_sha=29c23dd7f2568cb0e472d8c214ac6b85751e4d7a +firmware_version=5d1e86a3 +firmware_date=2024-05-13 +firmware_sha=5d1e86a31a3082dcf45eb908470710423d162ef4 proto_version=06abbc28 proto_date=2024-05-10 proto_sha=06abbc281f0919bfb728ca699e916b08b40f0e8f diff --git a/test/test_connection_mqtt_connection.py b/test/test_connection_mqtt_connection.py index 7e270af4..be1254e6 100644 --- a/test/test_connection_mqtt_connection.py +++ b/test/test_connection_mqtt_connection.py @@ -1,7 +1,7 @@ import asyncio from contextlib import asynccontextmanager from datetime import timedelta -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, call import pytest from asgi_lifespan import LifespanManager @@ -82,9 +82,9 @@ async def on_handshake(client, topic, payload, qos, properties): recv_handshake.set() @mqtt_client.subscribe('brewcast/cbox/req/+') - async def on_request(client, topic, payload, qos, properties): + async def on_request(client, topic, payload: bytes, qos, properties): resp_topic = topic.replace('/req/', '/resp/') - mqtt_client.publish(resp_topic, payload[::-1]) + mqtt_client.publish(resp_topic, f'1:0:{payload.decode()[::-1]}\n'.encode()) recv_req.set() @mqtt_client.subscribe('brewcast/cbox/resp/+') @@ -123,3 +123,20 @@ async def on_log(client, topic, payload, qos, properties): # Can safely be called await impl.close() assert impl.disconnected.is_set() + + +async def test_mqtt_message_handling(): + callbacks = AsyncMock(spec=connection_handler.ConnectionHandler) + impl = mqtt_connection.MqttConnection('1234', callbacks) + + await impl._resp_cb(None, None, '1:0:first,'.encode(), 0, None) + await impl._resp_cb(None, None, '2:1:second,'.encode(), 0, None) + await impl._resp_cb(None, None, '3:1:third\n'.encode(), 0, None) + await impl._resp_cb(None, None, '4:0:fourth\n'.encode(), 0, None) + await impl._resp_cb(None, None, '5:0:fifth,'.encode(), 0, None) + await impl._resp_cb(None, None, '5:1:fifth-second\n'.encode(), 0, None) + await impl._resp_cb(None, None, '6:0:sixth,'.encode(), 0, None) + await impl._resp_cb(None, None, 'garbled'.encode(), 0, None) + await impl._resp_cb(None, None, '6:1:sixth-second\n'.encode(), 0, None) + + assert callbacks.on_response.await_args_list == [call('fourth'), call('fifth,fifth-second')]