Skip to content

Commit

Permalink
Forest snapshot tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edongashi committed Nov 2, 2023
1 parent ad0703a commit 765c18b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import hashlib
import os

from syndiffix.common import *
from syndiffix.counters import UniqueAidCountersFactory
from syndiffix.forest import *
from syndiffix.microdata import apply_convertors, get_convertor
from syndiffix.tree import Branch, Leaf, Node

SALT = bytes([])
NOISELESS_SUPPRESSION = SuppressionParams(layer_sd=0.0)
Expand Down Expand Up @@ -48,3 +50,33 @@ def load_forest(
current_dir = os.path.dirname(os.path.abspath(__file__))
data_df = _load_csv(os.path.join(current_dir, "data", filename), columns)
return create_forest(data_df, anon_params=anon_params, bucketization_params=bucketization_params)


# Forest hashing


def _fstr(num: float) -> str:
return str(int(num)) if num.is_integer() else str(num)


def _ranges_str(node: Node) -> str:
intervals = [
f"[{_fstr(round(interval.min, 5))},{_fstr(round(interval.max, 5))}]" for interval in node.snapped_intervals
]
return ", ".join(intervals)


def hash_tree(node: Node) -> str:
outstr = ""

outstr += "Leaf;" if isinstance(node, Leaf) else "Branch;"
outstr += f"Count:{node.noisy_count()};"
outstr += f"Ranges:{_ranges_str(node)};"

if isinstance(node, Branch):
outstr += "Children:"
for key, child in sorted(node.children.items()):
child_hash = hash_tree(child)
outstr += f"Id:{key};Hash:{child_hash};"

return hashlib.sha1(outstr.encode("utf-8")).hexdigest()
24 changes: 24 additions & 0 deletions tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,27 @@ def test_depth_limiting() -> None:
forest = create_forest(DataFrame(data, columns=["col"]), bucketization_params=zero_depth_params)
tree = forest.get_tree((ColumnId(0),))
assert isinstance(tree, Leaf)


def test_taxi_snapshot() -> None:
forest = load_forest(
"taxi-1000.csv",
columns=["pickup_longitude", "pickup_latitude", "fare_amount", "rate_code", "passenger_count"],
anon_params=NOISELESS_PARAMS,
)

def get_hash(*cols: list[int]) -> str:
return hash_tree(forest.get_tree(tuple(ColumnId(col) for col in cols)))

# Single dimensions
assert get_hash(0) == "c54e1fcd3aeda4e2efaa8ac178f5dfe7d12d8cc2"
assert get_hash(1) == "8ab7f27328f922965b89fc0f894eab0ca474c986"
assert get_hash(2) == "0460917ff0942a43f9e28cbdc132dd3861995c82"
assert get_hash(3) == "cfa83a32d3eea908b4bb33d858e7a8786952b739"
assert get_hash(4) == "f41bc89f0e683e4ac65421b5681aa48fb1ed4456"

# Higher dimensions
assert get_hash(0, 1) == "daa2f4af52460728af6276fc5664d186f6a4df9b"
assert get_hash(0, 1, 2) == "55b1cf65c256545b3c1da970cc9c2083a87c2e97"
assert get_hash(0, 1, 2, 3) == "12a5ac07bf3d98501378360deb8505d90a791c32"
assert get_hash(0, 1, 2, 3, 4) == "bd75e1cdfd8957dd45a1f61ff6cc5a1e7254ee99"

0 comments on commit 765c18b

Please sign in to comment.