Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF] Refactor startup/shutdown events to lifespan events #407

Merged
merged 11 commits into from
Feb 20, 2025
2 changes: 1 addition & 1 deletion app/api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
# The CLIENT_ID is needed to verify the audience claim of ID tokens.
if AUTH_ENABLED and CLIENT_ID is None:
raise ValueError(
raise RuntimeError(
"Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
"Please set NB_QUERY_CLIENT_ID to the client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
)
Expand Down
104 changes: 76 additions & 28 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import os
import warnings
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request
Expand All @@ -18,10 +18,20 @@


def validate_environment_variables():
"""Validate required environment variables."""
if os.environ.get(util.GRAPH_USERNAME.name) is None or os.environ.get(util.GRAPH_PASSWORD.name) is None:
"""
Check that all required environment variables are set.

Ensures that the username and password for the graph database are provided.
If not, raises a RuntimeError to prevent the application from running without valid credentials.

Also checks that ALLOWED_ORIGINS is properly set. If missing, a warning is issued, but the app continues running.
"""
if (
os.environ.get(util.GRAPH_USERNAME.name) is None
or os.environ.get(util.GRAPH_PASSWORD.name) is None
):
raise RuntimeError(
f"The application was launched but could not find the {util.GRAPH_USERNAME.name} and/or {util.GRAPH_PASSWORD.name} environment variables."
f"The application was launched but could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables."
)

if os.environ.get(util.ALLOWED_ORIGINS.name, "") == "":
Expand All @@ -37,39 +47,60 @@ def validate_environment_variables():
)


def initialize_vocabularies(app):
"""
Create and store on the app instance a temporary directory for vocabulary term lookup JSON files
(each of which contain key-value pairings of IDs to human-readable names of terms),
and then fetch vocabularies using their respective native APIs and save them to the temporary directory for reuse.
"""
# We use Starlette's ability (FastAPI is Starlette underneath) to store arbitrary state on the app instance (https://www.starlette.io/applications/#storing-state-on-the-app-instance)
# to store a temporary directory object and its corresponding path. These data are local to the instance and will be recreated on every app launch (i.e. not persisted).

app.state.vocab_dir = TemporaryDirectory()
app.state.vocab_dir_path = Path(app.state.vocab_dir.name)

app.state.vocab_lookup_paths = {
"snomed_assessment": app.state.vocab_dir_path
/ "snomedct_assessment_term_labels.json",
"snomed_disorder": app.state.vocab_dir_path
/ "snomedct_disorder_term_labels.json",
}

util.create_snomed_assessment_lookup(
app.state.vocab_lookup_paths["snomed_assessment"]
)
util.create_snomed_disorder_lookup(
app.state.vocab_lookup_paths["snomed_disorder"]
)



@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan event handler for startup and shutdown logic."""
try:
# Validate environment variables
validate_environment_variables()

# Authentication check
check_client_id()
"""
Handles application startup and shutdown events.

# Create and store temporary directories
app.state.vocab_dir = TemporaryDirectory()
app.state.vocab_dir_path = Path(app.state.vocab_dir.name)
On startup:
- Validates required environment variables.
- Performs authentication checks.
- Initializes temporary directories for vocabulary lookups.

app.state.vocab_lookup_paths = {
"snomed_assessment": app.state.vocab_dir_path / "snomedct_assessment_term_labels.json",
"snomed_disorder": app.state.vocab_dir_path / "snomedct_disorder_term_labels.json"
}
On shutdown:
- Cleans up temporary directories to free resources.
"""
# Validate environment variables
validate_environment_variables()

# Create vocabulary lookups
util.create_snomed_assessment_lookup(app.state.vocab_lookup_paths["snomed_assessment"])
util.create_snomed_disorder_lookup(app.state.vocab_lookup_paths["snomed_disorder"])
# Authentication check
check_client_id()

except Exception as e:
raise RuntimeError(f"Startup failed: {str(e)}")
# Initialize vocabularies
initialize_vocabularies(app)

yield

# Shutdown logic
try:
app.state.vocab_dir.cleanup()
except Exception as e:
warnings.warn(f"Failed to clean up temporary directory: {str(e)}")
app.state.vocab_dir.cleanup()


app = FastAPI(
Expand All @@ -94,21 +125,34 @@ async def lifespan(app: FastAPI):

@app.get("/", response_class=HTMLResponse)
def root(request: Request):
"""
Display a welcome message and a link to the API documentation.
"""
return f"""
<html>
<body>
<h1>Welcome to the Neurobagel REST API!</h1>
<p>Please visit the <a href="{request.scope.get('root_path', '')}/docs">API documentation</a> to view
<p>Please visit the <a href="{request.scope.get('root_path', '')}/docs">API documentation</a> to view
available API endpoints.</p> </body> </html>"""


@app.get("/favicon.ico", include_in_schema=False)
async def favicon():
"""
Overrides the default favicon with a custom one.

NOTE: When the API is behind a reverse proxy that has a stripped path prefix (and root_path is defined),
the custom favicon doesn't appear to work correctly for any API paths other than the docs,
as the path in the favicon request isn't automatically adjusted to include the root path prefix.
"""
return RedirectResponse(url=favicon_url)


@app.get("/docs", include_in_schema=False)
def overridden_swagger(request: Request):
"""
Overrides the Swagger UI HTML for the "/docs" endpoint.
"""
return get_swagger_ui_html(
openapi_url=f"{request.scope.get('root_path', '')}/openapi.json",
title="Neurobagel API",
Expand All @@ -118,6 +162,9 @@ def overridden_swagger(request: Request):

@app.get("/redoc", include_in_schema=False)
def overridden_redoc(request: Request):
"""
Overrides the Redoc HTML for the "/redoc" endpoint.
"""
return get_redoc_html(
openapi_url=f"{request.scope.get('root_path', '')}/openapi.json",
title="Neurobagel API",
Expand All @@ -131,5 +178,6 @@ def overridden_redoc(request: Request):
app.include_router(diagnoses.router)
app.include_router(pipelines.router)

# Automatically start uvicorn server on execution of main.py
if __name__ == "__main__":
uvicorn.run("app.main:app", port=8000, reload=True)
4 changes: 2 additions & 2 deletions tests/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ def test_missing_client_id_raises_error_when_auth_enabled(
# but we set the values explicitly here for clarity
monkeypatch.setattr("app.api.security.CLIENT_ID", None)

with pytest.raises(ValueError) as exc_info:
with pytest.raises(RuntimeError) as exc_info:
with test_app:
pass

assert "NB_QUERY_CLIENT_ID is not set" in str(exc_info.value)
assert "Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set." in str(exc_info.value)


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
Expand Down