Skip to content

Commit

Permalink
backend: get rid of cachew dependency for marshalling DbVisit into sqlite
Browse files Browse the repository at this point in the history

we don't really benefit from cachew much here, and cachew.NTBinder has performance overhead/extra dependency
small amount of manual sqlite binding is a relatively small price to pay

plus some minor refactoring -- start moving database related stuff to promnesia.database
  • Loading branch information
karlicoss committed Nov 17, 2023
1 parent 060afde commit ddbb65f
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 67 deletions.
5 changes: 1 addition & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def main() -> None:
'tzlocal',
'more_itertools',
'pytz',
'sqlalchemy', # DB api
'cachew>=0.8.0', # caching with type hints
'sqlalchemy>=2.0', # DB api

*DEPS_INDEXER,
*DEPS_SERVER,
Expand Down Expand Up @@ -84,8 +83,6 @@ def main() -> None:
]

DEPS_SOURCES = {
# TODO make cachew optional?
# althrough server uses it so not sure...
('optional', 'dependencies that bring some bells & whistles'): [
'logzero', # pretty colored logging
'python-magic', # better mimetype detection
Expand Down
2 changes: 1 addition & 1 deletion src/promnesia/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .misc import install_server
from .common import Extractor, PathIsh, logger, get_tmpdir, DbVisit, Res
from .common import Source, get_system_tz, user_config_file, default_config_path
from .dump import visits_to_sqlite
from .database.dump import visits_to_sqlite
from .extract import extract_visits


Expand Down
5 changes: 3 additions & 2 deletions src/promnesia/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


from .common import DbVisit, Url, PathWithMtime # TODO ugh. figure out pythonpath
from .database.load import row_to_db_visit

# TODO include latest too?
# from cconfig import ignore, filtered
Expand Down Expand Up @@ -139,10 +140,10 @@ def compare_files(*files: Path, log=True) -> Iterator[Tuple[str, DbVisit]]:
this_dts = name[0: name.index('.')] # can't use stem due to multiple extensions..

from promnesia.server import _get_stuff # TODO ugh
engine, binder, table = _get_stuff(PathWithMtime.make(f))
engine, table = _get_stuff(PathWithMtime.make(f))

with engine.connect() as conn:
vis = [binder.from_row(row) for row in conn.execute(table.select())] # type: ignore[var-annotated]
vis = [row_to_db_visit(row) for row in conn.execute(table.select())] # type: ignore[var-annotated]

if last is not None:
between = f'{last_dts}:{this_dts}'
Expand Down
2 changes: 2 additions & 0 deletions src/promnesia/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def sources(self) -> Iterable[Res[Source]]:

@property
def cache_dir(self) -> Optional[Path]:
# TODO we used to use this for cachew, but it's best to rely on HPI modules etc to configure this
# keeping just in case for now
cd = self.CACHE_DIR
cpath: Optional[Path]
if cd is None:
Expand Down
66 changes: 66 additions & 0 deletions src/promnesia/database/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from datetime import datetime
from typing import Sequence, Tuple

from sqlalchemy import (
Column,
Integer,
Row,
String,
)

# TODO maybe later move DbVisit here completely?
# kinda an issue that it's technically an "api" because hook in config can patch up DbVisit
from ..common import DbVisit, Loc


def get_columns() -> Sequence[Column]:
    """Schema of the 'visits' table: one column per DbVisit field, with Loc flattened."""
    # fmt: off
    _spec = [
        ('norm_url'     , String),
        ('orig_url'     , String),
        ('dt'           , String),
        ('locator_title', String),
        ('locator_href' , String),
        ('src'          , String),
        ('context'      , String),
        ('duration'     , Integer),
    ]
    # fmt: on
    columns: Sequence[Column] = [Column(name, ctype()) for name, ctype in _spec]
    # sanity check: the Loc field contributes two columns (title/href), hence the +1
    assert len(columns) == len(DbVisit._fields) + 1
    return columns


def db_visit_to_row(v: DbVisit) -> Tuple:
    """Flatten a DbVisit into a tuple of sqlite-friendly primitives (str/int/None).

    Inverse of row_to_db_visit; the Loc locator is 'flattened' into the
    locator_title/locator_href columns.
    """
    # ugh, very hacky...
    # we want to make sure the resulting tuple only consists of simple types
    # so we can use dbengine directly
    #
    # NOTE: keep the defensive None handling that the previous implementation
    # (in promnesia.dump) had, instead of crashing on a missing dt
    dt_s = None if v.dt is None else v.dt.isoformat()
    row = (
        v.norm_url,
        v.orig_url,
        dt_s,
        v.locator.title,
        v.locator.href,
        v.src,
        v.context,
        v.duration,
    )
    return row


def row_to_db_visit(row: Sequence) -> DbVisit:
    """Reconstruct a DbVisit from a raw 'visits' table row (inverse of db_visit_to_row)."""
    norm_url, orig_url, dt_s, loc_title, loc_href, src, context, duration = row
    # backwards compatibility: legacy databases could store the tz name after the
    # iso timestamp, separated by whitespace -- keep only the timestamp part
    timestamp = dt_s.split()[0]
    loc = Loc(
        title=loc_title,
        href=loc_href,
    )
    return DbVisit(
        norm_url=norm_url,
        orig_url=orig_url,
        dt=datetime.fromisoformat(timestamp),
        locator=loc,
        src=src,
        context=context,
        duration=duration,
    )
39 changes: 9 additions & 30 deletions src/promnesia/dump.py → src/promnesia/database/dump.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
import shutil
import sqlite3
from typing import Dict, Iterable, List, Optional, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set

from more_itertools import chunked

Expand All @@ -17,9 +17,7 @@
)
from sqlalchemy.dialects import sqlite as dialect_sqlite

from cachew import NTBinder # TODO need to get rid of this

from .common import (
from ..common import (
DbVisit,
Loc,
Res,
Expand All @@ -28,7 +26,8 @@
get_tmpdir,
now_tz,
)
from . import config
from .common import get_columns, db_visit_to_row
from .. import config


# NOTE: I guess the main performance benefit from this is not creating too many tmp lists and avoiding overhead
Expand All @@ -54,10 +53,10 @@ def enable_wal(dbapi_con, con_record) -> None:

# returns critical warnings
def visits_to_sqlite(
vit: Iterable[Res[DbVisit]],
*,
overwrite_db: bool,
_db_path: Optional[Path] = None, # only used in tests
vit: Iterable[Res[DbVisit]],
*,
overwrite_db: bool,
_db_path: Optional[Path] = None, # only used in tests
) -> List[Exception]:
if _db_path is None:
db_path = config.get().db
Expand Down Expand Up @@ -89,28 +88,8 @@ def vit_ok() -> Iterable[DbVisit]:
index_stats[ev.src] = index_stats.get(ev.src, 0) + 1
yield ev

binder = NTBinder.make(DbVisit)
meta = MetaData()
# TODO is it ok to reuse meta/table??
table = Table('visits', meta, *binder.columns)

def db_visit_to_row(v: DbVisit) -> Tuple:
# ugh, very hacky...
# we want to make sure the resulting tuple only consists of simple types
# so we can use dbengine directly
dt = v.dt
dt_s = None if dt is None else dt.isoformat()
row = (
v.norm_url,
v.orig_url,
dt_s,
v.locator.title,
v.locator.href,
v.src,
v.context,
v.duration,
)
return row
table = Table('visits', meta, *get_columns())

def query_total_stats(conn) -> Stats:
query = select(table.c.src, func.count(table.c.src)).select_from(table).group_by(table.c.src)
Expand Down
19 changes: 8 additions & 11 deletions src/promnesia/read_db.py → src/promnesia/database/load.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
from pathlib import Path
from typing import Tuple, List

from cachew import NTBinder
from sqlalchemy import (
create_engine,
exc,
Engine,
MetaData,
Index,
Table,
)
from sqlalchemy.engine import Engine

from .common import DbVisit
from .common import DbVisit, get_columns, row_to_db_visit


DbStuff = Tuple[Engine, NTBinder, Table]
DbStuff = Tuple[Engine, Table]


def get_db_stuff(db_path: Path) -> DbStuff:
assert db_path.exists(), db_path
# todo how to open read only?
# actually not sure if we can since we are creating an index here
engine = create_engine(f'sqlite:///{db_path}') # , echo=True)

binder = NTBinder.make(DbVisit)
engine = create_engine(f'sqlite:///{db_path}') # , echo=True)

meta = MetaData()
table = Table('visits', meta, *binder.columns)
table = Table('visits', meta, *get_columns())

idx = Index('index_norm_url', table.c.norm_url)
try:
Expand All @@ -39,13 +36,13 @@ def get_db_stuff(db_path: Path) -> DbStuff:
raise e

# NOTE: apparently it's ok to open connection on every request? at least my comparisons didn't show anything
return engine, binder, table
return engine, table


def get_all_db_visits(db_path: Path) -> List[DbVisit]:
# NOTE: this is pretty inefficient if the DB is huge
# mostly intended for tests
engine, binder, table = get_db_stuff(db_path)
engine, table = get_db_stuff(db_path)
query = table.select()
with engine.connect() as conn:
return [binder.from_row(row) for row in conn.execute(query)]
return [row_to_db_visit(row) for row in conn.execute(query)]
15 changes: 7 additions & 8 deletions src/promnesia/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

import fastapi

from sqlalchemy import MetaData, exists, literal, between, or_, and_, exc, select
from sqlalchemy import literal, between, or_, and_, exc, select
from sqlalchemy import Column, Table, func, types
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql import text


from .common import PathWithMtime, DbVisit, Url, setup_logger, default_output_dir, get_system_tz
from .cannon import canonify
from .database.load import DbStuff, get_db_stuff, row_to_db_visit


Json = Dict[str, Any]
Expand Down Expand Up @@ -117,8 +118,6 @@ def get_db_path(check: bool=True) -> Path:
return db


from .read_db import DbStuff, get_db_stuff

@lru_cache(1)
# PathWithMtime aids lru_cache in reloading the sqlalchemy binder
def _get_stuff(db_path: PathWithMtime) -> DbStuff:
Expand All @@ -134,7 +133,7 @@ def get_stuff(db_path: Optional[Path]=None) -> DbStuff: # TODO better name


def db_stats(db_path: Path) -> Json:
engine, binder, table = get_stuff(db_path)
engine, table = get_stuff(db_path)
query = select(func.count()).select_from(table)
with engine.connect() as conn:
total = list(conn.execute(query))[0][0]
Expand Down Expand Up @@ -165,15 +164,15 @@ def search_common(url: str, where: Where) -> VisitsResponse:
url = original_url
logger.info('normalised url: %s', url)

engine, binder, table = get_stuff()
engine, table = get_stuff()

query = table.select().where(where(table=table, url=url))
logger.debug('query: %s', query)

with engine.connect() as conn:
try:
# TODO make more defensive here
visits: List[DbVisit] = [binder.from_row(row) for row in conn.execute(query)]
visits: List[DbVisit] = [row_to_db_visit(row) for row in conn.execute(query)]
except exc.OperationalError as e:
if getattr(e, 'msg', None) == 'no such table: visits':
logger.warn('you may have to run indexer first!')
Expand Down Expand Up @@ -361,7 +360,7 @@ def visited(request: VisitedRequest) -> VisitedResponse:
if len(snurls) == 0:
return []

engine, binder, table = get_stuff()
engine, table = get_stuff()

# sqlalchemy doesn't seem to support SELECT FROM (VALUES (...)) in its api
# also doesn't support array binding...
Expand Down Expand Up @@ -389,7 +388,7 @@ def visited(request: VisitedRequest) -> VisitedResponse:
# brings down large queries to 50ms...
with engine.connect() as conn:
res = list(conn.execute(query))
present: Dict[str, Any] = {row[0]: binder.from_row(row[1:]) for row in res}
present: Dict[str, Any] = {row[0]: row_to_db_visit(row[1:]) for row in res}
results = []
for nu in nurls:
r = present.get(nu, None)
Expand Down
10 changes: 8 additions & 2 deletions src/promnesia/sources/browser_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
from ..common import PathIsh, Results, Visit, Loc, logger, Second, is_sqlite_db
from .. import config

# todo mcachew?
from cachew import cachew
# optional dependency: fall back to a no-op stand-in when cachew isn't installed
try:
    from cachew import cachew  # type: ignore[import-not-found]
except ModuleNotFoundError as me:
    if me.name != 'cachew':
        # a *different* module failed inside cachew's import chain -- don't mask that
        raise me
    # this module is legacy anyway, so just make it defensive
    # NOTE(review): the stand-in only supports the called form @cachew(...);
    # a bare @cachew would replace the decorated function with the identity
    # decorator itself -- presumably this file only uses the called form, confirm
    def cachew(*args, **kwargs):  # type: ignore[no-redef]
        return lambda f: f


def index(p: PathIsh) -> Results:
Expand Down
10 changes: 9 additions & 1 deletion src/promnesia/sources/takeout_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ def index() -> Results:


from more_itertools import unique_everseen
from cachew import cachew

# optional dependency: fall back to a no-op stand-in when cachew isn't installed
try:
    from cachew import cachew  # type: ignore[import-not-found]
except ModuleNotFoundError as me:
    if me.name != 'cachew':
        # a *different* module failed inside cachew's import chain -- don't mask that
        raise me
    # this module is legacy anyway, so just make it defensive
    # NOTE(review): the stand-in only supports the called form @cachew(...);
    # a bare @cachew would replace the decorated function with the identity
    # decorator itself -- presumably this file only uses the called form, confirm
    def cachew(*args, **kwargs):  # type: ignore[no-redef]
        return lambda f: f


# TODO use CPath? Could encapsulate a path within an archive *or* within a directory
Expand Down
13 changes: 6 additions & 7 deletions src/promnesia/tests/test_db_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import pytz


from ..common import DbVisit, Loc, Res
from ..dump import visits_to_sqlite
from ..read_db import get_all_db_visits
from ..common import Loc, Res
from ..database.common import DbVisit
from ..database.dump import visits_to_sqlite
from ..database.load import get_all_db_visits
from ..sqlite import sqlite_connection

from .common import gc_control, running_on_ci
Expand Down Expand Up @@ -73,9 +74,6 @@ def test_one_visit(tmp_path: Path) -> None:

assert sqlite_visit == {
'context': None,
# NOTE: at the moment date is dumped like this because of cachew NTBinder
# however it's not really necessary for promnesia (and possibly results in a bit of performance hit)
# I think we could just convert to a format sqlite supports, just need to make it backwards compatible
'dt': '2023-11-14T23:11:01+01:00',
'duration': 123,
'locator_href': 'https://whatever.com',
Expand Down Expand Up @@ -109,7 +107,8 @@ def test_read_db_visits(tmp_path: Path) -> None:
);
'''
)
# this tz format might occur in databases that were created when promnesia was using cachew NTBinder
# this dt format (zone name after iso timestamp) might occur in legacy databases
# (that were created when promnesia was using cachew NTBinder)
conn.execute(
'''
INSERT INTO visits VALUES(
Expand Down
Loading

0 comments on commit ddbb65f

Please sign in to comment.