diff --git a/app/api/crud.py b/app/api/crud.py index 2eecc30..9ad0422 100644 --- a/app/api/crud.py +++ b/app/api/crud.py @@ -5,11 +5,18 @@ from pathlib import Path import httpx +import numpy as np import pandas as pd from fastapi import HTTPException, status from . import utility as util -from .models import CohortQueryResponse, VocabLabelsResponse +from .models import CohortQueryResponse, SessionResponse, VocabLabelsResponse + +ALL_SUBJECT_ATTRIBUTES = list(SessionResponse.__fields__.keys()) + [ + "dataset_uuid", + "dataset_name", + "dataset_portal_uri", +] def post_query_to_graph(query: str, timeout: float = 30.0) -> dict: @@ -137,7 +144,13 @@ async def get( image_modal=image_modal, ) ) - results_df = pd.DataFrame(util.unpack_http_response_json_to_dicts(results)) + + # Reindexing is needed here because when a certain attribute is missing from all matching sessions, + # the attribute does not end up in the graph API response or the below resulting processed dataframe. + # Conforming the columns to a list of expected attributes ensures every subject-session has the same response shape from the node API. + results_df = pd.DataFrame( + util.unpack_http_response_json_to_dicts(results) + ).reindex(columns=ALL_SUBJECT_ATTRIBUTES) matching_dataset_sizes = query_matching_dataset_sizes( results_df["dataset_uuid"].unique() @@ -175,6 +188,20 @@ async def get( } ) ) + + # TODO: Revisit this as there may be a more elegant solution. + # The following code replaces columns with all NaN values with values of None, to ensure they show up in the final JSON as `null`. + # This is needed as the above .agg() seems to turn NaN into None for object-type columns (which have some non-missing values) + # but not for columns with all NaN, which end up with a column type of float64. This is a problem because + # if the column corresponds to a SessionResponse attribute with an expected str type, then the column values will be converted + # to the string "nan" in the response JSON, which we don't want. + all_nan_columns = subject_data.columns[ + subject_data.isna().all() + ] + subject_data[all_nan_columns] = subject_data[ + all_nan_columns + ].replace({np.nan: None}) + subject_data = list(subject_data.to_dict("records")) response_obj.append(