Skip to content

Commit

Permalink
[ENH] Support partial term fetching failures (#65)
Browse files Browse the repository at this point in the history
* add test for partial node failure in terms fetching

* assert over full mocked partial success attribute response

* add response model for combined attribute instances

* test when terms fetching fails for all nodes

* handle node errors when fetching terms and do so async

* remove duplicate test

* refactor out mocked httpx.get raising a ConnectError

* refactor out combined response processing + use generic summary console message

* update comments

Co-authored-by: Sebastian Urchs <[email protected]>

* improve test logic

* refactor func for building combined response from nodes

* set status code in path operation function

* refactor setting of test federation nodes into fixture

* more informative docstrings and comments

---------

Co-authored-by: Sebastian Urchs <[email protected]>
  • Loading branch information
alyssadai and surchs authored Feb 6, 2024
1 parent 9d685db commit b33c34c
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 103 deletions.
111 changes: 63 additions & 48 deletions app/api/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,36 @@
import asyncio
import warnings

from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from fastapi import HTTPException

from . import utility as util


def build_combined_response(
total_nodes: int, cross_node_results: list | dict, node_errors: list
) -> dict:
"""Return a combined response containing all the nodes' responses and errors. Logs to console a summary of the federated request."""
content = {"errors": node_errors, "responses": cross_node_results}

if node_errors:
# TODO: Use logger instead of print. For example of how to set this up for FastAPI, see https://github.com/tiangolo/fastapi/discussions/8517
print(
f"Requests to {len(node_errors)}/{total_nodes} nodes failed: {[node_error['node_name'] for node_error in node_errors]}."
)
if len(node_errors) == total_nodes:
# See https://fastapi.tiangolo.com/advanced/additional-responses/ for more info
content["nodes_response_status"] = "fail"
else:
content["nodes_response_status"] = "partial success"
else:
print(
f"Requests to all nodes succeeded ({total_nodes}/{total_nodes})."
)
content["nodes_response_status"] = "success"

return content


async def get(
min_age: float,
max_age: float,
Expand Down Expand Up @@ -57,7 +81,6 @@ async def get(
node_errors = []

node_urls = util.validate_query_node_url_list(node_urls)
total_nodes = len(node_urls)

# Node API query parameters
params = {}
Expand Down Expand Up @@ -92,51 +115,25 @@ async def get(
node_errors.append(
{"node_name": node_name, "error": response.detail}
)
# TODO: Replace with logger
warnings.warn(
f"Query to node {node_name} ({node_url}) did not succeed: {response.detail}"
f"Request to node {node_name} ({node_url}) did not succeed: {response.detail}"
)
else:
for result in response:
result["node_name"] = node_name
cross_node_results.extend(response)

if node_errors:
# TODO: Use logger instead of print, see https://github.com/tiangolo/fastapi/issues/5003
print(
f"Queries to {len(node_errors)}/{total_nodes} nodes failed: {[node_error['node_name'] for node_error in node_errors]}."
)

if len(node_errors) == total_nodes:
# See https://fastapi.tiangolo.com/advanced/additional-responses/ for more info
return JSONResponse(
status_code=status.HTTP_207_MULTI_STATUS,
content={
"errors": node_errors,
"responses": cross_node_results,
"nodes_response_status": "fail",
},
)
return JSONResponse(
status_code=status.HTTP_207_MULTI_STATUS,
content={
"errors": node_errors,
"responses": cross_node_results,
"nodes_response_status": "partial success",
},
)

print(f"All nodes queried successfully ({total_nodes}/{total_nodes}).")
return {
"errors": node_errors,
"responses": cross_node_results,
"nodes_response_status": "success",
}
return build_combined_response(
total_nodes=len(node_urls),
cross_node_results=cross_node_results,
node_errors=node_errors,
)


async def get_terms(data_element_URI: str):
# TODO: Make this path able to handle partial successes as well
"""
Makes a GET request to one or more Neurobagel node APIs using send_get_request utility function where the only parameter is a data element URI.
Makes a GET request to all available Neurobagel node APIs using send_get_request utility function where the only parameter is a data element URI.
Parameters
----------
Expand All @@ -148,20 +145,38 @@ async def get_terms(data_element_URI: str):
dict
Dictionary where the key is the Neurobagel class and values correspond to all the unique terms representing available (i.e. used) instances of that class.
"""
cross_node_results = []
params = {data_element_URI: data_element_URI}
node_errors = []
unique_terms_dict = {}

for node_url in util.FEDERATION_NODES:
response = util.send_get_request(
params = {data_element_URI: data_element_URI}
tasks = [
util.send_get_request(
node_url + "attributes/" + data_element_URI, params
)
for node_url in util.FEDERATION_NODES
]
responses = await asyncio.gather(*tasks, return_exceptions=True)

cross_node_results.append(response)

unique_terms_dict = {}
for (node_url, node_name), response in zip(
util.FEDERATION_NODES.items(), responses
):
if isinstance(response, HTTPException):
node_errors.append(
{"node_name": node_name, "error": response.detail}
)
# TODO: Replace with logger
warnings.warn(
f"Request to node {node_name} ({node_url}) did not succeed: {response.detail}"
)
else:
# Build the dictionary of unique term-label pairings from all nodes
for term_dict in response[data_element_URI]:
unique_terms_dict[term_dict["TermURL"]] = term_dict

for list_of_terms in cross_node_results:
for term in list_of_terms[data_element_URI]:
unique_terms_dict[term["TermURL"]] = term
cross_node_results = {data_element_URI: list(unique_terms_dict.values())}

return {data_element_URI: list(unique_terms_dict.values())}
return build_combined_response(
total_nodes=len(util.FEDERATION_NODES),
cross_node_results=cross_node_results,
node_errors=node_errors,
)
8 changes: 8 additions & 0 deletions app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,11 @@ class CombinedQueryResponse(BaseModel):
errors: list[NodeError]
responses: list[CohortQueryResponse]
nodes_response_status: NodesResponseStatus


class CombinedAttributeResponse(BaseModel):
"""Data model for the combined available terms for a given Neurobagel attribute/variable across all available nodes."""

errors: list[NodeError]
responses: dict
nodes_response_status: NodesResponseStatus
23 changes: 17 additions & 6 deletions app/api/routers/attributes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from fastapi import APIRouter
from fastapi import APIRouter, Response, status
from pydantic import constr

from .. import crud
from ..models import CONTROLLED_TERM_REGEX
from ..models import CONTROLLED_TERM_REGEX, CombinedAttributeResponse

router = APIRouter(prefix="/attributes", tags=["attributes"])


@router.get("/{data_element_URI}")
async def get_terms(data_element_URI: constr(regex=CONTROLLED_TERM_REGEX)):
# We use the Response parameter below to change the status code of the response while still being able to validate the returned data using the response model.
# (see https://fastapi.tiangolo.com/advanced/response-change-status-code/ for more info).
#
# TODO: if our response model for fully successful vs. not fully successful responses grows more complex in the future,
# consider additionally using https://fastapi.tiangolo.com/advanced/additional-responses/#additional-response-with-model to document
# example responses for different status codes in the OpenAPI docs (less relevant for now since there is only one response model).
@router.get("/{data_element_URI}", response_model=CombinedAttributeResponse)
async def get_terms(
data_element_URI: constr(regex=CONTROLLED_TERM_REGEX), response: Response
):
"""When a GET request is sent, return a list dicts with the only key corresponding to controlled term of a neurobagel class and value corresponding to all the available terms."""
response = await crud.get_terms(data_element_URI)
response_dict = await crud.get_terms(data_element_URI)

return response
if response_dict["errors"]:
response.status_code = status.HTTP_207_MULTI_STATUS

return response_dict
19 changes: 15 additions & 4 deletions app/api/routers/query.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
"""Router for query path operations."""

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Response, status

from .. import crud
from ..models import CombinedQueryResponse, QueryModel

router = APIRouter(prefix="/query", tags=["query"])


# We use the Response parameter below to change the status code of the response while still being able to validate the returned data using the response model.
# (see https://fastapi.tiangolo.com/advanced/response-change-status-code/ for more info).
#
# TODO: if our response model for fully successful vs. not fully successful responses grows more complex in the future,
# consider additionally using https://fastapi.tiangolo.com/advanced/additional-responses/#additional-response-with-model to document
# example responses for different status codes in the OpenAPI docs (less relevant for now since there is only one response model).
@router.get("/", response_model=CombinedQueryResponse)
async def get_query(query: QueryModel = Depends(QueryModel)):
async def get_query(
response: Response, query: QueryModel = Depends(QueryModel)
):
"""When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset."""
response = await crud.get(
response_dict = await crud.get(
query.min_age,
query.max_age,
query.sex,
Expand All @@ -24,4 +32,7 @@ async def get_query(query: QueryModel = Depends(QueryModel)):
query.node_url,
)

return response
if response_dict["errors"]:
response.status_code = status.HTTP_207_MULTI_STATUS

return response_dict
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
import httpx
import pytest
from starlette.testclient import TestClient

from app.api import utility as util
from app.main import app


@pytest.fixture(scope="module")
def test_app():
client = TestClient(app)
yield client


@pytest.fixture(scope="function")
def set_valid_test_federation_nodes(monkeypatch):
"""Set two correctly formatted federation nodes for a test function (mocks the result of reading/parsing available public and local nodes on startup)."""
monkeypatch.setattr(
util,
"FEDERATION_NODES",
{
"https://firstpublicnode.org/": "First Public Node",
"https://secondpublicnode.org/": "Second Public Node",
},
)


@pytest.fixture()
def mock_failed_connection_httpx_get():
"""Return a mock for the httpx.AsyncClient.get method that raises a ConnectError when called."""

async def _mock_httpx_get_with_connect_error(self, **kwargs):
# The self parameter is necessary to match the signature of the method being mocked,
# which is a class method of the httpx.AsyncClient class (see https://www.python-httpx.org/api/#asyncclient).
raise httpx.ConnectError("Some connection error")

return _mock_httpx_get_with_connect_error
83 changes: 83 additions & 0 deletions tests/test_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import httpx
import pytest
from fastapi import status


def test_partially_failed_terms_fetching_handled_gracefully(
test_app,
monkeypatch,
set_valid_test_federation_nodes,
):
"""
When some nodes fail while getting term instances for an attribute (/attribute/{data_element_URI}),
the overall API get request still succeeds, and the response includes a list of the encountered errors along with the successfully fetched terms.
"""
mocked_node_attribute_response = {
"nb:Assessment": [
{
"TermURL": "cogatlas:trm_56a9137d9dce1",
"Label": "behavioral approach/inhibition systems",
},
{
"TermURL": "cogatlas:trm_55a6a8e81b7f4",
"Label": "Barratt Impulsiveness Scale",
},
]
}

async def mock_httpx_get(self, **kwargs):
# The self parameter is necessary to match the signature of the method being mocked,
# which is a class method of the httpx.AsyncClient class (see https://www.python-httpx.org/api/#asyncclient).
if (
kwargs["url"]
== "https://secondpublicnode.org/attributes/nb:Assessment"
):
return httpx.Response(
status_code=500, json={}, text="Some internal server error"
)
return httpx.Response(
status_code=200,
json=mocked_node_attribute_response,
)

monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)

with pytest.warns(UserWarning):
response = test_app.get("/attributes/nb:Assessment")

assert response.status_code == status.HTTP_207_MULTI_STATUS

response_object = response.json()
assert response_object["errors"] == [
{
"node_name": "Second Public Node",
"error": "Internal Server Error: Some internal server error",
}
]
assert response_object["responses"] == mocked_node_attribute_response
assert response_object["nodes_response_status"] == "partial success"


def test_fully_failed_terms_fetching_handled_gracefully(
test_app,
monkeypatch,
mock_failed_connection_httpx_get,
set_valid_test_federation_nodes,
):
"""
When *all* nodes fail while getting term instances for an attribute (/attribute/{data_element_URI}),
the overall API get request still succeeds, but includes an overall failure status and all encountered errors in the response.
"""
monkeypatch.setattr(
httpx.AsyncClient, "get", mock_failed_connection_httpx_get
)

with pytest.warns(UserWarning):
response = test_app.get("/attributes/nb:Assessment")

assert response.status_code == status.HTTP_207_MULTI_STATUS

response = response.json()
assert response["nodes_response_status"] == "fail"
assert len(response["errors"]) == 2
assert response["responses"] == {"nb:Assessment": []}
Loading

0 comments on commit b33c34c

Please sign in to comment.