From 0118c9fa13eabb60f3193a2caaaf1a1dc85e6ba7 Mon Sep 17 00:00:00 2001
From: JD Bothma
Date: Thu, 31 Oct 2024 16:41:01 +0000
Subject: [PATCH 01/23] memory-index-docs

---
 nomenklatura/index/entry.py | 14 ++++++--------
 nomenklatura/index/index.py | 24 +++++++++++++++++-------
 2 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/nomenklatura/index/entry.py b/nomenklatura/index/entry.py
index f5b9e299..787348f1 100644
--- a/nomenklatura/index/entry.py
+++ b/nomenklatura/index/entry.py
@@ -10,7 +10,6 @@ class Entry(object):
     __slots__ = "idf", "entities"
 
     def __init__(self) -> None:
-        self.idf: float = 0.0
         self.entities: Dict[Identifier, int] = dict()
 
     def add(self, entity_id: Identifier) -> None:
@@ -21,13 +20,15 @@ def add(self, entity_id: Identifier) -> None:
         except KeyError:
             self.entities[entity_id] = 1
 
-    def compute(self, field: "Field") -> None:
-        """Compute weighted term frequency for scoring."""
-        self.idf = math.log(field.len / len(self.entities))
-
     def frequencies(
         self, field: "Field"
     ) -> Generator[Tuple[Identifier, float], None, None]:
+        """
+        Term Frequency (TF) for each entity in this entry.
+
+        TF is the number of occurrences of this token in the entity divided
+        by the total number of tokens in the entity (scoped to this field).
+        """
         for entity_id, mentions in self.entities.items():
             field_len = max(1, field.entities[entity_id])
             yield entity_id, (mentions / field_len)
@@ -69,9 +70,6 @@ def compute(self) -> None:
         self.len = max(1, len(self.entities))
         self.avg_len = sum(self.entities.values()) / self.len
 
-        for entry in self.tokens.values():
-            entry.compute(self)
-
     def to_dict(self) -> Dict[str, Any]:
         return {
             "tokens": {t: e.to_dict() for t, e in self.tokens.items()},
diff --git a/nomenklatura/index/index.py b/nomenklatura/index/index.py
index 4c9a8fc3..fa361475 100644
--- a/nomenklatura/index/index.py
+++ b/nomenklatura/index/index.py
@@ -18,7 +18,12 @@
 
 
 class Index(BaseIndex[DS, CE]):
-    """An in-memory search index to match entities against a given dataset."""
+    """
+    An in-memory search index to match entities against a given dataset.
+
+    For each field in the dataset, the index stores the IDs which contain each
+    token, along with the absolute frequency of each token in the document.
+    """
 
     name = "memory"
 
@@ -73,9 +78,16 @@ def commit(self) -> None:
             field.compute()
 
     def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float]]:
-        """A second method of doing xref: summing up the pairwise match value
-        for all entities lineraly. This uses a lot of memory but is really
-        fast."""
+        """
+        A second method of doing xref: summing up the pairwise match value
+        for all entities linearly. This uses a lot of memory but is really
+        fast.
+
+        The score of each pair is the sum of the term frequencies of each
+        co-occurring token in each field of the pair, weighted by the field's boost.
+
+        We skip any tokens with more than 100 entities.
+ """ pairs: Dict[Pair, float] = {} log.info("Building index blocking pairs...") for field_name, field in self.fields.items(): @@ -86,9 +98,7 @@ def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float] if len(entry.entities) == 1 or len(entry.entities) > 100: continue - entities = sorted( - entry.frequencies(field), key=lambda f: f[1], reverse=True - ) + entities = entry.frequencies(field) for (left, lw), (right, rw) in combinations(entities, 2): if lw == 0.0 or rw == 0.0: continue From af42ebe02cd7514adcc5afad68cb7d4f0320b8af Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 1 Nov 2024 08:58:27 +0000 Subject: [PATCH 02/23] Start adding duckdb again --- nomenklatura/index/__init__.py | 11 ++- nomenklatura/index/duckdb_index.py | 112 +++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 2 deletions(-) create mode 100644 nomenklatura/index/duckdb_index.py diff --git a/nomenklatura/index/__init__.py b/nomenklatura/index/__init__.py index bf33cfdb..81c662ad 100644 --- a/nomenklatura/index/__init__.py +++ b/nomenklatura/index/__init__.py @@ -9,7 +9,7 @@ from nomenklatura.entity import CE log = logging.getLogger(__name__) -INDEX_TYPES = ["tantivy", Index.name] +INDEX_TYPES = ["tantivy", "duckdb", Index.name] def get_index( @@ -24,10 +24,17 @@ def get_index( clazz = TantivyIndex[DS, CE] except ImportError: log.warning("`tantivy` is not available, falling back to in-memory index.") + if type_ == "duckdb": + try: + from nomenklatura.index.duckdb_index import DuckDBIndex + + clazz = DuckDBIndex[DS, CE] + except ImportError: + log.warning("`tantivy` is not available, falling back to in-memory index.") index = clazz(view, path) index.build() return index -__all__ = ["BaseIndex", "Index", "TantivyIndex", "get_index"] +__all__ = ["BaseIndex", "Index", "TantivyIndex", "DuckDBIndex", "get_index"] diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py new file mode 100644 index 00000000..70098968 --- /dev/null +++ b/nomenklatura/index/duckdb_index.py @@ -0,0 +1,112 @@ +import csv +from pathlib import Path +import logging +from itertools import combinations +from tempfile import mkdtemp +from typing import Any, Dict, List, Set, Tuple +from followthemoney.types import registry +import duckdb + +from nomenklatura.util import PathLike +from nomenklatura.resolver import Pair, Identifier +from nomenklatura.dataset import DS +from nomenklatura.entity import CE +from nomenklatura.store import View +from nomenklatura.index.entry import Field +from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer +from nomenklatura.index.common import BaseIndex + +log = logging.getLogger(__name__) + + + +class DuckDBIndex(BaseIndex[DS, CE]): + """ + An in-memory search index to match entities against a given dataset. + + For each field in the dataset, the index stores the IDs which contains each + token, along with the absolute frequency of each token in the document. 
+ """ + + name = "duckdb" + + BOOSTS = { + NAME_PART_FIELD: 2.0, + WORD_FIELD: 0.5, + registry.name.name: 10.0, + # registry.country.name: 1.5, + # registry.date.name: 1.5, + # registry.language: 0.7, + # registry.iban.name: 3.0, + registry.phone.name: 3.0, + registry.email.name: 3.0, + # registry.entity: 0.0, + # registry.topic: 2.1, + registry.address.name: 2.5, + registry.identifier.name: 3.0, + } + + __slots__ = "view", "fields", "tokenizer", "entities" + + def __init__(self, view: View[DS, CE], data_dir: Path): + self.view = view + self.tokenizer = Tokenizer[DS, CE]() + self.path = Path(mkdtemp()) + self.con = duckdb.connect((self.path / "duckdb_index.db").as_posix()) + self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") + + def dump(self, writer, entity: CE) -> None: + + if not entity.schema.matchable or entity.id is None: + return + + for field, token in self.tokenizer.entity(entity): + writer.writerow([entity.id, field, token]) + + def build(self) -> None: + """Index all entities in the dataset.""" + log.info("Building index from: %r...", self.view) + csv_path = self.path / "mentions.csv" + with open(csv_path, "w") as fh: + writer = csv.writer(fh) + writer.writerow(["id", "field", "token"]) + for idx, entity in enumerate(self.view.entities()): + self.dump(writer, entity) + if idx % 10000 == 0: + log.info("Dumped %s entities" % idx) + + log.info("Loading data...") + self.con.execute(f"COPY entries from '{csv_path}'") + log.info("Index built.") + + def frequencies(self, field: str, token: str) -> List[Tuple[str, float]]: + """ + """ + + mentions_query = """ + SELECT id, count(*) as mentions + FROM entries + WHERE field = ? AND token = ? + GROUP BY id + """ + mentions_rel = self.con.sql( + mentions_query, alias="mentions", params=[field, token] + ) + field_len_query = """ + SELECT id, count(*) as field_len from entries + WHERE field = ? + GROUP BY id + """ + field_len_rel = self.con.sql(field_len_query, alias="field_len", params=[field]) + joined = mentions_rel.join( + field_len_rel, "mentions.id = field_len.id" + ).set_alias("joined") + # TODO: Do I really need the max(1, field_len) here? 
+ weights = self.con.sql("SELECT id, mentions / field_len from joined") + return list(weights.fetchall()) + + def __repr__(self) -> str: + return "" % ( + self.view.scope.name, + self.con, + ) From 41454b62a6b73b981556b3a030f1bcd0af117da7 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 1 Nov 2024 11:15:01 +0000 Subject: [PATCH 03/23] Lots of code, not quite working --- nomenklatura/index/__init__.py | 2 +- nomenklatura/index/duckdb_index.py | 117 +++++++++++++++++++++++------ nomenklatura/index/index.py | 6 +- 3 files changed, 102 insertions(+), 23 deletions(-) diff --git a/nomenklatura/index/__init__.py b/nomenklatura/index/__init__.py index 81c662ad..e001835d 100644 --- a/nomenklatura/index/__init__.py +++ b/nomenklatura/index/__init__.py @@ -30,7 +30,7 @@ def get_index( clazz = DuckDBIndex[DS, CE] except ImportError: - log.warning("`tantivy` is not available, falling back to in-memory index.") + log.warning("`duckdb` is not available, falling back to in-memory index.") index = clazz(view, path) index.build() diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index 70098968..21708969 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -3,7 +3,7 @@ import logging from itertools import combinations from tempfile import mkdtemp -from typing import Any, Dict, List, Set, Tuple +from typing import Any, Dict, Generator, List, Set, Tuple from followthemoney.types import registry import duckdb @@ -19,7 +19,6 @@ log = logging.getLogger(__name__) - class DuckDBIndex(BaseIndex[DS, CE]): """ An in-memory search index to match entities against a given dataset. @@ -77,33 +76,109 @@ def build(self) -> None: log.info("Loading data...") self.con.execute(f"COPY entries from '{csv_path}'") + + self.calculate_frequencies() log.info("Index built.") - - def frequencies(self, field: str, token: str) -> List[Tuple[str, float]]: - """ + + def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS): + pairs: Dict[Pair, float] = {} + for field_name, token, entities in self.frequencies(): + boost = self.BOOSTS.get(field_name, 1.0) + for (left, lw), (right, rw) in combinations(entities, 2): + if lw == 0.0 or rw == 0.0: + continue + pair = (max(left, right), min(left, right)) + if pair not in pairs: + pairs[pair] = 0 + score = (lw + rw) * boost + pairs[pair] += score + return sorted(pairs.items(), key=lambda p: p[1], reverse=True)[:max_pairs] + + def field_lengths(self): + field_len_query = """ + SELECT field, id, count(*) as field_len from entries + GROUP BY field, id + ORDER by field, id """ + field_len_rel = self.con.sql(field_len_query, alias="field_len") + row = field_len_rel.fetchone() + while row is not None: + yield row + row = field_len_rel.fetchone() + def mentions(self): mentions_query = """ - SELECT id, count(*) as mentions + SELECT field, id, token, count(*) as mentions FROM entries - WHERE field = ? AND token = ? 
- GROUP BY id + GROUP BY field, id, token + ORDER by field, id, token """ - mentions_rel = self.con.sql( - mentions_query, alias="mentions", params=[field, token] + mentions_rel = self.con.sql(mentions_query, alias="mentions") + row = mentions_rel.fetchone() + while row is not None: + yield row + row = mentions_rel.fetchone() + + def calculate_frequencies(self) -> None: + csv_path = self.path / "frequencies.csv" + with open(csv_path, "w") as fh: + writer = csv.writer(fh) + writer.writerow(["field", "id", "token", "frequency"]) + + mentions_gen = self.mentions() + mention_row = None + for field_name, id, field_len in self.field_lengths(): + if mention_row is None: # first iteration + mention_row = next(mentions_gen) + if mention_row is None: + # If there's at least one field length, there should be at least one mention + raise Exception("Unexpected empty mentions.") + frequencies = [] + (mention_field_name, mention_id, token, mention_count) = mention_row + + # For all the tokens in this field for this entity ID + while mention_field_name == field_name and mention_id == id: + frequencies.append((token, mention_count / field_len)) + mention_row = next(mentions_gen) + if mention_row is None: + break + (mention_field_name, mention_id, token, mention_count) = mention_row + + for token, freq in frequencies: + writer.writerow([field_name, id, token, freq]) + + log.info(f"Loading frequencies data... ({csv_path})") + self.con.execute( + "CREATE TABLE frequencies (field TEXT, id TEXT, token TEXT, frequency FLOAT)" ) - field_len_query = """ - SELECT id, count(*) as field_len from entries - WHERE field = ? - GROUP BY id + self.con.execute(f"COPY frequencies from '{csv_path}'") + log.info("Frequencies are loaded") + + def frequencies( + self, + ) -> Generator[Tuple[str, str, List[Tuple[Identifier, float]]], None, None]: + query = """ + SELECT field, token, id, frequency + FROM frequencies + ORDER by field, token """ - field_len_rel = self.con.sql(field_len_query, alias="field_len", params=[field]) - joined = mentions_rel.join( - field_len_rel, "mentions.id = field_len.id" - ).set_alias("joined") - # TODO: Do I really need the max(1, field_len) here? 
- weights = self.con.sql("SELECT id, mentions / field_len from joined") - return list(weights.fetchall()) + rel = self.con.sql(query, alias="mentions") + row = rel.fetchone() + entities = [] # the entities in this field, token group + field_name = None + token = None + while row is not None: + field_name, token, id, freq = row + entities.append((Identifier.get(id), freq)) + + row = rel.fetchone() + if row is None: + yield field_name, token, entities + break + new_field_name, new_token, _, _ = row + if new_field_name != field_name or new_token != token: + yield field_name, token, entities + entities = [] def __repr__(self) -> str: return "" % ( diff --git a/nomenklatura/index/index.py b/nomenklatura/index/index.py index fa361475..816c1dec 100644 --- a/nomenklatura/index/index.py +++ b/nomenklatura/index/index.py @@ -92,13 +92,17 @@ def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float] log.info("Building index blocking pairs...") for field_name, field in self.fields.items(): boost = self.BOOSTS.get(field_name, 1.0) - for idx, entry in enumerate(field.tokens.values()): + for idx, (token, entry) in enumerate(field.tokens.items()): if idx % 10000 == 0: log.info("Pairwise xref [%s]: %d" % (field_name, idx)) if len(entry.entities) == 1 or len(entry.entities) > 100: continue entities = entry.frequencies(field) + if field_name == "country": + for id, freq in entities: + if id.id == "NK-cVfXUNMeCpGWyQVFLkQCe7": + print(id, token, freq) for (left, lw), (right, rw) in combinations(entities, 2): if lw == 0.0 or rw == 0.0: continue From 69adca5b27ed016dc46a445c2fc968c72b6ed885 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 1 Nov 2024 13:18:54 +0000 Subject: [PATCH 04/23] Some sanity tests --- nomenklatura/index/duckdb_index.py | 5 ++- tests/index/test_duckdb_index.py | 70 ++++++++++++++++++++++++++++++ tests/index/test_index.py | 11 +++++ 3 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 tests/index/test_duckdb_index.py diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index 21708969..4856415a 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -106,7 +106,9 @@ def field_lengths(self): yield row row = field_len_rel.fetchone() - def mentions(self): + def mentions(self) -> Generator[Tuple[str, str, str, int], None, None]: + """Yields tuples of [field_name, entity_id, token, mention_count]""" + mentions_query = """ SELECT field, id, token, count(*) as mentions FROM entries @@ -128,6 +130,7 @@ def calculate_frequencies(self) -> None: mentions_gen = self.mentions() mention_row = None for field_name, id, field_len in self.field_lengths(): + print(field_name, id, field_len) if mention_row is None: # first iteration mention_row = next(mentions_gen) if mention_row is None: diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py new file mode 100644 index 00000000..7917b468 --- /dev/null +++ b/tests/index/test_duckdb_index.py @@ -0,0 +1,70 @@ +from collections import defaultdict +from pathlib import Path +from tempfile import NamedTemporaryFile, TemporaryDirectory + +from nomenklatura.dataset import Dataset +from nomenklatura.entity import CompositeEntity +from nomenklatura.index import Index +from nomenklatura.index.duckdb_index import DuckDBIndex +from nomenklatura.resolver.identifier import Identifier +from nomenklatura.store import SimpleMemoryStore + +DAIMLER = "66ce9f62af8c7d329506da41cb7c36ba058b3d28" +VERBAND_ID = "62ad0fe6f56dbbf6fee57ce3da76e88c437024d5" 
+VERBAND_BADEN_ID = "69401823a9f0a97cfdc37afa7c3158374e007669" +VERBAND_BADEN_DATA = { + "id": "bla", + "schema": "Company", + "properties": { + "name": ["VERBAND DER METALL UND ELEKTROINDUSTRIE BADEN WURTTEMBERG"] + }, +} + + +def test_field_lengths(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + field_names = set() + ids = set() + for field_name, id, field_len in duckdb_index.field_lengths(): + field_names.add(field_name) + ids.add(id) + + # Expect to see all matchable entities + # jq .schema tests/fixtures/donations.ijson | sort | uniq -c + # Organizations 17 + # Companies 56 + # Persons 22 + # Addresses 89 + assert len(ids) == 184, len(ids) + + # Expect to see all index fields for the matchable prop types and any applicable synthetic fields + # jq '.properties | keys | .[]' tests/fixtures/donations.ijson --raw-output|sort -u + expected_fields = { + "namepart", + "name", + "country", + "word", + } + assert field_names == expected_fields, field_names + + +def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + ids = set() + field_tokens = defaultdict(set) + + for field_name, id, token, count in duckdb_index.mentions(): + ids.add(id) + field_tokens[field_name].add(token) + + assert len(ids) == 184, len(ids) + assert "verband" in field_tokens["namepart"], field_tokens["namepart"] + assert "de" in field_tokens["country"], field_tokens["country"] + assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] + + +def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + view = dstore.default_view() + for field_name, token, entities in duckdb_index.frequencies(): + print(field_name, token) + for entity_id, tf in entities: + print(" ", entity_id, tf) + assert False diff --git a/tests/index/test_index.py b/tests/index/test_index.py index dc5d6913..621b6385 100644 --- a/tests/index/test_index.py +++ b/tests/index/test_index.py @@ -27,6 +27,17 @@ def test_index_build(index_path: Path, dstore: SimpleMemoryStore): assert len(index) == 184, len(index) +def test_frequencies(dstore: SimpleMemoryStore, dindex: Index): + view = dstore.default_view() + + for field_name, field in dindex.fields.items(): + for token, entry in field.tokens.items(): + print(field_name, token) + for ident, tf in entry.frequencies(field): + print(" ", ident.id, tf) + assert False + + def test_index_persist(dstore: SimpleMemoryStore, dindex): view = dstore.default_view() with TemporaryDirectory() as tmpdir: From 2fc7cb34ad38cc6398dfb9de235a414d2f05244c Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 1 Nov 2024 14:08:33 +0000 Subject: [PATCH 05/23] Break up frequency calculation --- nomenklatura/index/duckdb_index.py | 55 +++++++++++++++++------------- tests/conftest.py | 8 +++++ tests/index/test_duckdb_index.py | 16 +++++++++ 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index 4856415a..7fcceff5 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -107,7 +107,7 @@ def field_lengths(self): row = field_len_rel.fetchone() def mentions(self) -> Generator[Tuple[str, str, str, int], None, None]: - """Yields tuples of [field_name, entity_id, token, mention_count]""" + """Yields tuples of (field_name, entity_id, token, mention_count)""" mentions_query = """ SELECT field, id, token, count(*) as mentions @@ -121,34 +121,43 @@ def mentions(self) -> Generator[Tuple[str, str, str, int], None, None]: yield row row = mentions_rel.fetchone() + def 
id_grouped_mentions( + self, + ) -> Generator[Tuple[str, str, int, List[Tuple[str, int]]], None, None]: + """ + Yields tuples of (field_name, entity_id, field_len, [(token, mention_count)]) + """ + mentions_gen = self.mentions() + mention_row = None + # Read all field lengths into memory because the concurrent iteration + # sees to be exiting the outer loop early and giving partial results. + for field_name, id, field_len in list(self.field_lengths()): + mentions = [] + try: + if mention_row is None: # first iteration + mention_field_name, mention_id, token, mention_count = next( + mentions_gen + ) + + while mention_field_name == field_name and mention_id == id: + mentions.append((token, mention_count)) + mention_field_name, mention_id, token, mention_count = next( + mentions_gen + ) + yield field_name, id, field_len, mentions + except StopIteration: + yield field_name, id, field_len, mentions + break + def calculate_frequencies(self) -> None: csv_path = self.path / "frequencies.csv" with open(csv_path, "w") as fh: writer = csv.writer(fh) writer.writerow(["field", "id", "token", "frequency"]) - mentions_gen = self.mentions() - mention_row = None - for field_name, id, field_len in self.field_lengths(): - print(field_name, id, field_len) - if mention_row is None: # first iteration - mention_row = next(mentions_gen) - if mention_row is None: - # If there's at least one field length, there should be at least one mention - raise Exception("Unexpected empty mentions.") - frequencies = [] - (mention_field_name, mention_id, token, mention_count) = mention_row - - # For all the tokens in this field for this entity ID - while mention_field_name == field_name and mention_id == id: - frequencies.append((token, mention_count / field_len)) - mention_row = next(mentions_gen) - if mention_row is None: - break - (mention_field_name, mention_id, token, mention_count) = mention_row - - for token, freq in frequencies: - writer.writerow([field_name, id, token, freq]) + for field_name, id, field_len, mentions in self.id_grouped_mentions(): + for token, freq in mentions: + writer.writerow([field_name, id, token, freq / field_len]) log.info(f"Loading frequencies data... 
({csv_path})") self.con.execute( diff --git a/tests/conftest.py b/tests/conftest.py index ee28af17..c620315c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from tempfile import mkdtemp from nomenklatura import settings +from nomenklatura.index.duckdb_index import DuckDBIndex from nomenklatura.index.tantivy_index import TantivyIndex from nomenklatura.store import load_entity_file_store, SimpleMemoryStore from nomenklatura.kv import get_redis @@ -81,6 +82,13 @@ def tantivy_index(index_path: Path, dstore: SimpleMemoryStore): yield index +@pytest.fixture(scope="function") +def duckdb_index(index_path: Path, dstore: SimpleMemoryStore): + index = DuckDBIndex(dstore.default_view(), index_path) + index.build() + yield index + + @pytest.fixture(scope="function") def index_path(): index_path = Path(mkdtemp()) / "index-dir" diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index 7917b468..e57a1ea7 100644 --- a/tests/index/test_duckdb_index.py +++ b/tests/index/test_duckdb_index.py @@ -59,6 +59,22 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): assert "verband" in field_tokens["namepart"], field_tokens["namepart"] assert "de" in field_tokens["country"], field_tokens["country"] assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] + assert "word" in field_tokens["word"], field_tokens["word"] + + +def test_id_grouped_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + ids = set() + field_tokens = defaultdict(set) + for field_name, id, field_len, mentions in duckdb_index.id_grouped_mentions(): + ids.add(id) + for token, count in mentions: + field_tokens[field_name].add(token) + + assert len(ids) == 184, len(ids) + assert "verband" in field_tokens["namepart"], field_tokens["namepart"] + assert "de" in field_tokens["country"], field_tokens["country"] + print(field_tokens["name"]) + assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): From eabb22c9716fdefcd554ceebf6c1ca8d1e14ffbc Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 1 Nov 2024 15:54:34 +0000 Subject: [PATCH 06/23] Skip tokens occurring in more than 100 entities --- nomenklatura/index/duckdb_index.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index 7fcceff5..86611784 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -121,12 +121,27 @@ def mentions(self) -> Generator[Tuple[str, str, str, int], None, None]: yield row row = mentions_rel.fetchone() + def common_tokens(self) -> Set[Tuple[str, str]]: + """Yields tuples of (field_name, token)""" + query = """ + SELECT field, token, count(*) as frequency + FROM entries + GROUP BY field, token + """ + token_counts_rel = self.con.sql(query) + common_tokens_rel = self.con.sql("SELECT * from token_counts_rel where frequency > 100") + tokens: Set[Tuple[str, str]] = set() + for (field_name, token, freq) in common_tokens_rel.fetchall(): + tokens.add((field_name, token)) + return tokens + def id_grouped_mentions( self, ) -> Generator[Tuple[str, str, int, List[Tuple[str, int]]], None, None]: """ Yields tuples of (field_name, entity_id, field_len, [(token, mention_count)]) """ + common_tokens = self.common_tokens() mentions_gen = self.mentions() mention_row = None # Read all field lengths into memory because the concurrent iteration @@ -140,7 +155,8 @@ 
def id_grouped_mentions( ) while mention_field_name == field_name and mention_id == id: - mentions.append((token, mention_count)) + if (mention_field_name, token) not in common_tokens: + mentions.append((token, mention_count)) mention_field_name, mention_id, token, mention_count = next( mentions_gen ) From b4103190dfe1b377598cd0168e2fac4b7ad5afa0 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 1 Nov 2024 18:11:52 +0000 Subject: [PATCH 07/23] Largely working --- nomenklatura/index/duckdb_index.py | 22 ++++++++++++---------- tests/index/test_duckdb_index.py | 23 ++++++++++++++++------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index 86611784..3641d872 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -100,7 +100,7 @@ def field_lengths(self): GROUP BY field, id ORDER by field, id """ - field_len_rel = self.con.sql(field_len_query, alias="field_len") + field_len_rel = self.con.sql(field_len_query) row = field_len_rel.fetchone() while row is not None: yield row @@ -115,7 +115,7 @@ def mentions(self) -> Generator[Tuple[str, str, str, int], None, None]: GROUP BY field, id, token ORDER by field, id, token """ - mentions_rel = self.con.sql(mentions_query, alias="mentions") + mentions_rel = self.con.sql(mentions_query) row = mentions_rel.fetchone() while row is not None: yield row @@ -129,9 +129,10 @@ def common_tokens(self) -> Set[Tuple[str, str]]: GROUP BY field, token """ token_counts_rel = self.con.sql(query) - common_tokens_rel = self.con.sql("SELECT * from token_counts_rel where frequency > 100") + filter_query = "SELECT * from token_counts_rel where frequency > 100 and field != 'country'" + common_tokens_rel = self.con.sql(filter_query) tokens: Set[Tuple[str, str]] = set() - for (field_name, token, freq) in common_tokens_rel.fetchall(): + for field_name, token, freq in common_tokens_rel.fetchall(): tokens.add((field_name, token)) return tokens @@ -150,17 +151,18 @@ def id_grouped_mentions( mentions = [] try: if mention_row is None: # first iteration - mention_field_name, mention_id, token, mention_count = next( - mentions_gen - ) + mention_row = next(mentions_gen) + mention_field_name, mention_id, token, mention_count = mention_row while mention_field_name == field_name and mention_id == id: if (mention_field_name, token) not in common_tokens: mentions.append((token, mention_count)) - mention_field_name, mention_id, token, mention_count = next( - mentions_gen - ) + + mention_row = next(mentions_gen) + mention_field_name, mention_id, token, mention_count = mention_row + yield field_name, id, field_len, mentions + except StopIteration: yield field_name, id, field_len, mentions break diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index e57a1ea7..ee4b1d08 100644 --- a/tests/index/test_duckdb_index.py +++ b/tests/index/test_duckdb_index.py @@ -59,7 +59,7 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): assert "verband" in field_tokens["namepart"], field_tokens["namepart"] assert "de" in field_tokens["country"], field_tokens["country"] assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] - assert "word" in field_tokens["word"], field_tokens["word"] + assert "dortmund" in field_tokens["word"], field_tokens["word"] def test_id_grouped_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): @@ -73,14 +73,23 @@ def test_id_grouped_mentions(dstore: SimpleMemoryStore, 
duckdb_index: DuckDBInde assert len(ids) == 184, len(ids) assert "verband" in field_tokens["namepart"], field_tokens["namepart"] assert "de" in field_tokens["country"], field_tokens["country"] - print(field_tokens["name"]) assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] + assert "dortmund" in field_tokens["word"], field_tokens["word"] def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): view = dstore.default_view() - for field_name, token, entities in duckdb_index.frequencies(): - print(field_name, token) - for entity_id, tf in entities: - print(" ", entity_id, tf) - assert False + pairs = duckdb_index.pairs() + assert len(pairs) > 0, pairs + tokenizer = duckdb_index.tokenizer + pair, score = pairs[0] + entity0 = view.get_entity(str(pair[0])) + tokens0 = set(tokenizer.entity(entity0)) + entity1 = view.get_entity(str(pair[1])) + tokens1 = set(tokenizer.entity(entity1)) + overlap = tokens0.intersection(tokens1) + assert len(overlap) > 0, overlap + # assert "Schnabel" in (overlap, tokens0, tokens1) + # assert "Schnabel" in (entity0.caption, entity1.caption) + assert score > 0 + # assert False From b45aa32040528f2f41518e88ff4f57692401fade Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 1 Nov 2024 18:16:54 +0000 Subject: [PATCH 08/23] pairing on country is too expensive --- nomenklatura/index/duckdb_index.py | 2 +- tests/index/test_duckdb_index.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index 3641d872..557141f0 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -129,7 +129,7 @@ def common_tokens(self) -> Set[Tuple[str, str]]: GROUP BY field, token """ token_counts_rel = self.con.sql(query) - filter_query = "SELECT * from token_counts_rel where frequency > 100 and field != 'country'" + filter_query = "SELECT * from token_counts_rel where frequency > 100" common_tokens_rel = self.con.sql(filter_query) tokens: Set[Tuple[str, str]] = set() for field_name, token, freq in common_tokens_rel.fetchall(): diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index ee4b1d08..3d0ef005 100644 --- a/tests/index/test_duckdb_index.py +++ b/tests/index/test_duckdb_index.py @@ -57,7 +57,7 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): assert len(ids) == 184, len(ids) assert "verband" in field_tokens["namepart"], field_tokens["namepart"] - assert "de" in field_tokens["country"], field_tokens["country"] + assert "gb" in field_tokens["country"], field_tokens["country"] assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] assert "dortmund" in field_tokens["word"], field_tokens["word"] @@ -72,7 +72,7 @@ def test_id_grouped_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBInde assert len(ids) == 184, len(ids) assert "verband" in field_tokens["namepart"], field_tokens["namepart"] - assert "de" in field_tokens["country"], field_tokens["country"] + assert "gb" in field_tokens["country"], field_tokens["country"] assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] assert "dortmund" in field_tokens["word"], field_tokens["word"] From c2fbc2934102ed8f95db521b3dae02273a6125d7 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Fri, 1 Nov 2024 18:20:25 +0000 Subject: [PATCH 09/23] Tidy --- nomenklatura/index/index.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/nomenklatura/index/index.py 
b/nomenklatura/index/index.py index 816c1dec..c44a2b06 100644 --- a/nomenklatura/index/index.py +++ b/nomenklatura/index/index.py @@ -98,12 +98,7 @@ def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float] if len(entry.entities) == 1 or len(entry.entities) > 100: continue - entities = entry.frequencies(field) - if field_name == "country": - for id, freq in entities: - if id.id == "NK-cVfXUNMeCpGWyQVFLkQCe7": - print(id, token, freq) - for (left, lw), (right, rw) in combinations(entities, 2): + for (left, lw), (right, rw) in combinations(entry.frequencies(field), 2): if lw == 0.0 or rw == 0.0: continue pair = (max(left, right), min(left, right)) From 88e43988841ae500e3ecfc3ba8a9cff173ccc871 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Wed, 6 Nov 2024 16:50:34 +0000 Subject: [PATCH 10/23] Move more into the db --- nomenklatura/index/duckdb_index.py | 142 ++++++++++------------------- 1 file changed, 50 insertions(+), 92 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index 557141f0..e8aa60c9 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -1,3 +1,4 @@ +from collections import defaultdict import csv from pathlib import Path import logging @@ -52,6 +53,9 @@ def __init__(self, view: View[DS, CE], data_dir: Path): self.tokenizer = Tokenizer[DS, CE]() self.path = Path(mkdtemp()) self.con = duckdb.connect((self.path / "duckdb_index.db").as_posix()) + self.con.execute("SET memory_limit = '2GB';") + self.con.execute("SET max_memory = '2GB';") + self.con.execute("SET threads = 1;") self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") def dump(self, writer, entity: CE) -> None: @@ -77,122 +81,76 @@ def build(self) -> None: log.info("Loading data...") self.con.execute(f"COPY entries from '{csv_path}'") - self.calculate_frequencies() log.info("Index built.") - def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS): - pairs: Dict[Pair, float] = {} + def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float]]: + csv_path = self.path / "cooccurrences.csv" + with open(csv_path, "w") as fh: + writer = csv.writer(fh) + writer.writerow(["left", "right", "score"]) + for pair, score in self.cooccurring_tokens(): + writer.writerow([pair[0], pair[1], score]) + log.info("Loading co-occurrences...") + self.con.execute('CREATE TABLE cooccurrences ("left" TEXT, "right" TEXT, score FLOAT)') + self.con.execute(f"COPY cooccurrences from '{csv_path}'") + pairs_query = """ + SELECT "left", "right", sum(score) as score + FROM cooccurrences + GROUP BY "left", "right" + ORDER BY score DESC + LIMIT ? 
+ """ + pairs_rel = self.con.execute(pairs_query, [max_pairs]) + pairs: List[Tuple[Pair, float]] = [] + for left, right, score in pairs_rel.fetchall(): + pairs.append(((Identifier.get(left), Identifier.get(right)), score)) + return pairs + + def cooccurring_tokens(self): + logged = defaultdict(int) for field_name, token, entities in self.frequencies(): + logged[field_name] += 1 + if logged[field_name] % 10000 == 0: + log.info("Pairwise xref [%s]: %d" % (field_name, logged[field_name])) boost = self.BOOSTS.get(field_name, 1.0) for (left, lw), (right, rw) in combinations(entities, 2): if lw == 0.0 or rw == 0.0: continue pair = (max(left, right), min(left, right)) - if pair not in pairs: - pairs[pair] = 0 score = (lw + rw) * boost - pairs[pair] += score - return sorted(pairs.items(), key=lambda p: p[1], reverse=True)[:max_pairs] + yield pair, score - def field_lengths(self): + def frequencies( + self, + ) -> Generator[Tuple[str, str, List[Tuple[Identifier, float]]], None, None]: field_len_query = """ SELECT field, id, count(*) as field_len from entries GROUP BY field, id - ORDER by field, id """ - field_len_rel = self.con.sql(field_len_query) - row = field_len_rel.fetchone() - while row is not None: - yield row - row = field_len_rel.fetchone() - - def mentions(self) -> Generator[Tuple[str, str, str, int], None, None]: - """Yields tuples of (field_name, entity_id, token, mention_count)""" - + field_len = self.con.sql(field_len_query) mentions_query = """ SELECT field, id, token, count(*) as mentions FROM entries GROUP BY field, id, token - ORDER by field, id, token """ - mentions_rel = self.con.sql(mentions_query) - row = mentions_rel.fetchone() - while row is not None: - yield row - row = mentions_rel.fetchone() - - def common_tokens(self) -> Set[Tuple[str, str]]: - """Yields tuples of (field_name, token)""" - query = """ - SELECT field, token, count(*) as frequency + mentions = self.con.sql(mentions_query) + token_freq_query = """ + SELECT field, token, count(*) as token_freq FROM entries GROUP BY field, token """ - token_counts_rel = self.con.sql(query) - filter_query = "SELECT * from token_counts_rel where frequency > 100" - common_tokens_rel = self.con.sql(filter_query) - tokens: Set[Tuple[str, str]] = set() - for field_name, token, freq in common_tokens_rel.fetchall(): - tokens.add((field_name, token)) - return tokens - - def id_grouped_mentions( - self, - ) -> Generator[Tuple[str, str, int, List[Tuple[str, int]]], None, None]: - """ - Yields tuples of (field_name, entity_id, field_len, [(token, mention_count)]) - """ - common_tokens = self.common_tokens() - mentions_gen = self.mentions() - mention_row = None - # Read all field lengths into memory because the concurrent iteration - # sees to be exiting the outer loop early and giving partial results. 
- for field_name, id, field_len in list(self.field_lengths()): - mentions = [] - try: - if mention_row is None: # first iteration - mention_row = next(mentions_gen) - mention_field_name, mention_id, token, mention_count = mention_row - - while mention_field_name == field_name and mention_id == id: - if (mention_field_name, token) not in common_tokens: - mentions.append((token, mention_count)) - - mention_row = next(mentions_gen) - mention_field_name, mention_id, token, mention_count = mention_row - - yield field_name, id, field_len, mentions - - except StopIteration: - yield field_name, id, field_len, mentions - break - - def calculate_frequencies(self) -> None: - csv_path = self.path / "frequencies.csv" - with open(csv_path, "w") as fh: - writer = csv.writer(fh) - writer.writerow(["field", "id", "token", "frequency"]) - - for field_name, id, field_len, mentions in self.id_grouped_mentions(): - for token, freq in mentions: - writer.writerow([field_name, id, token, freq / field_len]) - - log.info(f"Loading frequencies data... ({csv_path})") - self.con.execute( - "CREATE TABLE frequencies (field TEXT, id TEXT, token TEXT, frequency FLOAT)" - ) - self.con.execute(f"COPY frequencies from '{csv_path}'") - log.info("Frequencies are loaded") - - def frequencies( - self, - ) -> Generator[Tuple[str, str, List[Tuple[Identifier, float]]], None, None]: + token_freq = self.con.sql(token_freq_query) query = """ - SELECT field, token, id, frequency - FROM frequencies - ORDER by field, token + SELECT mentions.field, mentions.token, mentions.id, mentions/field_len + FROM field_len + JOIN mentions + ON field_len.field = mentions.field AND field_len.id = mentions.id + JOIN token_freq + ON token_freq.field = mentions.field AND token_freq.token = mentions.token + where token_freq < 100 + ORDER BY mentions.field, mentions.token """ - rel = self.con.sql(query, alias="mentions") + rel = self.con.sql(query) row = rel.fetchone() entities = [] # the entities in this field, token group field_name = None From 684c5c05e0a8acb7a600552b285e0f9dde12ab7b Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Wed, 6 Nov 2024 16:51:06 +0000 Subject: [PATCH 11/23] Move even more into the db --- nomenklatura/index/duckdb_index.py | 89 ++++++++++-------------------- 1 file changed, 28 insertions(+), 61 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index e8aa60c9..e4dfb9df 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -4,7 +4,7 @@ import logging from itertools import combinations from tempfile import mkdtemp -from typing import Any, Dict, Generator, List, Set, Tuple +from typing import Any, Dict, Generator, Iterable, List, Set, Tuple from followthemoney.types import registry import duckdb @@ -19,6 +19,8 @@ log = logging.getLogger(__name__) +BATCH_SIZE = 1000 + class DuckDBIndex(BaseIndex[DS, CE]): """ @@ -57,6 +59,7 @@ def __init__(self, view: View[DS, CE], data_dir: Path): self.con.execute("SET max_memory = '2GB';") self.con.execute("SET threads = 1;") self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") + self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)") def dump(self, writer, entity: CE) -> None: @@ -77,52 +80,17 @@ def build(self) -> None: self.dump(writer, entity) if idx % 10000 == 0: log.info("Dumped %s entities" % idx) + for field, boost in self.BOOSTS.items(): + self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost]) log.info("Loading data...") self.con.execute(f"COPY entries from 
'{csv_path}'") log.info("Index built.") - def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float]]: - csv_path = self.path / "cooccurrences.csv" - with open(csv_path, "w") as fh: - writer = csv.writer(fh) - writer.writerow(["left", "right", "score"]) - for pair, score in self.cooccurring_tokens(): - writer.writerow([pair[0], pair[1], score]) - log.info("Loading co-occurrences...") - self.con.execute('CREATE TABLE cooccurrences ("left" TEXT, "right" TEXT, score FLOAT)') - self.con.execute(f"COPY cooccurrences from '{csv_path}'") - pairs_query = """ - SELECT "left", "right", sum(score) as score - FROM cooccurrences - GROUP BY "left", "right" - ORDER BY score DESC - LIMIT ? - """ - pairs_rel = self.con.execute(pairs_query, [max_pairs]) - pairs: List[Tuple[Pair, float]] = [] - for left, right, score in pairs_rel.fetchall(): - pairs.append(((Identifier.get(left), Identifier.get(right)), score)) - return pairs - - def cooccurring_tokens(self): - logged = defaultdict(int) - for field_name, token, entities in self.frequencies(): - logged[field_name] += 1 - if logged[field_name] % 10000 == 0: - log.info("Pairwise xref [%s]: %d" % (field_name, logged[field_name])) - boost = self.BOOSTS.get(field_name, 1.0) - for (left, lw), (right, rw) in combinations(entities, 2): - if lw == 0.0 or rw == 0.0: - continue - pair = (max(left, right), min(left, right)) - score = (lw + rw) * boost - yield pair, score - - def frequencies( - self, - ) -> Generator[Tuple[str, str, List[Tuple[Identifier, float]]], None, None]: + def pairs( + self, max_pairs: int = BaseIndex.MAX_PAIRS + ) -> Iterable[Tuple[Pair, float]]: field_len_query = """ SELECT field, id, count(*) as field_len from entries GROUP BY field, id @@ -140,33 +108,32 @@ def frequencies( GROUP BY field, token """ token_freq = self.con.sql(token_freq_query) - query = """ - SELECT mentions.field, mentions.token, mentions.id, mentions/field_len + term_frequencies_query = """ + SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf FROM field_len JOIN mentions ON field_len.field = mentions.field AND field_len.id = mentions.id JOIN token_freq ON token_freq.field = mentions.field AND token_freq.token = mentions.token where token_freq < 100 - ORDER BY mentions.field, mentions.token """ - rel = self.con.sql(query) - row = rel.fetchone() - entities = [] # the entities in this field, token group - field_name = None - token = None - while row is not None: - field_name, token, id, freq = row - entities.append((Identifier.get(id), freq)) - - row = rel.fetchone() - if row is None: - yield field_name, token, entities - break - new_field_name, new_token, _, _ = row - if new_field_name != field_name or new_token != token: - yield field_name, token, entities - entities = [] + term_frequencies = self.con.sql(term_frequencies_query) + pairs_query = """ + SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * boost) as score + FROM term_frequencies as "left" + JOIN term_frequencies as "right" + ON "left".field = "right".field AND "left".token = "right".token + JOIN boosts + ON "left".field = boosts.field + WHERE "left".id > "right".id + GROUP BY "left".id, "right".id + ORDER BY score DESC + LIMIT ? 
+ """ + results = self.con.execute(pairs_query, [max_pairs]) + while batch := results.fetchmany(BATCH_SIZE): + for left, right, score in batch: + yield (Identifier.get(left), Identifier.get(right)), score def __repr__(self) -> str: return "" % ( From 84d47e3d4311d13ae0f5fb391199e54527841d0f Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Thu, 7 Nov 2024 15:42:28 +0000 Subject: [PATCH 12/23] Unit test subqueries, typecheck --- nomenklatura/index/common.py | 8 ++- nomenklatura/index/duckdb_index.py | 81 ++++++++++++++++++------------ tests/index/test_duckdb_index.py | 71 ++++++++++++++++---------- tests/index/test_index.py | 56 +++++++++++++++------ 4 files changed, 137 insertions(+), 79 deletions(-) diff --git a/nomenklatura/index/common.py b/nomenklatura/index/common.py index 8cc75d49..c3f1792e 100644 --- a/nomenklatura/index/common.py +++ b/nomenklatura/index/common.py @@ -1,6 +1,6 @@ from pathlib import Path -from typing import Generic, List, Tuple -from nomenklatura.resolver import Identifier +from typing import Generic, Iterable, List, Tuple +from nomenklatura.resolver import Pair, Identifier from nomenklatura.dataset import DS from nomenklatura.entity import CE from nomenklatura.store import View @@ -16,9 +16,7 @@ def __init__(self, view: View[DS, CE], data_dir: Path) -> None: def build(self) -> None: raise NotImplementedError - def pairs( - self, max_pairs: int = MAX_PAIRS - ) -> List[Tuple[Tuple[Identifier, Identifier], float]]: + def pairs(self, max_pairs: int = MAX_PAIRS) -> Iterable[Tuple[Pair, float]]: raise NotImplementedError def match(self, entity: CE) -> List[Tuple[Identifier, float]]: diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index e4dfb9df..d3664d4a 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -1,21 +1,18 @@ -from collections import defaultdict -import csv +from duckdb import DuckDBPyRelation +from followthemoney.types import registry from pathlib import Path -import logging -from itertools import combinations from tempfile import mkdtemp -from typing import Any, Dict, Generator, Iterable, List, Set, Tuple -from followthemoney.types import registry +from typing import Iterable, Tuple +import csv import duckdb +import logging -from nomenklatura.util import PathLike -from nomenklatura.resolver import Pair, Identifier from nomenklatura.dataset import DS from nomenklatura.entity import CE -from nomenklatura.store import View -from nomenklatura.index.entry import Field -from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer from nomenklatura.index.common import BaseIndex +from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer +from nomenklatura.resolver import Pair, Identifier +from nomenklatura.store import View log = logging.getLogger(__name__) @@ -55,59 +52,74 @@ def __init__(self, view: View[DS, CE], data_dir: Path): self.tokenizer = Tokenizer[DS, CE]() self.path = Path(mkdtemp()) self.con = duckdb.connect((self.path / "duckdb_index.db").as_posix()) + + # https://duckdb.org/docs/guides/performance/environment + # > For ideal performance, aggregation-heavy workloads require approx. + # > 5 GB memory per thread and join-heavy workloads require approximately + # > 10 GB memory per thread. + # > Aim for 5-10 GB memory per thread. 
self.con.execute("SET memory_limit = '2GB';") self.con.execute("SET max_memory = '2GB';") + # > If you have a limited amount of memory, try to limit the number of threads self.con.execute("SET threads = 1;") - self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") - self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)") - - def dump(self, writer, entity: CE) -> None: - - if not entity.schema.matchable or entity.id is None: - return - - for field, token in self.tokenizer.entity(entity): - writer.writerow([entity.id, field, token]) def build(self) -> None: """Index all entities in the dataset.""" log.info("Building index from: %r...", self.view) + self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)") + for field, boost in self.BOOSTS.items(): + self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost]) + + self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") csv_path = self.path / "mentions.csv" + log.info("Dumping entity tokens to CSV for bulk load into the database...") with open(csv_path, "w") as fh: writer = csv.writer(fh) + + # csv.writer type gymnastics + def dump_entity(entity: CE) -> None: + if not entity.schema.matchable or entity.id is None: + return + for field, token in self.tokenizer.entity(entity): + writer.writerow([entity.id, field, token]) + writer.writerow(["id", "field", "token"]) for idx, entity in enumerate(self.view.entities()): - self.dump(writer, entity) + dump_entity(entity) if idx % 10000 == 0: log.info("Dumped %s entities" % idx) - for field, boost in self.BOOSTS.items(): - self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost]) log.info("Loading data...") self.con.execute(f"COPY entries from '{csv_path}'") - log.info("Index built.") - def pairs( - self, max_pairs: int = BaseIndex.MAX_PAIRS - ) -> Iterable[Tuple[Pair, float]]: + def field_len_rel(self) -> DuckDBPyRelation: field_len_query = """ SELECT field, id, count(*) as field_len from entries GROUP BY field, id """ - field_len = self.con.sql(field_len_query) + return self.con.sql(field_len_query) + + def mentions_rel(self) -> DuckDBPyRelation: mentions_query = """ SELECT field, id, token, count(*) as mentions FROM entries GROUP BY field, id, token """ - mentions = self.con.sql(mentions_query) + return self.con.sql(mentions_query) + + def token_freq_rel(self) -> DuckDBPyRelation: token_freq_query = """ SELECT field, token, count(*) as token_freq FROM entries GROUP BY field, token """ - token_freq = self.con.sql(token_freq_query) + return self.con.sql(token_freq_query) + + def frequencies_rel(self) -> DuckDBPyRelation: + field_len = self.field_len_rel() # noqa + mentions = self.mentions_rel() # noqa + token_freq = self.token_freq_rel() # noqa term_frequencies_query = """ SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf FROM field_len @@ -117,7 +129,12 @@ def pairs( ON token_freq.field = mentions.field AND token_freq.token = mentions.token where token_freq < 100 """ - term_frequencies = self.con.sql(term_frequencies_query) + return self.con.sql(term_frequencies_query) + + def pairs( + self, max_pairs: int = BaseIndex.MAX_PAIRS + ) -> Iterable[Tuple[Pair, float]]: + term_frequencies = self.frequencies_rel() # noqa pairs_query = """ SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * boost) as score FROM term_frequencies as "left" diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index 3d0ef005..04581f77 100644 --- a/tests/index/test_duckdb_index.py +++ 
b/tests/index/test_duckdb_index.py @@ -1,10 +1,5 @@ from collections import defaultdict -from pathlib import Path -from tempfile import NamedTemporaryFile, TemporaryDirectory -from nomenklatura.dataset import Dataset -from nomenklatura.entity import CompositeEntity -from nomenklatura.index import Index from nomenklatura.index.duckdb_index import DuckDBIndex from nomenklatura.resolver.identifier import Identifier from nomenklatura.store import SimpleMemoryStore @@ -24,7 +19,7 @@ def test_field_lengths(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): field_names = set() ids = set() - for field_name, id, field_len in duckdb_index.field_lengths(): + for field_name, id, field_len in duckdb_index.field_len_rel().fetchall(): field_names.add(field_name) ids.add(id) @@ -51,7 +46,7 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): ids = set() field_tokens = defaultdict(set) - for field_name, id, token, count in duckdb_index.mentions(): + for field_name, id, token, count in duckdb_index.mentions_rel().fetchall(): ids.add(id) field_tokens[field_name].add(token) @@ -62,25 +57,14 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): assert "dortmund" in field_tokens["word"], field_tokens["word"] -def test_id_grouped_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): - ids = set() - field_tokens = defaultdict(set) - for field_name, id, field_len, mentions in duckdb_index.id_grouped_mentions(): - ids.add(id) - for token, count in mentions: - field_tokens[field_name].add(token) - - assert len(ids) == 184, len(ids) - assert "verband" in field_tokens["namepart"], field_tokens["namepart"] - assert "gb" in field_tokens["country"], field_tokens["country"] - assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] - assert "dortmund" in field_tokens["word"], field_tokens["word"] - - def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): view = dstore.default_view() - pairs = duckdb_index.pairs() - assert len(pairs) > 0, pairs + pairs = list(duckdb_index.pairs()) + + # At least one pair is found + assert len(pairs) > 0, len(pairs) + + # A pair has tokens which overlap tokenizer = duckdb_index.tokenizer pair, score = pairs[0] entity0 = view.get_entity(str(pair[0])) @@ -89,7 +73,40 @@ def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): tokens1 = set(tokenizer.entity(entity1)) overlap = tokens0.intersection(tokens1) assert len(overlap) > 0, overlap - # assert "Schnabel" in (overlap, tokens0, tokens1) - # assert "Schnabel" in (entity0.caption, entity1.caption) + + # A pair has non-zero score assert score > 0 - # assert False + + # pairs are in descending score order + last_score = pairs[0][1] + for pair in pairs[1:]: + assert pair[1] <= last_score + last_score = pair[1] + + # Johanna Quandt <> Frau Johanna Quandt + jq = ( + Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"), + Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"), + ) + jq_score = [score for pair, score in pairs if jq == pair][0] + + # Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG + bmw = ( + Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"), + Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"), + ) + bmw_score = [score for pair, score in pairs if bmw == pair][0] + + # More tokens in BMW means lower TF, reducing the score + assert jq_score > bmw_score, (jq_score, bmw_score) + assert jq_score == 19.0, jq_score + assert 3.3 < bmw_score < 3.4, bmw_score + + # FERRING Arzneimittel 
GmbH <> Clou Container Leasing GmbH + false_pos = ( + Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"), + Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"), + ) + false_pos_score = [score for pair, score in pairs if false_pos == pair][0] + assert 1.1 < false_pos_score < 1.2, false_pos_score + assert bmw_score > false_pos_score, (bmw_score, false_pos_score) diff --git a/tests/index/test_index.py b/tests/index/test_index.py index 621b6385..c5e9fdac 100644 --- a/tests/index/test_index.py +++ b/tests/index/test_index.py @@ -27,17 +27,6 @@ def test_index_build(index_path: Path, dstore: SimpleMemoryStore): assert len(index) == 184, len(index) -def test_frequencies(dstore: SimpleMemoryStore, dindex: Index): - view = dstore.default_view() - - for field_name, field in dindex.fields.items(): - for token, entry in field.tokens.items(): - print(field_name, token) - for ident, tf in entry.frequencies(field): - print(" ", ident.id, tf) - assert False - - def test_index_persist(dstore: SimpleMemoryStore, dindex): view = dstore.default_view() with TemporaryDirectory() as tmpdir: @@ -57,7 +46,11 @@ def test_index_persist(dstore: SimpleMemoryStore, dindex): def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index): view = dstore.default_view() pairs = dindex.pairs() - assert len(pairs) > 0, pairs + + # At least one pair is found + assert len(pairs) > 0, len(pairs) + + # A pair has tokens which overlap tokenizer = dindex.tokenizer pair, score = pairs[0] entity0 = view.get_entity(str(pair[0])) @@ -66,10 +59,43 @@ def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index): tokens1 = set(tokenizer.entity(entity1)) overlap = tokens0.intersection(tokens1) assert len(overlap) > 0, overlap - # assert "Schnabel" in (overlap, tokens0, tokens1) - # assert "Schnabel" in (entity0.caption, entity1.caption) + + # A pair has non-zero score assert score > 0 - # assert False + + # pairs are in descending score order + last_score = pairs[0][1] + for pair in pairs[1:]: + assert pair[1] <= last_score + last_score = pair[1] + + # Johanna Quandt <> Frau Johanna Quandt + jq = ( + Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"), + Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"), + ) + jq_score = [score for pair, score in pairs if jq == pair][0] + + # Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG + bmw = ( + Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"), + Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"), + ) + bmw_score = [score for pair, score in pairs if bmw == pair][0] + + # More tokens in BMW means lower TF, reducing the score + assert jq_score > bmw_score, (jq_score, bmw_score) + assert jq_score == 19.0, jq_score + assert 3.3 < bmw_score < 3.4, bmw_score + + # FERRING Arzneimittel GmbH <> Clou Container Leasing GmbH + false_pos = ( + Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"), + Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"), + ) + false_pos_score = [score for pair, score in pairs if false_pos == pair][0] + assert 1.1 < false_pos_score < 1.2, false_pos_score + assert bmw_score > false_pos_score, (bmw_score, false_pos_score) def test_match_score(dstore: SimpleMemoryStore, dindex: Index): From 136a064683fe0d97916d8f372bb723d2e830e8b3 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Thu, 7 Nov 2024 15:54:53 +0000 Subject: [PATCH 13/23] Add duckdb as dep --- setup.py | 4 ++++ tests/index/test_duckdb_index.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/setup.py b/setup.py index 4e1008d6..7492c27d 
100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ "plyvel < 2.0.0", "redis > 5.0.0, < 6.0.0", "tantivy < 1.0.0", + "duckdb < 2.0.0", ], "leveldb": [ "plyvel < 2.0.0", @@ -69,5 +70,8 @@ "tantivy": [ "tantivy < 1.0.0", ], + "duckdb": [ + "duckdb < 2.0.0", + ], }, ) diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index 04581f77..745e753d 100644 --- a/tests/index/test_duckdb_index.py +++ b/tests/index/test_duckdb_index.py @@ -1,5 +1,7 @@ from collections import defaultdict +from pathlib import Path +from nomenklatura.index import get_index from nomenklatura.index.duckdb_index import DuckDBIndex from nomenklatura.resolver.identifier import Identifier from nomenklatura.store import SimpleMemoryStore @@ -16,6 +18,12 @@ } +def test_import(dstore: SimpleMemoryStore, index_path: Path): + view = dstore.default_view() + index = get_index(view, index_path, "duckdb") + assert isinstance(index, DuckDBIndex), type(index) + + def test_field_lengths(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): field_names = set() ids = set() From 655241f0ca9156eba10aca10b5923b1262328617 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Thu, 7 Nov 2024 17:55:28 +0000 Subject: [PATCH 14/23] Use provided index directory --- nomenklatura/index/duckdb_index.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index d3664d4a..5592f2c8 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -1,7 +1,7 @@ from duckdb import DuckDBPyRelation from followthemoney.types import registry from pathlib import Path -from tempfile import mkdtemp +from shutil import rmtree from typing import Iterable, Tuple import csv import duckdb @@ -50,8 +50,11 @@ class DuckDBIndex(BaseIndex[DS, CE]): def __init__(self, view: View[DS, CE], data_dir: Path): self.view = view self.tokenizer = Tokenizer[DS, CE]() - self.path = Path(mkdtemp()) - self.con = duckdb.connect((self.path / "duckdb_index.db").as_posix()) + self.data_dir = data_dir + if self.data_dir.exists(): + rmtree(self.data_dir.as_posix()) + self.data_dir.mkdir(parents=True) + self.con = duckdb.connect((self.data_dir / "duckdb_index.db").as_posix()) # https://duckdb.org/docs/guides/performance/environment # > For ideal performance, aggregation-heavy workloads require approx. 
@@ -71,7 +74,7 @@ def build(self) -> None: self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost]) self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") - csv_path = self.path / "mentions.csv" + csv_path = self.data_dir / "mentions.csv" log.info("Dumping entity tokens to CSV for bulk load into the database...") with open(csv_path, "w") as fh: writer = csv.writer(fh) From 0ed8da8397c9c67395ece4dae9314beb65f2a10f Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Mon, 11 Nov 2024 15:01:20 +0000 Subject: [PATCH 15/23] Add basic matching, but it's 3 times slower than tantivy --- nomenklatura/index/duckdb_index.py | 67 +++++++++++++++++++++++------- tests/index/test_duckdb_index.py | 41 ++++++++++++++++++ tests/index/test_index.py | 8 ++-- 3 files changed, 99 insertions(+), 17 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index 5592f2c8..f4b84916 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -2,7 +2,7 @@ from followthemoney.types import registry from pathlib import Path from shutil import rmtree -from typing import Iterable, Tuple +from typing import Any, Dict, Iterable, List, Tuple import csv import duckdb import logging @@ -21,10 +21,11 @@ class DuckDBIndex(BaseIndex[DS, CE]): """ - An in-memory search index to match entities against a given dataset. + An index using DuckDB for token matching and scoring, keeping data in memory + until it needs to spill to disk as it approaches the configured memory limit. - For each field in the dataset, the index stores the IDs which contains each - token, along with the absolute frequency of each token in the document. + Pairs match if they share one or more tokens. A basic similarity score is calculated + cumulatively based on each token's Term Frequency (TF) and the field's boost factor. 
""" name = "duckdb" @@ -47,8 +48,12 @@ class DuckDBIndex(BaseIndex[DS, CE]): __slots__ = "view", "fields", "tokenizer", "entities" - def __init__(self, view: View[DS, CE], data_dir: Path): + def __init__( + self, view: View[DS, CE], data_dir: Path, options: Dict[str, Any] = {} + ): self.view = view + # self.memory_budget = int(options.get("memory_budget", 500) * 1024 * 1024) + self.max_candidates = int(options.get("max_candidates", 50)) self.tokenizer = Tokenizer[DS, CE]() self.data_dir = data_dir if self.data_dir.exists(): @@ -89,11 +94,23 @@ def dump_entity(entity: CE) -> None: writer.writerow(["id", "field", "token"]) for idx, entity in enumerate(self.view.entities()): dump_entity(entity) - if idx % 10000 == 0: + if idx % 50000 == 0: log.info("Dumped %s entities" % idx) log.info("Loading data...") self.con.execute(f"COPY entries from '{csv_path}'") + + log.info("Calculating term frequencies...") + frequencies = self.frequencies_rel() # noqa + self.con.execute("CREATE TABLE term_frequencies as SELECT * FROM frequencies") + + log.info("Calculating stopwords...") + token_freq = self.token_freq_rel() # noqa + self.con.execute( + "CREATE TABLE stopwords as SELECT * FROM token_freq where token_freq > 100" + ) + + self.con.execute("CREATE TEMPORARY TABLE matching (field TEXT, token TEXT)") log.info("Index built.") def field_len_rel(self) -> DuckDBPyRelation: @@ -122,30 +139,28 @@ def token_freq_rel(self) -> DuckDBPyRelation: def frequencies_rel(self) -> DuckDBPyRelation: field_len = self.field_len_rel() # noqa mentions = self.mentions_rel() # noqa - token_freq = self.token_freq_rel() # noqa term_frequencies_query = """ SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf FROM field_len JOIN mentions ON field_len.field = mentions.field AND field_len.id = mentions.id - JOIN token_freq - ON token_freq.field = mentions.field AND token_freq.token = mentions.token - where token_freq < 100 """ return self.con.sql(term_frequencies_query) def pairs( self, max_pairs: int = BaseIndex.MAX_PAIRS ) -> Iterable[Tuple[Pair, float]]: - term_frequencies = self.frequencies_rel() # noqa pairs_query = """ - SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * boost) as score + SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * ifnull(boost, 1)) as score FROM term_frequencies as "left" JOIN term_frequencies as "right" ON "left".field = "right".field AND "left".token = "right".token - JOIN boosts + LEFT OUTER JOIN boosts ON "left".field = boosts.field - WHERE "left".id > "right".id + LEFT OUTER JOIN stopwords + ON stopwords.field = "left".field AND stopwords.token = "left".token + WHERE token_freq is NULL + AND "left".id > "right".id GROUP BY "left".id, "right".id ORDER BY score DESC LIMIT ? @@ -155,6 +170,30 @@ def pairs( for left, right, score in batch: yield (Identifier.get(left), Identifier.get(right)), score + def match(self, entity: CE) -> List[Tuple[Identifier, float]]: + """Match an entity against the index, returning a list of + (entity_id, score) pairs.""" + rows = list(self.tokenizer.entity(entity)) + + if rows: + self.con.executemany("INSERT INTO matching VALUES (?, ?)", rows) + + match_query = """ + SELECT id, sum(tf * ifnull(boost, 1)) as score + FROM term_frequencies + JOIN matching + ON term_frequencies.field = matching.field AND term_frequencies.token = matching.token + LEFT OUTER JOIN boosts + ON term_frequencies.field = boosts.field + GROUP BY id + ORDER BY score DESC + LIMIT ? 
+ """ + results = self.con.execute(match_query, [self.max_candidates]) + matches = [(Identifier.get(id), score) for id, score in results.fetchall()] + self.con.execute("DELETE FROM matching") + return matches + def __repr__(self) -> str: return "" % ( self.view.scope.name, diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index 745e753d..cc989337 100644 --- a/tests/index/test_duckdb_index.py +++ b/tests/index/test_duckdb_index.py @@ -1,6 +1,8 @@ from collections import defaultdict from pathlib import Path +from nomenklatura.dataset import Dataset +from nomenklatura.entity import CompositeEntity from nomenklatura.index import get_index from nomenklatura.index.duckdb_index import DuckDBIndex from nomenklatura.resolver.identifier import Identifier @@ -118,3 +120,42 @@ def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): false_pos_score = [score for pair, score in pairs if false_pos == pair][0] assert 1.1 < false_pos_score < 1.2, false_pos_score assert bmw_score > false_pos_score, (bmw_score, false_pos_score) + + assert len(pairs) == 428, len(pairs) + + +def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + """Match an entity that isn't itself in the index""" + dx = Dataset.make({"name": "test", "title": "Test"}) + entity = CompositeEntity.from_data(dx, VERBAND_BADEN_DATA) + matches = duckdb_index.match(entity) + # 9 entities in the index where some token in the query entity matches some + # token in the index. + assert len(matches) == 9, matches + + top_result = matches[0] + assert top_result[0] == Identifier(VERBAND_BADEN_ID), top_result + assert 1.99 < top_result[1] < 2, top_result + + next_result = matches[1] + assert next_result[0] == Identifier(VERBAND_ID), next_result + assert 1.66 < next_result[1] < 1.67, next_result + + match_identifiers = set(str(m[0]) for m in matches) + + +def test_top_match_matches_strong_pairs( + dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex +): + """Pairs with high scores are each others' top matches""" + + view = dstore.default_view() + strong_pairs = [p for p in duckdb_index.pairs() if p[1] > 3.0] + assert len(strong_pairs) > 4 + + for pair, pair_score in strong_pairs: + entity = view.get_entity(pair[0].id) + matches = duckdb_index.match(entity) + # it'll match itself and the other in the pair + for match, match_score in matches[:2]: + assert match in pair, (match, pair) diff --git a/tests/index/test_index.py b/tests/index/test_index.py index c5e9fdac..ac37d9fb 100644 --- a/tests/index/test_index.py +++ b/tests/index/test_index.py @@ -33,13 +33,13 @@ def test_index_persist(dstore: SimpleMemoryStore, dindex): with NamedTemporaryFile("w") as fh: path = Path(fh.name) dindex.save(path) - loaded = Index.load(dstore.default_view(), path, tmpdir) + loaded = Index.load(dstore.default_view(), path, Path(tmpdir)) assert len(dindex.entities) == len(loaded.entities), (dindex, loaded) assert len(dindex) == len(loaded), (dindex, loaded) path.unlink(missing_ok=True) with TemporaryDirectory() as tmpdir: - empty = Index.load(view, path, tmpdir) + empty = Index.load(view, path, Path(tmpdir)) assert len(empty) == len(loaded), (empty, loaded) @@ -97,6 +97,8 @@ def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index): assert 1.1 < false_pos_score < 1.2, false_pos_score assert bmw_score > false_pos_score, (bmw_score, false_pos_score) + assert len(pairs) == 428, len(pairs) + def test_match_score(dstore: SimpleMemoryStore, dindex: Index): """Match an entity that isn't itself in the index""" @@ 
-129,7 +131,7 @@ def test_top_match_matches_strong_pairs(dstore: SimpleMemoryStore, dindex: Index assert len(strong_pairs) > 4 for pair, pair_score in strong_pairs: - entity = view.get_entity(pair[0]) + entity = view.get_entity(pair[0].id) matches = dindex.match(entity) # it'll match itself and the other in the pair for match, match_score in matches[:2]: From ed1e28218441e021ca785e89f96cc40946f7c67a Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Mon, 11 Nov 2024 17:45:12 +0000 Subject: [PATCH 16/23] WIP horrid interface --- nomenklatura/enrich/__init__.py | 3 +- nomenklatura/enrich/common.py | 23 ++++++++++- nomenklatura/index/duckdb_index.py | 63 ++++++++++++++++++++---------- tests/conftest.py | 2 +- tests/index/test_duckdb_index.py | 37 ++++++++++-------- 5 files changed, 88 insertions(+), 40 deletions(-) diff --git a/nomenklatura/enrich/__init__.py b/nomenklatura/enrich/__init__.py index e41fc94d..78175828 100644 --- a/nomenklatura/enrich/__init__.py +++ b/nomenklatura/enrich/__init__.py @@ -6,7 +6,7 @@ from nomenklatura.dataset import DS from nomenklatura.cache import Cache from nomenklatura.matching import DefaultAlgorithm -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import Enricher, EnricherConfig, BulkEnricher from nomenklatura.enrich.common import EnrichmentAbort, EnrichmentException from nomenklatura.judgement import Judgement from nomenklatura.resolver import Resolver @@ -16,6 +16,7 @@ "Enricher", "EnrichmentAbort", "EnrichmentException", + "BulkEnricher", "make_enricher", "enrich", "match", diff --git a/nomenklatura/enrich/common.py b/nomenklatura/enrich/common.py index c4523227..3b3825c1 100644 --- a/nomenklatura/enrich/common.py +++ b/nomenklatura/enrich/common.py @@ -3,7 +3,7 @@ import logging import time from banal import as_bool -from typing import Union, Any, Dict, Optional, Generator, Generic +from typing import List, Tuple, Union, Any, Dict, Optional, Generator, Generic from abc import ABC, abstractmethod from requests import Session from requests.exceptions import RequestException @@ -18,6 +18,7 @@ from nomenklatura.dataset import DS from nomenklatura.cache import Cache from nomenklatura.util import HeadersType +from nomenklatura.resolver import Identifier EnricherConfig = Dict[str, Any] log = logging.getLogger(__name__) @@ -205,3 +206,23 @@ def close(self) -> None: self.cache.close() if self._session is not None: self._session.close() + + +class BulkEnricher(Enricher[DS], ABC): + """ + An enricher which performs matching in bulk, requiring all subject entities + to be loaded before matching. 
+ """ + def load_wrapped(self, entity: CE) -> None: + if not self._filter_entity(entity): + return + self.load(entity) + + def load(self, entity: CE) -> None: + raise NotImplementedError() + + def candidates(self) -> Generator[Tuple[Identifier, List[Tuple[Identifier, float]]], None, None]: + raise NotImplementedError() + + def match_candidates(self, entity: CE, candidates: List[Tuple[Identifier, float]]) -> Generator[CE, None, None]: + raise NotImplementedError() diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index f4b84916..bcfd97ab 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -1,8 +1,9 @@ +from io import TextIOWrapper from duckdb import DuckDBPyRelation from followthemoney.types import registry from pathlib import Path from shutil import rmtree -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any, Dict, Generator, Iterable, List, Tuple import csv import duckdb import logging @@ -60,6 +61,10 @@ def __init__( rmtree(self.data_dir.as_posix()) self.data_dir.mkdir(parents=True) self.con = duckdb.connect((self.data_dir / "duckdb_index.db").as_posix()) + self.matching_path = self.data_dir / "matching.csv" + self.matching_dump: TextIOWrapper | None = open(self.matching_path, "w") + writer = csv.writer(self.matching_dump) + writer.writerow(["id", "field", "token"]) # https://duckdb.org/docs/guides/performance/environment # > For ideal performance, aggregation-heavy workloads require approx. @@ -78,6 +83,7 @@ def build(self) -> None: for field, boost in self.BOOSTS.items(): self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost]) + self.con.execute("CREATE TABLE matching (id TEXT, field TEXT, token TEXT)") self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") csv_path = self.data_dir / "mentions.csv" log.info("Dumping entity tokens to CSV for bulk load into the database...") @@ -110,7 +116,6 @@ def dump_entity(entity: CE) -> None: "CREATE TABLE stopwords as SELECT * FROM token_freq where token_freq > 100" ) - self.con.execute("CREATE TEMPORARY TABLE matching (field TEXT, token TEXT)") log.info("Index built.") def field_len_rel(self) -> DuckDBPyRelation: @@ -170,29 +175,45 @@ def pairs( for left, right, score in batch: yield (Identifier.get(left), Identifier.get(right)), score - def match(self, entity: CE) -> List[Tuple[Identifier, float]]: - """Match an entity against the index, returning a list of - (entity_id, score) pairs.""" - rows = list(self.tokenizer.entity(entity)) + def add_matching_subject(self, entity: CE) -> None: + print("adding", entity) + writer = csv.writer(self.matching_dump) + for field, token in self.tokenizer.entity(entity): + writer.writerow([entity.id, field, token]) + + def matches( + self, + ) -> Generator[Tuple[Identifier, List[Tuple[Identifier, float]]], None, None]: + if self.matching_dump is not None: + self.matching_dump.close() + self.matching_dump = None + log.info("Loading matching subjects...") + print(self.con.execute(f"COPY entries from '{self.matching_path}'").fetchall()) + log.info("Finished loading matching subjects.") - if rows: - self.con.executemany("INSERT INTO matching VALUES (?, ?)", rows) - - match_query = """ - SELECT id, sum(tf * ifnull(boost, 1)) as score - FROM term_frequencies + pairs_query = """ + SELECT matching.id, matches.id, sum(matches.tf * ifnull(boost, 1)) as score + FROM term_frequencies as matches JOIN matching - ON term_frequencies.field = matching.field AND term_frequencies.token = matching.token 
+ ON matches.field = matching.field AND matches.token = matching.token LEFT OUTER JOIN boosts - ON term_frequencies.field = boosts.field - GROUP BY id - ORDER BY score DESC - LIMIT ? + ON matches.field = boosts.field + GROUP BY matches.id, matching.id + ORDER BY matching.id, score DESC """ - results = self.con.execute(match_query, [self.max_candidates]) - matches = [(Identifier.get(id), score) for id, score in results.fetchall()] - self.con.execute("DELETE FROM matching") - return matches + results = self.con.execute(pairs_query) + print("results", results.fetchall) + previous_id = None + matches: List[Tuple[Identifier, float]] = [] + while batch := results.fetchmany(BATCH_SIZE): + print("batch") + for matching_id, match_id, score in batch: + if previous_id is not None and matching_id != previous_id: + if matches: + yield Identifier.get(previous_id), matches + matches = [] + previous_id = matching_id + matches.append((Identifier.get(match_id), score)) def __repr__(self) -> str: return "" % ( diff --git a/tests/conftest.py b/tests/conftest.py index c620315c..4897a63e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,4 +93,4 @@ def duckdb_index(index_path: Path, dstore: SimpleMemoryStore): def index_path(): index_path = Path(mkdtemp()) / "index-dir" yield index_path - shutil.rmtree(index_path, ignore_errors=True) + #shutil.rmtree(index_path, ignore_errors=True) diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index cc989337..b528e6b5 100644 --- a/tests/index/test_duckdb_index.py +++ b/tests/index/test_duckdb_index.py @@ -125,10 +125,15 @@ def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + print(duckdb_index.data_dir) """Match an entity that isn't itself in the index""" dx = Dataset.make({"name": "test", "title": "Test"}) entity = CompositeEntity.from_data(dx, VERBAND_BADEN_DATA) - matches = duckdb_index.match(entity) + duckdb_index.add_matching_subject(entity) + match_sets = list(duckdb_index.matches()) + assert len(match_sets) == 1, match_sets + subject_id, matches = match_sets[0] + # 9 entities in the index where some token in the query entity matches some # token in the index. 
assert len(matches) == 9, matches @@ -144,18 +149,18 @@ def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): match_identifiers = set(str(m[0]) for m in matches) -def test_top_match_matches_strong_pairs( - dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex -): - """Pairs with high scores are each others' top matches""" - - view = dstore.default_view() - strong_pairs = [p for p in duckdb_index.pairs() if p[1] > 3.0] - assert len(strong_pairs) > 4 - - for pair, pair_score in strong_pairs: - entity = view.get_entity(pair[0].id) - matches = duckdb_index.match(entity) - # it'll match itself and the other in the pair - for match, match_score in matches[:2]: - assert match in pair, (match, pair) +#def test_top_match_matches_strong_pairs( +# dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex +#): +# """Pairs with high scores are each others' top matches""" +# +# view = dstore.default_view() +# strong_pairs = [p for p in duckdb_index.pairs() if p[1] > 3.0] +# assert len(strong_pairs) > 4 +# +# for pair, pair_score in strong_pairs: +# entity = view.get_entity(pair[0].id) +# matches = duckdb_index.match(entity) +# # it'll match itself and the other in the pair +# for match, match_score in matches[:2]: +# assert match in pair, (match, pair) From d5f4d17e3481acbb2ba56995ebdc39aab9f9180f Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Mon, 11 Nov 2024 17:59:05 +0000 Subject: [PATCH 17/23] Fixy --- nomenklatura/index/duckdb_index.py | 12 +++++++----- tests/conftest.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index bcfd97ab..d5bcc8b2 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -104,7 +104,8 @@ def dump_entity(entity: CE) -> None: log.info("Dumped %s entities" % idx) log.info("Loading data...") - self.con.execute(f"COPY entries from '{csv_path}'") + result = self.con.execute(f"COPY entries from '{csv_path}'").fetchall() + log.info("Loaded %r rows", len(result)) log.info("Calculating term frequencies...") frequencies = self.frequencies_rel() # noqa @@ -188,10 +189,10 @@ def matches( self.matching_dump.close() self.matching_dump = None log.info("Loading matching subjects...") - print(self.con.execute(f"COPY entries from '{self.matching_path}'").fetchall()) + print(self.con.execute(f"COPY matching from '{self.matching_path}'").fetchall()) log.info("Finished loading matching subjects.") - pairs_query = """ + match_query = """ SELECT matching.id, matches.id, sum(matches.tf * ifnull(boost, 1)) as score FROM term_frequencies as matches JOIN matching @@ -201,8 +202,8 @@ def matches( GROUP BY matches.id, matching.id ORDER BY matching.id, score DESC """ - results = self.con.execute(pairs_query) - print("results", results.fetchall) + print(self.con.execute(match_query).fetchall()) + results = self.con.execute(match_query) previous_id = None matches: List[Tuple[Identifier, float]] = [] while batch := results.fetchmany(BATCH_SIZE): @@ -214,6 +215,7 @@ def matches( matches = [] previous_id = matching_id matches.append((Identifier.get(match_id), score)) + yield Identifier.get(previous_id), matches def __repr__(self) -> str: return "" % ( diff --git a/tests/conftest.py b/tests/conftest.py index 4897a63e..c620315c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,4 +93,4 @@ def duckdb_index(index_path: Path, dstore: SimpleMemoryStore): def index_path(): index_path = Path(mkdtemp()) / "index-dir" yield index_path - #shutil.rmtree(index_path, 
ignore_errors=True) + shutil.rmtree(index_path, ignore_errors=True) From 287d70a0b04e2c796979d1f77e5f418bc015d22e Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Mon, 11 Nov 2024 21:09:45 +0000 Subject: [PATCH 18/23] WIP --- nomenklatura/index/duckdb_index.py | 16 +++++++++------- tests/index/test_duckdb_index.py | 1 + 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index d5bcc8b2..1538730a 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -71,8 +71,8 @@ def __init__( # > 5 GB memory per thread and join-heavy workloads require approximately # > 10 GB memory per thread. # > Aim for 5-10 GB memory per thread. - self.con.execute("SET memory_limit = '2GB';") - self.con.execute("SET max_memory = '2GB';") + self.con.execute("SET memory_limit = '3GB';") + self.con.execute("SET max_memory = '3GB';") # > If you have a limited amount of memory, try to limit the number of threads self.con.execute("SET threads = 1;") @@ -177,7 +177,6 @@ def pairs( yield (Identifier.get(left), Identifier.get(right)), score def add_matching_subject(self, entity: CE) -> None: - print("adding", entity) writer = csv.writer(self.matching_dump) for field, token in self.tokenizer.entity(entity): writer.writerow([entity.id, field, token]) @@ -189,27 +188,30 @@ def matches( self.matching_dump.close() self.matching_dump = None log.info("Loading matching subjects...") - print(self.con.execute(f"COPY matching from '{self.matching_path}'").fetchall()) + self.con.execute(f"COPY matching from '{self.matching_path}'") log.info("Finished loading matching subjects.") match_query = """ SELECT matching.id, matches.id, sum(matches.tf * ifnull(boost, 1)) as score FROM term_frequencies as matches + LEFT OUTER JOIN stopwords + ON stopwords.field = matches.field AND stopwords.token = matches.token JOIN matching ON matches.field = matching.field AND matches.token = matching.token LEFT OUTER JOIN boosts ON matches.field = boosts.field + WHERE token_freq is NULL GROUP BY matches.id, matching.id ORDER BY matching.id, score DESC """ - print(self.con.execute(match_query).fetchall()) results = self.con.execute(match_query) previous_id = None matches: List[Tuple[Identifier, float]] = [] while batch := results.fetchmany(BATCH_SIZE): - print("batch") for matching_id, match_id, score in batch: - if previous_id is not None and matching_id != previous_id: + if previous_id is None: + previous_id = matching_id + if matching_id != previous_id: if matches: yield Identifier.get(previous_id), matches matches = [] diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index b528e6b5..2ef739d0 100644 --- a/tests/index/test_duckdb_index.py +++ b/tests/index/test_duckdb_index.py @@ -133,6 +133,7 @@ def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): match_sets = list(duckdb_index.matches()) assert len(match_sets) == 1, match_sets subject_id, matches = match_sets[0] + assert subject_id == Identifier("bla"), subject_id # 9 entities in the index where some token in the query entity matches some # token in the index. 
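Aside, not part of the patch series: the `pairs` and `matches` queries built up in the preceding patches boil down to the same TF arithmetic as the in-memory index. The pure-Python sketch below walks through that arithmetic on invented toy entities; the boost value and the stopword cutoff of 100 stand in for the class constants, and the entity rows are illustrative only.

    from collections import Counter, defaultdict
    from itertools import combinations

    BOOSTS = {"namepart": 2.0}   # assumed field boost; unlisted fields default to 1.0
    STOPWORD_CUTOFF = 100        # tokens mentioned more often than this are ignored

    # (entity_id, field, token) rows, analogous to the `entries` table the build step loads
    entries = [
        ("e1", "namepart", "johanna"), ("e1", "namepart", "quandt"),
        ("e2", "namepart", "frau"), ("e2", "namepart", "johanna"), ("e2", "namepart", "quandt"),
    ]

    # Stopwords: tokens whose total mention count exceeds the cutoff
    token_freq = Counter((field, token) for _, field, token in entries)
    stopwords = {key for key, count in token_freq.items() if count > STOPWORD_CUTOFF}

    # Field length per entity and mention count per (field, entity, token)
    field_len: Counter = Counter()
    mentions: Counter = Counter()
    for entity_id, field, token in entries:
        if (field, token) in stopwords:
            continue
        field_len[(field, entity_id)] += 1
        mentions[(field, entity_id, token)] += 1

    # Term frequency: mentions of the token / total tokens of that entity in that field
    tf = {
        (field, entity_id, token): count / field_len[(field, entity_id)]
        for (field, entity_id, token), count in mentions.items()
    }

    # Pair score: sum of (left TF + right TF) * field boost over every shared token
    postings = defaultdict(list)
    for (field, entity_id, token), weight in tf.items():
        postings[(field, token)].append((entity_id, weight))

    scores: defaultdict = defaultdict(float)
    for (field, token), posting in postings.items():
        boost = BOOSTS.get(field, 1.0)
        for (left, lw), (right, rw) in combinations(sorted(posting), 2):
            scores[(left, right)] += (lw + rw) * boost

    print(sorted(scores.items(), key=lambda item: item[1], reverse=True))

Materialising `term_frequencies` as a table is what lets the `pairs` query and the later bulk `matches` query reuse the same per-token weights instead of recomputing them inside each join.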
From 178a941e39952d13916bd6d370c14b8493da3bbf Mon Sep 17 00:00:00 2001
From: JD Bothma
Date: Fri, 15 Nov 2024 15:40:10 +0000
Subject: [PATCH 19/23] Reduce memory consumption by letting it materialise
 intermediate results more explicitly instead of doing multiple joins
 concurrently
---
 nomenklatura/index/duckdb_index.py | 95 ++++++++++++++++--------------
 tests/index/test_duckdb_index.py | 16 ++---
 2 files changed, 61 insertions(+), 50 deletions(-)

diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py
index 1538730a..d3547ef4 100644
--- a/nomenklatura/index/duckdb_index.py
+++ b/nomenklatura/index/duckdb_index.py
@@ -3,7 +3,7 @@
 from followthemoney.types import registry
 from pathlib import Path
 from shutil import rmtree
-from typing import Any, Dict, Generator, Iterable, List, Tuple
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
 import csv
 import duckdb
 import logging
@@ -53,26 +53,31 @@ def __init__(
         self, view: View[DS, CE], data_dir: Path, options: Dict[str, Any] = {}
     ):
         self.view = view
-        # self.memory_budget = int(options.get("memory_budget", 500) * 1024 * 1024)
+        memory_budget = options.get("memory_budget", None)
+        self.memory_budget: Optional[int] = (
+            (int(memory_budget) * 1024) if memory_budget else None
+        )
+        """Memory budget in megabytes"""
         self.max_candidates = int(options.get("max_candidates", 50))
         self.tokenizer = Tokenizer[DS, CE]()
         self.data_dir = data_dir
         if self.data_dir.exists():
-            rmtree(self.data_dir.as_posix())
+            rmtree(self.data_dir)
         self.data_dir.mkdir(parents=True)
         self.con = duckdb.connect((self.data_dir / "duckdb_index.db").as_posix())
         self.matching_path = self.data_dir / "matching.csv"
+        self.matching_path.unlink(missing_ok=True)
         self.matching_dump: TextIOWrapper | None = open(self.matching_path, "w")
         writer = csv.writer(self.matching_dump)
         writer.writerow(["id", "field", "token"])

         # https://duckdb.org/docs/guides/performance/environment
-        # > For ideal performance, aggregation-heavy workloads require approx.
-        # > 5 GB memory per thread and join-heavy workloads require approximately
-        # > 10 GB memory per thread.
+        # > For ideal performance,
+        # > aggregation-heavy workloads require approx. 5 GB memory per thread and
+        # > join-heavy workloads require approximately 10 GB memory per thread.
         # > Aim for 5-10 GB memory per thread.
- self.con.execute("SET memory_limit = '3GB';") - self.con.execute("SET max_memory = '3GB';") + if self.memory_budget is not None: + self.con.execute("SET memory_limit = ?;", [f"{self.memory_budget}MB"]) # > If you have a limited amount of memory, try to limit the number of threads self.con.execute("SET threads = 1;") @@ -96,62 +101,72 @@ def dump_entity(entity: CE) -> None: return for field, token in self.tokenizer.entity(entity): writer.writerow([entity.id, field, token]) + writer.writerow(["id", "field", "token"]) - writer.writerow(["id", "field", "token"]) for idx, entity in enumerate(self.view.entities()): dump_entity(entity) if idx % 50000 == 0: log.info("Dumped %s entities" % idx) - log.info("Loading data...") - result = self.con.execute(f"COPY entries from '{csv_path}'").fetchall() - log.info("Loaded %r rows", len(result)) - - log.info("Calculating term frequencies...") - frequencies = self.frequencies_rel() # noqa - self.con.execute("CREATE TABLE term_frequencies as SELECT * FROM frequencies") - - log.info("Calculating stopwords...") - token_freq = self.token_freq_rel() # noqa - self.con.execute( - "CREATE TABLE stopwords as SELECT * FROM token_freq where token_freq > 100" - ) + self.con.execute(f"COPY entries from '{csv_path}'") + log.info("Done.") + self._build_frequencies() log.info("Index built.") - def field_len_rel(self) -> DuckDBPyRelation: + def _build_field_len(self) -> None: + self._build_stopwords() + log.info("Calculating field lengths...") field_len_query = """ - SELECT field, id, count(*) as field_len from entries - GROUP BY field, id + CREATE TABLE IF NOT EXISTS field_len as + SELECT entries.field, entries.id, count(*) as field_len from entries + LEFT OUTER JOIN stopwords + ON stopwords.field = entries.field AND stopwords.token = entries.token + WHERE token_freq is NULL + GROUP BY entries.field, entries.id """ - return self.con.sql(field_len_query) + self.con.execute(field_len_query) - def mentions_rel(self) -> DuckDBPyRelation: + def _build_mentions(self) -> None: + self._build_stopwords() + log.info("Calculating mention counts...") mentions_query = """ - SELECT field, id, token, count(*) as mentions + CREATE TABLE IF NOT EXISTS mentions as + SELECT entries.field, entries.id, entries.token, count(*) as mentions FROM entries - GROUP BY field, id, token + LEFT OUTER JOIN stopwords + ON stopwords.field = entries.field AND stopwords.token = entries.token + WHERE token_freq is NULL + GROUP BY entries.field, entries.id, entries.token """ - return self.con.sql(mentions_query) + self.con.execute(mentions_query) - def token_freq_rel(self) -> DuckDBPyRelation: + def _build_stopwords(self) -> None: token_freq_query = """ SELECT field, token, count(*) as token_freq FROM entries GROUP BY field, token """ - return self.con.sql(token_freq_query) + token_freq = self.con.sql(token_freq_query) # noqa + self.con.execute( + """ + CREATE TABLE IF NOT EXISTS stopwords as + SELECT * FROM token_freq where token_freq > 100 + """ + ) - def frequencies_rel(self) -> DuckDBPyRelation: - field_len = self.field_len_rel() # noqa - mentions = self.mentions_rel() # noqa + def _build_frequencies(self) -> None: + self._build_field_len() + self._build_mentions() + log.info("Calculating term frequencies...") term_frequencies_query = """ + CREATE TABLE IF NOT EXISTS term_frequencies as SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf FROM field_len JOIN mentions ON field_len.field = mentions.field AND field_len.id = mentions.id """ - return 
self.con.sql(term_frequencies_query) + self.con.execute(term_frequencies_query) def pairs( self, max_pairs: int = BaseIndex.MAX_PAIRS @@ -163,9 +178,6 @@ def pairs( ON "left".field = "right".field AND "left".token = "right".token LEFT OUTER JOIN boosts ON "left".field = boosts.field - LEFT OUTER JOIN stopwords - ON stopwords.field = "left".field AND stopwords.token = "left".token - WHERE token_freq is NULL AND "left".id > "right".id GROUP BY "left".id, "right".id ORDER BY score DESC @@ -194,13 +206,10 @@ def matches( match_query = """ SELECT matching.id, matches.id, sum(matches.tf * ifnull(boost, 1)) as score FROM term_frequencies as matches - LEFT OUTER JOIN stopwords - ON stopwords.field = matches.field AND stopwords.token = matches.token JOIN matching ON matches.field = matching.field AND matches.token = matching.token LEFT OUTER JOIN boosts ON matches.field = boosts.field - WHERE token_freq is NULL GROUP BY matches.id, matching.id ORDER BY matching.id, score DESC """ @@ -217,7 +226,7 @@ def matches( matches = [] previous_id = matching_id matches.append((Identifier.get(match_id), score)) - yield Identifier.get(previous_id), matches + yield Identifier.get(previous_id), matches[: self.max_candidates] def __repr__(self) -> str: return "" % ( diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py index 2ef739d0..77e63874 100644 --- a/tests/index/test_duckdb_index.py +++ b/tests/index/test_duckdb_index.py @@ -29,7 +29,9 @@ def test_import(dstore: SimpleMemoryStore, index_path: Path): def test_field_lengths(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): field_names = set() ids = set() - for field_name, id, field_len in duckdb_index.field_len_rel().fetchall(): + + field_len_rel = duckdb_index.con.sql("SELECT * FROM field_len") + for field_name, id, field_len in field_len_rel.fetchall(): field_names.add(field_name) ids.add(id) @@ -56,7 +58,8 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): ids = set() field_tokens = defaultdict(set) - for field_name, id, token, count in duckdb_index.mentions_rel().fetchall(): + mentions_rel = duckdb_index.con.sql("SELECT * FROM mentions") + for field_name, id, token, count in mentions_rel.fetchall(): ids.add(id) field_tokens[field_name].add(token) @@ -121,11 +124,10 @@ def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): assert 1.1 < false_pos_score < 1.2, false_pos_score assert bmw_score > false_pos_score, (bmw_score, false_pos_score) - assert len(pairs) == 428, len(pairs) + assert len(pairs) >= 428, len(pairs) def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): - print(duckdb_index.data_dir) """Match an entity that isn't itself in the index""" dx = Dataset.make({"name": "test", "title": "Test"}) entity = CompositeEntity.from_data(dx, VERBAND_BADEN_DATA) @@ -147,12 +149,12 @@ def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): assert next_result[0] == Identifier(VERBAND_ID), next_result assert 1.66 < next_result[1] < 1.67, next_result - match_identifiers = set(str(m[0]) for m in matches) + #match_identifiers = set(str(m[0]) for m in matches) -#def test_top_match_matches_strong_pairs( +# def test_top_match_matches_strong_pairs( # dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex -#): +# ): # """Pairs with high scores are each others' top matches""" # # view = dstore.default_view() From 4d5abe2e0dbfa694ddaf9c38bc33acbeb5e2ac06 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Tue, 19 Nov 2024 10:57:07 +0000 Subject: [PATCH 20/23] 
It's already megabytes --- nomenklatura/index/duckdb_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index d3547ef4..a1b8088c 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -55,7 +55,7 @@ def __init__( self.view = view memory_budget = options.get("memory_budget", None) self.memory_budget: Optional[int] = ( - (int(memory_budget) * 1024) if memory_budget else None + int(memory_budget) if memory_budget else None ) """Memory budget in megabytes""" self.max_candidates = int(options.get("max_candidates", 50)) From 4c682de336891ae3890a2495dafeebb07af40876 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Tue, 19 Nov 2024 12:29:33 +0000 Subject: [PATCH 21/23] Split enricher types for different interfaces --- nomenklatura/enrich/aleph.py | 4 +- nomenklatura/enrich/common.py | 47 +++++++++++++++++------- nomenklatura/enrich/nominatim.py | 4 +- nomenklatura/enrich/opencorporates.py | 4 +- nomenklatura/enrich/openfigi.py | 4 +- nomenklatura/enrich/permid.py | 4 +- nomenklatura/enrich/wikidata/__init__.py | 4 +- nomenklatura/enrich/yente.py | 4 +- nomenklatura/index/duckdb_index.py | 16 +++++--- 9 files changed, 59 insertions(+), 32 deletions(-) diff --git a/nomenklatura/enrich/aleph.py b/nomenklatura/enrich/aleph.py index be8971cb..4f4a1d43 100644 --- a/nomenklatura/enrich/aleph.py +++ b/nomenklatura/enrich/aleph.py @@ -12,12 +12,12 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig log = logging.getLogger(__name__) -class AlephEnricher(Enricher[DS]): +class AlephEnricher(ItemEnricher[DS]): def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): super().__init__(dataset, cache, config) self._host: str = os.environ.get("ALEPH_HOST", "https://aleph.occrp.org/") diff --git a/nomenklatura/enrich/common.py b/nomenklatura/enrich/common.py index 3b3825c1..5197eaa9 100644 --- a/nomenklatura/enrich/common.py +++ b/nomenklatura/enrich/common.py @@ -21,6 +21,9 @@ from nomenklatura.resolver import Identifier EnricherConfig = Dict[str, Any] +MatchCandidates = List[Tuple[Identifier, float]] +"""A list of candidate matches with their scores from a cheaper blocking comparison.""" + log = logging.getLogger(__name__) @@ -184,20 +187,11 @@ def _filter_entity(self, entity: CompositeEntity) -> bool: return False return True - def match_wrapped(self, entity: CE) -> Generator[CE, None, None]: - if not self._filter_entity(entity): - return - yield from self.match(entity) - def expand_wrapped(self, entity: CE, match: CE) -> Generator[CE, None, None]: if not self._filter_entity(entity): return yield from self.expand(entity, match) - @abstractmethod - def match(self, entity: CE) -> Generator[CE, None, None]: - raise NotImplementedError() - @abstractmethod def expand(self, entity: CE, match: CE) -> Generator[CE, None, None]: raise NotImplementedError() @@ -208,21 +202,48 @@ def close(self) -> None: self._session.close() +class ItemEnricher(Enricher[DS], ABC): + """ + An enricher which performs matching on individual entities, one at a time. 
+ """ + + def match_wrapped(self, entity: CE) -> Generator[CE, None, None]: + if not self._filter_entity(entity): + return + yield from self.match(entity) + + @abstractmethod + def match(self, entity: CE) -> Generator[CE, None, None]: + raise NotImplementedError() + + class BulkEnricher(Enricher[DS], ABC): """ An enricher which performs matching in bulk, requiring all subject entities - to be loaded before matching. + to be loaded before matching. + + Once loaded, matching can be done by iterating over the `candidates` method + which provides the subject entity ID and a list of IDs of candidate matches. + + `match_candidates` is then called for each subject entity and its + `MatchCandidates` yielding matching entities. """ + def load_wrapped(self, entity: CE) -> None: if not self._filter_entity(entity): return self.load(entity) + @abstractmethod def load(self, entity: CE) -> None: raise NotImplementedError() - - def candidates(self) -> Generator[Tuple[Identifier, List[Tuple[Identifier, float]]], None, None]: + + @abstractmethod + def candidates(self) -> Generator[Tuple[Identifier, MatchCandidates], None, None]: raise NotImplementedError() - def match_candidates(self, entity: CE, candidates: List[Tuple[Identifier, float]]) -> Generator[CE, None, None]: + @abstractmethod + def match_candidates( + self, entity: CE, candidates: MatchCandidates + ) -> Generator[CE, None, None]: raise NotImplementedError() diff --git a/nomenklatura/enrich/nominatim.py b/nomenklatura/enrich/nominatim.py index a396fe4f..c33028a4 100644 --- a/nomenklatura/enrich/nominatim.py +++ b/nomenklatura/enrich/nominatim.py @@ -5,14 +5,14 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig log = logging.getLogger(__name__) NOMINATIM = "https://nominatim.openstreetmap.org/search.php" -class NominatimEnricher(Enricher[DS]): +class NominatimEnricher(ItemEnricher[DS]): def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): super().__init__(dataset, cache, config) self.cache.preload(f"{NOMINATIM}%") diff --git a/nomenklatura/enrich/opencorporates.py b/nomenklatura/enrich/opencorporates.py index 4e65ed7e..db2a2c3c 100644 --- a/nomenklatura/enrich/opencorporates.py +++ b/nomenklatura/enrich/opencorporates.py @@ -11,7 +11,7 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig from nomenklatura.enrich.common import EnrichmentAbort, EnrichmentException @@ -22,7 +22,7 @@ def parse_date(raw: Any) -> Optional[str]: return registry.date.clean(raw) -class OpenCorporatesEnricher(Enricher[DS]): +class OpenCorporatesEnricher(ItemEnricher[DS]): COMPANY_SEARCH_API = "https://api.opencorporates.com/v0.4/companies/search" OFFICER_SEARCH_API = "https://api.opencorporates.com/v0.4/officers/search" UI_PART = "://opencorporates.com/" diff --git a/nomenklatura/enrich/openfigi.py b/nomenklatura/enrich/openfigi.py index 2109e182..f8f3b300 100644 --- a/nomenklatura/enrich/openfigi.py +++ b/nomenklatura/enrich/openfigi.py @@ -6,12 +6,12 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, 
EnricherConfig log = logging.getLogger(__name__) -class OpenFIGIEnricher(Enricher[DS]): +class OpenFIGIEnricher(ItemEnricher[DS]): """Uses the `OpenFIGI` search API to look up FIGIs by company name.""" SEARCH_URL = "https://api.openfigi.com/v3/search" diff --git a/nomenklatura/enrich/permid.py b/nomenklatura/enrich/permid.py index 327a22d0..b84aa4c5 100644 --- a/nomenklatura/enrich/permid.py +++ b/nomenklatura/enrich/permid.py @@ -14,7 +14,7 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig from nomenklatura.enrich.common import EnrichmentAbort from nomenklatura.util import fingerprint_name @@ -28,7 +28,7 @@ } -class PermIDEnricher(Enricher[DS]): +class PermIDEnricher(ItemEnricher[DS]): MATCHING_API = "https://api-eit.refinitiv.com/permid/match" def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): diff --git a/nomenklatura/enrich/wikidata/__init__.py b/nomenklatura/enrich/wikidata/__init__.py index cebc1210..021be194 100644 --- a/nomenklatura/enrich/wikidata/__init__.py +++ b/nomenklatura/enrich/wikidata/__init__.py @@ -18,7 +18,7 @@ PROPS_TOPICS, ) from nomenklatura.enrich.wikidata.model import Claim, Item -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig WD_API = "https://www.wikidata.org/w/api.php" LABEL_PREFIX = "wd:lb:" @@ -29,7 +29,7 @@ def clean_name(name: str) -> str: return clean_brackets(name).strip() -class WikidataEnricher(Enricher[DS]): +class WikidataEnricher(ItemEnricher[DS]): def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): super().__init__(dataset, cache, config) self.depth = self.get_config_int("depth", 1) diff --git a/nomenklatura/enrich/yente.py b/nomenklatura/enrich/yente.py index 7c28aaf4..d42f8289 100644 --- a/nomenklatura/enrich/yente.py +++ b/nomenklatura/enrich/yente.py @@ -11,13 +11,13 @@ from nomenklatura.entity import CE, CompositeEntity from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig from nomenklatura.enrich.common import EnrichmentException log = logging.getLogger(__name__) -class YenteEnricher(Enricher[DS]): +class YenteEnricher(ItemEnricher[DS]): """Uses the `yente` match API to look up entities in a specific dataset.""" def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index a1b8088c..6ade3f33 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -1,9 +1,8 @@ from io import TextIOWrapper -from duckdb import DuckDBPyRelation from followthemoney.types import registry from pathlib import Path from shutil import rmtree -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple +from typing import Any, Dict, Generator, Iterable, Optional, Tuple import csv import duckdb import logging @@ -14,6 +13,7 @@ from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer from nomenklatura.resolver import Pair, Identifier from nomenklatura.store import View +from nomenklatura.enrich.common import MatchCandidates log = logging.getLogger(__name__) @@ -189,13 +189,15 @@ def pairs( yield (Identifier.get(left), 
Identifier.get(right)), score def add_matching_subject(self, entity: CE) -> None: + if self.matching_dump is None: + raise Exception("Cannot add matching subject after getting candidates.") writer = csv.writer(self.matching_dump) for field, token in self.tokenizer.entity(entity): writer.writerow([entity.id, field, token]) def matches( self, - ) -> Generator[Tuple[Identifier, List[Tuple[Identifier, float]]], None, None]: + ) -> Generator[Tuple[Identifier, MatchCandidates], None, None]: if self.matching_dump is not None: self.matching_dump.close() self.matching_dump = None @@ -215,18 +217,22 @@ def matches( """ results = self.con.execute(match_query) previous_id = None - matches: List[Tuple[Identifier, float]] = [] + matches: MatchCandidates = [] while batch := results.fetchmany(BATCH_SIZE): for matching_id, match_id, score in batch: + # first row if previous_id is None: previous_id = matching_id + # Next pair of subject and candidates if matching_id != previous_id: if matches: yield Identifier.get(previous_id), matches matches = [] previous_id = matching_id matches.append((Identifier.get(match_id), score)) - yield Identifier.get(previous_id), matches[: self.max_candidates] + # Last pair or subject and candidates + if matches and previous_id is not None: + yield Identifier.get(previous_id), matches[: self.max_candidates] def __repr__(self) -> str: return "" % ( From fe1af3a3cae5314cc06be415e4ccb44b2c5ef14b Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Tue, 19 Nov 2024 12:29:54 +0000 Subject: [PATCH 22/23] Handle split enricher interfaces in nomenklatura --- nomenklatura/enrich/__init__.py | 139 ++++++++++++++++++++++++++------ 1 file changed, 114 insertions(+), 25 deletions(-) diff --git a/nomenklatura/enrich/__init__.py b/nomenklatura/enrich/__init__.py index 78175828..b5ffae65 100644 --- a/nomenklatura/enrich/__init__.py +++ b/nomenklatura/enrich/__init__.py @@ -1,12 +1,17 @@ import logging from importlib import import_module -from typing import Iterable, Generator, Optional, Type, cast +from typing import Dict, Iterable, Generator, Optional, Type, cast from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache from nomenklatura.matching import DefaultAlgorithm -from nomenklatura.enrich.common import Enricher, EnricherConfig, BulkEnricher +from nomenklatura.enrich.common import ( + Enricher, + EnricherConfig, + ItemEnricher, + BulkEnricher, +) from nomenklatura.enrich.common import EnrichmentAbort, EnrichmentException from nomenklatura.judgement import Judgement from nomenklatura.resolver import Resolver @@ -43,44 +48,128 @@ def make_enricher( # nk dedupe -i entities-with-matches.json -r resolver.json def match( enricher: Enricher[DS], resolver: Resolver[CE], entities: Iterable[CE] +) -> Generator[CE, None, None]: + if isinstance(enricher, BulkEnricher): + yield from get_bulk_matches(enricher, resolver, entities) + elif isinstance(enricher, ItemEnricher): + yield from get_itemwise_matches(enricher, resolver, entities) + else: + raise EnrichmentException("Invalid enricher type: %r" % enricher) + + +def get_itemwise_matches( + enricher: ItemEnricher[DS], resolver: Resolver[CE], entities: Iterable[CE] ) -> Generator[CE, None, None]: for entity in entities: yield entity try: for match in enricher.match_wrapped(entity): - if entity.id is None or match.id is None: - continue - if not resolver.check_candidate(entity.id, match.id): - continue - if not entity.schema.can_match(match.schema): - continue - result = DefaultAlgorithm.compare(entity, 
match) - log.info("Match [%s]: %.2f -> %s", entity, result.score, match) - resolver.suggest(entity.id, match.id, result.score) - match.datasets.add(enricher.dataset.name) - match = resolver.apply(match) - yield match + match_result = match_item(entity, match, resolver, enricher.dataset) + if match_result is not None: + yield match_result + except EnrichmentException: + log.exception("Failed to match: %r" % entity) + + +def get_bulk_matches( + enricher: BulkEnricher[DS], resolver: Resolver[CE], entities: Iterable[CE] +) -> Generator[CE, None, None]: + entity_lookup: Dict[str, CE] = {} + for entity in entities: + try: + enricher.load_wrapped(entity) + if entity.id is None: + raise EnrichmentException("Entity has no ID: %r" % entity) + if entity.id in entity_lookup: + raise EnrichmentException("Duplicate entity ID: %r" % entity.id) + entity_lookup[entity.id] = entity + except EnrichmentException: + log.exception("Failed to match: %r" % entity) + for entity_id, candidate_set in enricher.candidates(): + entity = entity_lookup[entity_id.id] + try: + for match in enricher.match_candidates(entity, candidate_set): + match_result = match_item(entity, match, resolver, enricher.dataset) + if match_result is not None: + yield match_result except EnrichmentException: log.exception("Failed to match: %r" % entity) +def match_item( + entity: CE, match: CE, resolver: Resolver[CE], dataset: DS +) -> Optional[CE]: + if entity.id is None or match.id is None: + return None + if not resolver.check_candidate(entity.id, match.id): + return None + if not entity.schema.can_match(match.schema): + return None + result = DefaultAlgorithm.compare(entity, match) + log.info("Match [%s]: %.2f -> %s", entity, result.score, match) + resolver.suggest(entity.id, match.id, result.score) + match.datasets.add(dataset.name) + match = resolver.apply(match) + return match + + # nk enrich -i entities.json -r resolver.json -o combined.json def enrich( enricher: Enricher[DS], resolver: Resolver[CE], entities: Iterable[CE] +) -> Generator[CE, None, None]: + if isinstance(enricher, BulkEnricher): + yield from get_bulk_enrichments(enricher, resolver, entities) + elif isinstance(enricher, ItemEnricher): + yield from get_itemwise_enrichments(enricher, resolver, entities) + else: + raise EnrichmentException("Invalid enricher type: %r" % enricher) + + +def get_itemwise_enrichments( + enricher: ItemEnricher[DS], resolver: Resolver[CE], entities: Iterable[CE] ) -> Generator[CE, None, None]: for entity in entities: try: for match in enricher.match_wrapped(entity): - if entity.id is None or match.id is None: - continue - judgement = resolver.get_judgement(match.id, entity.id) - if judgement != Judgement.POSITIVE: - continue - - log.info("Enrich [%s]: %r", entity, match) - for adjacent in enricher.expand_wrapped(entity, match): - adjacent.datasets.add(enricher.dataset.name) - adjacent = resolver.apply(adjacent) - yield adjacent + yield from enrich_item(enricher, entity, match, resolver) + except EnrichmentException: + log.exception("Failed to enrich: %r" % entity) + + +def get_bulk_enrichments( + enricher: BulkEnricher[DS], resolver: Resolver[CE], entities: Iterable[CE] +) -> Generator[CE, None, None]: + entity_lookup: Dict[str, CE] = {} + for entity in entities: + try: + enricher.load_wrapped(entity) + if entity.id is None: + raise EnrichmentException("Entity has no ID: %r" % entity) + if entity.id in entity_lookup: + raise EnrichmentException("Duplicate entity ID: %r" % entity.id) + entity_lookup[entity.id] = entity + except EnrichmentException: 
+ log.exception("Failed to match: %r" % entity) + for entity_id, candidate_set in enricher.candidates(): + entity = entity_lookup[entity_id.id] + try: + for match in enricher.match_candidates(entity, candidate_set): + yield from enrich_item(enricher, entity, match, resolver) except EnrichmentException: log.exception("Failed to enrich: %r" % entity) + + +def enrich_item( + enricher: Enricher[DS], entity: CE, match: CE, resolver: Resolver[CE] +) -> Generator[CE, None, None]: + if entity.id is None or match.id is None: + return None + judgement = resolver.get_judgement(match.id, entity.id) + if judgement != Judgement.POSITIVE: + return None + + log.info("Enrich [%s]: %r", entity, match) + for adjacent in enricher.expand_wrapped(entity, match): + adjacent.datasets.add(enricher.dataset.name) + adjacent = resolver.apply(adjacent) + yield adjacent From ad6ebdc2694d9b15e48c598dad7aa11ce45f3d35 Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Tue, 19 Nov 2024 16:43:11 +0000 Subject: [PATCH 23/23] Fix import --- nomenklatura/enrich/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nomenklatura/enrich/__init__.py b/nomenklatura/enrich/__init__.py index b5ffae65..bba9b735 100644 --- a/nomenklatura/enrich/__init__.py +++ b/nomenklatura/enrich/__init__.py @@ -21,6 +21,7 @@ "Enricher", "EnrichmentAbort", "EnrichmentException", + "ItemEnricher", "BulkEnricher", "make_enricher", "enrich",
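Aside, not part of the patch series: read together, the bulk enrichment flow added above is one calling sequence — stage every subject entity, then iterate the grouped candidates, then run the full comparison per candidate. Below is a minimal sketch of that sequence using only the `BulkEnricher` methods defined in these patches; `handle_match` is a placeholder for whatever the caller does with each candidate (the real loops suggest it to the resolver and re-emit it).

    from typing import Callable, Dict, Iterable

    from nomenklatura.dataset import DS
    from nomenklatura.entity import CE
    from nomenklatura.enrich.common import BulkEnricher


    def run_bulk(
        enricher: BulkEnricher[DS],
        entities: Iterable[CE],
        handle_match: Callable[[CE, CE], None],
    ) -> None:
        # Stage every subject entity before any matching happens.
        subjects: Dict[str, CE] = {}
        for entity in entities:
            enricher.load_wrapped(entity)
            if entity.id is not None:
                subjects[entity.id] = entity

        # One pass over the grouped candidates, then a full comparison per candidate.
        for subject_id, candidates in enricher.candidates():
            subject = subjects[subject_id.id]
            for match in enricher.match_candidates(subject, candidates):
                handle_match(subject, match)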