Skip to content

Commit

Permalink
🐛 Numeric aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
simonwoerpel committed Mar 14, 2024
1 parent 916c855 commit 548d28b
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 53 deletions.
20 changes: 5 additions & 15 deletions ftmq/aggregations.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
import statistics
from collections import defaultdict
from functools import cache
from typing import Any, Generator, Iterable, TypeAlias

from anystore.util import clean_dict
from banal import ensure_list
from followthemoney.schema import Schema
from followthemoney.types import registry
from pydantic import BaseModel

from ftmq.enums import Aggregations, Fields, Properties
from ftmq.types import CE, CEGenerator
from ftmq.util import to_numeric
from ftmq.util import prop_is_numeric, to_numeric

Value: TypeAlias = int | float | str
Values: TypeAlias = list[Value]


@cache
def get_is_numeric(schema: Schema, prop: str) -> bool:
prop = schema.get(prop)
if prop is not None:
return prop.type == registry.number
return False


class Aggregation(BaseModel):
prop: Properties | Fields
func: Aggregations
Expand Down Expand Up @@ -71,7 +61,7 @@ def get_proxy_values(
yield from proxy.get(prop, quiet=True)

def collect(self, proxy: CE) -> CE:
is_numeric = get_is_numeric(proxy.schema, self.prop)
is_numeric = prop_is_numeric(proxy.schema, self.prop)
for value in self.get_proxy_values(proxy):
if is_numeric:
value = to_numeric(value)
Expand Down Expand Up @@ -118,9 +108,9 @@ def __exit__(self, *args, **kwargs) -> None:
for agg in self.aggregations:
self.result[str(agg.func)][str(agg.prop)] = agg.value
for group in agg.group_props:
self.result["groups"][str(group)][str(agg.func)][
str(agg.prop)
] = agg.groups[group]
self.result["groups"][str(group)][str(agg.func)][str(agg.prop)] = (
agg.groups[group]
)
self.result = clean_dict(self.result)

def apply(self, proxies: CEGenerator) -> CEGenerator:
Expand Down
19 changes: 10 additions & 9 deletions ftmq/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
)
from ftmq.sql import Sql
from ftmq.types import CEGenerator
from ftmq.util import parse_comparator, parse_unknown_filters
from ftmq.util import (
parse_comparator,
parse_unknown_filters,
prop_is_numeric,
to_numeric,
)

Q = TypeVar("Q", bound="Query")
Slice = TypeVar("Slice", bound=slice)
Expand All @@ -33,9 +38,10 @@ def __init__(self, values: Iterable[str], ascending: bool | None = True) -> None
def apply(self, proxy: CE) -> tuple[str]:
values = tuple()
for v in self.values:
p_values = proxy.get(v, quiet=True)
if p_values is not None:
values = values + (tuple(p_values))
p_values = proxy.get(v, quiet=True) or []
if prop_is_numeric(proxy.schema, v):
p_values = map(to_numeric, p_values)
values = values + (tuple(p_values))
return values

def apply_iter(self, proxies: CEGenerator) -> CEGenerator:
Expand Down Expand Up @@ -301,8 +307,3 @@ def apply_iter(self, proxies: CEGenerator) -> CEGenerator:
self.aggregator = self.get_aggregator()
proxies = self.aggregator.apply(proxies)
yield from proxies

