Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF] Refactor startup/shutdown events to lifespan events #407

Merged
merged 11 commits into from
Feb 20, 2025
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ repos:
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
args:
- "--profile=black"
- "--filter-files"
- "--line-length=79"

- repo: https://github.com/codespell-project/codespell
rev: v2.4.0
Expand Down
2 changes: 1 addition & 1 deletion app/api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
# The CLIENT_ID is needed to verify the audience claim of ID tokens.
if AUTH_ENABLED and CLIENT_ID is None:
raise ValueError(
raise RuntimeError(
"Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
"Please set NB_QUERY_CLIENT_ID to the client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
)
Expand Down
161 changes: 94 additions & 67 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import warnings
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import TemporaryDirectory

Expand All @@ -15,13 +16,101 @@
from .api.routers import assessments, attributes, diagnoses, pipelines, query
from .api.security import check_client_id


def validate_environment_variables():
"""
Check that all required environment variables are set.

Ensures that the username and password for the graph database are provided.
If not, raises a RuntimeError to prevent the application from running without valid credentials.

Also checks that ALLOWED_ORIGINS is properly set. If missing, a warning is issued, but the app continues running.
"""
if (
os.environ.get(util.GRAPH_USERNAME.name) is None
or os.environ.get(util.GRAPH_PASSWORD.name) is None
):
raise RuntimeError(
f"The application was launched but could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables."
)

if os.environ.get(util.ALLOWED_ORIGINS.name, "") == "":
warnings.warn(
f"The API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment "
f"variable."
"This means that the API will only be accessible from the same origin it is hosted from: "
"https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy."
f"If you want to access the API from tools hosted at other origins such as the Neurobagel query tool, "
f"explicitly set the value of {util.ALLOWED_ORIGINS.name} to the origin(s) of these tools (e.g. "
f"http://localhost:3000)."
"Multiple allowed origins should be separated with spaces in a single string enclosed in quotes."
)


def initialize_vocabularies():
"""
Create and store on the app instance a temporary directory for vocabulary term lookup JSON files
(each of which contain key-value pairings of IDs to human-readable names of terms),
and then fetch vocabularies using their respective native APIs and save them to the temporary directory for reuse.
"""
# We use Starlette's ability (FastAPI is Starlette underneath) to store arbitrary state on the app instance (https://www.starlette.io/applications/#storing-state-on-the-app-instance)
# to store a temporary directory object and its corresponding path. These data are local to the instance and will be recreated on every app launch (i.e. not persisted).

app.state.vocab_dir = TemporaryDirectory()
app.state.vocab_dir_path = Path(app.state.vocab_dir.name)

app.state.vocab_lookup_paths = {
"snomed_assessment": app.state.vocab_dir_path
/ "snomedct_assessment_term_labels.json",
"snomed_disorder": app.state.vocab_dir_path
/ "snomedct_disorder_term_labels.json",
}

util.create_snomed_assessment_lookup(
app.state.vocab_lookup_paths["snomed_assessment"]
)
util.create_snomed_disorder_lookup(
app.state.vocab_lookup_paths["snomed_disorder"]
)


@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Handles application startup and shutdown events.

On startup:
- Validates required environment variables.
- Performs authentication checks.
- Initializes temporary directories for vocabulary lookups.

On shutdown:
- Cleans up temporary directories to free resources.
"""
# Validate environment variables
validate_environment_variables()

# Authentication check
check_client_id()

# Initialize vocabularies
initialize_vocabularies()

yield

# Shutdown logic
app.state.vocab_dir.cleanup()


app = FastAPI(
root_path=util.ROOT_PATH.val,
lifespan=lifespan,
default_response_class=ORJSONResponse,
docs_url=None,
redoc_url=None,
redirect_slashes=False,
)

favicon_url = "https://raw.githubusercontent.com/neurobagel/documentation/main/docs/imgs/logo/neurobagel_favicon.png"

app.add_middleware(
Expand All @@ -42,7 +131,7 @@ def root(request: Request):
<html>
<body>
<h1>Welcome to the Neurobagel REST API!</h1>
<p>Please visit the <a href="{request.scope.get("root_path", "")}/docs">API documentation</a> to view available API endpoints.</p>
<p>Please visit the <a href="{request.scope.get('root_path', '')}/docs">API documentation</a> to view available API endpoints.</p>
</body>
</html>
"""
Expand All @@ -52,6 +141,10 @@ def root(request: Request):
async def favicon():
"""
Overrides the default favicon with a custom one.

NOTE: When the API is behind a reverse proxy that has a stripped path prefix (and root_path is defined),
the custom favicon doesn't appear to work correctly for any API paths other than the docs,
as the path in the favicon request isn't automatically adjusted to include the root path prefix.
"""
return RedirectResponse(url=favicon_url)

Expand Down Expand Up @@ -80,72 +173,6 @@ def overridden_redoc(request: Request):
)


