Commit

Fix failing unit tests
muralov committed Jan 15, 2025
1 parent 0067066 commit 799bc14
Showing 8 changed files with 177 additions and 247 deletions.
235 changes: 120 additions & 115 deletions poetry.lock

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions src/services/k8s.py
@@ -7,7 +7,6 @@

 import requests
 from kubernetes import client, dynamic
-from requests import Response

 from services.data_sanitizer import DataSanitizer
 from utils import logging
@@ -23,7 +22,7 @@ def model_dump(self) -> None:
         """Dump the model without any confidential data."""
         ...

-    def execute_get_api_request(self, uri: str) -> Response:
+    def execute_get_api_request(self, uri: str) -> dict | list[dict]:
         """Execute a GET request to the Kubernetes API."""
         ...

@@ -154,7 +153,7 @@ def _get_auth_headers(self) -> dict:
             "Content-Type": "application/json",
         }

-    def execute_get_api_request(self, uri: str) -> Response:
+    def execute_get_api_request(self, uri: str) -> dict | list[dict]:
         """Execute a GET request to the Kubernetes API."""
         response = requests.get(
             url=f"{self.api_server}/{uri.lstrip('/')}",
@@ -169,7 +168,7 @@ def execute_get_api_request(self, uri: str) -> Response:

         if self.data_sanitizer:
             return self.data_sanitizer.sanitize(response.json())
-        return response
+        return response.json()

     def list_resources(self, api_version: str, kind: str, namespace: str) -> list[dict]:
         """List resources of a specific kind in a namespace.
@@ -244,9 +243,7 @@ def list_not_running_pods(self, namespace: str) -> list[dict]:

     def list_nodes_metrics(self) -> list[dict]:
         """List all nodes metrics."""
-        result = self.execute_get_api_request(
-            "apis/metrics.k8s.io/v1beta1/nodes"
-        ).json()
+        result = self.execute_get_api_request("apis/metrics.k8s.io/v1beta1/nodes")
         return list(result["items"])

     def list_k8s_events(self, namespace: str) -> list[dict]:
@@ -303,7 +300,11 @@ def fetch_pod_logs(
         if is_terminated:
             uri += "&previous=true"

-        response = self.execute_get_api_request(uri)
+        response = requests.get(
+            url=f"{self.api_server}/{uri.lstrip('/')}",
+            headers=self._get_auth_headers(),
+            verify=self.ca_temp_filename,
+        )
         if response.status_code != HTTPStatus.OK:
             raise ValueError(
                 f"Failed to fetch logs for pod {name} in namespace {namespace} "
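After this change, execute_get_api_request returns parsed JSON (dict | list[dict]), already sanitized when a DataSanitizer is configured, so callers drop their .json() calls — as list_nodes_metrics does above. A minimal sketch of a caller under the new contract; the helper name and the nodes URI here are illustrative, not part of the commit:

from typing import Any
from unittest.mock import Mock

def count_nodes(k8s_client: Any) -> int:
    # k8s_client is anything implementing the protocol's execute_get_api_request,
    # which now returns parsed JSON instead of a requests.Response.
    result = k8s_client.execute_get_api_request("api/v1/nodes")
    items = result.get("items", []) if isinstance(result, dict) else result
    return len(items)

# Usage with a stub, mirroring how the unit tests below now fake the client:
client = Mock()
client.execute_get_api_request.return_value = {"items": [{"metadata": {"name": "node-1"}}]}
assert count_nodes(client) == 1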
5 changes: 4 additions & 1 deletion src/utils/response.py
@@ -1,6 +1,7 @@
 import json
 from typing import Any

+from agents.common.constants import PLANNER
 from agents.supervisor.agent import SUPERVISOR
 from utils.logging import get_logger

@@ -18,8 +19,10 @@ def process_response(data: dict[str, Any], agent: str) -> dict[str, Any]:
     if "messages" in agent_data and agent_data["messages"]:
         answer["content"] = agent_data["messages"][-1].get("content")

-    if agent == SUPERVISOR:
+    if agent == PLANNER:
         answer["subtasks"] = agent_data.get("subtasks")
+
+    if agent == SUPERVISOR:
         answer["next"] = agent_data.get("next")

     return {"agent": agent, "answer": answer}
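With this split, planner responses carry subtasks while supervisor responses carry only next. A standalone mimic of the new branching, for illustration — the constant values and the message shape are assumptions, only the branch logic comes from the hunk:

from typing import Any

PLANNER = "Planner"        # assumed values; the real constants are imported above
SUPERVISOR = "Supervisor"

def shape_answer(agent_data: dict[str, Any], agent: str) -> dict[str, Any]:
    # Mirrors the branching in process_response after this commit.
    answer: dict[str, Any] = {}
    if "messages" in agent_data and agent_data["messages"]:
        answer["content"] = agent_data["messages"][-1].get("content")
    if agent == PLANNER:
        answer["subtasks"] = agent_data.get("subtasks")
    if agent == SUPERVISOR:
        answer["next"] = agent_data.get("next")
    return {"agent": agent, "answer": answer}

assert shape_answer(
    {"messages": [{"content": "done"}], "next": "KubernetesAgent"}, SUPERVISOR
) == {"agent": SUPERVISOR, "answer": {"content": "done", "next": "KubernetesAgent"}}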
4 changes: 2 additions & 2 deletions tests/unit/agents/k8s/tools/test_query.py
@@ -41,7 +41,7 @@ def sample_k8s_sanitized_secret():
             "v1/secret/my-secret",
             sample_k8s_secret(),
             None,
-            sample_k8s_sanitized_secret(),
+            sample_k8s_secret(),
             None,
         ),
         # Test case: the execute_get_api_request returns an exception.
@@ -78,7 +78,7 @@ def test_k8s_query_tool(
     if given_exception:
         k8s_client.execute_get_api_request.side_effect = given_exception
     else:
-        k8s_client.execute_get_api_request.return_value.json.return_value = given_object
+        k8s_client.execute_get_api_request.return_value = given_object

     # When: invoke the tool.
     result = tool_node.invoke(
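Since the client now hands back parsed JSON, the test stubs the method's return value directly instead of a nested .json() mock. The two styles side by side, as a self-contained sketch:

from unittest.mock import Mock

k8s_client = Mock()
given_object = {"kind": "Secret", "metadata": {"name": "my-secret"}}

# Before: the mock stood in for a requests.Response, so .json() was stubbed.
k8s_client.execute_get_api_request.return_value.json.return_value = given_object

# After: the method itself returns the parsed object.
k8s_client.execute_get_api_request.return_value = given_object
assert k8s_client.execute_get_api_request("v1/secret/my-secret") == given_object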
1 change: 1 addition & 0 deletions tests/unit/conftest.py
@@ -17,6 +17,7 @@ def mock_config():
             ModelConfig(
                 name=ModelType.GEMINI_10_PRO, deployment_id="dep3", temperature=0
             ),
+            ModelConfig(name="unsupported_model", deployment_id="dep4", temperature=0),
         ]
     )

5 changes: 2 additions & 3 deletions tests/unit/routers/test_conversations.py
@@ -8,7 +8,7 @@

 from agents.common.data import Message
 from main import app
-from routers.conversations import get_conversation_service
+from routers.conversations import init_conversation_service
 from services.conversation import IService
 from services.k8s import IK8sClient

@@ -59,8 +59,7 @@ def _create_client(expected_error=None):
     def get_mock_service():
         return mock_service

-    app.dependency_overrides[get_conversation_service] = get_mock_service
-
+    app.dependency_overrides[init_conversation_service] = get_mock_service
     test_client = TestClient(app)

     return test_client
47 changes: 29 additions & 18 deletions tests/unit/services/test_conversation.py
@@ -4,7 +4,8 @@
 from langchain_core.messages import AIMessage

 from agents.common.data import Message
-from services.conversation import ConversationService
+from services.conversation import TOKEN_LIMIT, ConversationService
+from utils.models.factory import ModelType

 TIME_STAMP = 1.8
 QUESTIONS = ["question1?", "question2?", "question3?"]
@@ -34,8 +35,9 @@ class TestConversation:
     @pytest.fixture
     def mock_model_factory(self):
         mock_model = Mock()
+        mock_models = {ModelType.GPT4O_MINI: mock_model}
         with patch("services.conversation.ModelFactory") as mock:
-            mock.return_value.create_model.return_value = mock_model
+            mock.return_value.create_models.return_value = mock_models
             yield mock

     @pytest.fixture
@@ -54,33 +56,31 @@ def mock_companion_graph(self):

     @pytest.fixture
     def mock_redis_saver(self):
-        async def async_mock_add_conversation_message(*args, **kwargs):
-            pass
-
-        with patch("services.conversation.RedisSaver") as mock:
-            mock.return_value.add_conversation_message = AsyncMock(
-                side_effect=async_mock_add_conversation_message
-            )
+        with patch("services.conversation.AsyncRedisSaver") as mock:
+            mock.from_conn_info.return_value = Mock()
             yield mock

     @pytest.fixture
-    def mock_init_pool(self):
-        with patch("services.conversation.initialize_async_pool") as mock:
-            yield mock
+    def mock_config(self):
+        mock_config = Mock()
+        mock_config.sanitization_config = Mock()
+        return mock_config

     def test_new_conversation(
         self,
         mock_model_factory,
         mock_companion_graph,
         mock_redis_saver,
-        mock_init_pool,
+        mock_config,
     ) -> None:
         # Given:
         mock_handler = Mock()
         mock_handler.fetch_relevant_data_from_k8s_cluster = Mock(return_value=POD_YAML)
+        mock_handler.apply_token_limit = Mock(return_value=POD_YAML)
         mock_handler.generate_questions = Mock(return_value=QUESTIONS)
         conversation_service = ConversationService(
-            initial_questions_handler=mock_handler
+            config=mock_config,
+            initial_questions_handler=mock_handler,
         )

         mock_k8s_client = Mock()
@@ -92,13 +92,19 @@ def test_new_conversation(

         # Then:
         assert result == QUESTIONS
+        mock_handler.fetch_relevant_data_from_k8s_cluster.assert_called_once_with(
+            message=TEST_MESSAGE, k8s_client=mock_k8s_client
+        )
+        mock_handler.apply_token_limit.assert_called_once_with(POD_YAML, TOKEN_LIMIT)
+        mock_handler.generate_questions.assert_called_once_with(context=POD_YAML)

     @pytest.mark.asyncio
     async def test_handle_followup_questions(
         self,
         mock_model_factory,
         mock_companion_graph,
-        mock_init_pool,
+        mock_redis_saver,
+        mock_config,
     ) -> None:
         # Given:
         dummy_conversation_history = [
@@ -111,7 +117,8 @@ async def test_handle_followup_questions(
         mock_handler.generate_questions = Mock(return_value=QUESTIONS)
         # initialize ConversationService instance.
         conversation_service = ConversationService(
-            followup_questions_handler=mock_handler
+            config=mock_config,
+            followup_questions_handler=mock_handler,
         )
         conversation_service._followup_questions_handler = mock_handler
         # define mock for CompanionGraph.
@@ -133,13 +140,17 @@ async def test_handle_followup_questions(

     @pytest.mark.asyncio
     async def test_handle_request(
-        self, mock_model_factory, mock_init_pool, mock_redis_saver, mock_companion_graph
+        self,
+        mock_model_factory,
+        mock_redis_saver,
+        mock_companion_graph,
+        mock_config,
     ):
         # Given:
         mock_k8s_client = Mock()

         # When:
-        messaging_service = ConversationService()
+        messaging_service = ConversationService(config=mock_config)

         # Then:
         result = [
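The checkpointer mock changed shape with the move from RedisSaver to AsyncRedisSaver: the saver is now obtained through a from_conn_info() factory, so the fixture stubs that classmethod rather than the constructor. A stdlib-only sketch of the pattern; the connection arguments are illustrative, not from this commit:

from unittest.mock import Mock

# saver_cls plays the patched services.conversation.AsyncRedisSaver.
saver_cls = Mock()
saver_cls.from_conn_info.return_value = Mock(name="checkpointer")

# Code under test would call the factory instead of the constructor:
checkpointer = saver_cls.from_conn_info(host="localhost", port=6379, db=0)

saver_cls.from_conn_info.assert_called_once_with(host="localhost", port=6379, db=0)
assert checkpointer is saver_cls.from_conn_info.return_value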
110 changes: 10 additions & 100 deletions tests/unit/utils/test_factory.py
@@ -1,48 +1,21 @@
-from unittest.mock import Mock, patch
+from unittest.mock import patch

 import pytest

-from utils.models.exceptions import ModelNotFoundError
+from utils.models.exceptions import ModelNotFoundError, UnsupportedModelError
 from utils.models.factory import (
     GeminiModel,
     ModelFactory,
     ModelType,
     OpenAIModel,
-    get_model_config,
 )

-SUPPORTED_MODEL_COUNT = 3
-
-
 @pytest.fixture
-def mock_model_config():
-    return Mock(name="test_model", deployment_id="test_deployment")
-
-
-@pytest.mark.parametrize(
-    "model_name, expected_deployment_id, expected_error",
-    [
-        (ModelType.GPT4O, "dep1", None),
-        (ModelType.GPT35, "dep2", None),
-        (ModelType.GEMINI_10_PRO, "dep3", None),
-        ("non_existent_model", None, None),
-        ("", None, None),
-        (None, None, None),
-    ],
-)
-def test_get_model_config(
-    mock_get_config, model_name, expected_deployment_id, expected_error
-):
-    if expected_error:
-        with pytest.raises(expected_error):
-            get_model_config(model_name)
-    else:
-        result = get_model_config(model_name)
-        if expected_deployment_id:
-            assert result.name == model_name
-            assert result.deployment_id == expected_deployment_id
-        else:
-            assert result is None
+def model_factory(mock_get_proxy_client, mock_config):
+    return ModelFactory(mock_config)


 class TestModelFactory:
@@ -56,10 +29,6 @@ def mock_gemini_model(self):
         with patch("utils.models.factory.GeminiModel") as mock:
             yield mock

-    @pytest.fixture
-    def model_factory(self, mock_get_proxy_client):
-        return ModelFactory()
-
     @pytest.mark.parametrize(
         "test_case,model_name,expected_model_class,expected_exception",
         [
@@ -87,11 +56,16 @@ def model_factory(self, mock_get_proxy_client):
                 None,
                 ModelNotFoundError,
             ),
+            (
+                "should raise error when unsupported model is requested",
+                "unsupported_model",
+                None,
+                UnsupportedModelError,
+            ),
         ],
     )
     def test_create_model(
         self,
-        mock_get_config,
         mock_openai_model,
         mock_gemini_model,
         model_factory,
@@ -112,67 +86,3 @@
         else:
             mock_gemini_model.assert_called_once()
             assert model == mock_gemini_model.return_value
-
-    def test_create_models_returns_all_supported_models(
-        self,
-        mock_get_config,
-        mock_openai_model,
-        mock_gemini_model,
-        model_factory,
-    ):
-        # When
-        models = model_factory.create_models()
-
-        # Then
-        assert (
-            len(models) == SUPPORTED_MODEL_COUNT
-        )  # Verify we get all supported models
-
-        # Verify correct model instances were created
-        assert models[ModelType.GPT4O] == mock_openai_model.return_value
-        assert models[ModelType.GPT35] == mock_openai_model.return_value
-        assert models[ModelType.GEMINI_10_PRO] == mock_gemini_model.return_value
-
-        # Verify each model has proper configuration
-        for model in models.values():
-            assert model.name is not None
-            assert model.deployment_id is not None
-
-    @pytest.mark.parametrize(
-        "test_case,model_name,expected_model_class",
-        [
-            ("get gpt4o model should return OpenAIModel", ModelType.GPT4O, OpenAIModel),
-            (
-                "get gemini_10_pro model should return GeminiModel",
-                ModelType.GEMINI_10_PRO,
-                GeminiModel,
-            ),
-            (
-                "get non_existent_model should raise error",
-                "non_existent_model",
-                None,
-            ),
-        ],
-    )
-    def test_get_model(
-        self,
-        mock_get_config,
-        mock_openai_model,
-        mock_gemini_model,
-        model_factory,
-        test_case,
-        model_name,
-        expected_model_class,
-    ):
-        if expected_model_class is None:
-            with pytest.raises(ModelNotFoundError):
-                model_factory.create_model(model_name)
-        else:
-            model = model_factory.create_model(model_name)
-
-            if expected_model_class == OpenAIModel:
-                mock_openai_model.assert_called_once()
-                assert model == mock_openai_model.return_value
-            else:
-                mock_gemini_model.assert_called_once()
-                assert model == mock_gemini_model.return_value
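Taken together with the conftest.py change above, the surviving parametrized test pins down two failure modes: a model name missing from the config raises ModelNotFoundError, while a configured but unimplemented model raises UnsupportedModelError. A hedged, stand-alone mimic of that contract — the exception names come from the diff, the lookup logic here is an assumption for illustration only:

import pytest

# Names present in the config; "unsupported_model" has no backing implementation.
SUPPORTED = {"gpt-4o", "gpt-35", "gemini-1.0-pro"}
CONFIGURED = SUPPORTED | {"unsupported_model"}

class ModelNotFoundError(Exception): ...
class UnsupportedModelError(Exception): ...

def create_model(name: str):
    if name not in CONFIGURED:
        raise ModelNotFoundError(name)     # not in the config at all
    if name not in SUPPORTED:
        raise UnsupportedModelError(name)  # configured, but no implementation
    return object()

with pytest.raises(ModelNotFoundError):
    create_model("non_existent_model")
with pytest.raises(UnsupportedModelError):
    create_model("unsupported_model")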
