duckdb blocking #179

Closed
wants to merge 25 commits
Changes from 14 commits
11 changes: 9 additions & 2 deletions nomenklatura/index/__init__.py
@@ -9,7 +9,7 @@
 from nomenklatura.entity import CE

 log = logging.getLogger(__name__)
-INDEX_TYPES = ["tantivy", Index.name]
+INDEX_TYPES = ["tantivy", "duckdb", Index.name]


 def get_index(
@@ -24,10 +24,17 @@ def get_index(
             clazz = TantivyIndex[DS, CE]
         except ImportError:
             log.warning("`tantivy` is not available, falling back to in-memory index.")
+    if type_ == "duckdb":
+        try:
+            from nomenklatura.index.duckdb_index import DuckDBIndex
+
+            clazz = DuckDBIndex[DS, CE]
+        except ImportError:
+            log.warning("`duckdb` is not available, falling back to in-memory index.")

     index = clazz(view, path)
     index.build()
     return index


-__all__ = ["BaseIndex", "Index", "TantivyIndex", "get_index"]
+__all__ = ["BaseIndex", "Index", "TantivyIndex", "DuckDBIndex", "get_index"]
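With this hunk, the DuckDB backend is selectable by name and degrades gracefully when the optional dependency is missing. A minimal usage sketch (assuming the `get_index(view, path, type_)` signature shown above; `view` is a placeholder for any store view, e.g. `store.default_view()`):

from pathlib import Path

from nomenklatura.index import get_index

# view: a View[DS, CE] obtained from a store (placeholder, not from this PR)
index = get_index(view, Path("/tmp/index-dir"), "duckdb")
for (left, right), score in index.pairs():
    print(left, right, score)

If `duckdb` is not installed, this logs a warning and falls back to the in-memory `Index`.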
8 changes: 3 additions & 5 deletions nomenklatura/index/common.py
@@ -1,6 +1,6 @@
 from pathlib import Path
-from typing import Generic, List, Tuple
-from nomenklatura.resolver import Identifier
+from typing import Generic, Iterable, List, Tuple
+from nomenklatura.resolver import Pair, Identifier
 from nomenklatura.dataset import DS
 from nomenklatura.entity import CE
 from nomenklatura.store import View
@@ -16,9 +16,7 @@ def __init__(self, view: View[DS, CE], data_dir: Path) -> None:
     def build(self) -> None:
         raise NotImplementedError

-    def pairs(
-        self, max_pairs: int = MAX_PAIRS
-    ) -> List[Tuple[Tuple[Identifier, Identifier], float]]:
+    def pairs(self, max_pairs: int = MAX_PAIRS) -> Iterable[Tuple[Pair, float]]:
         raise NotImplementedError

     def match(self, entity: CE) -> List[Tuple[Identifier, float]]:
162 changes: 162 additions & 0 deletions nomenklatura/index/duckdb_index.py
@@ -0,0 +1,162 @@
from duckdb import DuckDBPyRelation
from followthemoney.types import registry
from pathlib import Path
from shutil import rmtree
from typing import Iterable, Tuple
import csv
import duckdb
import logging

from nomenklatura.dataset import DS
from nomenklatura.entity import CE
from nomenklatura.index.common import BaseIndex
from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer
from nomenklatura.resolver import Pair, Identifier
from nomenklatura.store import View

log = logging.getLogger(__name__)

BATCH_SIZE = 1000


class DuckDBIndex(BaseIndex[DS, CE]):
"""
An in-memory search index to match entities against a given dataset.

For each field in the dataset, the index stores the IDs which contains each
token, along with the absolute frequency of each token in the document.
"""
jbothma marked this conversation as resolved.
Show resolved Hide resolved

name = "duckdb"

BOOSTS = {
NAME_PART_FIELD: 2.0,
WORD_FIELD: 0.5,
registry.name.name: 10.0,
# registry.country.name: 1.5,
# registry.date.name: 1.5,
# registry.language: 0.7,
# registry.iban.name: 3.0,
registry.phone.name: 3.0,
registry.email.name: 3.0,
# registry.entity: 0.0,
# registry.topic: 2.1,
registry.address.name: 2.5,
registry.identifier.name: 3.0,
}

    __slots__ = "view", "fields", "tokenizer", "entities"

    def __init__(self, view: View[DS, CE], data_dir: Path):
        self.view = view
        self.tokenizer = Tokenizer[DS, CE]()
        self.data_dir = data_dir
        if self.data_dir.exists():
            rmtree(self.data_dir.as_posix())
        self.data_dir.mkdir(parents=True)
        self.con = duckdb.connect((self.data_dir / "duckdb_index.db").as_posix())

        # https://duckdb.org/docs/guides/performance/environment
        # > For ideal performance, aggregation-heavy workloads require approx.
        # > 5 GB memory per thread and join-heavy workloads require approximately
        # > 10 GB memory per thread.
        # > Aim for 5-10 GB memory per thread.
        self.con.execute("SET memory_limit = '2GB';")
        self.con.execute("SET max_memory = '2GB';")
        # > If you have a limited amount of memory, try to limit the number of threads
        self.con.execute("SET threads = 1;")

    def build(self) -> None:
        """Index all entities in the dataset."""
        log.info("Building index from: %r...", self.view)
        self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)")
        for field, boost in self.BOOSTS.items():
            self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost])

        self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")
        csv_path = self.data_dir / "mentions.csv"
        log.info("Dumping entity tokens to CSV for bulk load into the database...")
        with open(csv_path, "w") as fh:
            writer = csv.writer(fh)

            # csv.writer type gymnastics
            def dump_entity(entity: CE) -> None:
                if not entity.schema.matchable or entity.id is None:
                    return
                for field, token in self.tokenizer.entity(entity):
                    writer.writerow([entity.id, field, token])

            writer.writerow(["id", "field", "token"])
            for idx, entity in enumerate(self.view.entities()):
                dump_entity(entity)
                if idx % 10000 == 0:
                    log.info("Dumped %s entities" % idx)

        log.info("Loading data...")
        self.con.execute(f"COPY entries from '{csv_path}'")
        log.info("Index built.")

    def field_len_rel(self) -> DuckDBPyRelation:
        field_len_query = """
            SELECT field, id, count(*) as field_len
            FROM entries
            GROUP BY field, id
        """
        return self.con.sql(field_len_query)

    def mentions_rel(self) -> DuckDBPyRelation:
        mentions_query = """
            SELECT field, id, token, count(*) as mentions
            FROM entries
            GROUP BY field, id, token
        """
        return self.con.sql(mentions_query)

    def token_freq_rel(self) -> DuckDBPyRelation:
        token_freq_query = """
            SELECT field, token, count(*) as token_freq
            FROM entries
            GROUP BY field, token
        """
        return self.con.sql(token_freq_query)

    def frequencies_rel(self) -> DuckDBPyRelation:
        field_len = self.field_len_rel()  # noqa
        mentions = self.mentions_rel()  # noqa
        token_freq = self.token_freq_rel()  # noqa
        term_frequencies_query = """
            SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf
            FROM field_len
            JOIN mentions
                ON field_len.field = mentions.field AND field_len.id = mentions.id
            JOIN token_freq
                ON token_freq.field = mentions.field AND token_freq.token = mentions.token
            WHERE token_freq < 100
        """
        return self.con.sql(term_frequencies_query)

    def pairs(
        self, max_pairs: int = BaseIndex.MAX_PAIRS
    ) -> Iterable[Tuple[Pair, float]]:
        term_frequencies = self.frequencies_rel()  # noqa
        pairs_query = """
            SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * boost) as score
            FROM term_frequencies as "left"
            JOIN term_frequencies as "right"
                ON "left".field = "right".field AND "left".token = "right".token
            JOIN boosts
ON "left".field = boosts.field
WHERE "left".id > "right".id
GROUP BY "left".id, "right".id
ORDER BY score DESC
LIMIT ?
"""
results = self.con.execute(pairs_query, [max_pairs])
while batch := results.fetchmany(BATCH_SIZE):
for left, right, score in batch:
yield (Identifier.get(left), Identifier.get(right)), score

    def __repr__(self) -> str:
        return "<DuckDBIndex(%r, %r)>" % (
            self.view.scope.name,
            self.con,
        )
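A note on the scoring `pairs` implements: for every token two entities share within a field, the pair accrues `(tf_left + tf_right) * boost`, and because `boosts` is inner-joined, fields missing from `BOOSTS` contribute nothing. A toy re-implementation of the SQL's semantics (hypothetical data, not part of the PR):

from collections import defaultdict
from typing import Dict, List, Tuple

# rows mirror the term_frequencies relation: (field, token, entity_id, tf)
rows = [
    ("name", "smith", "a", 0.50),
    ("name", "smith", "b", 0.25),
    ("word", "london", "a", 0.20),
    ("word", "london", "b", 0.10),
]
boosts = {"name": 10.0, "word": 0.5}

postings: Dict[Tuple[str, str], List[Tuple[str, float]]] = defaultdict(list)
for field, token, entity_id, tf in rows:
    postings[(field, token)].append((entity_id, tf))

scores: Dict[Tuple[str, str], float] = defaultdict(float)
for (field, token), entities in postings.items():
    if field not in boosts:
        continue  # the inner JOIN on boosts drops unboosted fields
    for left, ltf in entities:
        for right, rtf in entities:
            if left > right:  # mirrors WHERE "left".id > "right".id
                scores[(left, right)] += (ltf + rtf) * boosts[field]

# scores[("b", "a")] == (0.25 + 0.50) * 10.0 + (0.10 + 0.20) * 0.5 == 7.65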
14 changes: 6 additions & 8 deletions nomenklatura/index/entry.py
@@ -10,7 +10,6 @@ class Entry(object):
     __slots__ = "idf", "entities"

     def __init__(self) -> None:
-        self.idf: float = 0.0
         self.entities: Dict[Identifier, int] = dict()

     def add(self, entity_id: Identifier) -> None:
@@ -21,13 +20,15 @@ def add(self, entity_id: Identifier) -> None:
         except KeyError:
             self.entities[entity_id] = 1

-    def compute(self, field: "Field") -> None:
-        """Compute weighted term frequency for scoring."""
-        self.idf = math.log(field.len / len(self.entities))
-
     def frequencies(
         self, field: "Field"
     ) -> Generator[Tuple[Identifier, float], None, None]:
+        """
+        Term Frequency (TF) for each entity in this entry.
+
+        TF is the number of occurrences of this token in the entity divided
+        by the total number of tokens in the entity (scoped to this field).
+        """
         for entity_id, mentions in self.entities.items():
             field_len = max(1, field.entities[entity_id])
             yield entity_id, (mentions / field_len)
@@ -69,9 +70,6 @@ def compute(self) -> None:
         self.len = max(1, len(self.entities))
         self.avg_len = sum(self.entities.values()) / self.len

-        for entry in self.tokens.values():
-            entry.compute(self)
-
     def to_dict(self) -> Dict[str, Any]:
         return {
             "tokens": {t: e.to_dict() for t, e in self.tokens.items()},
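A quick worked example of the TF that `frequencies` now documents (hypothetical numbers, matching the formula above):

# an entity with 4 tokens in this field, 2 of which are this entry's token
mentions = 2
field_len = 4
tf = mentions / max(1, field_len)  # 0.5, as yielded by Entry.frequencies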
27 changes: 18 additions & 9 deletions nomenklatura/index/index.py
@@ -18,7 +18,12 @@


 class Index(BaseIndex[DS, CE]):
-    """An in-memory search index to match entities against a given dataset."""
+    """
+    An in-memory search index to match entities against a given dataset.
+
+    For each field in the dataset, the index stores the IDs which contain each
+    token, along with the absolute frequency of each token in the document.
+    """

     name = "memory"

@@ -73,23 +78,27 @@ def commit(self) -> None:
             field.compute()

     def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float]]:
-        """A second method of doing xref: summing up the pairwise match value
-        for all entities lineraly. This uses a lot of memory but is really
-        fast."""
+        """
+        A second method of doing xref: summing up the pairwise match value
+        for all entities linearly. This uses a lot of memory but is really
+        fast.
+
+        The score of each pair is the sum of the product of term frequencies for
+        each co-occurring token in each field of the pair.
+
+        We skip any tokens with more than 100 entities.
+        """
         pairs: Dict[Pair, float] = {}
         log.info("Building index blocking pairs...")
         for field_name, field in self.fields.items():
             boost = self.BOOSTS.get(field_name, 1.0)
-            for idx, entry in enumerate(field.tokens.values()):
+            for idx, (token, entry) in enumerate(field.tokens.items()):
                 if idx % 10000 == 0:
                     log.info("Pairwise xref [%s]: %d" % (field_name, idx))

                 if len(entry.entities) == 1 or len(entry.entities) > 100:
                     continue
-                entities = sorted(
-                    entry.frequencies(field), key=lambda f: f[1], reverse=True
-                )
-                for (left, lw), (right, rw) in combinations(entities, 2):
+                for (left, lw), (right, rw) in combinations(entry.frequencies(field), 2):
                     if lw == 0.0 or rw == 0.0:
                         continue
                     pair = (max(left, right), min(left, right))
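Dropping the `sorted(...)` call is behavior-preserving: the pair key is normalized with `max`/`min` before accumulation, so input order never influenced the scores. A sketch of the accumulation for a single token's postings (hypothetical numbers; the actual accumulation line sits below this hunk, so the scoring shown is an assumption following the docstring's description):

from itertools import combinations
from typing import Dict, Tuple

freqs = [("a", 0.5), ("b", 0.25), ("c", 0.1)]  # (entity_id, tf) per entity
boost = 10.0
pairs: Dict[Tuple[str, str], float] = {}
for (left, lw), (right, rw) in combinations(freqs, 2):
    if lw == 0.0 or rw == 0.0:
        continue
    pair = (max(left, right), min(left, right))
    # assumed: product of TFs weighted by boost, per the docstring above
    pairs[pair] = pairs.get(pair, 0.0) + lw * rw * boost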
4 changes: 4 additions & 0 deletions setup.py
@@ -59,6 +59,7 @@
"plyvel < 2.0.0",
"redis > 5.0.0, < 6.0.0",
"tantivy < 1.0.0",
"duckdb < 2.0.0",
],
"leveldb": [
"plyvel < 2.0.0",
@@ -69,5 +70,8 @@
"tantivy": [
"tantivy < 1.0.0",
],
"duckdb": [
"duckdb < 2.0.0",
],
},
)
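With the new extra in place, the optional dependency can be pulled in via `pip install "nomenklatura[duckdb]"` (assuming the extras are published as declared here), keeping DuckDB out of the base install.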
8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -6,6 +6,7 @@
 from tempfile import mkdtemp

 from nomenklatura import settings
+from nomenklatura.index.duckdb_index import DuckDBIndex
 from nomenklatura.index.tantivy_index import TantivyIndex
 from nomenklatura.store import load_entity_file_store, SimpleMemoryStore
 from nomenklatura.kv import get_redis
@@ -81,6 +82,13 @@ def tantivy_index(index_path: Path, dstore: SimpleMemoryStore):
     yield index


+@pytest.fixture(scope="function")
+def duckdb_index(index_path: Path, dstore: SimpleMemoryStore):
+    index = DuckDBIndex(dstore.default_view(), index_path)
+    index.build()
+    yield index
+
+
 @pytest.fixture(scope="function")
 def index_path():
     index_path = Path(mkdtemp()) / "index-dir"
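The new fixture mirrors `tantivy_index`, so index tests can run against both backends. A hypothetical test using it (not part of this diff):

def test_duckdb_pairs(duckdb_index):
    # the fixture yields an already-built DuckDBIndex; pairs() yields
    # ((Identifier, Identifier), score) tuples, highest score first
    pairs = list(duckdb_index.pairs(max_pairs=10))
    assert len(pairs) <= 10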