Split enricher types for different interfaces
jbothma committed Nov 19, 2024
1 parent 4d5abe2 commit 4c682de
Showing 9 changed files with 59 additions and 32 deletions.
4 changes: 2 additions & 2 deletions nomenklatura/enrich/aleph.py
@@ -12,12 +12,12 @@
 from nomenklatura.entity import CE
 from nomenklatura.dataset import DS
 from nomenklatura.cache import Cache
-from nomenklatura.enrich.common import Enricher, EnricherConfig
+from nomenklatura.enrich.common import ItemEnricher, EnricherConfig
 
 log = logging.getLogger(__name__)
 
 
-class AlephEnricher(Enricher[DS]):
+class AlephEnricher(ItemEnricher[DS]):
     def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig):
         super().__init__(dataset, cache, config)
         self._host: str = os.environ.get("ALEPH_HOST", "https://aleph.occrp.org/")
47 changes: 34 additions & 13 deletions nomenklatura/enrich/common.py
@@ -21,6 +21,9 @@
 from nomenklatura.resolver import Identifier
 
 EnricherConfig = Dict[str, Any]
+MatchCandidates = List[Tuple[Identifier, float]]
+"""A list of candidate matches with their scores from a cheaper blocking comparison."""
+
 log = logging.getLogger(__name__)
 
 
@@ -184,20 +187,11 @@ def _filter_entity(self, entity: CompositeEntity) -> bool:
                 return False
         return True
 
-    def match_wrapped(self, entity: CE) -> Generator[CE, None, None]:
-        if not self._filter_entity(entity):
-            return
-        yield from self.match(entity)
-
     def expand_wrapped(self, entity: CE, match: CE) -> Generator[CE, None, None]:
         if not self._filter_entity(entity):
             return
         yield from self.expand(entity, match)
 
-    @abstractmethod
-    def match(self, entity: CE) -> Generator[CE, None, None]:
-        raise NotImplementedError()
-
     @abstractmethod
     def expand(self, entity: CE, match: CE) -> Generator[CE, None, None]:
         raise NotImplementedError()
@@ -208,21 +202,48 @@ def close(self) -> None:
         self._session.close()
 
 
+class ItemEnricher(Enricher[DS], ABC):
+    """
+    An enricher which performs matching on individual entities, one at a time.
+    """
+
+    def match_wrapped(self, entity: CE) -> Generator[CE, None, None]:
+        if not self._filter_entity(entity):
+            return
+        yield from self.match(entity)
+
+    @abstractmethod
+    def match(self, entity: CE) -> Generator[CE, None, None]:
+        raise NotImplementedError()
+
+
 class BulkEnricher(Enricher[DS], ABC):
     """
     An enricher which performs matching in bulk, requiring all subject entities
     to be loaded before matching.
+    Once loaded, matching can be done by iterating over the `candidates` method
+    which provides the subject entity ID and a list of IDs of candidate matches.
+    `match_candidates` is then called for each subject entity and its
+    `MatchCandidates` yielding matching entities.
     """
 
     def load_wrapped(self, entity: CE) -> None:
         if not self._filter_entity(entity):
             return
         self.load(entity)
 
     @abstractmethod
     def load(self, entity: CE) -> None:
         raise NotImplementedError()
 
-    def candidates(self) -> Generator[Tuple[Identifier, List[Tuple[Identifier, float]]], None, None]:
+    @abstractmethod
+    def candidates(self) -> Generator[Tuple[Identifier, MatchCandidates], None, None]:
         raise NotImplementedError()
 
-    def match_candidates(self, entity: CE, candidates: List[Tuple[Identifier, float]]) -> Generator[CE, None, None]:
+    @abstractmethod
+    def match_candidates(
+        self, entity: CE, candidates: MatchCandidates
+    ) -> Generator[CE, None, None]:
         raise NotImplementedError()
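
The two interfaces above differ mainly in how a caller drives them. The sketch below is illustrative only and not part of the commit: the driver functions, the `by_id` lookup and the `str(subject_id)` conversion are assumptions, while `match_wrapped`, `expand_wrapped`, `load_wrapped`, `candidates` and `match_candidates` are the methods defined in common.py above.

from typing import Dict, Iterable, List

from nomenklatura.entity import CE
from nomenklatura.enrich.common import BulkEnricher, ItemEnricher


def enrich_one_by_one(enricher: ItemEnricher, entities: Iterable[CE]) -> List[CE]:
    # ItemEnricher: each subject entity is matched against the remote source
    # as soon as it is seen.
    results: List[CE] = []
    for entity in entities:
        for match in enricher.match_wrapped(entity):
            results.append(match)
            # expand_wrapped stays on the shared Enricher base class.
            results.extend(enricher.expand_wrapped(entity, match))
    return results


def enrich_in_bulk(enricher: BulkEnricher, entities: Iterable[CE]) -> List[CE]:
    # BulkEnricher: all subject entities are loaded first; candidates are then
    # produced per subject and resolved in a second pass.
    by_id: Dict[str, CE] = {e.id: e for e in entities if e.id is not None}
    for entity in by_id.values():
        enricher.load_wrapped(entity)
    results: List[CE] = []
    for subject_id, candidates in enricher.candidates():
        subject = by_id.get(str(subject_id))  # Identifier -> entity ID string (assumed)
        if subject is None:
            continue
        for match in enricher.match_candidates(subject, candidates):
            results.append(match)
            results.extend(enricher.expand_wrapped(subject, match))
    return results

Either way, expansion of accepted matches is unchanged, since `expand_wrapped` remains on the shared `Enricher` base class.
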
4 changes: 2 additions & 2 deletions nomenklatura/enrich/nominatim.py
@@ -5,14 +5,14 @@
 from nomenklatura.entity import CE
 from nomenklatura.dataset import DS
 from nomenklatura.cache import Cache
-from nomenklatura.enrich.common import Enricher, EnricherConfig
+from nomenklatura.enrich.common import ItemEnricher, EnricherConfig
 
 
 log = logging.getLogger(__name__)
 NOMINATIM = "https://nominatim.openstreetmap.org/search.php"
 
 
-class NominatimEnricher(Enricher[DS]):
+class NominatimEnricher(ItemEnricher[DS]):
     def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig):
         super().__init__(dataset, cache, config)
         self.cache.preload(f"{NOMINATIM}%")
4 changes: 2 additions & 2 deletions nomenklatura/enrich/opencorporates.py
@@ -11,7 +11,7 @@
 from nomenklatura.entity import CE
 from nomenklatura.dataset import DS
 from nomenklatura.cache import Cache
-from nomenklatura.enrich.common import Enricher, EnricherConfig
+from nomenklatura.enrich.common import ItemEnricher, EnricherConfig
 from nomenklatura.enrich.common import EnrichmentAbort, EnrichmentException
 
 
@@ -22,7 +22,7 @@ def parse_date(raw: Any) -> Optional[str]:
     return registry.date.clean(raw)
 
 
-class OpenCorporatesEnricher(Enricher[DS]):
+class OpenCorporatesEnricher(ItemEnricher[DS]):
     COMPANY_SEARCH_API = "https://api.opencorporates.com/v0.4/companies/search"
     OFFICER_SEARCH_API = "https://api.opencorporates.com/v0.4/officers/search"
     UI_PART = "://opencorporates.com/"
4 changes: 2 additions & 2 deletions nomenklatura/enrich/openfigi.py
@@ -6,12 +6,12 @@
 from nomenklatura.entity import CE
 from nomenklatura.dataset import DS
 from nomenklatura.cache import Cache
-from nomenklatura.enrich.common import Enricher, EnricherConfig
+from nomenklatura.enrich.common import ItemEnricher, EnricherConfig
 
 log = logging.getLogger(__name__)
 
 
-class OpenFIGIEnricher(Enricher[DS]):
+class OpenFIGIEnricher(ItemEnricher[DS]):
     """Uses the `OpenFIGI` search API to look up FIGIs by company name."""
 
     SEARCH_URL = "https://api.openfigi.com/v3/search"
4 changes: 2 additions & 2 deletions nomenklatura/enrich/permid.py
@@ -14,7 +14,7 @@
 from nomenklatura.entity import CE
 from nomenklatura.dataset import DS
 from nomenklatura.cache import Cache
-from nomenklatura.enrich.common import Enricher, EnricherConfig
+from nomenklatura.enrich.common import ItemEnricher, EnricherConfig
 from nomenklatura.enrich.common import EnrichmentAbort
 from nomenklatura.util import fingerprint_name
 
@@ -28,7 +28,7 @@
 }
 
 
-class PermIDEnricher(Enricher[DS]):
+class PermIDEnricher(ItemEnricher[DS]):
     MATCHING_API = "https://api-eit.refinitiv.com/permid/match"
 
     def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig):
4 changes: 2 additions & 2 deletions nomenklatura/enrich/wikidata/__init__.py
@@ -18,7 +18,7 @@
     PROPS_TOPICS,
 )
 from nomenklatura.enrich.wikidata.model import Claim, Item
-from nomenklatura.enrich.common import Enricher, EnricherConfig
+from nomenklatura.enrich.common import ItemEnricher, EnricherConfig
 
 WD_API = "https://www.wikidata.org/w/api.php"
 LABEL_PREFIX = "wd:lb:"
@@ -29,7 +29,7 @@ def clean_name(name: str) -> str:
     return clean_brackets(name).strip()
 
 
-class WikidataEnricher(Enricher[DS]):
+class WikidataEnricher(ItemEnricher[DS]):
     def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig):
         super().__init__(dataset, cache, config)
         self.depth = self.get_config_int("depth", 1)
4 changes: 2 additions & 2 deletions nomenklatura/enrich/yente.py
@@ -11,13 +11,13 @@
 from nomenklatura.entity import CE, CompositeEntity
 from nomenklatura.dataset import DS
 from nomenklatura.cache import Cache
-from nomenklatura.enrich.common import Enricher, EnricherConfig
+from nomenklatura.enrich.common import ItemEnricher, EnricherConfig
 from nomenklatura.enrich.common import EnrichmentException
 
 log = logging.getLogger(__name__)
 
 
-class YenteEnricher(Enricher[DS]):
+class YenteEnricher(ItemEnricher[DS]):
     """Uses the `yente` match API to look up entities in a specific dataset."""
 
     def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig):
16 changes: 11 additions & 5 deletions nomenklatura/index/duckdb_index.py
@@ -1,9 +1,8 @@
 from io import TextIOWrapper
 from duckdb import DuckDBPyRelation
 from followthemoney.types import registry
 from pathlib import Path
 from shutil import rmtree
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Generator, Iterable, Optional, Tuple
 import csv
 import duckdb
 import logging
@@ -14,6 +13,7 @@
 from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer
 from nomenklatura.resolver import Pair, Identifier
 from nomenklatura.store import View
+from nomenklatura.enrich.common import MatchCandidates
 
 log = logging.getLogger(__name__)
 
@@ -189,13 +189,15 @@ def pairs(
             yield (Identifier.get(left), Identifier.get(right)), score
 
     def add_matching_subject(self, entity: CE) -> None:
+        if self.matching_dump is None:
+            raise Exception("Cannot add matching subject after getting candidates.")
         writer = csv.writer(self.matching_dump)
         for field, token in self.tokenizer.entity(entity):
             writer.writerow([entity.id, field, token])
 
     def matches(
         self,
-    ) -> Generator[Tuple[Identifier, List[Tuple[Identifier, float]]], None, None]:
+    ) -> Generator[Tuple[Identifier, MatchCandidates], None, None]:
         if self.matching_dump is not None:
             self.matching_dump.close()
             self.matching_dump = None
@@ -215,18 +217,22 @@ def matches(
         """
         results = self.con.execute(match_query)
         previous_id = None
-        matches: List[Tuple[Identifier, float]] = []
+        matches: MatchCandidates = []
         while batch := results.fetchmany(BATCH_SIZE):
             for matching_id, match_id, score in batch:
+                # first row
                 if previous_id is None:
                     previous_id = matching_id
+                # Next pair of subject and candidates
                 if matching_id != previous_id:
                     if matches:
                         yield Identifier.get(previous_id), matches
                     matches = []
                     previous_id = matching_id
                 matches.append((Identifier.get(match_id), score))
-        yield Identifier.get(previous_id), matches[: self.max_candidates]
+        # Last pair of subject and candidates
+        if matches and previous_id is not None:
+            yield Identifier.get(previous_id), matches[: self.max_candidates]
 
     def __repr__(self) -> str:
         return "<DuckDBIndex(%r, %r)>" % (

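The grouping logic in `matches` above relies on the query results being ordered by the matching (subject) entity ID, and emits one candidate list per subject. A standalone sketch of the same pattern, using plain strings in place of `Identifier` and applying the candidate cap on every yield (both simplifications of this example, not a copy of the method):

from typing import Iterable, List, Optional, Tuple


def group_candidates(
    rows: Iterable[Tuple[str, str, float]], max_candidates: int = 10
) -> Iterable[Tuple[str, List[Tuple[str, float]]]]:
    """Group (subject_id, match_id, score) rows, assumed sorted by subject_id,
    into one capped candidate list per subject."""
    previous_id: Optional[str] = None
    matches: List[Tuple[str, float]] = []
    for subject_id, match_id, score in rows:
        # First row: remember the first subject ID.
        if previous_id is None:
            previous_id = subject_id
        # New subject: flush the previous subject's candidates.
        if subject_id != previous_id:
            if matches:
                yield previous_id, matches[:max_candidates]
            matches = []
            previous_id = subject_id
        matches.append((match_id, score))
    # Last subject and its candidates.
    if matches and previous_id is not None:
        yield previous_id, matches[:max_candidates]

For rows sorted as ("a", "x", 0.9), ("a", "y", 0.5), ("b", "z", 0.7), this yields "a" with two candidates and then "b" with one.
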