From 3c438181fae9af58977a31b469742ff37a1615a7 Mon Sep 17 00:00:00 2001 From: Edon Gashi Date: Thu, 2 Nov 2023 19:37:48 +0100 Subject: [PATCH] Forest snapshot tests --- tests/clustering/test_stitching.py | 5 +- tests/conftest.py | 16 ++++- tests/data/tree.0.json | 82 ++++++++++++++++++++++ tests/data/tree.0_1.json | 109 +++++++++++++++++++++++++++++ tests/data/tree.0_1_2.json | 82 ++++++++++++++++++++++ tests/data/tree.1.json | 38 ++++++++++ tests/data/tree.2.json | 16 +++++ tests/data/tree.csv | 33 +++++++++ tests/test_forest.py | 32 +++++++++ 9 files changed, 410 insertions(+), 3 deletions(-) create mode 100644 tests/data/tree.0.json create mode 100644 tests/data/tree.0_1.json create mode 100644 tests/data/tree.0_1_2.json create mode 100644 tests/data/tree.1.json create mode 100644 tests/data/tree.2.json create mode 100644 tests/data/tree.csv diff --git a/tests/clustering/test_stitching.py b/tests/clustering/test_stitching.py index a9a2223..94a1ab9 100644 --- a/tests/clustering/test_stitching.py +++ b/tests/clustering/test_stitching.py @@ -27,7 +27,10 @@ def build_rows(*cols: list[int]) -> list[MicrodataRow]: (ColumnId(2), ColumnId(3)): build_rows(col_c_right, col_d), } - metadata = StitchingMetadata(dimension_is_integral=[True, True, True], entropy_1dim=np.array([1.0, 2.0, 3.0])) + metadata = StitchingMetadata( + dimension_is_integral=[True, True, True, True], + entropy_1dim=np.array([1.0, 1.0, 1.0, 1.0]), + ) def materialize_tree(_forest: Forest, columns: list[ColumnId]) -> tuple[list[MicrodataRow], Combination]: combination = tuple(sorted(columns)) diff --git a/tests/conftest.py b/tests/conftest.py index c923d0c..0f93438 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,12 @@ +import json import os +from typing import Any from syndiffix.common import * from syndiffix.counters import UniqueAidCountersFactory from syndiffix.forest import * from syndiffix.microdata import apply_convertors, get_convertor +from syndiffix.tree import Branch, Leaf, Node SALT = bytes([]) NOISELESS_SUPPRESSION = SuppressionParams(layer_sd=0.0) @@ -32,6 +35,11 @@ def create_forest( ) +def _test_file_dir(filename: str) -> str: + current_dir = os.path.dirname(os.path.abspath(__file__)) + return os.path.join(current_dir, "data", filename) + + def _load_csv(path: str, columns: list[str] | None) -> pd.DataFrame: df = pd.read_csv(path, keep_default_na=False, na_values=[""], low_memory=False) if columns is not None: @@ -45,6 +53,10 @@ def load_forest( anon_params: AnonymizationParams | None = None, bucketization_params: BucketizationParams | None = None, ) -> Forest: - current_dir = os.path.dirname(os.path.abspath(__file__)) - data_df = _load_csv(os.path.join(current_dir, "data", filename), columns) + data_df = _load_csv(_test_file_dir(filename), columns) return create_forest(data_df, anon_params=anon_params, bucketization_params=bucketization_params) + + +def load_json(filename: str) -> Any: + with open(_test_file_dir(filename), "r", encoding="utf-8") as file: + return json.load(file) diff --git a/tests/data/tree.0.json b/tests/data/tree.0.json new file mode 100644 index 0000000..bf96625 --- /dev/null +++ b/tests/data/tree.0.json @@ -0,0 +1,82 @@ +{ + "ranges": [[0.0, 8.0]], + "count": 32, + "children": { + "0": { + "ranges": [[0.0, 4.0]], + "count": 16, + "children": { + "0": { + "ranges": [[0.0, 2.0]], + "count": 8, + "children": { + "0": { + "ranges": [[0.0, 1.0]], + "count": 4, + "children": null + }, + "1": { + "ranges": [[1.0, 2.0]], + "count": 4, + "children": null + } + } + }, + "1": { + "ranges": [[2.0, 4.0]], + "count": 8, + "children": { + "0": { + "ranges": [[2.0, 3.0]], + "count": 4, + "children": null + }, + "1": { + "ranges": [[3.0, 4.0]], + "count": 4, + "children": null + } + } + } + } + }, + "1": { + "ranges": [[4.0, 8.0]], + "count": 16, + "children": { + "0": { + "ranges": [[4.0, 6.0]], + "count": 8, + "children": { + "0": { + "ranges": [[4.0, 5.0]], + "count": 4, + "children": null + }, + "1": { + "ranges": [[5.0, 6.0]], + "count": 4, + "children": null + } + } + }, + "1": { + "ranges": [[6.0, 8.0]], + "count": 8, + "children": { + "0": { + "ranges": [[6.0, 7.0]], + "count": 4, + "children": null + }, + "1": { + "ranges": [[7.0, 8.0]], + "count": 4, + "children": null + } + } + } + } + } + } +} diff --git a/tests/data/tree.0_1.json b/tests/data/tree.0_1.json new file mode 100644 index 0000000..11864a8 --- /dev/null +++ b/tests/data/tree.0_1.json @@ -0,0 +1,109 @@ +{ + "ranges": [ + [0.0, 8.0], + [0.0, 4.0] + ], + "count": 32, + "children": { + "0": { + "ranges": [ + [0.0, 4.0], + [0.0, 2.0] + ], + "count": 8, + "children": { + "0": { + "ranges": [ + [0.0, 2.0], + [0.0, 1.0] + ], + "count": 4, + "children": null + }, + "1": { + "ranges": [ + [0.0, 2.0], + [1.0, 2.0] + ], + "count": 4, + "children": null + } + } + }, + "1": { + "ranges": [ + [0.0, 4.0], + [2.0, 4.0] + ], + "count": 8, + "children": { + "2": { + "ranges": [ + [2.0, 4.0], + [2.0, 3.0] + ], + "count": 4, + "children": null + }, + "3": { + "ranges": [ + [2.0, 4.0], + [3.0, 4.0] + ], + "count": 4, + "children": null + } + } + }, + "2": { + "ranges": [ + [4.0, 8.0], + [0.0, 2.0] + ], + "count": 8, + "children": { + "0": { + "ranges": [ + [4.0, 6.0], + [0.0, 1.0] + ], + "count": 4, + "children": null + }, + "1": { + "ranges": [ + [4.0, 6.0], + [1.0, 2.0] + ], + "count": 4, + "children": null + } + } + }, + "3": { + "ranges": [ + [4.0, 8.0], + [2.0, 4.0] + ], + "count": 8, + "children": { + "2": { + "ranges": [ + [6.0, 8.0], + [2.0, 3.0] + ], + "count": 4, + "children": null + }, + "3": { + "ranges": [ + [6.0, 8.0], + [3.0, 4.0] + ], + "count": 4, + "children": null + } + } + } + } +} diff --git a/tests/data/tree.0_1_2.json b/tests/data/tree.0_1_2.json new file mode 100644 index 0000000..f4af163 --- /dev/null +++ b/tests/data/tree.0_1_2.json @@ -0,0 +1,82 @@ +{ + "ranges": [ + [0.0, 8.0], + [0.0, 4.0], + [0.0, 2.0] + ], + "count": 32, + "children": { + "0": { + "ranges": [ + [0.0, 4.0], + [0.0, 2.0], + [0.0, 1.0] + ], + "count": 4, + "children": null + }, + "1": { + "ranges": [ + [0.0, 4.0], + [0.0, 2.0], + [1.0, 2.0] + ], + "count": 4, + "children": null + }, + "2": { + "ranges": [ + [0.0, 4.0], + [2.0, 4.0], + [0.0, 1.0] + ], + "count": 4, + "children": null + }, + "3": { + "ranges": [ + [0.0, 4.0], + [2.0, 4.0], + [1.0, 2.0] + ], + "count": 4, + "children": null + }, + "4": { + "ranges": [ + [4.0, 8.0], + [0.0, 2.0], + [0.0, 1.0] + ], + "count": 4, + "children": null + }, + "5": { + "ranges": [ + [4.0, 8.0], + [0.0, 2.0], + [1.0, 2.0] + ], + "count": 4, + "children": null + }, + "6": { + "ranges": [ + [4.0, 8.0], + [2.0, 4.0], + [0.0, 1.0] + ], + "count": 4, + "children": null + }, + "7": { + "ranges": [ + [4.0, 8.0], + [2.0, 4.0], + [1.0, 2.0] + ], + "count": 4, + "children": null + } + } +} diff --git a/tests/data/tree.1.json b/tests/data/tree.1.json new file mode 100644 index 0000000..483c153 --- /dev/null +++ b/tests/data/tree.1.json @@ -0,0 +1,38 @@ +{ + "ranges": [[0.0, 4.0]], + "count": 32, + "children": { + "0": { + "ranges": [[0.0, 2.0]], + "count": 16, + "children": { + "0": { + "ranges": [[0.0, 1.0]], + "count": 8, + "children": null + }, + "1": { + "ranges": [[1.0, 2.0]], + "count": 8, + "children": null + } + } + }, + "1": { + "ranges": [[2.0, 4.0]], + "count": 16, + "children": { + "0": { + "ranges": [[2.0, 3.0]], + "count": 8, + "children": null + }, + "1": { + "ranges": [[3.0, 4.0]], + "count": 8, + "children": null + } + } + } + } +} diff --git a/tests/data/tree.2.json b/tests/data/tree.2.json new file mode 100644 index 0000000..fbb48b2 --- /dev/null +++ b/tests/data/tree.2.json @@ -0,0 +1,16 @@ +{ + "ranges": [[0.0, 2.0]], + "count": 32, + "children": { + "0": { + "ranges": [[0.0, 1.0]], + "count": 16, + "children": null + }, + "1": { + "ranges": [[1.0, 2.0]], + "count": 16, + "children": null + } + } +} diff --git a/tests/data/tree.csv b/tests/data/tree.csv new file mode 100644 index 0000000..7a3dfab --- /dev/null +++ b/tests/data/tree.csv @@ -0,0 +1,33 @@ +a,b,c +0,0.0,cat_1 +1,1.0,cat_1 +2,2.0,cat_1 +3,3.0,cat_1 +4,0.0,cat_1 +5,1.0,cat_1 +6,2.0,cat_1 +7,3.0,cat_1 +0,0.0,cat_1 +1,1.0,cat_1 +2,2.0,cat_1 +3,3.0,cat_1 +4,0.0,cat_1 +5,1.0,cat_1 +6,2.0,cat_1 +7,3.0,cat_1 +0,0.0,cat_2 +1,1.0,cat_2 +2,2.0,cat_2 +3,3.0,cat_2 +4,0.0,cat_2 +5,1.0,cat_2 +6,2.0,cat_2 +7,3.0,cat_2 +0,0.0,cat_2 +1,1.0,cat_2 +2,2.0,cat_2 +3,3.0,cat_2 +4,0.0,cat_2 +5,1.0,cat_2 +6,2.0,cat_2 +7,3.0,cat_2 diff --git a/tests/test_forest.py b/tests/test_forest.py index af55439..e56977e 100644 --- a/tests/test_forest.py +++ b/tests/test_forest.py @@ -1,3 +1,5 @@ +from typing import Any, Optional + import pytest from syndiffix.common import * @@ -141,3 +143,33 @@ def test_depth_limiting() -> None: forest = create_forest(DataFrame(data, columns=["col"]), bucketization_params=zero_depth_params) tree = forest.get_tree((ColumnId(0),)) assert isinstance(tree, Leaf) + + +def _node_to_dict(node: Node) -> dict[str, Any]: + count = node.noisy_count() + children: Optional[dict[str, Any]] = None + + if isinstance(node, Branch): + children = {str(key): _node_to_dict(child) for key, child in sorted(node.children.items())} + + return { + "ranges": [[interval.min, interval.max] for interval in node.snapped_intervals], + "count": count, + "children": children, + } + + +def test_tree_snapshot() -> None: + forest = load_forest( + "tree.csv", + anon_params=NOISELESS_PARAMS, + ) + + def get_tree(*comb: int) -> dict[str, Any]: + return _node_to_dict(forest.get_tree(tuple(comb))) + + assert get_tree(0) == load_json("tree.0.json") + assert get_tree(1) == load_json("tree.1.json") + assert get_tree(2) == load_json("tree.2.json") + assert get_tree(0, 1) == load_json("tree.0_1.json") + assert get_tree(0, 1, 2) == load_json("tree.0_1_2.json")