@app.on_event("startup")
async def auth_check():
"""
Checks whether authentication has been enabled for API queries and whether the
username and password environment variables for the graph backend have been set.

TODO: Refactor once startup events have been replaced by lifespan event
"""
check_client_id()

if (
# TODO: Check if this error is still raised when variables are empty strings
os.environ.get(util.GRAPH_USERNAME.name) is None
or os.environ.get(util.GRAPH_PASSWORD.name) is None
):
raise RuntimeError(
f"The application was launched but could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables."
)


@app.on_event("startup")
async def allowed_origins_check():
"""Raises warning if allowed origins environment variable has not been set or is an empty string."""
if os.environ.get(util.ALLOWED_ORIGINS.name, "") == "":
warnings.warn(
f"The API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment variable. "
"This means that the API will only be accessible from the same origin it is hosted from: https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy. "
f"If you want to access the API from tools hosted at other origins such as the Neurobagel query tool, explicitly set the value of {util.ALLOWED_ORIGINS.name} to the origin(s) of these tools (e.g. http://localhost:3000). "
"Multiple allowed origins should be separated with spaces in a single string enclosed in quotes. "
)


@app.on_event("startup")
async def fetch_vocabularies_to_temp_dir():
"""
Create and store on the app instance a temporary directory for vocabulary term lookup JSON files
(each of which contain key-value pairings of IDs to human-readable names of terms),
and then fetch vocabularies using their respective native APIs and save them to the temporary directory for reuse.
"""
# We use Starlette's ability (FastAPI is Starlette underneath) to store arbitrary state on the app instance (https://www.starlette.io/applications/#storing-state-on-the-app-instance)
# to store a temporary directory object and its corresponding path. These data are local to the instance and will be recreated on every app launch (i.e. not persisted).
app.state.vocab_dir = TemporaryDirectory()
app.state.vocab_dir_path = Path(app.state.vocab_dir.name)

app.state.vocab_lookup_paths = {}
app.state.vocab_lookup_paths["snomed_assessment"] = (
app.state.vocab_dir_path / "snomedct_assessment_term_labels.json"
)
app.state.vocab_lookup_paths["snomed_disorder"] = (
app.state.vocab_dir_path / "snomedct_disorder_term_labels.json"
)

util.create_snomed_assessment_lookup(
app.state.vocab_lookup_paths["snomed_assessment"]
)
util.create_snomed_disorder_lookup(
app.state.vocab_lookup_paths["snomed_disorder"]
)


@app.on_event("shutdown")
async def cleanup_temp_vocab_dir():
"""Clean up the temporary directory created on startup."""
app.state.vocab_dir.cleanup()


app.include_router(query.router)
app.include_router(attributes.router)
app.include_router(assessments.router)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_missing_client_id_raises_error_when_auth_enabled(
# but we set the values explicitly here for clarity
monkeypatch.setattr("app.api.security.CLIENT_ID", None)

with pytest.raises(ValueError) as exc_info:
with pytest.raises(RuntimeError) as exc_info:
with test_app:
pass

Expand Down