Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not require SQL URIs to be prefixed with SQLAlchemy driver #810

Merged
merged 9 commits into from
Nov 5, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Write the date in place of the "Unreleased" in the case a new version is release

- Drop support for Python 3.8, which is reached end of life
upstream on 7 October 2024.
- Do not require SQL database URIs to specify a "driver" (Python
library to be used for connecting).

## v0.1.0b10 (2024-10-11)

Expand Down
3 changes: 2 additions & 1 deletion tiled/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..catalog import from_uri, in_memory
from ..client.base import BaseClient
from ..server.settings import get_settings
from ..utils import ensure_specified_sql_driver
from .utils import enter_username_password as utils_enter_uname_passwd
from .utils import temp_postgres

Expand Down Expand Up @@ -152,7 +153,7 @@ async def postgresql_with_example_data_adapter(request, tmpdir):
if uri.endswith("/"):
uri = uri[:-1]
uri_with_database_name = f"{uri}/{DATABASE_NAME}"
engine = create_async_engine(uri_with_database_name)
engine = create_async_engine(ensure_specified_sql_driver(uri_with_database_name))
try:
async with engine.connect():
pass
Expand Down
75 changes: 75 additions & 0 deletions tiled/_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from pathlib import Path

from ..utils import ensure_specified_sql_driver


def test_ensure_specified_sql_driver():
# Postgres
# Default driver is added if missing.
assert (
ensure_specified_sql_driver(
"postgresql://user:password@localhost:5432/database"
)
== "postgresql+asyncpg://user:password@localhost:5432/database"
)
# Default driver passes through if specified.
assert (
ensure_specified_sql_driver(
"postgresql+asyncpg://user:password@localhost:5432/database"
)
== "postgresql+asyncpg://user:password@localhost:5432/database"
)
# Do not override user-provided.
assert (
ensure_specified_sql_driver(
"postgresql+custom://user:password@localhost:5432/database"
)
== "postgresql+custom://user:password@localhost:5432/database"
)

# SQLite
# Default driver is added if missing.
assert (
ensure_specified_sql_driver("sqlite:////test.db")
== "sqlite+aiosqlite:////test.db"
)
# Default driver passes through if specified.
assert (
ensure_specified_sql_driver("sqlite+aiosqlite:////test.db")
== "sqlite+aiosqlite:////test.db"
)
# Do not override user-provided.
assert (
ensure_specified_sql_driver("sqlite+custom:////test.db")
== "sqlite+custom:////test.db"
)
# Handle SQLite :memory: URIs
assert (
ensure_specified_sql_driver("sqlite+aiosqlite://:memory:")
== "sqlite+aiosqlite://:memory:"
)
assert (
ensure_specified_sql_driver("sqlite://:memory:")
== "sqlite+aiosqlite://:memory:"
)
# Handle SQLite relative URIs
assert (
ensure_specified_sql_driver("sqlite+aiosqlite:///test.db")
== "sqlite+aiosqlite:///test.db"
)
assert (
ensure_specified_sql_driver("sqlite:///test.db")
== "sqlite+aiosqlite:///test.db"
)
# Filepaths are implicitly SQLite databases.
# Relative path
assert ensure_specified_sql_driver("test.db") == "sqlite+aiosqlite:///test.db"
# Path object
assert ensure_specified_sql_driver(Path("test.db")) == "sqlite+aiosqlite:///test.db"
# Relative path anchored to .
assert ensure_specified_sql_driver("./test.db") == "sqlite+aiosqlite:///test.db"
# Absolute path
assert (
ensure_specified_sql_driver(Path("/tmp/test.db"))
== f"sqlite+aiosqlite:///{Path('/tmp/test.db')}"
)
3 changes: 2 additions & 1 deletion tiled/_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..client import context
from ..client.base import BaseClient
from ..utils import ensure_specified_sql_driver

