diff --git a/setup.py b/setup.py index 1f0d1299..f97ccb4a 100644 --- a/setup.py +++ b/setup.py @@ -38,8 +38,7 @@ def main() -> None: 'tzlocal', 'more_itertools', 'pytz', - 'sqlalchemy', # DB api - 'cachew>=0.8.0', # caching with type hints + 'sqlalchemy>=2.0', # DB api *DEPS_INDEXER, *DEPS_SERVER, @@ -84,8 +83,6 @@ def main() -> None: ] DEPS_SOURCES = { - # TODO make cachew optional? - # althrough server uses it so not sure... ('optional', 'dependencies that bring some bells & whistles'): [ 'logzero', # pretty colored logging 'python-magic', # better mimetype decetion diff --git a/src/promnesia/__main__.py b/src/promnesia/__main__.py index 6401cb76..1bad03b7 100644 --- a/src/promnesia/__main__.py +++ b/src/promnesia/__main__.py @@ -18,7 +18,7 @@ from .misc import install_server from .common import Extractor, PathIsh, logger, get_tmpdir, DbVisit, Res from .common import Source, get_system_tz, user_config_file, default_config_path -from .dump import visits_to_sqlite +from .database.dump import visits_to_sqlite from .extract import extract_visits diff --git a/src/promnesia/compare.py b/src/promnesia/compare.py index 382b9ad6..ce5dd1c9 100755 --- a/src/promnesia/compare.py +++ b/src/promnesia/compare.py @@ -8,6 +8,7 @@ from .common import DbVisit, Url, PathWithMtime # TODO ugh. figure out pythonpath +from .database.load import row_to_db_visit # TODO include latest too? # from cconfig import ignore, filtered @@ -139,10 +140,10 @@ def compare_files(*files: Path, log=True) -> Iterator[Tuple[str, DbVisit]]: this_dts = name[0: name.index('.')] # can't use stem due to multiple extensions.. 
from promnesia.server import _get_stuff # TODO ugh - engine, binder, table = _get_stuff(PathWithMtime.make(f)) + engine, table = _get_stuff(PathWithMtime.make(f)) with engine.connect() as conn: - vis = [binder.from_row(row) for row in conn.execute(table.select())] # type: ignore[var-annotated] + vis = [row_to_db_visit(row) for row in conn.execute(table.select())] # type: ignore[var-annotated] if last is not None: between = f'{last_dts}:{this_dts}' diff --git a/src/promnesia/config.py b/src/promnesia/config.py index 8164cae5..9e4903e0 100644 --- a/src/promnesia/config.py +++ b/src/promnesia/config.py @@ -69,6 +69,8 @@ def sources(self) -> Iterable[Res[Source]]: @property def cache_dir(self) -> Optional[Path]: + # TODO we used to use this for cachew, but it's best to rely on HPI modules etc to configure this + # keeping just in case for now cd = self.CACHE_DIR cpath: Optional[Path] if cd is None: diff --git a/src/promnesia/database/common.py b/src/promnesia/database/common.py new file mode 100644 index 00000000..09e30ed2 --- /dev/null +++ b/src/promnesia/database/common.py @@ -0,0 +1,66 @@ +from datetime import datetime +from typing import Sequence, Tuple + +from sqlalchemy import ( + Column, + Integer, + Row, + String, +) + +# TODO maybe later move DbVisit here completely? +# kinda an issue that it's technically an "api" because hook in config can patch up DbVisit +from ..common import DbVisit, Loc + + +def get_columns() -> Sequence[Column]: + # fmt: off + res: Sequence[Column] = [ + Column('norm_url' , String()), + Column('orig_url' , String()), + Column('dt' , String()), + Column('locator_title', String()), + Column('locator_href' , String()), + Column('src' , String()), + Column('context' , String()), + Column('duration' , Integer()) + ] + # fmt: on + assert len(res) == len(DbVisit._fields) + 1 # +1 because Locator is 'flattened' + return res + + +def db_visit_to_row(v: DbVisit) -> Tuple: + # ugh, very hacky... 
+ # we want to make sure the resulting tuple only consists of simple types + # so we can use dbengine directly + dt_s = v.dt.isoformat() + row = ( + v.norm_url, + v.orig_url, + dt_s, + v.locator.title, + v.locator.href, + v.src, + v.context, + v.duration, + ) + return row + + +def row_to_db_visit(row: Sequence) -> DbVisit: + (norm_url, orig_url, dt_s, locator_title, locator_href, src, context, duration) = row + dt_s = dt_s.split()[0] # backwards compatibility: previously it could be a string separated with tz name + dt = datetime.fromisoformat(dt_s) + return DbVisit( + norm_url=norm_url, + orig_url=orig_url, + dt=dt, + locator=Loc( + title=locator_title, + href=locator_href, + ), + src=src, + context=context, + duration=duration, + ) diff --git a/src/promnesia/dump.py b/src/promnesia/database/dump.py similarity index 85% rename from src/promnesia/dump.py rename to src/promnesia/database/dump.py index 6eb688dc..3d6178f0 100644 --- a/src/promnesia/dump.py +++ b/src/promnesia/database/dump.py @@ -1,7 +1,7 @@ from pathlib import Path import shutil import sqlite3 -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Dict, Iterable, List, Optional, Set from more_itertools import chunked @@ -17,9 +17,7 @@ ) from sqlalchemy.dialects import sqlite as dialect_sqlite -from cachew import NTBinder # TODO need to get rid of this - -from .common import ( +from ..common import ( DbVisit, Loc, Res, @@ -28,7 +26,8 @@ get_tmpdir, now_tz, ) -from . import config +from .common import get_columns, db_visit_to_row +from .. 
import config # NOTE: I guess the main performance benefit from this is not creating too many tmp lists and avoiding overhead @@ -54,10 +53,10 @@ def enable_wal(dbapi_con, con_record) -> None: # returns critical warnings def visits_to_sqlite( - vit: Iterable[Res[DbVisit]], - *, - overwrite_db: bool, - _db_path: Optional[Path] = None, # only used in tests + vit: Iterable[Res[DbVisit]], + *, + overwrite_db: bool, + _db_path: Optional[Path] = None, # only used in tests ) -> List[Exception]: if _db_path is None: db_path = config.get().db @@ -89,28 +88,8 @@ def vit_ok() -> Iterable[DbVisit]: index_stats[ev.src] = index_stats.get(ev.src, 0) + 1 yield ev - binder = NTBinder.make(DbVisit) meta = MetaData() - # TODO is it ok to reuse meta/table?? - table = Table('visits', meta, *binder.columns) - - def db_visit_to_row(v: DbVisit) -> Tuple: - # ugh, very hacky... - # we want to make sure the resulting tuple only consists of simple types - # so we can use dbengine directly - dt = v.dt - dt_s = None if dt is None else dt.isoformat() - row = ( - v.norm_url, - v.orig_url, - dt_s, - v.locator.title, - v.locator.href, - v.src, - v.context, - v.duration, - ) - return row + table = Table('visits', meta, *get_columns()) def query_total_stats(conn) -> Stats: query = select(table.c.src, func.count(table.c.src)).select_from(table).group_by(table.c.src) diff --git a/src/promnesia/read_db.py b/src/promnesia/database/load.py similarity index 69% rename from src/promnesia/read_db.py rename to src/promnesia/database/load.py index 5ec72abb..8819175e 100644 --- a/src/promnesia/read_db.py +++ b/src/promnesia/database/load.py @@ -1,32 +1,29 @@ from pathlib import Path from typing import Tuple, List -from cachew import NTBinder from sqlalchemy import ( create_engine, exc, + Engine, MetaData, Index, Table, ) -from sqlalchemy.engine import Engine -from .common import DbVisit +from .common import DbVisit, get_columns, row_to_db_visit -DbStuff = Tuple[Engine, NTBinder, Table] +DbStuff = Tuple[Engine, 
Table] def get_db_stuff(db_path: Path) -> DbStuff: assert db_path.exists(), db_path # todo how to open read only? # actually not sure if we can since we are creating an index here - engine = create_engine(f'sqlite:///{db_path}') # , echo=True) - - binder = NTBinder.make(DbVisit) + engine = create_engine(f'sqlite:///{db_path}') # , echo=True) meta = MetaData() - table = Table('visits', meta, *binder.columns) + table = Table('visits', meta, *get_columns()) idx = Index('index_norm_url', table.c.norm_url) try: @@ -39,13 +36,13 @@ def get_db_stuff(db_path: Path) -> DbStuff: raise e # NOTE: apparently it's ok to open connection on every request? at least my comparisons didn't show anything - return engine, binder, table + return engine, table def get_all_db_visits(db_path: Path) -> List[DbVisit]: # NOTE: this is pretty inefficient if the DB is huge # mostly intended for tests - engine, binder, table = get_db_stuff(db_path) + engine, table = get_db_stuff(db_path) query = table.select() with engine.connect() as conn: - return [binder.from_row(row) for row in conn.execute(query)] + return [row_to_db_visit(row) for row in conn.execute(query)] diff --git a/src/promnesia/server.py b/src/promnesia/server.py index 17671305..707691e5 100644 --- a/src/promnesia/server.py +++ b/src/promnesia/server.py @@ -18,7 +18,7 @@ import fastapi -from sqlalchemy import MetaData, exists, literal, between, or_, and_, exc, select +from sqlalchemy import literal, between, or_, and_, exc, select from sqlalchemy import Column, Table, func, types from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql import text @@ -26,6 +26,7 @@ from .common import PathWithMtime, DbVisit, Url, setup_logger, default_output_dir, get_system_tz from .cannon import canonify +from .database.load import DbStuff, get_db_stuff, row_to_db_visit Json = Dict[str, Any] @@ -117,8 +118,6 @@ def get_db_path(check: bool=True) -> Path: return db -from .read_db import DbStuff, get_db_stuff - @lru_cache(1) # 
PathWithMtime aids lru_cache in reloading the sqlalchemy binder def _get_stuff(db_path: PathWithMtime) -> DbStuff: @@ -134,7 +133,7 @@ def get_stuff(db_path: Optional[Path]=None) -> DbStuff: # TODO better name def db_stats(db_path: Path) -> Json: - engine, binder, table = get_stuff(db_path) + engine, table = get_stuff(db_path) query = select(func.count()).select_from(table) with engine.connect() as conn: total = list(conn.execute(query))[0][0] @@ -165,7 +164,7 @@ def search_common(url: str, where: Where) -> VisitsResponse: url = original_url logger.info('normalised url: %s', url) - engine, binder, table = get_stuff() + engine, table = get_stuff() query = table.select().where(where(table=table, url=url)) logger.debug('query: %s', query) @@ -173,7 +172,7 @@ def search_common(url: str, where: Where) -> VisitsResponse: with engine.connect() as conn: try: # TODO make more defensive here - visits: List[DbVisit] = [binder.from_row(row) for row in conn.execute(query)] + visits: List[DbVisit] = [row_to_db_visit(row) for row in conn.execute(query)] except exc.OperationalError as e: if getattr(e, 'msg', None) == 'no such table: visits': logger.warn('you may have to run indexer first!') @@ -361,7 +360,7 @@ def visited(request: VisitedRequest) -> VisitedResponse: if len(snurls) == 0: return [] - engine, binder, table = get_stuff() + engine, table = get_stuff() # sqlalchemy doesn't seem to support SELECT FROM (VALUES (...)) in its api # also doesn't support array binding... @@ -389,7 +388,7 @@ def visited(request: VisitedRequest) -> VisitedResponse: # brings down large queries to 50ms... 
with engine.connect() as conn: res = list(conn.execute(query)) - present: Dict[str, Any] = {row[0]: binder.from_row(row[1:]) for row in res} + present: Dict[str, Any] = {row[0]: row_to_db_visit(row[1:]) for row in res} results = [] for nu in nurls: r = present.get(nu, None) diff --git a/src/promnesia/sources/browser_legacy.py b/src/promnesia/sources/browser_legacy.py index 1bfbd51e..1b343bba 100644 --- a/src/promnesia/sources/browser_legacy.py +++ b/src/promnesia/sources/browser_legacy.py @@ -9,8 +9,14 @@ from ..common import PathIsh, Results, Visit, Loc, logger, Second, is_sqlite_db from .. import config -# todo mcachew? -from cachew import cachew +try: + from cachew import cachew # type: ignore[import-not-found] +except ModuleNotFoundError as me: + if me.name != 'cachew': + raise me + # this module is legacy anyway, so just make it defensive + def cachew(*args, **kwargs): # type: ignore[no-redef] + return lambda f: f def index(p: PathIsh) -> Results: diff --git a/src/promnesia/sources/takeout_legacy.py b/src/promnesia/sources/takeout_legacy.py index da1043cf..2bf7ec85 100644 --- a/src/promnesia/sources/takeout_legacy.py +++ b/src/promnesia/sources/takeout_legacy.py @@ -34,7 +34,15 @@ def index() -> Results: from more_itertools import unique_everseen -from cachew import cachew + +try: + from cachew import cachew # type: ignore[import-not-found] +except ModuleNotFoundError as me: + if me.name != 'cachew': + raise me + # this module is legacy anyway, so just make it defensive + def cachew(*args, **kwargs): # type: ignore[no-redef] + return lambda f: f # TODO use CPath? 
Could encapsulate a path within an archive *or* within a directory diff --git a/src/promnesia/tests/test_db_dump.py b/src/promnesia/tests/test_db_dump.py index ccba7bad..50c5db84 100644 --- a/src/promnesia/tests/test_db_dump.py +++ b/src/promnesia/tests/test_db_dump.py @@ -13,9 +13,10 @@ import pytz -from ..common import DbVisit, Loc, Res -from ..dump import visits_to_sqlite -from ..read_db import get_all_db_visits +from ..common import Loc, Res +from ..database.common import DbVisit +from ..database.dump import visits_to_sqlite +from ..database.load import get_all_db_visits from ..sqlite import sqlite_connection from .common import gc_control, running_on_ci @@ -73,9 +74,6 @@ def test_one_visit(tmp_path: Path) -> None: assert sqlite_visit == { 'context': None, - # NOTE: at the moment date is dumped like this because of cachew NTBinder - # however it's not really necessary for promnesia (and possibly results in a bit of performance hit) - # I think we could just convert to a format sqlite supports, just need to make it backwards compatible 'dt': '2023-11-14T23:11:01+01:00', 'duration': 123, 'locator_href': 'https://whatever.com', @@ -109,7 +107,8 @@ def test_read_db_visits(tmp_path: Path) -> None: ); ''' ) - # this tz format might occur in databases that were created when promnesia was using cachew NTBinder + # this dt format (zone name after iso timestamp) might occur in legacy databases + # (that were created when promnesia was using cachew NTBinder) conn.execute( ''' INSERT INTO visits VALUES( diff --git a/tests/integration_test.py b/tests/integration_test.py index 8bc5c59c..a85f2543 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -10,7 +10,7 @@ from common import under_ci, DATA, GIT_ROOT, promnesia_bin from promnesia.common import _is_windows, DbVisit -from promnesia.read_db import get_all_db_visits +from promnesia.database.load import get_all_db_visits def run_index(cfg: Path, *, update=False) -> None: diff --git a/tests/server_test.py 
b/tests/server_test.py index 5b175908..92f9c45c 100644 --- a/tests/server_test.py +++ b/tests/server_test.py @@ -64,6 +64,7 @@ def wserver(db: Optional[PathIsh]=None): # TODO err not sure what type should it @contextmanager def _test_helper(tmp_path): tdir = Path(tmp_path) + # TODO probably don't need this anymore? cache_dir = tdir / 'cache' cache_dir.mkdir()