Skip to content

Commit

Permalink
Pass complete settings object
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Feb 19, 2025
1 parent 5fba427 commit c9b1667
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 58 deletions.
5 changes: 3 additions & 2 deletions tiled/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,9 @@ def from_app(
# Extract the API key from the app and set it.
from ..server.settings import get_settings

settings = app.dependency_overrides[get_settings]()
api_key = settings.single_user_api_key or None
api_key = (
app.state.settings or get_settings()
).single_user_api_key or None
else:
# This is a multi-user server but no API key was passed,
# so we will leave it as None on the Context.
Expand Down
110 changes: 54 additions & 56 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import urllib.parse
import warnings
from contextlib import asynccontextmanager
from functools import cache, partial
from functools import partial
from pathlib import Path
from typing import Any, Mapping, Optional

Expand Down Expand Up @@ -51,7 +51,7 @@
from ..validation_registration import ValidationRegistry, default_validation_registry
from .compression import CompressionMiddleware
from .router import get_router
from .settings import get_settings
from .settings import Settings, get_settings
from .utils import API_KEY_COOKIE_NAME, CSRF_COOKIE_NAME, get_root_url, record_timing

SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
Expand Down Expand Up @@ -106,7 +106,7 @@ def build_app(
tree: Mapping[str, Any],
*,
authentication=None,
server_settings=None,
server_settings: Optional[Settings | dict[str, Any]] = None,
query_registry: Optional[QueryRegistry] = None,
serialization_registry: Optional[SerializationRegistry] = None,
deserialization_registry: Optional[SerializationRegistry] = None,
Expand Down Expand Up @@ -207,6 +207,44 @@ def build_app(
)
# If we reach here, the no configuration problems were found.

complete_settings = get_settings()
for item in [
"allow_anonymous_access",
"secret_keys",
"single_user_api_key",
"access_token_max_age",
"refresh_token_max_age",
"session_max_age",
]:
if authentication.get(item) is not None:
setattr(complete_settings, item, authentication[item])
if authentication.get("single_user_api_key") is not None:
complete_settings.single_user_api_key_generated = False
for item in [
"allow_origins",
"response_bytesize_limit",
"reject_undeclared_specs",
"expose_raw_assets",
]:
if server_settings.get(item) is not None:
setattr(complete_settings, item, server_settings[item])
database = server_settings.get("database", {})
if database.get("uri"):
complete_settings.database_uri = database["uri"]
if database.get("pool_size"):
complete_settings.database_pool_size = database["pool_size"]
if database.get("pool_pre_ping"):
complete_settings.database_pool_pre_ping = database["pool_pre_ping"]
if database.get("max_overflow"):
complete_settings.database_max_overflow = database["max_overflow"]
if database.get("init_if_not_exists"):
complete_settings.database_init_if_not_exists = database["init_if_not_exists"]
if authentication.get("providers"):
# If we support authentication providers, we need a database, so if one is
# not set, use a SQLite database in memory. Horizontally scaled deployments
# must specify a persistent database.
complete_settings.database_uri = complete_settings.database_uri or "sqlite://"

@asynccontextmanager
async def lifespan(app: FastAPI):
"Manage lifespan events for each event loop that the app runs in"
Expand Down Expand Up @@ -376,54 +414,12 @@ async def unhandled_exception_handler(

app.include_router(authentication_router)

@cache
def override_get_settings():
settings = get_settings()
for item in [
"allow_anonymous_access",
"secret_keys",
"single_user_api_key",
"access_token_max_age",
"refresh_token_max_age",
"session_max_age",
]:
if authentication.get(item) is not None:
setattr(settings, item, authentication[item])
if authentication.get("single_user_api_key") is not None:
settings.single_user_api_key_generated = False
for item in [
"allow_origins",
"response_bytesize_limit",
"reject_undeclared_specs",
"expose_raw_assets",
]:
if server_settings.get(item) is not None:
setattr(settings, item, server_settings[item])
database = server_settings.get("database", {})
if database.get("uri"):
settings.database_uri = database["uri"]
if database.get("pool_size"):
settings.database_pool_size = database["pool_size"]
if database.get("pool_pre_ping"):
settings.database_pool_pre_ping = database["pool_pre_ping"]
if database.get("max_overflow"):
settings.database_max_overflow = database["max_overflow"]
if database.get("init_if_not_exists"):
settings.database_init_if_not_exists = database["init_if_not_exists"]
if authentication.get("providers"):
# If we support authentication providers, we need a database, so if one is
# not set, use a SQLite database in memory. Horizontally scaled deployments
# must specify a persistent database.
settings.database_uri = settings.database_uri or "sqlite://"
return settings

async def startup_event():
from .. import __version__

logger.info(f"Tiled version {__version__}")
# Validate the single-user API key.
settings = app.dependency_overrides[get_settings]()
single_user_api_key = settings.single_user_api_key
single_user_api_key = complete_settings.single_user_api_key
API_KEY_MSG = """
Here are two ways to generate a good API key:
Expand Down Expand Up @@ -468,13 +464,13 @@ async def startup_event():
asyncio_task = asyncio.create_task(task())
app.state.tasks.append(asyncio_task)

app.state.allow_origins.extend(settings.allow_origins)
app.state.allow_origins.extend(complete_settings.allow_origins)
# Expose the root_tree here to make it easier to access it from tests,
# in usages like:
# client.context.app.state.root_tree
app.state.root_tree = tree

if settings.database_uri is not None:
if complete_settings.database_uri is not None:
from sqlalchemy.ext.asyncio import AsyncSession

from ..alembic_utils import (
Expand All @@ -495,7 +491,7 @@ async def startup_event():
# This creates a connection pool and stashes it in a module-global
# registry, keyed on database_settings, where can be retrieved by
# the Dependency get_database_session.
engine = open_database_connection_pool(settings.database_settings)
engine = open_database_connection_pool(complete_settings.database_settings)
if not engine.url.database:
# Special-case for in-memory SQLite: Because it is transient we can
# skip over anything related to migrations.
Expand All @@ -506,7 +502,7 @@ async def startup_event():
try:
await check_database(engine, REQUIRED_REVISION, ALL_REVISIONS)
except UninitializedDatabase:
if settings.database_init_if_not_exists:
if complete_settings.database_init_if_not_exists:
# The alembic stamping can only be does synchronously.
# The cleanest option available is to start a subprocess
# because SQLite is allergic to threads.
Expand Down Expand Up @@ -599,13 +595,12 @@ async def shutdown_event():
for task in tasks.get("shutdown", []):
await task()

settings = app.dependency_overrides[get_settings]()
if settings.database_uri is not None:
if complete_settings.database_uri is not None:
from ..authn_database.connection_pool import close_database_connection_pool

for task in app.state.tasks:
task.cancel()
await close_database_connection_pool(settings.database_settings)
await close_database_connection_pool(complete_settings.database_settings)

app.add_middleware(
CompressionMiddleware,
Expand Down Expand Up @@ -697,7 +692,6 @@ async def set_cookies(request: Request, call_next):
return response

app.openapi = partial(custom_openapi, app)
app.dependency_overrides[get_settings] = override_get_settings

@app.middleware("http")
async def capture_metrics(request: Request, call_next):
Expand Down Expand Up @@ -838,12 +832,16 @@ def __getattr__(name):


def print_admin_api_key_if_generated(
web_app: FastAPI, host: str, port: int, force: bool = False
web_app: FastAPI,
host: str,
port: int,
force: bool = False,
):
"Print message to stderr with API key if server-generated (or force=True)."
host = host or "127.0.0.1"
port = port or 8000
settings = web_app.dependency_overrides.get(get_settings, get_settings)()

settings = web_app.state.settings or get_settings

if settings.allow_anonymous_access:
print(
Expand Down

0 comments on commit c9b1667

Please sign in to comment.