def apply_aggregations(self, proxies: CEGenerator) -> Aggregator:
aggregator = self.get_aggregator()
[x for x in aggregator.apply(proxies)]
return aggregator
3 changes: 2 additions & 1 deletion ftmq/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,11 @@ def _sorted_statements(self) -> Select:
value = self.table.c.value
if PropertyTypesMap[prop].value == registry.number:
value = func.cast(self.table.c.value, NUMERIC)
group_func = func.min if self.q.sort.ascending else func.max
inner = (
select(
self.table.c.canonical_id,
func.group_concat(value).label("sortable_value"),
group_func(value).label("sortable_value"),
)
.where(
and_(
Expand Down
4 changes: 2 additions & 2 deletions ftmq/store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def aggregations(self, query: Q) -> AggregatorResult | None:
key = f"agg-{hash(query)}"
if key in self._cache:
return self._cache[key]
aggregator = query.apply_aggregations(self.entities(query))
res = dict(aggregator.result)
_ = [x for x in self.entities(query)]
res = dict(query.aggregator.result)
self._cache[key] = res
return res
9 changes: 9 additions & 0 deletions ftmq/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pycountry
from banal import ensure_list
from followthemoney.schema import Schema
from followthemoney.types import registry
from followthemoney.util import make_entity_id, sanitize_text
from nomenklatura.dataset import Dataset
Expand Down Expand Up @@ -196,3 +197,11 @@ def make_string_id(value: Any) -> str | None:
@lru_cache(1024)
def make_fingerprint_id(value: Any) -> str | None:
return make_entity_id(make_fingerprint(value))


@cache
def prop_is_numeric(schema: Schema, prop: str) -> bool:
prop = schema.get(prop)
if prop is not None:
return prop.type == registry.number
return False
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,6 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
env = [
"DEBUG=1",
"NOMENKLATURA_STATEMENT_TABLE=test_table"
"NOMENKLATURA_STATEMENT_TABLE=test_table",
"MAX_SQL_AGG_GROUPS=11",
]
14 changes: 14 additions & 0 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ def test_proxy_sort(proxies):
break
assert tested

# numeric sort
tested = False
q = Query().where(schema="Payment").order_by("amountEur")
for proxy in q.apply_iter(proxies):
assert proxy.get("amountEur") == ["50000"]
tested = True
break
tested = False
q = Query().where(schema="Payment").order_by("amountEur", ascending=False)
for proxy in q.apply_iter(proxies):
assert proxy.get("amountEur") == ["2334526"]
tested = True
break


def test_proxy_slice(proxies):
q = Query()[:10]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_sql():
q.sql.statements,
f"""
SELECT {fields}, anon_1.canonical_id AS canonical_id_1, anon_1.sortable_value
FROM test_table JOIN (SELECT test_table.canonical_id AS canonical_id, group_concat(test_table.value) AS sortable_value
FROM test_table JOIN (SELECT test_table.canonical_id AS canonical_id, max(test_table.value) AS sortable_value
FROM test_table
WHERE test_table.prop = :prop_1 AND test_table.canonical_id IN (SELECT DISTINCT test_table.canonical_id
FROM test_table WHERE (test_table.dataset = :dataset_1 OR test_table.dataset = :dataset_2)
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_sql():
q.sql.statements,
f"""
SELECT {fields}, anon_1.canonical_id AS canonical_id_1, anon_1.sortable_value
FROM test_table JOIN (SELECT test_table.canonical_id AS canonical_id, group_concat(test_table.value) AS sortable_value
FROM test_table JOIN (SELECT test_table.canonical_id AS canonical_id, min(test_table.value) AS sortable_value
FROM test_table
WHERE test_table.prop = :prop_1 AND test_table.canonical_id IN (SELECT DISTINCT test_table.canonical_id
FROM test_table WHERE (test_table.dataset = :dataset_1 OR test_table.dataset = :dataset_2)
Expand Down
41 changes: 22 additions & 19 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,28 @@ def _run_store_test(cls: Store, proxies, **kwargs):
]

# ordering
# q = Query().where(schema="Payment", prop="date", value=2011, comparator="gte")
# q = q.order_by("amountEur")
# res = [e for e in view.entities(q)]
# assert len(res) == 21
# assert res[0].get("amountEur") == ["50001"]
# q = q.order_by("amountEur", ascending=False)
# res = [e for e in view.entities(q)]
# assert len(res) == 21
# assert res[0].get("amountEur") == ["320000"]
q = Query().where(schema="Payment", prop="date", value=2011, comparator="gte")
q = q.order_by("amountEur")
res = [e for e in view.entities(q)]
assert len(res) == 21
assert res[0].get("amountEur") == ["50001"]
q = q.order_by("amountEur", ascending=False)
res = [e for e in view.entities(q)]
assert len(res) == 21
assert res[0].get("amountEur") == ["320000"]

# slice
# q = Query().where(schema="Payment", prop="date", value=2011, comparator="gte")
# q = q.order_by("amountEur")
# q = q[:10]
# res = [e for e in view.entities(q)]
# assert len(res) == 10
# assert res[0].get("payer") == ["62ad0fe6f56dbbf6fee57ce3da76e88c437024d5"]
q = Query().where(schema="Payment", prop="date", value=2011, comparator="gte")
q = q.order_by("amountEur")
q = q[:10]
res = [e for e in view.entities(q)]
assert len(res) == 10
assert res[0].get("payer") == ["efccc434cdf141c7ba6f6e539bb6b42ecd97c368"]

q = Query().where(schema="Person").order_by("name")[0]
res = [e for e in view.entities(q)]
assert len(res) == 1
assert res[0].caption == "Dr.-Ing. E. h. Martin Herrenknecht"

# aggregation
q = Query().aggregate("max", "date").aggregate("min", "date")
Expand All @@ -119,9 +124,7 @@ def _run_store_test(cls: Store, proxies, **kwargs):
]
== 10
)
# assert (
# sum(res["groups"]["beneficiary"]["count"]["id"].values()) == res["count"]["id"]
# )
assert len(proxies) == res["count"]["id"]

q = (
Query()
Expand All @@ -144,7 +147,7 @@ def _run_store_test(cls: Store, proxies, **kwargs):
"9fbaa5733790781e56eec4998aeacf5093dccbf5": 290725,
"9e292c150c617eec85e5479c5f039f8441569441": 175000,
"49d46f7e70e19bc497a17734af53ea1a00c831d6": 1221256,
"4b308dc2b128377e63a4bf2e4c1b9fcd59614eee": 52000,
"4b308dc2b128377e63a4bf2e4c1b9fcd59614eee": 52000, # pytest: MAX_SQL_AGG_GROUPS=11
}
}
}
Expand Down
13 changes: 9 additions & 4 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

import cloudpickle
import pytest

if sys.version_info >= (3, 11):
from enum import EnumType

from followthemoney import model
from nomenklatura.dataset import Dataset

from ftmq import util
from ftmq.enums import Comparators, StrEnum

if sys.version_info >= (3, 11):
from enum import EnumType


def test_util_make_dataset():
ds = util.make_dataset("Test")
Expand Down Expand Up @@ -101,3 +101,8 @@ def test_util_get_year():
assert util.make_fingerprint(" ") is None
assert util.make_fingerprint("") is None
assert util.make_fingerprint(None) is None


def test_util_prop_is_numeric():
assert not util.prop_is_numeric(model.get("Person"), "name")
assert util.prop_is_numeric(model.get("Payment"), "amountEur")

0 comments on commit 548d28b

Please sign in to comment.