Skip to content

Commit

Permalink
test: enhance stream data handling and verify received traces in Test…
Browse files Browse the repository at this point in the history
…WebInterface
  • Loading branch information
provos committed Jan 28, 2025
1 parent b37081d commit f45cafb
Showing 1 changed file with 43 additions and 7 deletions.
50 changes: 43 additions & 7 deletions tests/planai/test_web_interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import io
import json
import threading
import time
import unittest
from typing import List, Type
from unittest.mock import MagicMock, Mock, patch
Expand Down Expand Up @@ -121,6 +120,9 @@ def test_stream_route(self, mock_memory_stats):
"peak": 120.0,
}

mock_traces = {(1, "task1"): "trace1", (2, "task2"): "trace2"}
self.mock_dispatcher.get_traces.return_value = mock_traces

self.mock_dispatcher.get_queued_tasks.return_value = [
{
"id": "task1",
Expand All @@ -129,35 +131,69 @@ def test_stream_route(self, mock_memory_stats):
"provenance": ["TestWorker_1"],
}
]

self.mock_dispatcher.get_active_tasks.return_value = []
self.mock_dispatcher.get_completed_tasks.return_value = []
self.mock_dispatcher.get_failed_tasks.return_value = []
self.mock_dispatcher.get_execution_statistics.return_value = {
"TestWorker": {"completed": 0, "failed": 0, "active": 0}
}
self.mock_dispatcher.get_user_input_requests.return_value = []
self.mock_dispatcher.get_traces.return_value = {}
self.mock_dispatcher.get_logs.return_value = []

# Create an Event to signal when we've received data
data_received = threading.Event()
received_data = []
stream_error = None

def mock_stream():
response = self.client.get("/stream")
return response.get_data(as_text=True)
nonlocal stream_error
try:
with self.client.get("/stream") as response:
# Read the first chunk of data
for line in response.response:
try:
decoded_line = line.decode("utf-8")
if decoded_line.startswith("data: "):
json_data = json.loads(decoded_line[6:])
received_data.append(json_data)
data_received.set()
return # Exit after first data received
except Exception as e:
stream_error = e
break
except Exception as e:
stream_error = e
finally:
data_received.set() # Always set the event

# Start stream in a separate thread
stream_thread = threading.Thread(target=mock_stream)
stream_thread.daemon = True
stream_thread.start()

# Wait a bit for data
time.sleep(0.5)
# Wait for data with timeout
if not data_received.wait(timeout=5.0):
self.fail("Timeout waiting for stream data")

# Check if there was an error in the stream thread
if stream_error:
self.fail(f"Error in stream thread: {stream_error}")

# Verify that we received the expected data
self.assertTrue(len(received_data) > 0, "No data received from stream")
if received_data:
data = received_data[0]
self.assertIn("trace", data)
expected_trace = {"1_task1": "trace1", "2_task2": "trace2"}
self.assertEqual(data["trace"], expected_trace)

# Verify that dispatcher methods were called
self.mock_dispatcher.get_queued_tasks.assert_called()
self.mock_dispatcher.get_active_tasks.assert_called()
self.mock_dispatcher.get_completed_tasks.assert_called()
self.mock_dispatcher.get_failed_tasks.assert_called()
self.mock_dispatcher.get_execution_statistics.assert_called()
self.mock_dispatcher.get_traces.assert_called()
self.mock_dispatcher.get_user_input_requests.assert_called()


Expand Down

0 comments on commit f45cafb

Please sign in to comment.