diff --git a/README.md b/README.md index f4f6350..bbb933d 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,13 @@ # Neurobagel API -![GitHub branch check runs](https://img.shields.io/github/check-runs/neurobagel/api/main?style=flat-square) -![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/neurobagel/api/test.yaml?branch=main&style=flat-square&label=tests) -![Codecov](https://img.shields.io/codecov/c/github/neurobagel/api?token=ZEOGQFFZMJ&style=flat-square&logo=codecov&link=https%3A%2F%2Fcodecov.io%2Fgh%2Fneurobagel%2Fapi) -![Static Badge](https://img.shields.io/badge/python-3.10-blue?style=flat-square&logo=python) -![GitHub License](https://img.shields.io/github/license/neurobagel/api?style=flat-square&color=purple&link=LICENSE) -![Docker Image Version (tag)](https://img.shields.io/docker/v/neurobagel/api/latest?style=flat-square&logo=docker&link=https%3A%2F%2Fhub.docker.com%2Fr%2Fneurobagel%2Fapi%2Ftags) -![Docker Pulls](https://img.shields.io/docker/pulls/neurobagel/api?style=flat-square&logo=docker&link=https%3A%2F%2Fhub.docker.com%2Fr%2Fneurobagel%2Fapi%2Ftags) +[![Main branch check status](https://img.shields.io/github/check-runs/neurobagel/api/main?style=flat-square&logo=github)](https://github.com/neurobagel/api/actions?query=branch:main) +[![Tests Status](https://img.shields.io/github/actions/workflow/status/neurobagel/api/test.yaml?branch=main&style=flat-square&logo=github&label=tests)](https://github.com/neurobagel/api/actions/workflows/test.yaml) +[![Codecov](https://img.shields.io/codecov/c/github/neurobagel/api?token=ZEOGQFFZMJ&style=flat-square&logo=codecov&link=https%3A%2F%2Fcodecov.io%2Fgh%2Fneurobagel%2Fapi)](https://app.codecov.io/gh/neurobagel/api) +[![Python versions static](https://img.shields.io/badge/python-3.10-blue?style=flat-square&logo=python)](https://www.python.org) +[![License](https://img.shields.io/github/license/neurobagel/api?style=flat-square&color=purple&link=LICENSE)](LICENSE) +[![Docker Image Version (tag)](https://img.shields.io/docker/v/neurobagel/api/latest?style=flat-square&logo=docker&link=https%3A%2F%2Fhub.docker.com%2Fr%2Fneurobagel%2Fapi%2Ftags)](https://hub.docker.com/r/neurobagel/api/tags) +[![Docker Pulls](https://img.shields.io/docker/pulls/neurobagel/api?style=flat-square&logo=docker&link=https%3A%2F%2Fhub.docker.com%2Fr%2Fneurobagel%2Fapi%2Ftags)](https://hub.docker.com/r/neurobagel/api/tags) diff --git a/app/api/config.py b/app/api/config.py new file mode 100644 index 0000000..67d6629 --- /dev/null +++ b/app/api/config.py @@ -0,0 +1,31 @@ +"""Configuration environment variables for the API.""" + +from pydantic import Field, computed_field +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + """Data model for configurable API settings.""" + + # NOTE: Environment variables are case-insensitive by default + # (see https://docs.pydantic.dev/latest/concepts/pydantic_settings/#case-sensitivity) + root_path: str = Field(alias="NB_NAPI_BASE_PATH", default="") + allowed_origins: str = Field(alias="NB_API_ALLOWED_ORIGINS", default="") + graph_username: str | None = Field(alias="NB_GRAPH_USERNAME", default=None) + graph_password: str | None = Field(alias="NB_GRAPH_PASSWORD", default=None) + graph_address: str = Field(alias="NB_GRAPH_ADDRESS", default="127.0.0.1") + graph_db: str = Field(alias="NB_GRAPH_DB", default="repositories/my_db") + graph_port: int = Field(alias="NB_GRAPH_PORT", default=7200) + return_agg: bool = Field(alias="NB_RETURN_AGG", default=True) + min_cell_size: int = Field(alias="NB_MIN_CELL_SIZE", default=0) + auth_enabled: bool = Field(alias="NB_ENABLE_AUTH", default=True) + client_id: str | None = Field(alias="NB_QUERY_CLIENT_ID", default=None) + + @computed_field + @property + def query_url(self) -> str: + """Construct the URL of the graph store to be queried.""" + return f"http://{self.graph_address}:{self.graph_port}/{self.graph_db}" + + +settings = Settings() diff --git a/app/api/crud.py b/app/api/crud.py index 4670926..a003f23 100644 --- a/app/api/crud.py +++ b/app/api/crud.py @@ -1,6 +1,5 @@ """CRUD functions called by path operations.""" -import os import warnings from pathlib import Path @@ -10,9 +9,10 @@ from fastapi import HTTPException, status from . import utility as util +from .config import settings from .models import CohortQueryResponse, SessionResponse, VocabLabelsResponse -ALL_SUBJECT_ATTRIBUTES = list(SessionResponse.__fields__.keys()) + [ +ALL_SUBJECT_ATTRIBUTES = list(SessionResponse.model_fields.keys()) + [ "dataset_uuid", "dataset_name", "dataset_portal_uri", @@ -41,12 +41,12 @@ def post_query_to_graph(query: str, timeout: float = None) -> dict: """ try: response = httpx.post( - url=util.QUERY_URL, + url=settings.query_url, content=query, headers=util.QUERY_HEADER, auth=httpx.BasicAuth( - os.environ.get(util.GRAPH_USERNAME.name), - os.environ.get(util.GRAPH_PASSWORD.name), + settings.graph_username, + settings.graph_password, ), timeout=timeout, ) @@ -141,7 +141,7 @@ async def get( """ results = post_query_to_graph( util.create_query( - return_agg=util.RETURN_AGG.val, + return_agg=settings.return_agg, age=(min_age, max_age), sex=sex, diagnosis=diagnosis, @@ -177,9 +177,9 @@ async def get( # results for datasets with fewer than min_cell_count subjects. But # ideally we would handle this directly inside SPARQL so we don't even # get the results in the first place. See #267 for a solution. - if num_matching_subjects <= util.MIN_CELL_SIZE.val: + if num_matching_subjects <= settings.min_cell_size: continue - if util.RETURN_AGG.val: + if settings.return_agg: subject_data = "protected" else: subject_data = ( @@ -294,7 +294,7 @@ async def get( else None ), num_matching_subjects=num_matching_subjects, - records_protected=util.RETURN_AGG.val, + records_protected=settings.return_agg, subject_data=subject_data, image_modals=list( group["image_modal"][ diff --git a/app/api/routers/query.py b/app/api/routers/query.py index 0d8f1eb..af43c75 100644 --- a/app/api/routers/query.py +++ b/app/api/routers/query.py @@ -5,7 +5,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi.security import OAuth2 -from .. import crud, security +from .. import crud +from ..config import settings from ..models import CohortQueryResponse, QueryModel from ..security import verify_token @@ -36,7 +37,7 @@ async def get_query( token: str | None = Depends(oauth2_scheme), ): """When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset.""" - if security.AUTH_ENABLED: + if settings.auth_enabled: if token is None: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/app/api/security.py b/app/api/security.py index 551bf92..90e0768 100644 --- a/app/api/security.py +++ b/app/api/security.py @@ -1,14 +1,11 @@ """Functions for handling authentication. Same ones as used in Neurobagel's federation API.""" -import os - import jwt from fastapi import HTTPException, status from fastapi.security.utils import get_authorization_scheme_param from jwt import PyJWKClient, PyJWTError -AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true" -CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None) +from .config import Settings, settings KEYS_URL = "https://neurobagel.ca.auth0.com/.well-known/jwks.json" ISSUER = "https://neurobagel.ca.auth0.com/" @@ -19,11 +16,11 @@ def check_client_id(): - """Check if the CLIENT_ID environment variable is set.""" - # The CLIENT_ID is needed to verify the audience claim of ID tokens. - if AUTH_ENABLED and CLIENT_ID is None: + """Check if the app client ID environment variable is set.""" + # The client ID is needed to verify the audience claim of ID tokens. + if settings.auth_enabled and settings.client_id is None: raise RuntimeError( - "Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. " + f"Authentication has been enabled ({Settings.model_fields['auth_enabled'].alias}) but the environment variable NB_QUERY_CLIENT_ID is not set. " "Please set NB_QUERY_CLIENT_ID to the client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens." ) @@ -47,7 +44,7 @@ def verify_token(token: str): "verify_signature": True, "require": ["aud", "iss", "exp", "iat"], }, - audience=CLIENT_ID, + audience=settings.client_id, issuer=ISSUER, ) except (PyJWTError, ValueError) as exc: diff --git a/app/api/utility.py b/app/api/utility.py index 2c4d22a..ff34ba2 100644 --- a/app/api/utility.py +++ b/app/api/utility.py @@ -1,47 +1,11 @@ """Constants for graph server connection and utility functions for writing the SPARQL query.""" import json -import os import textwrap from collections import namedtuple from pathlib import Path from typing import Optional -# Request constants -EnvVar = namedtuple("EnvVar", ["name", "val"]) - -ROOT_PATH = EnvVar( - "NB_NAPI_BASE_PATH", os.environ.get("NB_NAPI_BASE_PATH", "") -) - -ALLOWED_ORIGINS = EnvVar( - "NB_API_ALLOWED_ORIGINS", os.environ.get("NB_API_ALLOWED_ORIGINS", "") -) - -GRAPH_USERNAME = EnvVar( - "NB_GRAPH_USERNAME", os.environ.get("NB_GRAPH_USERNAME") -) -GRAPH_PASSWORD = EnvVar( - "NB_GRAPH_PASSWORD", os.environ.get("NB_GRAPH_PASSWORD") -) -GRAPH_ADDRESS = EnvVar( - "NB_GRAPH_ADDRESS", os.environ.get("NB_GRAPH_ADDRESS", "127.0.0.1") -) -GRAPH_DB = EnvVar( - "NB_GRAPH_DB", os.environ.get("NB_GRAPH_DB", "repositories/my_db") -) -GRAPH_PORT = EnvVar("NB_GRAPH_PORT", os.environ.get("NB_GRAPH_PORT", 7200)) -# TODO: Environment variables can't be parsed as bool so this is a workaround but isn't ideal. -# Another option is to switch this to a command-line argument, but that would require changing the -# Dockerfile also since Uvicorn can't accept custom command-line args. -RETURN_AGG = EnvVar( - "NB_RETURN_AGG", os.environ.get("NB_RETURN_AGG", "True").lower() == "true" -) -MIN_CELL_SIZE = EnvVar( - "NB_MIN_CELL_SIZE", int(os.environ.get("NB_MIN_CELL_SIZE", 0)) -) - -QUERY_URL = f"http://{GRAPH_ADDRESS.val}:{GRAPH_PORT.val}/{GRAPH_DB.val}" QUERY_HEADER = { "Content-Type": "application/sparql-query", "Accept": "application/sparql-results+json", @@ -303,7 +267,7 @@ def create_query( """ ) - # The query defined above will return all subject-level attributes from the graph. If RETURN_AGG variable has been set to true, + # The query defined above will return all subject-level attributes from the graph. If aggregate results have been enabled, # wrap query in an aggregating statement so data returned from graph include only attributes needed for dataset-level aggregate metadata. if return_agg: query_string = ( diff --git a/app/main.py b/app/main.py index 1273115..6c31276 100644 --- a/app/main.py +++ b/app/main.py @@ -1,8 +1,9 @@ """Main app.""" -import os import warnings from contextlib import asynccontextmanager + +# from functools import lru_cache from pathlib import Path from tempfile import TemporaryDirectory @@ -13,9 +14,14 @@ from fastapi.responses import HTMLResponse, ORJSONResponse, RedirectResponse from .api import utility as util +from .api.config import Settings, settings from .api.routers import assessments, attributes, diagnoses, pipelines, query from .api.security import check_client_id +# @lru_cache +# def get_settings(): +# return Settings() + def validate_environment_variables(): """ @@ -24,24 +30,21 @@ def validate_environment_variables(): Ensures that the username and password for the graph database are provided. If not, raises a RuntimeError to prevent the application from running without valid credentials. - Also checks that ALLOWED_ORIGINS is properly set. If missing, a warning is issued, but the app continues running. + Also checks that allowed origins are properly set. If missing, a warning is issued, but the app continues running. """ - if ( - os.environ.get(util.GRAPH_USERNAME.name) is None - or os.environ.get(util.GRAPH_PASSWORD.name) is None - ): + if settings.graph_username is None or settings.graph_password is None: raise RuntimeError( - f"The application was launched but could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables." + f"The application was launched but could not find the {Settings.model_fields['graph_username'].alias} and / or {Settings.model_fields['graph_password'].alias} environment variables." ) - if os.environ.get(util.ALLOWED_ORIGINS.name, "") == "": + if settings.allowed_origins == "": warnings.warn( - f"The API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment " + f"The API was launched without providing any values for the {Settings.model_fields['allowed_origins'].alias} environment " f"variable." "This means that the API will only be accessible from the same origin it is hosted from: " "https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy." f"If you want to access the API from tools hosted at other origins such as the Neurobagel query tool, " - f"explicitly set the value of {util.ALLOWED_ORIGINS.name} to the origin(s) of these tools (e.g. " + f"explicitly set the value of {Settings.model_fields['allowed_origins'].alias} to the origin(s) of these tools (e.g. " f"http://localhost:3000)." "Multiple allowed origins should be separated with spaces in a single string enclosed in quotes." ) @@ -103,7 +106,7 @@ async def lifespan(app: FastAPI): app = FastAPI( - root_path=util.ROOT_PATH.val, + root_path=settings.root_path, lifespan=lifespan, default_response_class=ORJSONResponse, docs_url=None, @@ -115,7 +118,7 @@ async def lifespan(app: FastAPI): app.add_middleware( CORSMiddleware, - allow_origins=util.parse_origins_as_list(util.ALLOWED_ORIGINS.val), + allow_origins=util.parse_origins_as_list(settings.allowed_origins), allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/pytest.ini b/pytest.ini index 7eefaf5..8994749 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,3 +3,18 @@ markers = integration: mark integration tests that need the test graph to run ; Default to not running tests with the integration marker addopts = -m "not integration" +; NOTE: The following environment variables are set to non-default values before any tests are run, +; allowing us to actually test that user-provided values are parsed accurately from the environment +; while avoiding issues related to import order in tests. +; In individual tests, these values may then be overridden as needed to test downstream logic +; by monkeypatching attributes of the global settings object +env = + NB_API_ALLOWED_ORIGINS=* + NB_GRAPH_USERNAME=DBUSER + NB_GRAPH_PASSWORD=DBPASSWORD + NB_GRAPH_PORT=7201 + NB_RETURN_AGG=False + ; NOTE: We set the minimum cell size to a different value than the default (0) here + ; to confirm that it is still read in correctly as an int + ; but we choose 1 so as to avoid filtering out any results being returned as part of the tests + NB_MIN_CELL_SIZE=1 diff --git a/requirements.txt b/requirements.txt index fd6ee0d..cf8e66b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ orjson==3.9.15 packaging==21.3 pandas==1.5.2 platformdirs==2.5.4 -pluggy==1.0.0 +pluggy==1.5.0 pre-commit==3.6.0 pyasn1==0.6.0 pyasn1_modules==0.4.0 @@ -36,7 +36,8 @@ pydantic-settings==2.7.1 pydantic_core==2.27.2 PyJWT==2.10.1 pyparsing==3.0.9 -pytest==7.2.0 +pytest==8.3.4 +pytest-env==1.1.5 python-dateutil==2.8.2 python-dotenv==1.0.1 pytz==2022.7 diff --git a/tests/conftest.py b/tests/conftest.py index 3dd63d3..46bd14e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,7 @@ import pytest from starlette.testclient import TestClient -from app.api import utility as util -from app.main import app +from app.main import app, settings @pytest.fixture(scope="module") @@ -14,7 +13,7 @@ def test_app(): @pytest.fixture def enable_auth(monkeypatch): """Enable the authentication requirement for the API.""" - monkeypatch.setattr("app.api.security.AUTH_ENABLED", True) + monkeypatch.setattr(settings, "auth_enabled", True) @pytest.fixture @@ -23,14 +22,7 @@ def disable_auth(monkeypatch): Disable the authentication requirement for the API to skip startup checks (for when the tested route does not require authentication). """ - monkeypatch.setattr("app.api.security.AUTH_ENABLED", False) - - -@pytest.fixture(scope="function") -def set_test_credentials(monkeypatch): - """Set random username and password to avoid error from startup check for set credentials.""" - monkeypatch.setenv(util.GRAPH_USERNAME.name, "DBUSER") - monkeypatch.setenv(util.GRAPH_PASSWORD.name, "DBPASSWORD") + monkeypatch.setattr(settings, "auth_enabled", False) @pytest.fixture() @@ -57,6 +49,19 @@ def mock_auth_header() -> dict: return {"Authorization": "Bearer foo"} +@pytest.fixture() +def set_graph_url_vars_for_integration_tests(monkeypatch): + """ + Set the graph URL to the default value for integration tests. + + NOTE: These should correspond to the default configuration values, but are set explicitly here for clarity and + to override any environment defined in pytest.ini. + """ + monkeypatch.setattr(settings, "graph_address", "localhost") + monkeypatch.setattr(settings, "graph_port", 7200) + monkeypatch.setattr(settings, "graph_db", "repositories/my_db") + + @pytest.fixture() def test_data(): """Create valid aggregate response data for two toy datasets for testing.""" diff --git a/tests/test_app_events.py b/tests/test_app_events.py index 5595a9f..36f3277 100644 --- a/tests/test_app_events.py +++ b/tests/test_app_events.py @@ -1,66 +1,46 @@ """Test events occurring on app startup or shutdown.""" -import os import warnings -import httpx import pytest from app.api import utility as util +from app.main import settings @pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") def test_start_app_without_environment_vars_fails( - test_app, monkeypatch, disable_auth + test_app, disable_auth, monkeypatch ): """Given non-existing username and password environment variables, raises an informative RuntimeError.""" - monkeypatch.delenv(util.GRAPH_USERNAME.name, raising=False) - monkeypatch.delenv(util.GRAPH_PASSWORD.name, raising=False) + monkeypatch.setattr(settings, "graph_username", None) + monkeypatch.setattr(settings, "graph_password", None) with pytest.raises(RuntimeError) as e_info: with test_app: pass assert ( - f"could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables" + "could not find the NB_GRAPH_USERNAME and / or NB_GRAPH_PASSWORD environment variables" in str(e_info.value) ) -@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") -def test_app_with_invalid_environment_vars( - test_app, monkeypatch, mock_auth_header, set_mock_verify_token -): - """Given invalid environment variables for the graph, returns a 401 status code.""" - monkeypatch.setenv(util.GRAPH_USERNAME.name, "something") - monkeypatch.setenv(util.GRAPH_PASSWORD.name, "cool") - - def mock_httpx_post(**kwargs): - return httpx.Response(status_code=401) - - monkeypatch.setattr(httpx, "post", mock_httpx_post) - response = test_app.get("/query", headers=mock_auth_header) - assert response.status_code == 401 - - def test_app_with_unset_allowed_origins( test_app, - monkeypatch, - set_test_credentials, disable_auth, + monkeypatch, ): """Tests that when the environment variable for allowed origins has not been set, a warning is raised and the app uses a default value.""" - monkeypatch.delenv(util.ALLOWED_ORIGINS.name, raising=False) + monkeypatch.setattr(settings, "allowed_origins", "") with pytest.warns( UserWarning, - match=f"API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment variable", + match="API was launched without providing any values for the NB_API_ALLOWED_ORIGINS environment variable", ): with test_app: pass - assert util.parse_origins_as_list( - os.environ.get(util.ALLOWED_ORIGINS.name, "") - ) == [""] + assert util.parse_origins_as_list(settings.allowed_origins) == [""] @pytest.mark.parametrize( @@ -71,7 +51,7 @@ def test_app_with_unset_allowed_origins( [""], pytest.warns( UserWarning, - match=f"API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment variable", + match="API was launched without providing any values for the NB_API_ALLOWED_ORIGINS environment variable", ), ), ( @@ -94,7 +74,6 @@ def test_app_with_unset_allowed_origins( def test_app_with_set_allowed_origins( test_app, monkeypatch, - set_test_credentials, allowed_origins, parsed_origins, expectation, @@ -104,16 +83,14 @@ def test_app_with_set_allowed_origins( Test that when the environment variable for allowed origins has been explicitly set, the app correctly parses it into a list and raises a warning if the value is an empty string. """ - monkeypatch.setenv(util.ALLOWED_ORIGINS.name, allowed_origins) + monkeypatch.setattr(settings, "allowed_origins", allowed_origins) with expectation: with test_app: pass assert set(parsed_origins).issubset( - util.parse_origins_as_list( - os.environ.get(util.ALLOWED_ORIGINS.name, "") - ) + util.parse_origins_as_list(settings.allowed_origins) ) @@ -124,7 +101,6 @@ def test_app_with_set_allowed_origins( @pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") def test_stored_vocab_lookup_file_created_on_startup( test_app, - set_test_credentials, disable_auth, lookup_file, ): diff --git a/tests/test_attribute_factory_routes.py b/tests/test_attribute_factory_routes.py index d2f3a41..6dccffc 100644 --- a/tests/test_attribute_factory_routes.py +++ b/tests/test_attribute_factory_routes.py @@ -7,7 +7,6 @@ def test_get_instances_endpoint_with_vocab_lookup( test_app, monkeypatch, - set_test_credentials, # Since this test runs the API startup events to fetch the vocabularies used in the test, # we need to disable auth to avoid startup errors about unset auth-related environment variables. disable_auth, @@ -68,9 +67,7 @@ def mock_httpx_post(**kwargs): } -def test_get_instances_endpoint_without_vocab_lookup( - test_app, monkeypatch, set_test_credentials -): +def test_get_instances_endpoint_without_vocab_lookup(test_app, monkeypatch): """ Given a GET request to /pipelines/ (attribute without a vocabulary lookup file available), test that the endpoint correctly returns the found graph instances as prefixed term URIs with empty label fields. @@ -130,7 +127,6 @@ def mock_httpx_post(**kwargs): def test_get_vocab_endpoint( test_app, monkeypatch, - set_test_credentials, attribute, expected_vocab_name, expected_namespace_pfx, diff --git a/tests/test_attributes.py b/tests/test_attributes.py index e671fda..f442504 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -6,7 +6,6 @@ def test_get_attributes( test_app, monkeypatch, - set_test_credentials, ): """Given a GET request to the /attributes endpoint, successfully returns controlled term attributes with namespaces abbrieviated and as a list.""" mock_response_json = { diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index ac436e0..a4fbe9f 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -3,9 +3,7 @@ BASE_ROUTE = "/pipelines" -def test_get_pipeline_versions_response( - test_app, monkeypatch, set_test_credentials -): +def test_get_pipeline_versions_response(test_app, monkeypatch): """ Given a request to /pipelines/{pipeline_term}/versions with a valid pipeline name, returns a dict where the key is the pipeline resource and the value is a list of pipeline versions. diff --git a/tests/test_query.py b/tests/test_query.py index 60d0218..dbad33e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -3,9 +3,9 @@ import pytest from fastapi import HTTPException -import app.api.utility as util from app.api import crud from app.api.models import QueryModel +from app.main import settings ROUTE = "/query" @@ -65,9 +65,7 @@ def test_null_modalities( set_mock_verify_token, ): """Given a response containing a dataset with no recorded modalities, returns an empty list for the imaging modalities.""" - monkeypatch.setattr( - util, "RETURN_AGG", util.EnvVar(util.RETURN_AGG.name, True) - ) + monkeypatch.setattr(settings, "return_agg", True) monkeypatch.setattr( crud, "post_query_to_graph", mock_post_agg_query_to_graph ) @@ -617,7 +615,6 @@ def test_get_valid_pipeline_name_version( def test_aggregate_query_response_structure( test_app, - set_test_credentials, mock_post_agg_query_to_graph, mock_query_matching_dataset_sizes, monkeypatch, @@ -625,9 +622,7 @@ def test_aggregate_query_response_structure( set_mock_verify_token, ): """Test that when aggregate results are enabled, a cohort query response has the expected structure.""" - monkeypatch.setattr( - util, "RETURN_AGG", util.EnvVar(util.RETURN_AGG.name, True) - ) + monkeypatch.setattr(settings, "return_agg", True) monkeypatch.setattr( crud, "post_query_to_graph", mock_post_agg_query_to_graph ) @@ -642,11 +637,7 @@ def test_aggregate_query_response_structure( def test_query_without_token_succeeds_when_auth_disabled( - test_app, - mock_successful_get, - monkeypatch, - disable_auth, - set_test_credentials, + test_app, mock_successful_get, monkeypatch, disable_auth ): """ Test that when authentication is disabled, a request to the /query route without a token succeeds. @@ -656,20 +647,33 @@ def test_query_without_token_succeeds_when_auth_disabled( assert response.status_code == 200 +@pytest.mark.integration +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") +def test_app_with_invalid_environment_vars( + test_app, + monkeypatch, + disable_auth, + set_graph_url_vars_for_integration_tests, +): + """Given invalid credentials for the graph, returns a 401 status code.""" + monkeypatch.setattr(settings, "graph_username", "wrong_username") + monkeypatch.setattr(settings, "graph_password", "wrong_password") + + response = test_app.get("/query") + assert response.status_code == 401 + + @pytest.mark.integration def test_integration_query_without_auth_succeeds( - test_app, monkeypatch, disable_auth, set_test_credentials + test_app, + monkeypatch, + disable_auth, + set_graph_url_vars_for_integration_tests, ): """ Running a test against a real local test graph should succeed when authentication is disabled. """ - # Patching the QUERY_URL directly means we don't need to worry about the constituent - # graph environment variables - monkeypatch.setattr( - util, "QUERY_URL", "http://localhost:7200/repositories/my_db" - ) - response = test_app.get(ROUTE) assert response.status_code == 200 @@ -686,9 +690,7 @@ def test_derivatives_info_handled_by_agg_api_response( Test that in the aggregated API mode, pipeline information for matching subjects is correctly parsed and formatted in the API response. """ - monkeypatch.setattr( - util, "RETURN_AGG", util.EnvVar(util.RETURN_AGG.name, True) - ) + monkeypatch.setattr(settings, "return_agg", True) monkeypatch.setattr( crud, "post_query_to_graph", mock_post_agg_query_to_graph ) @@ -719,9 +721,7 @@ def test_missing_derivatives_info_handled_by_nonagg_api_response( Test that in the non-aggregated API mode, when all matching subjects lack pipeline information, the API does not error out and pipeline variables in the API response still have the expected structure. """ - monkeypatch.setattr( - util, "RETURN_AGG", util.EnvVar(util.RETURN_AGG.name, False) - ) + monkeypatch.setattr(settings, "return_agg", False) monkeypatch.setattr( crud, "post_query_to_graph", mock_post_nonagg_query_to_graph ) @@ -740,17 +740,15 @@ def test_missing_derivatives_info_handled_by_nonagg_api_response( @pytest.mark.integration def test_only_imaging_and_phenotypic_sessions_returned_in_query_response( - test_app, monkeypatch, disable_auth, set_test_credentials + test_app, + monkeypatch, + disable_auth, + set_graph_url_vars_for_integration_tests, ): """ Test that only sessions of type PhenotypicSession and ImagingSession are returned in an unaggregated query response. """ - monkeypatch.setattr( - util, "RETURN_AGG", util.EnvVar(util.RETURN_AGG.name, False) - ) - monkeypatch.setattr( - util, "QUERY_URL", "http://localhost:7200/repositories/my_db" - ) + monkeypatch.setattr(settings, "return_agg", False) response = test_app.get(ROUTE) assert response.status_code == 200 @@ -773,17 +771,15 @@ def test_only_imaging_and_phenotypic_sessions_returned_in_query_response( @pytest.mark.integration def test_min_cell_size_removes_results( - test_app, monkeypatch, disable_auth, set_test_credentials + test_app, + monkeypatch, + disable_auth, + set_graph_url_vars_for_integration_tests, ): """ - If MIN_CELL_SIZE is high enough, all results should be filtered out + If the minimum cell size is large enough, all results should be filtered out """ - monkeypatch.setattr( - util, "MIN_CELL_SIZE", util.EnvVar(util.MIN_CELL_SIZE.name, 100) - ) - monkeypatch.setattr( - util, "QUERY_URL", "http://localhost:7200/repositories/my_db" - ) + monkeypatch.setattr(settings, "min_cell_size", 100) response = test_app.get(ROUTE) assert response.status_code == 200 diff --git a/tests/test_security.py b/tests/test_security.py index cd06de1..26c393f 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,18 +1,19 @@ import pytest from fastapi import HTTPException +from app.api.config import settings from app.api.security import verify_token @pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") def test_missing_client_id_raises_error_when_auth_enabled( - monkeypatch, test_app, set_test_credentials, enable_auth + monkeypatch, test_app, enable_auth ): """Test that a missing client ID raises an error on startup when authentication is enabled.""" - # We're using what should be default values of CLIENT_ID and AUTH_ENABLED here + # We're using what should be default values of client_id and auth_enabled here # (if the corresponding environment variables are unset), # but we set the values explicitly here for clarity - monkeypatch.setattr("app.api.security.CLIENT_ID", None) + monkeypatch.setattr(settings, "client_id", None) with pytest.raises(RuntimeError) as exc_info: with test_app: @@ -22,12 +23,10 @@ def test_missing_client_id_raises_error_when_auth_enabled( @pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") -def test_missing_client_id_ignored_when_auth_disabled( - monkeypatch, test_app, set_test_credentials -): +def test_missing_client_id_ignored_when_auth_disabled(monkeypatch, test_app): """Test that a missing client ID does not raise an error when authentication is disabled.""" - monkeypatch.setattr("app.api.security.CLIENT_ID", None) - monkeypatch.setattr("app.api.security.AUTH_ENABLED", False) + monkeypatch.setattr(settings, "client_id", None) + monkeypatch.setattr(settings, "auth_enabled", False) with test_app: pass @@ -61,7 +60,7 @@ def test_query_with_malformed_auth_header_fails( Test that when authentication is enabled, a request to the /query route with a missing or malformed authorization header fails. """ - monkeypatch.setattr("app.api.security.CLIENT_ID", "foo.id") + monkeypatch.setattr(settings, "client_id", "foo.id") response = test_app.get( "/query", diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 0000000..9bf3381 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,21 @@ +from app.api.config import Settings + + +def test_settings_read_correctly(): + """Ensure that settings are read correctly from environment variables, with correct types and default values.""" + settings = Settings() + + # Check that defaults are applied correctly for undefined environment variables + assert settings.root_path == "" + assert settings.graph_address == "127.0.0.1" + assert settings.graph_db == "repositories/my_db" + assert settings.auth_enabled is True + assert settings.client_id is None + + # Check that set environment variables are read and typed correctly + assert settings.allowed_origins == "*" + assert settings.graph_username == "DBUSER" + assert settings.graph_password == "DBPASSWORD" + assert settings.graph_port == 7201 + assert settings.return_agg is False + assert settings.min_cell_size == 1