From f45cafbbb4e97e4fce6818defda391e5829270d0 Mon Sep 17 00:00:00 2001
From: Niels Provos <provos@gmail.com>
Date: Tue, 28 Jan 2025 10:13:46 -0800
Subject: [PATCH] test: enhance stream data handling and verify received traces
 in TestWebInterface

---
 tests/planai/test_web_interface.py | 50 +++++++++++++++++++++++++-----
 1 file changed, 43 insertions(+), 7 deletions(-)

diff --git a/tests/planai/test_web_interface.py b/tests/planai/test_web_interface.py
index d8510bc..e91b08c 100644
--- a/tests/planai/test_web_interface.py
+++ b/tests/planai/test_web_interface.py
@@ -1,7 +1,6 @@
 import io
 import json
 import threading
-import time
 import unittest
 from typing import List, Type
 from unittest.mock import MagicMock, Mock, patch
@@ -121,6 +120,9 @@ def test_stream_route(self, mock_memory_stats):
             "peak": 120.0,
         }
 
+        mock_traces = {(1, "task1"): "trace1", (2, "task2"): "trace2"}
+        self.mock_dispatcher.get_traces.return_value = mock_traces
+
         self.mock_dispatcher.get_queued_tasks.return_value = [
             {
                 "id": "task1",
@@ -129,7 +131,6 @@ def test_stream_route(self, mock_memory_stats):
                 "provenance": ["TestWorker_1"],
             }
         ]
-
         self.mock_dispatcher.get_active_tasks.return_value = []
         self.mock_dispatcher.get_completed_tasks.return_value = []
         self.mock_dispatcher.get_failed_tasks.return_value = []
@@ -137,20 +138,54 @@ def test_stream_route(self, mock_memory_stats):
             "TestWorker": {"completed": 0, "failed": 0, "active": 0}
         }
         self.mock_dispatcher.get_user_input_requests.return_value = []
-        self.mock_dispatcher.get_traces.return_value = {}
         self.mock_dispatcher.get_logs.return_value = []
 
+        # Create an Event to signal when we've received data
+        data_received = threading.Event()
+        received_data = []
+        stream_error = None
+
         def mock_stream():
-            response = self.client.get("/stream")
-            return response.get_data(as_text=True)
+            nonlocal stream_error
+            try:
+                with self.client.get("/stream") as response:
+                    # Read the first chunk of data
+                    for line in response.response:
+                        try:
+                            decoded_line = line.decode("utf-8")
+                            if decoded_line.startswith("data: "):
+                                json_data = json.loads(decoded_line[6:])
+                                received_data.append(json_data)
+                                data_received.set()
+                                return  # Exit after first data received
+                        except Exception as e:
+                            stream_error = e
+                            break
+            except Exception as e:
+                stream_error = e
+            finally:
+                data_received.set()  # Always set the event
 
         # Start stream in a separate thread
         stream_thread = threading.Thread(target=mock_stream)
         stream_thread.daemon = True
         stream_thread.start()
 
-        # Wait a bit for data
-        time.sleep(0.5)
+        # Wait for data with timeout
+        if not data_received.wait(timeout=5.0):
+            self.fail("Timeout waiting for stream data")
+
+        # Check if there was an error in the stream thread
+        if stream_error:
+            self.fail(f"Error in stream thread: {stream_error}")
+
+        # Verify that we received the expected data
+        self.assertTrue(len(received_data) > 0, "No data received from stream")
+        if received_data:
+            data = received_data[0]
+            self.assertIn("trace", data)
+            expected_trace = {"1_task1": "trace1", "2_task2": "trace2"}
+            self.assertEqual(data["trace"], expected_trace)
 
         # Verify that dispatcher methods were called
         self.mock_dispatcher.get_queued_tasks.assert_called()
@@ -158,6 +193,7 @@ def mock_stream():
         self.mock_dispatcher.get_completed_tasks.assert_called()
         self.mock_dispatcher.get_failed_tasks.assert_called()
         self.mock_dispatcher.get_execution_statistics.assert_called()
+        self.mock_dispatcher.get_traces.assert_called()
         self.mock_dispatcher.get_user_input_requests.assert_called()