if sys.version_info < (3, 9):
import importlib_resources as resources
Expand All @@ -33,7 +34,7 @@ async def temp_postgres(uri):
if uri.endswith("/"):
uri = uri[:-1]
# Create a fresh database.
engine = create_async_engine(uri)
engine = create_async_engine(ensure_specified_sql_driver(uri))
database_name = f"tiled_test_disposable_{uuid.uuid4().hex}"
async with engine.connect() as connection:
await connection.execute(
Expand Down
5 changes: 4 additions & 1 deletion tiled/authn_database/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

from ..server.settings import get_settings
from ..utils import ensure_specified_sql_driver

# A given process probably only has one of these at a time, but we
# key on database_settings just case in some testing context or something
Expand All @@ -16,7 +17,9 @@ def open_database_connection_pool(database_settings):
# kwargs["pool_pre_ping"] = database_settings.pool_pre_ping
# kwargs["max_overflow"] = database_settings.max_overflow
engine = create_async_engine(
database_settings.uri, connect_args=connect_args, **kwargs
ensure_specified_sql_driver(database_settings.uri),
connect_args=connect_args,
**kwargs,
)
_connection_pools[database_settings] = engine
return engine
Expand Down
12 changes: 6 additions & 6 deletions tiled/catalog/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@
from ..server.schemas import Asset, DataSource, Management, Revision, Spec
from ..structures.core import StructureFamily
from ..utils import (
SCHEME_PATTERN,
UNCHANGED,
Conflicts,
OneShotCachedMap,
UnsupportedQueryType,
ensure_awaitable,
ensure_specified_sql_driver,
ensure_uri,
import_object,
path_from_uri,
Expand Down Expand Up @@ -1347,7 +1347,7 @@ def from_uri(
echo=DEFAULT_ECHO,
adapters_by_mimetype=None,
):
uri = str(uri)
uri = ensure_specified_sql_driver(uri)
if init_if_not_exists:
# The alembic stamping can only be does synchronously.
# The cleanest option available is to start a subprocess
Expand All @@ -1366,9 +1366,6 @@ def from_uri(
stderr = process.stderr.decode()
logging.info(f"Subprocess stdout: {stdout}")
logging.error(f"Subprocess stderr: {stderr}")
if not SCHEME_PATTERN.match(uri):
# Interpret URI as filepath.
uri = f"sqlite+aiosqlite:///{uri}"

parsed_url = make_url(uri)
if (parsed_url.get_dialect().name == "sqlite") and (
Expand All @@ -1381,7 +1378,10 @@ def from_uri(
else:
poolclass = None # defer to sqlalchemy default
engine = create_async_engine(
uri, echo=echo, json_serializer=json_serializer, poolclass=poolclass
uri,
echo=echo,
json_serializer=json_serializer,
poolclass=poolclass,
)
if engine.dialect.name == "sqlite":
event.listens_for(engine.sync_engine, "connect")(_set_sqlite_pragma)
Expand Down
9 changes: 6 additions & 3 deletions tiled/commandline/_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ def initialize_database(database_uri: str):
REQUIRED_REVISION,
initialize_database,
)
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
try:
await check_database(engine, REQUIRED_REVISION, ALL_REVISIONS)
Expand Down Expand Up @@ -71,9 +72,10 @@ def upgrade_database(
ALEMBIC_INI_TEMPLATE_PATH,
)
from ..authn_database.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
current_revision = await get_current_revision(engine, ALL_REVISIONS)
await engine.dispose()
Expand Down Expand Up @@ -107,9 +109,10 @@ def downgrade_database(
ALEMBIC_INI_TEMPLATE_PATH,
)
from ..authn_database.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

async def do_setup():
engine = create_async_engine(database_uri)
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
redacted_url = engine.url._replace(password="[redacted]")
current_revision = await get_current_revision(engine, ALL_REVISIONS)
if current_revision is None:
Expand Down
12 changes: 8 additions & 4 deletions tiled/commandline/_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ def init(
from ..alembic_utils import UninitializedDatabase, check_database, stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS, REQUIRED_REVISION, initialize_database
from ..utils import SCHEME_PATTERN
from ..utils import ensure_specified_sql_driver

if not SCHEME_PATTERN.match(database):
# Interpret URI as filepath.
database = f"sqlite+aiosqlite:///{database}"
database = ensure_specified_sql_driver(database)

async def do_setup():
engine = create_async_engine(database)
Expand Down Expand Up @@ -94,6 +92,9 @@ def upgrade_database(
from ..alembic_utils import get_current_revision, upgrade
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

database_uri = ensure_specified_sql_driver(database_uri)

async def do_setup():
engine = create_async_engine(database_uri)
Expand Down Expand Up @@ -127,6 +128,9 @@ def downgrade_database(
from ..alembic_utils import downgrade, get_current_revision
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import ALL_REVISIONS
from ..utils import ensure_specified_sql_driver

database_uri = ensure_specified_sql_driver(database_uri)

async def do_setup():
engine = create_async_engine(database_uri)
Expand Down
6 changes: 4 additions & 2 deletions tiled/commandline/_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ def serve_directory(
from ..alembic_utils import stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import initialize_database
from ..utils import ensure_specified_sql_driver

engine = create_async_engine(database)
engine = create_async_engine(ensure_specified_sql_driver(database))
asyncio.run(initialize_database(engine))
stamp_head(ALEMBIC_INI_TEMPLATE_PATH, ALEMBIC_DIR, database)

Expand Down Expand Up @@ -389,8 +390,9 @@ def serve_catalog(
from ..alembic_utils import stamp_head
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
from ..catalog.core import initialize_database
from ..utils import ensure_specified_sql_driver

engine = create_async_engine(database)
engine = create_async_engine(ensure_specified_sql_driver(database))
asyncio.run(initialize_database(engine))
stamp_head(ALEMBIC_INI_TEMPLATE_PATH, ALEMBIC_DIR, database)

Expand Down
27 changes: 27 additions & 0 deletions tiled/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,33 @@ def ensure_uri(uri_or_path) -> str:
return str(uri_str)


SCHEME_TO_SCHEME_PLUS_DRIVER = {
"postgresql": "postgresql+asyncpg",
"sqlite": "sqlite+aiosqlite",
}


def ensure_specified_sql_driver(uri: str) -> str:
"""
Given a URI without a driver in the scheme, add Tiled's preferred driver.

If a driver is already specified, the specified one will be used; it
will NOT be overriden by this function.

'postgresql://...' -> 'postgresql+asynpg://...'
'sqlite://...' -> 'sqlite+aiosqlite://...'
'postgresql+asyncpg://...' -> 'postgresql+asynpg://...'
'postgresql+my_custom_driver://...' -> 'postgresql+my_custom_driver://...'
'/path/to/file.db' -> 'sqlite+aiosqlite:////path/to/file.db'
"""
if not SCHEME_PATTERN.match(str(uri)):
# Interpret URI as filepath.
uri = f"sqlite+aiosqlite:///{Path(uri)}"
scheme, rest = uri.split(":", 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you not use urllib to parse because in theory this could be URI that's not a URL?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember reaching for urllib first. I believe the issue was that certain SQLite URIs do not round-trip quite right.

new_scheme = SCHEME_TO_SCHEME_PLUS_DRIVER.get(scheme, scheme)
return ":".join([new_scheme, rest])


class catch_warning_msg(warnings.catch_warnings):
"""Backward compatible version of catch_warnings for python <3.11.

Expand Down
Loading