diff --git a/nomenklatura/enrich/__init__.py b/nomenklatura/enrich/__init__.py index e41fc94d..bba9b735 100644 --- a/nomenklatura/enrich/__init__.py +++ b/nomenklatura/enrich/__init__.py @@ -1,12 +1,17 @@ import logging from importlib import import_module -from typing import Iterable, Generator, Optional, Type, cast +from typing import Dict, Iterable, Generator, Optional, Type, cast from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache from nomenklatura.matching import DefaultAlgorithm -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ( + Enricher, + EnricherConfig, + ItemEnricher, + BulkEnricher, +) from nomenklatura.enrich.common import EnrichmentAbort, EnrichmentException from nomenklatura.judgement import Judgement from nomenklatura.resolver import Resolver @@ -16,6 +21,8 @@ "Enricher", "EnrichmentAbort", "EnrichmentException", + "ItemEnricher", + "BulkEnricher", "make_enricher", "enrich", "match", @@ -42,44 +49,128 @@ def make_enricher( # nk dedupe -i entities-with-matches.json -r resolver.json def match( enricher: Enricher[DS], resolver: Resolver[CE], entities: Iterable[CE] +) -> Generator[CE, None, None]: + if isinstance(enricher, BulkEnricher): + yield from get_bulk_matches(enricher, resolver, entities) + elif isinstance(enricher, ItemEnricher): + yield from get_itemwise_matches(enricher, resolver, entities) + else: + raise EnrichmentException("Invalid enricher type: %r" % enricher) + + +def get_itemwise_matches( + enricher: ItemEnricher[DS], resolver: Resolver[CE], entities: Iterable[CE] ) -> Generator[CE, None, None]: for entity in entities: yield entity try: for match in enricher.match_wrapped(entity): - if entity.id is None or match.id is None: - continue - if not resolver.check_candidate(entity.id, match.id): - continue - if not entity.schema.can_match(match.schema): - continue - result = DefaultAlgorithm.compare(entity, match) - log.info("Match [%s]: %.2f -> %s", entity, result.score, match) - resolver.suggest(entity.id, match.id, result.score) - match.datasets.add(enricher.dataset.name) - match = resolver.apply(match) - yield match + match_result = match_item(entity, match, resolver, enricher.dataset) + if match_result is not None: + yield match_result + except EnrichmentException: + log.exception("Failed to match: %r" % entity) + + +def get_bulk_matches( + enricher: BulkEnricher[DS], resolver: Resolver[CE], entities: Iterable[CE] +) -> Generator[CE, None, None]: + entity_lookup: Dict[str, CE] = {} + for entity in entities: + try: + enricher.load_wrapped(entity) + if entity.id is None: + raise EnrichmentException("Entity has no ID: %r" % entity) + if entity.id in entity_lookup: + raise EnrichmentException("Duplicate entity ID: %r" % entity.id) + entity_lookup[entity.id] = entity + except EnrichmentException: + log.exception("Failed to match: %r" % entity) + for entity_id, candidate_set in enricher.candidates(): + entity = entity_lookup[entity_id.id] + try: + for match in enricher.match_candidates(entity, candidate_set): + match_result = match_item(entity, match, resolver, enricher.dataset) + if match_result is not None: + yield match_result except EnrichmentException: log.exception("Failed to match: %r" % entity) +def match_item( + entity: CE, match: CE, resolver: Resolver[CE], dataset: DS +) -> Optional[CE]: + if entity.id is None or match.id is None: + return None + if not resolver.check_candidate(entity.id, match.id): + return None + if not 
entity.schema.can_match(match.schema): + return None + result = DefaultAlgorithm.compare(entity, match) + log.info("Match [%s]: %.2f -> %s", entity, result.score, match) + resolver.suggest(entity.id, match.id, result.score) + match.datasets.add(dataset.name) + match = resolver.apply(match) + return match + + # nk enrich -i entities.json -r resolver.json -o combined.json def enrich( enricher: Enricher[DS], resolver: Resolver[CE], entities: Iterable[CE] +) -> Generator[CE, None, None]: + if isinstance(enricher, BulkEnricher): + yield from get_bulk_enrichments(enricher, resolver, entities) + elif isinstance(enricher, ItemEnricher): + yield from get_itemwise_enrichments(enricher, resolver, entities) + else: + raise EnrichmentException("Invalid enricher type: %r" % enricher) + + +def get_itemwise_enrichments( + enricher: ItemEnricher[DS], resolver: Resolver[CE], entities: Iterable[CE] ) -> Generator[CE, None, None]: for entity in entities: try: for match in enricher.match_wrapped(entity): - if entity.id is None or match.id is None: - continue - judgement = resolver.get_judgement(match.id, entity.id) - if judgement != Judgement.POSITIVE: - continue - - log.info("Enrich [%s]: %r", entity, match) - for adjacent in enricher.expand_wrapped(entity, match): - adjacent.datasets.add(enricher.dataset.name) - adjacent = resolver.apply(adjacent) - yield adjacent + yield from enrich_item(enricher, entity, match, resolver) + except EnrichmentException: + log.exception("Failed to enrich: %r" % entity) + + +def get_bulk_enrichments( + enricher: BulkEnricher[DS], resolver: Resolver[CE], entities: Iterable[CE] +) -> Generator[CE, None, None]: + entity_lookup: Dict[str, CE] = {} + for entity in entities: + try: + enricher.load_wrapped(entity) + if entity.id is None: + raise EnrichmentException("Entity has no ID: %r" % entity) + if entity.id in entity_lookup: + raise EnrichmentException("Duplicate entity ID: %r" % entity.id) + entity_lookup[entity.id] = entity + except EnrichmentException: + log.exception("Failed to match: %r" % entity) + for entity_id, candidate_set in enricher.candidates(): + entity = entity_lookup[entity_id.id] + try: + for match in enricher.match_candidates(entity, candidate_set): + yield from enrich_item(enricher, entity, match, resolver) except EnrichmentException: log.exception("Failed to enrich: %r" % entity) + + +def enrich_item( + enricher: Enricher[DS], entity: CE, match: CE, resolver: Resolver[CE] +) -> Generator[CE, None, None]: + if entity.id is None or match.id is None: + return None + judgement = resolver.get_judgement(match.id, entity.id) + if judgement != Judgement.POSITIVE: + return None + + log.info("Enrich [%s]: %r", entity, match) + for adjacent in enricher.expand_wrapped(entity, match): + adjacent.datasets.add(enricher.dataset.name) + adjacent = resolver.apply(adjacent) + yield adjacent diff --git a/nomenklatura/enrich/aleph.py b/nomenklatura/enrich/aleph.py index be8971cb..4f4a1d43 100644 --- a/nomenklatura/enrich/aleph.py +++ b/nomenklatura/enrich/aleph.py @@ -12,12 +12,12 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig log = logging.getLogger(__name__) -class AlephEnricher(Enricher[DS]): +class AlephEnricher(ItemEnricher[DS]): def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): super().__init__(dataset, cache, config) self._host: str = 
os.environ.get("ALEPH_HOST", "https://aleph.occrp.org/") diff --git a/nomenklatura/enrich/common.py b/nomenklatura/enrich/common.py index c4523227..5197eaa9 100644 --- a/nomenklatura/enrich/common.py +++ b/nomenklatura/enrich/common.py @@ -3,7 +3,7 @@ import logging import time from banal import as_bool -from typing import Union, Any, Dict, Optional, Generator, Generic +from typing import List, Tuple, Union, Any, Dict, Optional, Generator, Generic from abc import ABC, abstractmethod from requests import Session from requests.exceptions import RequestException @@ -18,8 +18,12 @@ from nomenklatura.dataset import DS from nomenklatura.cache import Cache from nomenklatura.util import HeadersType +from nomenklatura.resolver import Identifier EnricherConfig = Dict[str, Any] +MatchCandidates = List[Tuple[Identifier, float]] +"""A list of candidate matches with their scores from a cheaper blocking comparison.""" + log = logging.getLogger(__name__) @@ -183,25 +187,63 @@ def _filter_entity(self, entity: CompositeEntity) -> bool: return False return True + def expand_wrapped(self, entity: CE, match: CE) -> Generator[CE, None, None]: + if not self._filter_entity(entity): + return + yield from self.expand(entity, match) + + @abstractmethod + def expand(self, entity: CE, match: CE) -> Generator[CE, None, None]: + raise NotImplementedError() + + def close(self) -> None: + self.cache.close() + if self._session is not None: + self._session.close() + + +class ItemEnricher(Enricher[DS], ABC): + """ + An enricher which performs matching on individual entities, one at a time. + """ + def match_wrapped(self, entity: CE) -> Generator[CE, None, None]: if not self._filter_entity(entity): return yield from self.match(entity) - def expand_wrapped(self, entity: CE, match: CE) -> Generator[CE, None, None]: + @abstractmethod + def match(self, entity: CE) -> Generator[CE, None, None]: + raise NotImplementedError() + + +class BulkEnricher(Enricher[DS], ABC): + """ + An enricher which performs matching in bulk, requiring all subject entities + to be loaded before matching. + + Once loaded, matching can be done by iterating over the `candidates` method + which provides the subject entity ID and a list of IDs of candidate matches. + + `match_candidates` is then called for each subject entity and its + `MatchCandidates` yielding matching entities. 
+ """ + + def load_wrapped(self, entity: CE) -> None: if not self._filter_entity(entity): return - yield from self.expand(entity, match) + self.load(entity) @abstractmethod - def match(self, entity: CE) -> Generator[CE, None, None]: + def load(self, entity: CE) -> None: raise NotImplementedError() @abstractmethod - def expand(self, entity: CE, match: CE) -> Generator[CE, None, None]: + def candidates(self) -> Generator[Tuple[Identifier, MatchCandidates], None, None]: raise NotImplementedError() - def close(self) -> None: - self.cache.close() - if self._session is not None: - self._session.close() + @abstractmethod + def match_candidates( + self, entity: CE, candidates: MatchCandidates + ) -> Generator[CE, None, None]: + raise NotImplementedError() diff --git a/nomenklatura/enrich/nominatim.py b/nomenklatura/enrich/nominatim.py index a396fe4f..c33028a4 100644 --- a/nomenklatura/enrich/nominatim.py +++ b/nomenklatura/enrich/nominatim.py @@ -5,14 +5,14 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig log = logging.getLogger(__name__) NOMINATIM = "https://nominatim.openstreetmap.org/search.php" -class NominatimEnricher(Enricher[DS]): +class NominatimEnricher(ItemEnricher[DS]): def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): super().__init__(dataset, cache, config) self.cache.preload(f"{NOMINATIM}%") diff --git a/nomenklatura/enrich/opencorporates.py b/nomenklatura/enrich/opencorporates.py index 4e65ed7e..db2a2c3c 100644 --- a/nomenklatura/enrich/opencorporates.py +++ b/nomenklatura/enrich/opencorporates.py @@ -11,7 +11,7 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig from nomenklatura.enrich.common import EnrichmentAbort, EnrichmentException @@ -22,7 +22,7 @@ def parse_date(raw: Any) -> Optional[str]: return registry.date.clean(raw) -class OpenCorporatesEnricher(Enricher[DS]): +class OpenCorporatesEnricher(ItemEnricher[DS]): COMPANY_SEARCH_API = "https://api.opencorporates.com/v0.4/companies/search" OFFICER_SEARCH_API = "https://api.opencorporates.com/v0.4/officers/search" UI_PART = "://opencorporates.com/" diff --git a/nomenklatura/enrich/openfigi.py b/nomenklatura/enrich/openfigi.py index 2109e182..f8f3b300 100644 --- a/nomenklatura/enrich/openfigi.py +++ b/nomenklatura/enrich/openfigi.py @@ -6,12 +6,12 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig log = logging.getLogger(__name__) -class OpenFIGIEnricher(Enricher[DS]): +class OpenFIGIEnricher(ItemEnricher[DS]): """Uses the `OpenFIGI` search API to look up FIGIs by company name.""" SEARCH_URL = "https://api.openfigi.com/v3/search" diff --git a/nomenklatura/enrich/permid.py b/nomenklatura/enrich/permid.py index 327a22d0..b84aa4c5 100644 --- a/nomenklatura/enrich/permid.py +++ b/nomenklatura/enrich/permid.py @@ -14,7 +14,7 @@ from nomenklatura.entity import CE from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import 
ItemEnricher, EnricherConfig from nomenklatura.enrich.common import EnrichmentAbort from nomenklatura.util import fingerprint_name @@ -28,7 +28,7 @@ } -class PermIDEnricher(Enricher[DS]): +class PermIDEnricher(ItemEnricher[DS]): MATCHING_API = "https://api-eit.refinitiv.com/permid/match" def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): diff --git a/nomenklatura/enrich/wikidata/__init__.py b/nomenklatura/enrich/wikidata/__init__.py index cebc1210..021be194 100644 --- a/nomenklatura/enrich/wikidata/__init__.py +++ b/nomenklatura/enrich/wikidata/__init__.py @@ -18,7 +18,7 @@ PROPS_TOPICS, ) from nomenklatura.enrich.wikidata.model import Claim, Item -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig WD_API = "https://www.wikidata.org/w/api.php" LABEL_PREFIX = "wd:lb:" @@ -29,7 +29,7 @@ def clean_name(name: str) -> str: return clean_brackets(name).strip() -class WikidataEnricher(Enricher[DS]): +class WikidataEnricher(ItemEnricher[DS]): def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): super().__init__(dataset, cache, config) self.depth = self.get_config_int("depth", 1) diff --git a/nomenklatura/enrich/yente.py b/nomenklatura/enrich/yente.py index 7c28aaf4..d42f8289 100644 --- a/nomenklatura/enrich/yente.py +++ b/nomenklatura/enrich/yente.py @@ -11,13 +11,13 @@ from nomenklatura.entity import CE, CompositeEntity from nomenklatura.dataset import DS from nomenklatura.cache import Cache -from nomenklatura.enrich.common import Enricher, EnricherConfig +from nomenklatura.enrich.common import ItemEnricher, EnricherConfig from nomenklatura.enrich.common import EnrichmentException log = logging.getLogger(__name__) -class YenteEnricher(Enricher[DS]): +class YenteEnricher(ItemEnricher[DS]): """Uses the `yente` match API to look up entities in a specific dataset.""" def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): diff --git a/nomenklatura/index/__init__.py b/nomenklatura/index/__init__.py index bf33cfdb..e001835d 100644 --- a/nomenklatura/index/__init__.py +++ b/nomenklatura/index/__init__.py @@ -9,7 +9,7 @@ from nomenklatura.entity import CE log = logging.getLogger(__name__) -INDEX_TYPES = ["tantivy", Index.name] +INDEX_TYPES = ["tantivy", "duckdb", Index.name] def get_index( @@ -24,10 +24,17 @@ def get_index( clazz = TantivyIndex[DS, CE] except ImportError: log.warning("`tantivy` is not available, falling back to in-memory index.") + if type_ == "duckdb": + try: + from nomenklatura.index.duckdb_index import DuckDBIndex + + clazz = DuckDBIndex[DS, CE] + except ImportError: + log.warning("`duckdb` is not available, falling back to in-memory index.") index = clazz(view, path) index.build() return index -__all__ = ["BaseIndex", "Index", "TantivyIndex", "get_index"] +__all__ = ["BaseIndex", "Index", "TantivyIndex", "DuckDBIndex", "get_index"] diff --git a/nomenklatura/index/common.py b/nomenklatura/index/common.py index 8cc75d49..c3f1792e 100644 --- a/nomenklatura/index/common.py +++ b/nomenklatura/index/common.py @@ -1,6 +1,6 @@ from pathlib import Path -from typing import Generic, List, Tuple -from nomenklatura.resolver import Identifier +from typing import Generic, Iterable, List, Tuple +from nomenklatura.resolver import Pair, Identifier from nomenklatura.dataset import DS from nomenklatura.entity import CE from nomenklatura.store import View @@ -16,9 +16,7 @@ def __init__(self, view: View[DS, CE], data_dir: Path) -> None: def build(self) -> None: raise 
NotImplementedError - def pairs( - self, max_pairs: int = MAX_PAIRS - ) -> List[Tuple[Tuple[Identifier, Identifier], float]]: + def pairs(self, max_pairs: int = MAX_PAIRS) -> Iterable[Tuple[Pair, float]]: raise NotImplementedError def match(self, entity: CE) -> List[Tuple[Identifier, float]]: diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py new file mode 100644 index 00000000..6ade3f33 --- /dev/null +++ b/nomenklatura/index/duckdb_index.py @@ -0,0 +1,241 @@ +from io import TextIOWrapper +from followthemoney.types import registry +from pathlib import Path +from shutil import rmtree +from typing import Any, Dict, Generator, Iterable, Optional, Tuple +import csv +import duckdb +import logging + +from nomenklatura.dataset import DS +from nomenklatura.entity import CE +from nomenklatura.index.common import BaseIndex +from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer +from nomenklatura.resolver import Pair, Identifier +from nomenklatura.store import View +from nomenklatura.enrich.common import MatchCandidates + +log = logging.getLogger(__name__) + +BATCH_SIZE = 1000 + + +class DuckDBIndex(BaseIndex[DS, CE]): + """ + An index using DuckDB for token matching and scoring, keeping data in memory + until it needs to spill to disk as it approaches the configured memory limit. + + Pairs match if they share one or more tokens. A basic similarity score is calculated + cumulatively based on each token's Term Frequency (TF) and the field's boost factor. + """ + + name = "duckdb" + + BOOSTS = { + NAME_PART_FIELD: 2.0, + WORD_FIELD: 0.5, + registry.name.name: 10.0, + # registry.country.name: 1.5, + # registry.date.name: 1.5, + # registry.language: 0.7, + # registry.iban.name: 3.0, + registry.phone.name: 3.0, + registry.email.name: 3.0, + # registry.entity: 0.0, + # registry.topic: 2.1, + registry.address.name: 2.5, + registry.identifier.name: 3.0, + } + + __slots__ = "view", "fields", "tokenizer", "entities" + + def __init__( + self, view: View[DS, CE], data_dir: Path, options: Dict[str, Any] = {} + ): + self.view = view + memory_budget = options.get("memory_budget", None) + self.memory_budget: Optional[int] = ( + int(memory_budget) if memory_budget else None + ) + """Memory budget in megabytes""" + self.max_candidates = int(options.get("max_candidates", 50)) + self.tokenizer = Tokenizer[DS, CE]() + self.data_dir = data_dir + if self.data_dir.exists(): + rmtree(self.data_dir) + self.data_dir.mkdir(parents=True) + self.con = duckdb.connect((self.data_dir / "duckdb_index.db").as_posix()) + self.matching_path = self.data_dir / "matching.csv" + self.matching_path.unlink(missing_ok=True) + self.matching_dump: TextIOWrapper | None = open(self.matching_path, "w") + writer = csv.writer(self.matching_dump) + writer.writerow(["id", "field", "token"]) + + # https://duckdb.org/docs/guides/performance/environment + # > For ideal performance, + # > aggregation-heavy workloads require approx. 5 GB memory per thread and + # > join-heavy workloads require approximately 10 GB memory per thread. + # > Aim for 5-10 GB memory per thread. 
+ if self.memory_budget is not None: + self.con.execute("SET memory_limit = ?;", [f"{self.memory_budget}MB"]) + # > If you have a limited amount of memory, try to limit the number of threads + self.con.execute("SET threads = 1;") + + def build(self) -> None: + """Index all entities in the dataset.""" + log.info("Building index from: %r...", self.view) + self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)") + for field, boost in self.BOOSTS.items(): + self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost]) + + self.con.execute("CREATE TABLE matching (id TEXT, field TEXT, token TEXT)") + self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") + csv_path = self.data_dir / "mentions.csv" + log.info("Dumping entity tokens to CSV for bulk load into the database...") + with open(csv_path, "w") as fh: + writer = csv.writer(fh) + + # csv.writer type gymnastics + def dump_entity(entity: CE) -> None: + if not entity.schema.matchable or entity.id is None: + return + for field, token in self.tokenizer.entity(entity): + writer.writerow([entity.id, field, token]) + writer.writerow(["id", "field", "token"]) + + for idx, entity in enumerate(self.view.entities()): + dump_entity(entity) + if idx % 50000 == 0: + log.info("Dumped %s entities" % idx) + log.info("Loading data...") + self.con.execute(f"COPY entries from '{csv_path}'") + log.info("Done.") + + self._build_frequencies() + log.info("Index built.") + + def _build_field_len(self) -> None: + self._build_stopwords() + log.info("Calculating field lengths...") + field_len_query = """ + CREATE TABLE IF NOT EXISTS field_len as + SELECT entries.field, entries.id, count(*) as field_len from entries + LEFT OUTER JOIN stopwords + ON stopwords.field = entries.field AND stopwords.token = entries.token + WHERE token_freq is NULL + GROUP BY entries.field, entries.id + """ + self.con.execute(field_len_query) + + def _build_mentions(self) -> None: + self._build_stopwords() + log.info("Calculating mention counts...") + mentions_query = """ + CREATE TABLE IF NOT EXISTS mentions as + SELECT entries.field, entries.id, entries.token, count(*) as mentions + FROM entries + LEFT OUTER JOIN stopwords + ON stopwords.field = entries.field AND stopwords.token = entries.token + WHERE token_freq is NULL + GROUP BY entries.field, entries.id, entries.token + """ + self.con.execute(mentions_query) + + def _build_stopwords(self) -> None: + token_freq_query = """ + SELECT field, token, count(*) as token_freq + FROM entries + GROUP BY field, token + """ + token_freq = self.con.sql(token_freq_query) # noqa + self.con.execute( + """ + CREATE TABLE IF NOT EXISTS stopwords as + SELECT * FROM token_freq where token_freq > 100 + """ + ) + + def _build_frequencies(self) -> None: + self._build_field_len() + self._build_mentions() + log.info("Calculating term frequencies...") + term_frequencies_query = """ + CREATE TABLE IF NOT EXISTS term_frequencies as + SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf + FROM field_len + JOIN mentions + ON field_len.field = mentions.field AND field_len.id = mentions.id + """ + self.con.execute(term_frequencies_query) + + def pairs( + self, max_pairs: int = BaseIndex.MAX_PAIRS + ) -> Iterable[Tuple[Pair, float]]: + pairs_query = """ + SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * ifnull(boost, 1)) as score + FROM term_frequencies as "left" + JOIN term_frequencies as "right" + ON "left".field = "right".field AND "left".token = "right".token + LEFT OUTER JOIN boosts + ON 
"left".field = boosts.field + AND "left".id > "right".id + GROUP BY "left".id, "right".id + ORDER BY score DESC + LIMIT ? + """ + results = self.con.execute(pairs_query, [max_pairs]) + while batch := results.fetchmany(BATCH_SIZE): + for left, right, score in batch: + yield (Identifier.get(left), Identifier.get(right)), score + + def add_matching_subject(self, entity: CE) -> None: + if self.matching_dump is None: + raise Exception("Cannot add matching subject after getting candidates.") + writer = csv.writer(self.matching_dump) + for field, token in self.tokenizer.entity(entity): + writer.writerow([entity.id, field, token]) + + def matches( + self, + ) -> Generator[Tuple[Identifier, MatchCandidates], None, None]: + if self.matching_dump is not None: + self.matching_dump.close() + self.matching_dump = None + log.info("Loading matching subjects...") + self.con.execute(f"COPY matching from '{self.matching_path}'") + log.info("Finished loading matching subjects.") + + match_query = """ + SELECT matching.id, matches.id, sum(matches.tf * ifnull(boost, 1)) as score + FROM term_frequencies as matches + JOIN matching + ON matches.field = matching.field AND matches.token = matching.token + LEFT OUTER JOIN boosts + ON matches.field = boosts.field + GROUP BY matches.id, matching.id + ORDER BY matching.id, score DESC + """ + results = self.con.execute(match_query) + previous_id = None + matches: MatchCandidates = [] + while batch := results.fetchmany(BATCH_SIZE): + for matching_id, match_id, score in batch: + # first row + if previous_id is None: + previous_id = matching_id + # Next pair of subject and candidates + if matching_id != previous_id: + if matches: + yield Identifier.get(previous_id), matches + matches = [] + previous_id = matching_id + matches.append((Identifier.get(match_id), score)) + # Last pair or subject and candidates + if matches and previous_id is not None: + yield Identifier.get(previous_id), matches[: self.max_candidates] + + def __repr__(self) -> str: + return "" % ( + self.view.scope.name, + self.con, + ) diff --git a/nomenklatura/index/index.py b/nomenklatura/index/index.py index fa361475..cddd8e73 100644 --- a/nomenklatura/index/index.py +++ b/nomenklatura/index/index.py @@ -92,7 +92,7 @@ def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float] log.info("Building index blocking pairs...") for field_name, field in self.fields.items(): boost = self.BOOSTS.get(field_name, 1.0) - for idx, entry in enumerate(field.tokens.values()): + for idx, (token, entry) in enumerate(field.tokens.items()): if idx % 10000 == 0: log.info("Pairwise xref [%s]: %d" % (field_name, idx)) diff --git a/setup.py b/setup.py index 861e1ba6..aa787a5a 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ "plyvel < 2.0.0", "redis > 5.0.0, < 6.0.0", "tantivy < 1.0.0", + "duckdb < 2.0.0", ], "leveldb": [ "plyvel < 2.0.0", @@ -69,5 +70,8 @@ "tantivy": [ "tantivy < 1.0.0", ], + "duckdb": [ + "duckdb < 2.0.0", + ], }, ) diff --git a/tests/conftest.py b/tests/conftest.py index ee28af17..c620315c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from tempfile import mkdtemp from nomenklatura import settings +from nomenklatura.index.duckdb_index import DuckDBIndex from nomenklatura.index.tantivy_index import TantivyIndex from nomenklatura.store import load_entity_file_store, SimpleMemoryStore from nomenklatura.kv import get_redis @@ -81,6 +82,13 @@ def tantivy_index(index_path: Path, dstore: SimpleMemoryStore): yield index +@pytest.fixture(scope="function") +def 
duckdb_index(index_path: Path, dstore: SimpleMemoryStore): + index = DuckDBIndex(dstore.default_view(), index_path) + index.build() + yield index + + @pytest.fixture(scope="function") def index_path(): index_path = Path(mkdtemp()) / "index-dir" diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py new file mode 100644 index 00000000..77e63874 --- /dev/null +++ b/tests/index/test_duckdb_index.py @@ -0,0 +1,169 @@ +from collections import defaultdict +from pathlib import Path + +from nomenklatura.dataset import Dataset +from nomenklatura.entity import CompositeEntity +from nomenklatura.index import get_index +from nomenklatura.index.duckdb_index import DuckDBIndex +from nomenklatura.resolver.identifier import Identifier +from nomenklatura.store import SimpleMemoryStore + +DAIMLER = "66ce9f62af8c7d329506da41cb7c36ba058b3d28" +VERBAND_ID = "62ad0fe6f56dbbf6fee57ce3da76e88c437024d5" +VERBAND_BADEN_ID = "69401823a9f0a97cfdc37afa7c3158374e007669" +VERBAND_BADEN_DATA = { + "id": "bla", + "schema": "Company", + "properties": { + "name": ["VERBAND DER METALL UND ELEKTROINDUSTRIE BADEN WURTTEMBERG"] + }, +} + + +def test_import(dstore: SimpleMemoryStore, index_path: Path): + view = dstore.default_view() + index = get_index(view, index_path, "duckdb") + assert isinstance(index, DuckDBIndex), type(index) + + +def test_field_lengths(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + field_names = set() + ids = set() + + field_len_rel = duckdb_index.con.sql("SELECT * FROM field_len") + for field_name, id, field_len in field_len_rel.fetchall(): + field_names.add(field_name) + ids.add(id) + + # Expect to see all matchable entities + # jq .schema tests/fixtures/donations.ijson | sort | uniq -c + # Organizations 17 + # Companies 56 + # Persons 22 + # Addresses 89 + assert len(ids) == 184, len(ids) + + # Expect to see all index fields for the matchable prop types and any applicable synthetic fields + # jq '.properties | keys | .[]' tests/fixtures/donations.ijson --raw-output|sort -u + expected_fields = { + "namepart", + "name", + "country", + "word", + } + assert field_names == expected_fields, field_names + + +def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + ids = set() + field_tokens = defaultdict(set) + + mentions_rel = duckdb_index.con.sql("SELECT * FROM mentions") + for field_name, id, token, count in mentions_rel.fetchall(): + ids.add(id) + field_tokens[field_name].add(token) + + assert len(ids) == 184, len(ids) + assert "verband" in field_tokens["namepart"], field_tokens["namepart"] + assert "gb" in field_tokens["country"], field_tokens["country"] + assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"] + assert "dortmund" in field_tokens["word"], field_tokens["word"] + + +def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + view = dstore.default_view() + pairs = list(duckdb_index.pairs()) + + # At least one pair is found + assert len(pairs) > 0, len(pairs) + + # A pair has tokens which overlap + tokenizer = duckdb_index.tokenizer + pair, score = pairs[0] + entity0 = view.get_entity(str(pair[0])) + tokens0 = set(tokenizer.entity(entity0)) + entity1 = view.get_entity(str(pair[1])) + tokens1 = set(tokenizer.entity(entity1)) + overlap = tokens0.intersection(tokens1) + assert len(overlap) > 0, overlap + + # A pair has non-zero score + assert score > 0 + + # pairs are in descending score order + last_score = pairs[0][1] + for pair in pairs[1:]: + assert pair[1] <= last_score + last_score = pair[1] + + # 
Johanna Quandt <> Frau Johanna Quandt + jq = ( + Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"), + Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"), + ) + jq_score = [score for pair, score in pairs if jq == pair][0] + + # Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG + bmw = ( + Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"), + Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"), + ) + bmw_score = [score for pair, score in pairs if bmw == pair][0] + + # More tokens in BMW means lower TF, reducing the score + assert jq_score > bmw_score, (jq_score, bmw_score) + assert jq_score == 19.0, jq_score + assert 3.3 < bmw_score < 3.4, bmw_score + + # FERRING Arzneimittel GmbH <> Clou Container Leasing GmbH + false_pos = ( + Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"), + Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"), + ) + false_pos_score = [score for pair, score in pairs if false_pos == pair][0] + assert 1.1 < false_pos_score < 1.2, false_pos_score + assert bmw_score > false_pos_score, (bmw_score, false_pos_score) + + assert len(pairs) >= 428, len(pairs) + + +def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex): + """Match an entity that isn't itself in the index""" + dx = Dataset.make({"name": "test", "title": "Test"}) + entity = CompositeEntity.from_data(dx, VERBAND_BADEN_DATA) + duckdb_index.add_matching_subject(entity) + match_sets = list(duckdb_index.matches()) + assert len(match_sets) == 1, match_sets + subject_id, matches = match_sets[0] + assert subject_id == Identifier("bla"), subject_id + + # 9 entities in the index where some token in the query entity matches some + # token in the index. + assert len(matches) == 9, matches + + top_result = matches[0] + assert top_result[0] == Identifier(VERBAND_BADEN_ID), top_result + assert 1.99 < top_result[1] < 2, top_result + + next_result = matches[1] + assert next_result[0] == Identifier(VERBAND_ID), next_result + assert 1.66 < next_result[1] < 1.67, next_result + + #match_identifiers = set(str(m[0]) for m in matches) + + +# def test_top_match_matches_strong_pairs( +# dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex +# ): +# """Pairs with high scores are each others' top matches""" +# +# view = dstore.default_view() +# strong_pairs = [p for p in duckdb_index.pairs() if p[1] > 3.0] +# assert len(strong_pairs) > 4 +# +# for pair, pair_score in strong_pairs: +# entity = view.get_entity(pair[0].id) +# matches = duckdb_index.match(entity) +# # it'll match itself and the other in the pair +# for match, match_score in matches[:2]: +# assert match in pair, (match, pair) diff --git a/tests/index/test_index.py b/tests/index/test_index.py index dc5d6913..ac37d9fb 100644 --- a/tests/index/test_index.py +++ b/tests/index/test_index.py @@ -33,20 +33,24 @@ def test_index_persist(dstore: SimpleMemoryStore, dindex): with NamedTemporaryFile("w") as fh: path = Path(fh.name) dindex.save(path) - loaded = Index.load(dstore.default_view(), path, tmpdir) + loaded = Index.load(dstore.default_view(), path, Path(tmpdir)) assert len(dindex.entities) == len(loaded.entities), (dindex, loaded) assert len(dindex) == len(loaded), (dindex, loaded) path.unlink(missing_ok=True) with TemporaryDirectory() as tmpdir: - empty = Index.load(view, path, tmpdir) + empty = Index.load(view, path, Path(tmpdir)) assert len(empty) == len(loaded), (empty, loaded) def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index): view = dstore.default_view() pairs = dindex.pairs() - assert 
len(pairs) > 0, pairs + + # At least one pair is found + assert len(pairs) > 0, len(pairs) + + # A pair has tokens which overlap tokenizer = dindex.tokenizer pair, score = pairs[0] entity0 = view.get_entity(str(pair[0])) @@ -55,10 +59,45 @@ def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index): tokens1 = set(tokenizer.entity(entity1)) overlap = tokens0.intersection(tokens1) assert len(overlap) > 0, overlap - # assert "Schnabel" in (overlap, tokens0, tokens1) - # assert "Schnabel" in (entity0.caption, entity1.caption) + + # A pair has non-zero score assert score > 0 - # assert False + + # pairs are in descending score order + last_score = pairs[0][1] + for pair in pairs[1:]: + assert pair[1] <= last_score + last_score = pair[1] + + # Johanna Quandt <> Frau Johanna Quandt + jq = ( + Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"), + Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"), + ) + jq_score = [score for pair, score in pairs if jq == pair][0] + + # Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG + bmw = ( + Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"), + Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"), + ) + bmw_score = [score for pair, score in pairs if bmw == pair][0] + + # More tokens in BMW means lower TF, reducing the score + assert jq_score > bmw_score, (jq_score, bmw_score) + assert jq_score == 19.0, jq_score + assert 3.3 < bmw_score < 3.4, bmw_score + + # FERRING Arzneimittel GmbH <> Clou Container Leasing GmbH + false_pos = ( + Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"), + Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"), + ) + false_pos_score = [score for pair, score in pairs if false_pos == pair][0] + assert 1.1 < false_pos_score < 1.2, false_pos_score + assert bmw_score > false_pos_score, (bmw_score, false_pos_score) + + assert len(pairs) == 428, len(pairs) def test_match_score(dstore: SimpleMemoryStore, dindex: Index): @@ -92,7 +131,7 @@ def test_top_match_matches_strong_pairs(dstore: SimpleMemoryStore, dindex: Index assert len(strong_pairs) > 4 for pair, pair_score in strong_pairs: - entity = view.get_entity(pair[0]) + entity = view.get_entity(pair[0].id) matches = dindex.match(entity) # it'll match itself and the other in the pair for match, match_score in matches[:2]:
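
Reviewer note: a minimal sketch of what a concrete BulkEnricher looks like against the new ABC in nomenklatura/enrich/common.py. The NameBlockingEnricher class and its name-based blocking are hypothetical illustration only, not part of this change; the load/candidates/match_candidates/expand hooks, the MatchCandidates alias and Identifier.get are the ones from the diff.

from collections import defaultdict
from typing import Dict, Generator, Set, Tuple

from followthemoney.types import registry

from nomenklatura.cache import Cache
from nomenklatura.dataset import DS
from nomenklatura.entity import CE
from nomenklatura.enrich.common import BulkEnricher, EnricherConfig, MatchCandidates
from nomenklatura.resolver import Identifier


class NameBlockingEnricher(BulkEnricher[DS]):
    """Hypothetical: match loaded subjects against each other on shared names."""

    def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig):
        super().__init__(dataset, cache, config)
        self._subjects: Dict[str, CE] = {}
        self._blocks: Dict[str, Set[str]] = defaultdict(set)

    def load(self, entity: CE) -> None:
        # Called via load_wrapped() for every subject entity up front.
        if entity.id is None:
            return
        self._subjects[entity.id] = entity
        for name in entity.get_type_values(registry.name):
            self._blocks[name.lower()].add(entity.id)

    def candidates(self) -> Generator[Tuple[Identifier, MatchCandidates], None, None]:
        # One (subject, scored candidate list) pair per loaded subject.
        for entity_id, entity in self._subjects.items():
            scores: Dict[Identifier, float] = {}
            for name in entity.get_type_values(registry.name):
                for other_id in self._blocks.get(name.lower(), set()):
                    if other_id != entity_id:
                        ident = Identifier.get(other_id)
                        scores[ident] = scores.get(ident, 0.0) + 1.0
            if scores:
                ranked = sorted(scores.items(), key=lambda c: c[1], reverse=True)
                yield Identifier.get(entity_id), ranked

    def match_candidates(
        self, entity: CE, candidates: MatchCandidates
    ) -> Generator[CE, None, None]:
        # Turn candidate identifiers back into entities for scoring and judgement.
        for ident, _score in candidates:
            match = self._subjects.get(ident.id)
            if match is not None:
                yield match

    def expand(self, entity: CE, match: CE) -> Generator[CE, None, None]:
        # No adjacent entities in this toy example.
        yield match

Driven through nomenklatura.enrich.match() or enrich(), the isinstance dispatch added above routes such a class through get_bulk_matches/get_bulk_enrichments: every subject goes through load_wrapped first, then candidates() and match_candidates() supply the entities that are scored and suggested to the resolver.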
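
Similarly, a sketch of how the new DuckDBIndex is exercised end to end. It assumes a store View is already available (e.g. dstore.default_view() as in the test fixtures) and a scratch directory; the index API itself — build(), pairs(), add_matching_subject(), matches() — is the one introduced in this diff.

from pathlib import Path
from tempfile import mkdtemp
from typing import Iterable

from nomenklatura.dataset import DS
from nomenklatura.entity import CE
from nomenklatura.index.duckdb_index import DuckDBIndex
from nomenklatura.store import View


def block_and_match(view: View[DS, CE], subjects: Iterable[CE]) -> None:
    """Build the index, list dedupe pairs, then match external subject entities."""
    index = DuckDBIndex(view, Path(mkdtemp()) / "duckdb-index")
    index.build()

    # Internal deduplication candidates, best-scored first.
    for (left, right), score in index.pairs(max_pairs=100):
        print("pair", left.id, right.id, score)

    # External matching: stage all subjects, then stream candidates per subject.
    for subject in subjects:
        index.add_matching_subject(subject)
    for subject_id, candidates in index.matches():
        for candidate_id, score in candidates:
            print("match", subject_id.id, candidate_id.id, score)

matches() already yields (Identifier, MatchCandidates) pairs, i.e. exactly the shape a BulkEnricher's candidates() is expected to produce, which is presumably why MatchCandidates lives in nomenklatura.enrich.common.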
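
Finally, a small worked mirror of the pairs() scoring, for checking the expected test scores by hand. The token rows and TF values are made up; only the formula sum((tf_left + tf_right) * boost) with a default boost of 1.0, and the boost values for name/namepart/word, follow the query and BOOSTS table in duckdb_index.py.

# Hypothetical shared tokens between two entities: (field, tf_left, tf_right),
# where tf = mentions of the token / total tokens of that field on the entity.
shared_tokens = [
    ("name", 0.5, 0.5),       # the full name token matches
    ("namepart", 0.25, 0.2),  # one shared name part
    ("word", 0.1, 0.05),      # one shared word token
]
BOOSTS = {"name": 10.0, "namepart": 2.0, "word": 0.5}

score = sum(
    (tf_left + tf_right) * BOOSTS.get(field, 1.0)
    for field, tf_left, tf_right in shared_tokens
)
# (0.5 + 0.5) * 10.0 + (0.25 + 0.2) * 2.0 + (0.1 + 0.05) * 0.5
# = 10.0 + 0.9 + 0.075 = 10.975
print(score)

This is also why the Johanna Quandt pair outscores the BMW pair in the tests: fewer tokens per field mean higher term frequencies, so the same shared tokens contribute more to the sum.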