[REF] Move environment variable handling to Pydantic BaseSettings class #416

Status: Open — wants to merge 20 commits into base branch main
- 0002024 create Pydantic class for environment variables (alyssadai, Feb 20, 2025)
- cb863c1 Merge remote-tracking branch 'origin/main' into refactor-env-vars (alyssadai, Feb 20, 2025)
- 39f4958 add aliases and defaults for env vars (alyssadai, Feb 20, 2025)
- 05b7237 update env var references in startup checks (alyssadai, Feb 20, 2025)
- 5741cad update leftover deprecated pydantic 1 syntax (alyssadai, Feb 20, 2025)
- f9ec034 test vars are read in using new settings class using pytest-env (alyssadai, Feb 20, 2025)
- 6c20845 monkeypatch settings class attributes in tests (alyssadai, Feb 21, 2025)
- f912918 remove set_test_credentials fixture in favour of pytest-env (alyssadai, Feb 21, 2025)
- f8371c6 update value of NB_MIN_CELL_SIZE set for tests (alyssadai, Feb 21, 2025)
- 2112208 remove leftover set_test_credentials calls (alyssadai, Feb 21, 2025)
- 392e92d update type hints & restore default value validation for settings (alyssadai, Feb 21, 2025)
- 145b2f8 replace global env var variables with settings (alyssadai, Feb 21, 2025)
- c051c85 use computed field for the query URL (alyssadai, Feb 21, 2025)
- 1a10bee update docstrings (alyssadai, Feb 21, 2025)
- 9166c1a update env var values for tests (alyssadai, Feb 21, 2025)
- 7fd4eb3 create fixture for graph URL vars for integration tests (alyssadai, Feb 21, 2025)
- 1b734ad change test_app_with_invalid_environment_vars to integration test (alyssadai, Feb 21, 2025)
- 2a8d6d7 fix test (alyssadai, Feb 21, 2025)
- cd43265 add links to README badges (alyssadai, Feb 21, 2025)
- 0659c2b add github logo to README badges (alyssadai, Feb 21, 2025)
14 changes: 7 additions & 7 deletions README.md
@@ -2,13 +2,13 @@

# Neurobagel API

![GitHub branch check runs](https://img.shields.io/github/check-runs/neurobagel/api/main?style=flat-square)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/neurobagel/api/test.yaml?branch=main&style=flat-square&label=tests)
![Codecov](https://img.shields.io/codecov/c/github/neurobagel/api?token=ZEOGQFFZMJ&style=flat-square&logo=codecov&link=https%3A%2F%2Fcodecov.io%2Fgh%2Fneurobagel%2Fapi)
![Static Badge](https://img.shields.io/badge/python-3.10-blue?style=flat-square&logo=python)
![GitHub License](https://img.shields.io/github/license/neurobagel/api?style=flat-square&color=purple&link=LICENSE)
![Docker Image Version (tag)](https://img.shields.io/docker/v/neurobagel/api/latest?style=flat-square&logo=docker&link=https%3A%2F%2Fhub.docker.com%2Fr%2Fneurobagel%2Fapi%2Ftags)
![Docker Pulls](https://img.shields.io/docker/pulls/neurobagel/api?style=flat-square&logo=docker&link=https%3A%2F%2Fhub.docker.com%2Fr%2Fneurobagel%2Fapi%2Ftags)
[![Main branch check status](https://img.shields.io/github/check-runs/neurobagel/api/main?style=flat-square&logo=github)](https://github.com/neurobagel/api/actions?query=branch:main)
[![Tests Status](https://img.shields.io/github/actions/workflow/status/neurobagel/api/test.yaml?branch=main&style=flat-square&logo=github&label=tests)](https://github.com/neurobagel/api/actions/workflows/test.yaml)
[![Codecov](https://img.shields.io/codecov/c/github/neurobagel/api?token=ZEOGQFFZMJ&style=flat-square&logo=codecov&link=https%3A%2F%2Fcodecov.io%2Fgh%2Fneurobagel%2Fapi)](https://app.codecov.io/gh/neurobagel/api)
[![Python versions static](https://img.shields.io/badge/python-3.10-blue?style=flat-square&logo=python)](https://www.python.org)
[![License](https://img.shields.io/github/license/neurobagel/api?style=flat-square&color=purple&link=LICENSE)](LICENSE)
[![Docker Image Version (tag)](https://img.shields.io/docker/v/neurobagel/api/latest?style=flat-square&logo=docker&link=https%3A%2F%2Fhub.docker.com%2Fr%2Fneurobagel%2Fapi%2Ftags)](https://hub.docker.com/r/neurobagel/api/tags)
[![Docker Pulls](https://img.shields.io/docker/pulls/neurobagel/api?style=flat-square&logo=docker&link=https%3A%2F%2Fhub.docker.com%2Fr%2Fneurobagel%2Fapi%2Ftags)](https://hub.docker.com/r/neurobagel/api/tags)

</div>

31 changes: 31 additions & 0 deletions app/api/config.py
@@ -0,0 +1,31 @@
"""Configuration environment variables for the API."""

from pydantic import Field, computed_field
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
"""Data model for configurable API settings."""

# NOTE: Environment variables are case-insensitive by default
# (see https://docs.pydantic.dev/latest/concepts/pydantic_settings/#case-sensitivity)
root_path: str = Field(alias="NB_NAPI_BASE_PATH", default="")
allowed_origins: str = Field(alias="NB_API_ALLOWED_ORIGINS", default="")
graph_username: str | None = Field(alias="NB_GRAPH_USERNAME", default=None)
graph_password: str | None = Field(alias="NB_GRAPH_PASSWORD", default=None)
graph_address: str = Field(alias="NB_GRAPH_ADDRESS", default="127.0.0.1")
graph_db: str = Field(alias="NB_GRAPH_DB", default="repositories/my_db")
graph_port: int = Field(alias="NB_GRAPH_PORT", default=7200)
return_agg: bool = Field(alias="NB_RETURN_AGG", default=True)
min_cell_size: int = Field(alias="NB_MIN_CELL_SIZE", default=0)
auth_enabled: bool = Field(alias="NB_ENABLE_AUTH", default=True)
client_id: str | None = Field(alias="NB_QUERY_CLIENT_ID", default=None)

@computed_field
@property
def query_url(self) -> str:
"""Construct the URL of the graph store to be queried."""
return f"http://{self.graph_address}:{self.graph_port}/{self.graph_db}"


settings = Settings()
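The Settings class above gives each field an env var alias, a default, and a declared type, and pydantic-settings handles lookup and coercion automatically. The following stdlib-only sketch (hypothetical, not the PR's code — pydantic does all of this internally) illustrates the behavior the class relies on: case-insensitive alias matching, fallback defaults, and type coercion of string values:

```python
import os

# Hypothetical stdlib-only illustration of what pydantic-settings automates:
# look a field up by its alias in the environment (case-insensitively),
# fall back to a default, and coerce the raw string to the declared type.
def read_env(alias: str, default, cast=str):
    for key, value in os.environ.items():
        if key.upper() == alias.upper():
            return cast(value)
    return default

def parse_bool(value: str) -> bool:
    # Env vars arrive as strings; "False" must not be truthy
    return value.strip().lower() in ("1", "true", "yes", "on")

os.environ["NB_GRAPH_PORT"] = "7201"
os.environ["nb_return_agg"] = "False"  # lowercase name still matches

graph_port = read_env("NB_GRAPH_PORT", 7200, int)
return_agg = read_env("NB_RETURN_AGG", True, parse_bool)
min_cell_size = read_env("NB_MIN_CELL_SIZE", 0, int)  # unset: default used

print(graph_port, return_agg, min_cell_size)  # 7201 False 0
```

This is also why the old `NB_RETURN_AGG` workaround in utility.py (manual `.lower() == "true"`) can be dropped: BaseSettings performs the bool coercion itself.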
18 changes: 9 additions & 9 deletions app/api/crud.py
@@ -1,6 +1,5 @@
"""CRUD functions called by path operations."""

import os
import warnings
from pathlib import Path

@@ -10,9 +9,10 @@
from fastapi import HTTPException, status

from . import utility as util
from .config import settings
from .models import CohortQueryResponse, SessionResponse, VocabLabelsResponse

ALL_SUBJECT_ATTRIBUTES = list(SessionResponse.__fields__.keys()) + [
ALL_SUBJECT_ATTRIBUTES = list(SessionResponse.model_fields.keys()) + [
"dataset_uuid",
"dataset_name",
"dataset_portal_uri",
@@ -41,12 +41,12 @@ def post_query_to_graph(query: str, timeout: float = None) -> dict:
"""
try:
response = httpx.post(
url=util.QUERY_URL,
url=settings.query_url,
content=query,
headers=util.QUERY_HEADER,
auth=httpx.BasicAuth(
os.environ.get(util.GRAPH_USERNAME.name),
os.environ.get(util.GRAPH_PASSWORD.name),
settings.graph_username,
settings.graph_password,
),
timeout=timeout,
)
@@ -141,7 +141,7 @@ async def get(
"""
results = post_query_to_graph(
util.create_query(
return_agg=util.RETURN_AGG.val,
return_agg=settings.return_agg,
age=(min_age, max_age),
sex=sex,
diagnosis=diagnosis,
@@ -177,9 +177,9 @@ async def get(
# results for datasets with fewer than min_cell_count subjects. But
# ideally we would handle this directly inside SPARQL so we don't even
# get the results in the first place. See #267 for a solution.
if num_matching_subjects <= util.MIN_CELL_SIZE.val:
if num_matching_subjects <= settings.min_cell_size:
continue
if util.RETURN_AGG.val:
if settings.return_agg:
subject_data = "protected"
else:
subject_data = (
@@ -294,7 +294,7 @@ async def get(
else None
),
num_matching_subjects=num_matching_subjects,
records_protected=util.RETURN_AGG.val,
records_protected=settings.return_agg,
subject_data=subject_data,
image_modals=list(
group["image_modal"][
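The crud.py changes replace the import-time `QUERY_URL` constant with `settings.query_url`, a computed property. The practical difference is that a property is re-evaluated on every access, so test-time overrides of settings attributes are reflected; a module-level constant is frozen when the module is first imported. A minimal sketch (with a hypothetical stand-in class, not the PR's Settings):

```python
# Hypothetical sketch: computed property vs import-time constant.
class FakeSettings:
    graph_address = "127.0.0.1"
    graph_port = 7200
    graph_db = "repositories/my_db"

    @property
    def query_url(self) -> str:
        # Recomputed on every access, so attribute overrides are reflected
        return f"http://{self.graph_address}:{self.graph_port}/{self.graph_db}"

settings = FakeSettings()

# An import-time constant captures the value at module load...
FROZEN_URL = settings.query_url

# ...but the property tracks later changes (e.g. a monkeypatch in a test)
settings.graph_port = 7201

print(FROZEN_URL)          # http://127.0.0.1:7200/repositories/my_db
print(settings.query_url)  # http://127.0.0.1:7201/repositories/my_db
```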
5 changes: 3 additions & 2 deletions app/api/routers/query.py
@@ -5,7 +5,8 @@
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.security import OAuth2

from .. import crud, security
from .. import crud
from ..config import settings
from ..models import CohortQueryResponse, QueryModel
from ..security import verify_token

@@ -36,7 +37,7 @@ async def get_query(
token: str | None = Depends(oauth2_scheme),
):
"""When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset."""
if security.AUTH_ENABLED:
if settings.auth_enabled:
if token is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
15 changes: 6 additions & 9 deletions app/api/security.py
@@ -1,14 +1,11 @@
"""Functions for handling authentication. Same ones as used in Neurobagel's federation API."""

import os

import jwt
from fastapi import HTTPException, status
from fastapi.security.utils import get_authorization_scheme_param
from jwt import PyJWKClient, PyJWTError

AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true"
CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None)
from .config import Settings, settings

KEYS_URL = "https://neurobagel.ca.auth0.com/.well-known/jwks.json"
ISSUER = "https://neurobagel.ca.auth0.com/"
@@ -19,11 +16,11 @@


def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
# The CLIENT_ID is needed to verify the audience claim of ID tokens.
if AUTH_ENABLED and CLIENT_ID is None:
"""Check if the app client ID environment variable is set."""
# The client ID is needed to verify the audience claim of ID tokens.
if settings.auth_enabled and settings.client_id is None:
raise RuntimeError(
"Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
f"Authentication has been enabled ({Settings.model_fields['auth_enabled'].alias}) but the environment variable NB_QUERY_CLIENT_ID is not set. "
"Please set NB_QUERY_CLIENT_ID to the client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
)

@@ -47,7 +44,7 @@ def verify_token(token: str):
"verify_signature": True,
"require": ["aud", "iss", "exp", "iat"],
},
audience=CLIENT_ID,
audience=settings.client_id,
issuer=ISSUER,
)
except (PyJWTError, ValueError) as exc:
38 changes: 1 addition & 37 deletions app/api/utility.py
@@ -1,47 +1,11 @@
"""Constants for graph server connection and utility functions for writing the SPARQL query."""

import json
import os
import textwrap
from collections import namedtuple
from pathlib import Path
from typing import Optional

# Request constants
EnvVar = namedtuple("EnvVar", ["name", "val"])

ROOT_PATH = EnvVar(
"NB_NAPI_BASE_PATH", os.environ.get("NB_NAPI_BASE_PATH", "")
)

ALLOWED_ORIGINS = EnvVar(
"NB_API_ALLOWED_ORIGINS", os.environ.get("NB_API_ALLOWED_ORIGINS", "")
)

GRAPH_USERNAME = EnvVar(
"NB_GRAPH_USERNAME", os.environ.get("NB_GRAPH_USERNAME")
)
GRAPH_PASSWORD = EnvVar(
"NB_GRAPH_PASSWORD", os.environ.get("NB_GRAPH_PASSWORD")
)
GRAPH_ADDRESS = EnvVar(
"NB_GRAPH_ADDRESS", os.environ.get("NB_GRAPH_ADDRESS", "127.0.0.1")
)
GRAPH_DB = EnvVar(
"NB_GRAPH_DB", os.environ.get("NB_GRAPH_DB", "repositories/my_db")
)
GRAPH_PORT = EnvVar("NB_GRAPH_PORT", os.environ.get("NB_GRAPH_PORT", 7200))
# TODO: Environment variables can't be parsed as bool so this is a workaround but isn't ideal.
# Another option is to switch this to a command-line argument, but that would require changing the
# Dockerfile also since Uvicorn can't accept custom command-line args.
RETURN_AGG = EnvVar(
"NB_RETURN_AGG", os.environ.get("NB_RETURN_AGG", "True").lower() == "true"
)
MIN_CELL_SIZE = EnvVar(
"NB_MIN_CELL_SIZE", int(os.environ.get("NB_MIN_CELL_SIZE", 0))
)

QUERY_URL = f"http://{GRAPH_ADDRESS.val}:{GRAPH_PORT.val}/{GRAPH_DB.val}"
QUERY_HEADER = {
"Content-Type": "application/sparql-query",
"Accept": "application/sparql-results+json",
@@ -303,7 +267,7 @@ def create_query(
"""
)

# The query defined above will return all subject-level attributes from the graph. If RETURN_AGG variable has been set to true,
# The query defined above will return all subject-level attributes from the graph. If aggregate results have been enabled,
# wrap query in an aggregating statement so data returned from graph include only attributes needed for dataset-level aggregate metadata.
if return_agg:
query_string = (
27 changes: 15 additions & 12 deletions app/main.py
@@ -1,8 +1,9 @@
"""Main app."""

import os
import warnings
from contextlib import asynccontextmanager

# from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory

@@ -13,9 +14,14 @@
from fastapi.responses import HTMLResponse, ORJSONResponse, RedirectResponse

from .api import utility as util
from .api.config import Settings, settings
from .api.routers import assessments, attributes, diagnoses, pipelines, query
from .api.security import check_client_id

# @lru_cache
# def get_settings():
# return Settings()


def validate_environment_variables():
"""
@@ -24,24 +30,21 @@ def validate_environment_variables():
Ensures that the username and password for the graph database are provided.
If not, raises a RuntimeError to prevent the application from running without valid credentials.

Also checks that ALLOWED_ORIGINS is properly set. If missing, a warning is issued, but the app continues running.
Also checks that allowed origins are properly set. If missing, a warning is issued, but the app continues running.
"""
if (
os.environ.get(util.GRAPH_USERNAME.name) is None
or os.environ.get(util.GRAPH_PASSWORD.name) is None
):
if settings.graph_username is None or settings.graph_password is None:
raise RuntimeError(
f"The application was launched but could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables."
f"The application was launched but could not find the {Settings.model_fields['graph_username'].alias} and / or {Settings.model_fields['graph_password'].alias} environment variables."
)

if os.environ.get(util.ALLOWED_ORIGINS.name, "") == "":
if settings.allowed_origins == "":
warnings.warn(
f"The API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment "
f"The API was launched without providing any values for the {Settings.model_fields['allowed_origins'].alias} environment "
f"variable."
"This means that the API will only be accessible from the same origin it is hosted from: "
"https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy."
f"If you want to access the API from tools hosted at other origins such as the Neurobagel query tool, "
f"explicitly set the value of {util.ALLOWED_ORIGINS.name} to the origin(s) of these tools (e.g. "
f"explicitly set the value of {Settings.model_fields['allowed_origins'].alias} to the origin(s) of these tools (e.g. "
f"http://localhost:3000)."
"Multiple allowed origins should be separated with spaces in a single string enclosed in quotes."
)
@@ -103,7 +106,7 @@ async def lifespan(app: FastAPI):


app = FastAPI(
root_path=util.ROOT_PATH.val,
root_path=settings.root_path,
lifespan=lifespan,
default_response_class=ORJSONResponse,
docs_url=None,
@@ -115,7 +118,7 @@ async def lifespan(app: FastAPI):

app.add_middleware(
CORSMiddleware,
allow_origins=util.parse_origins_as_list(util.ALLOWED_ORIGINS.val),
allow_origins=util.parse_origins_as_list(settings.allowed_origins),
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
15 changes: 15 additions & 0 deletions pytest.ini
@@ -3,3 +3,18 @@ markers =
integration: mark integration tests that need the test graph to run
; Default to not running tests with the integration marker
addopts = -m "not integration"
; NOTE: The following environment variables are set to non-default values before any tests are run,
; allowing us to actually test that user-provided values are parsed accurately from the environment
; while avoiding issues related to import order in tests.
; In individual tests, these values may then be overridden as needed to test downstream logic
; by monkeypatching attributes of the global settings object
env =
NB_API_ALLOWED_ORIGINS=*
NB_GRAPH_USERNAME=DBUSER
NB_GRAPH_PASSWORD=DBPASSWORD
NB_GRAPH_PORT=7201
NB_RETURN_AGG=False
; NOTE: We set the minimum cell size to a different value than the default (0) here
; to confirm that it is still read in correctly as an int
; but we choose 1 so as to avoid filtering out any results being returned as part of the tests
NB_MIN_CELL_SIZE=1
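The pytest.ini comments describe a two-layer pattern: pytest-env seeds non-default env values before any module is imported, and individual tests then monkeypatch attributes of the global settings object. A hypothetical sketch of the second layer, using `unittest.mock.patch.object` in place of pytest's monkeypatch fixture so it runs self-contained (the class and threshold function here are stand-ins, not the PR's code):

```python
from unittest.mock import patch

class FakeSettings:
    min_cell_size = 1   # value seeded via pytest-env in the real suite
    return_agg = False

settings = FakeSettings()

def result_is_filtered(num_matching_subjects: int) -> bool:
    # Mirrors the crud.py check: counts at or below the threshold are dropped
    return num_matching_subjects <= settings.min_cell_size

assert result_is_filtered(1) is True

# A single test can raise the threshold without touching the environment;
# the override is undone automatically when the context manager exits.
with patch.object(settings, "min_cell_size", 10):
    assert result_is_filtered(5) is True

# Outside the patch, the seeded value applies again
assert result_is_filtered(5) is False
```

Because crud.py reads `settings.min_cell_size` at call time rather than caching it at import, the patched value takes effect immediately.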
5 changes: 3 additions & 2 deletions requirements.txt
@@ -26,7 +26,7 @@ orjson==3.9.15
packaging==21.3
pandas==1.5.2
platformdirs==2.5.4
pluggy==1.0.0
pluggy==1.5.0
pre-commit==3.6.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
@@ -36,7 +36,8 @@ pydantic-settings==2.7.1
pydantic_core==2.27.2
PyJWT==2.10.1
pyparsing==3.0.9
pytest==7.2.0
pytest==8.3.4
pytest-env==1.1.5
python-dateutil==2.8.2
python-dotenv==1.0.1
pytz==2022.7