From f45cafbbb4e97e4fce6818defda391e5829270d0 Mon Sep 17 00:00:00 2001 From: Niels Provos <provos@gmail.com> Date: Tue, 28 Jan 2025 10:13:46 -0800 Subject: [PATCH] test: enhance stream data handling and verify received traces in TestWebInterface --- tests/planai/test_web_interface.py | 50 +++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/tests/planai/test_web_interface.py b/tests/planai/test_web_interface.py index d8510bc..e91b08c 100644 --- a/tests/planai/test_web_interface.py +++ b/tests/planai/test_web_interface.py @@ -1,7 +1,6 @@ import io import json import threading -import time import unittest from typing import List, Type from unittest.mock import MagicMock, Mock, patch @@ -121,6 +120,9 @@ def test_stream_route(self, mock_memory_stats): "peak": 120.0, } + mock_traces = {(1, "task1"): "trace1", (2, "task2"): "trace2"} + self.mock_dispatcher.get_traces.return_value = mock_traces + self.mock_dispatcher.get_queued_tasks.return_value = [ { "id": "task1", @@ -129,7 +131,6 @@ def test_stream_route(self, mock_memory_stats): "provenance": ["TestWorker_1"], } ] - self.mock_dispatcher.get_active_tasks.return_value = [] self.mock_dispatcher.get_completed_tasks.return_value = [] self.mock_dispatcher.get_failed_tasks.return_value = [] @@ -137,20 +138,54 @@ def test_stream_route(self, mock_memory_stats): "TestWorker": {"completed": 0, "failed": 0, "active": 0} } self.mock_dispatcher.get_user_input_requests.return_value = [] - self.mock_dispatcher.get_traces.return_value = {} self.mock_dispatcher.get_logs.return_value = [] + # Create an Event to signal when we've received data + data_received = threading.Event() + received_data = [] + stream_error = None + def mock_stream(): - response = self.client.get("/stream") - return response.get_data(as_text=True) + nonlocal stream_error + try: + with self.client.get("/stream") as response: + # Read the first chunk of data + for line in response.response: + try: + decoded_line = line.decode("utf-8") + if decoded_line.startswith("data: "): + json_data = json.loads(decoded_line[6:]) + received_data.append(json_data) + data_received.set() + return # Exit after first data received + except Exception as e: + stream_error = e + break + except Exception as e: + stream_error = e + finally: + data_received.set() # Always set the event # Start stream in a separate thread stream_thread = threading.Thread(target=mock_stream) stream_thread.daemon = True stream_thread.start() - # Wait a bit for data - time.sleep(0.5) + # Wait for data with timeout + if not data_received.wait(timeout=5.0): + self.fail("Timeout waiting for stream data") + + # Check if there was an error in the stream thread + if stream_error: + self.fail(f"Error in stream thread: {stream_error}") + + # Verify that we received the expected data + self.assertTrue(len(received_data) > 0, "No data received from stream") + if received_data: + data = received_data[0] + self.assertIn("trace", data) + expected_trace = {"1_task1": "trace1", "2_task2": "trace2"} + self.assertEqual(data["trace"], expected_trace) # Verify that dispatcher methods were called self.mock_dispatcher.get_queued_tasks.assert_called() @@ -158,6 +193,7 @@ def mock_stream(): self.mock_dispatcher.get_completed_tasks.assert_called() self.mock_dispatcher.get_failed_tasks.assert_called() self.mock_dispatcher.get_execution_statistics.assert_called() + self.mock_dispatcher.get_traces.assert_called() self.mock_dispatcher.get_user_input_requests.assert_called()