diff --git a/.bumpversion.cfg b/.bumpversion.cfg index ec3e8384..d4b63425 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 3.13.1 +current_version = 3.13.2 tag_name = {new_version} commit = True tag = True diff --git a/nomenklatura/__init__.py b/nomenklatura/__init__.py index 042e6c49..9837db7b 100644 --- a/nomenklatura/__init__.py +++ b/nomenklatura/__init__.py @@ -4,7 +4,7 @@ from nomenklatura.store import Store, View from nomenklatura.index import Index -__version__ = "3.13.1" +__version__ = "3.13.2" __all__ = [ "Dataset", "CompositeEntity", diff --git a/nomenklatura/index/index.py b/nomenklatura/index/index.py index c44a2b06..cddd8e73 100644 --- a/nomenklatura/index/index.py +++ b/nomenklatura/index/index.py @@ -98,7 +98,8 @@ def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float] if len(entry.entities) == 1 or len(entry.entities) > 100: continue - for (left, lw), (right, rw) in combinations(entry.frequencies(field), 2): + entities = entry.frequencies(field) + for (left, lw), (right, rw) in combinations(entities, 2): if lw == 0.0 or rw == 0.0: continue pair = (max(left, right), min(left, right)) diff --git a/nomenklatura/statement/serialize.py b/nomenklatura/statement/serialize.py index 7d8899de..d4ecb448 100644 --- a/nomenklatura/statement/serialize.py +++ b/nomenklatura/statement/serialize.py @@ -2,7 +2,7 @@ from io import TextIOWrapper from pathlib import Path from types import TracebackType -from typing import BinaryIO, Generator, Iterable, List, Optional, Type +from typing import BinaryIO, Generator, Iterable, List, Optional, TextIO, Type import click import orjson @@ -16,6 +16,7 @@ PACK = "pack" FORMATS = [JSON, CSV, PACK] +CSV_BATCH = 5000 CSV_COLUMNS = [ "canonical_id", "entity_id", @@ -108,9 +109,11 @@ def read_path_statements( def get_statement_writer(fh: BinaryIO, format: str) -> "StatementWriter": if format == CSV: - return CSVStatementWriter(fh) + wrapped = TextIOWrapper(fh, encoding="utf-8") + return CSVStatementWriter(wrapped) elif format == PACK: - return PackStatementWriter(fh) + wrapped = TextIOWrapper(fh, encoding="utf-8") + return PackStatementWriter(wrapped) elif format == JSON: return JSONStatementWriter(fh) raise RuntimeError("Unknown statement format: %s" % format) @@ -124,14 +127,11 @@ def write_statements(fh: BinaryIO, format: str, statements: Iterable[S]) -> None class StatementWriter(object): - def __init__(self, fh: BinaryIO) -> None: - self.fh = fh - def write(self, stmt: S) -> None: raise NotImplementedError() def close(self) -> None: - self.fh.close() + raise NotImplementedError() def __enter__(self) -> "StatementWriter": return self @@ -146,37 +146,47 @@ def __exit__( class JSONStatementWriter(StatementWriter): + def __init__(self, fh: BinaryIO) -> None: + self.fh = fh + def write(self, stmt: S) -> None: data = stmt.to_dict() out = orjson.dumps(data, option=orjson.OPT_APPEND_NEWLINE) self.fh.write(out) + def close(self) -> None: + self.fh.close() + class CSVStatementWriter(StatementWriter): - def __init__(self, fh: BinaryIO) -> None: - super().__init__(fh) - self.wrapper = TextIOWrapper(fh, encoding="utf-8") - self.writer = csv.writer(self.wrapper, dialect=csv.unix_dialect) + def __init__(self, fh: TextIO) -> None: + self.fh = fh + self.writer = csv.writer(self.fh, dialect=csv.unix_dialect) self.writer.writerow(CSV_COLUMNS) + self._batch: List[List[Optional[str]]] = [] def write(self, stmt: S) -> None: row = stmt.to_csv_row() - self.writer.writerow([row.get(c) for c in CSV_COLUMNS]) + self._batch.append([row.get(c) for c in CSV_COLUMNS]) + if len(self._batch) >= CSV_BATCH: + self.writer.writerows(self._batch) + self._batch.clear() def close(self) -> None: - self.wrapper.close() - super().close() + if len(self._batch) > 0: + self.writer.writerows(self._batch) + self.fh.close() class PackStatementWriter(StatementWriter): - def __init__(self, fh: BinaryIO) -> None: - super().__init__(fh) - self.wrapper = TextIOWrapper(fh, encoding="utf-8") + def __init__(self, fh: TextIO) -> None: + self.fh = fh self.writer = csv.writer( - self.wrapper, + self.fh, dialect=csv.unix_dialect, quoting=csv.QUOTE_MINIMAL, ) + self._batch: List[List[Optional[str]]] = [] def write(self, stmt: S) -> None: row = stmt.to_csv_row() @@ -185,8 +195,12 @@ def write(self, stmt: S) -> None: if prop is None or schema is None: raise ValueError("Cannot pack statement without prop and schema") row["prop"] = pack_prop(schema, prop) - self.writer.writerow([row.get(c) for c in PACK_COLUMNS]) + self._batch.append([row.get(c) for c in PACK_COLUMNS]) + if len(self._batch) >= CSV_BATCH: + self.writer.writerows(self._batch) + self._batch.clear() def close(self) -> None: - self.wrapper.close() - super().close() + if len(self._batch) > 0: + self.writer.writerows(self._batch) + self.fh.close() diff --git a/nomenklatura/store/level.py b/nomenklatura/store/level.py index 494385ec..a9deeaa0 100644 --- a/nomenklatura/store/level.py +++ b/nomenklatura/store/level.py @@ -158,7 +158,7 @@ def __init__( ) -> None: super().__init__(store, scope, external=external) self.store: LevelDBStore[DS, CE] = store - self.last_seens: Dict[str, str] = {} + self.last_seens: Dict[str, Optional[str]] = {} def has_entity(self, id: str) -> bool: prefix = b(f"s:{id}:") @@ -188,7 +188,10 @@ def get_entity(self, id: str) -> Optional[CE]: for v in it: statements.append(unpack_statement(v, id, True)) for stmt in statements: - if stmt.dataset not in self.last_seens: + if ( + stmt.dataset not in self.last_seens + or self.last_seens[stmt.dataset] is None + ): ls_val = self.store.db.get(b(f"ls:{stmt.dataset}")) ls = ls_val.decode("utf-8") if ls_val is not None else None self.last_seens[stmt.dataset] = ls diff --git a/setup.py b/setup.py index 7492c27d..aa787a5a 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="nomenklatura", - version="3.13.1", + version="3.13.2", description="Make record linkages in followthemoney data.", long_description=long_description, long_description_content_type="text/markdown",