From 80d591e4abdf7a2e5b673a04079f000e18cd432b Mon Sep 17 00:00:00 2001 From: xJoskiy Date: Sun, 5 Jan 2025 04:32:20 +0300 Subject: [PATCH] Add data cleaning scenario --- examples/datasets/taxes_2.csv | 3 +- examples/expert/data_cleaning.py | 87 ++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 examples/expert/data_cleaning.py diff --git a/examples/datasets/taxes_2.csv b/examples/datasets/taxes_2.csv index 9f80920fca..831c8c42ba 100644 --- a/examples/datasets/taxes_2.csv +++ b/examples/datasets/taxes_2.csv @@ -5,7 +5,8 @@ NewYork,5000,0.3 Wisconsin,5000,0.15 Wisconsin,6000,0.2 Wisconsin,4000,0.1 +Wisconsin,3000,0.3 Texas,1000,0.15 Texas,2000,0.25 Texas,3000,0.3 -Texas,5000,0.05 +Texas,4000,0.1 diff --git a/examples/expert/data_cleaning.py b/examples/expert/data_cleaning.py new file mode 100644 index 0000000000..f35aa273e0 --- /dev/null +++ b/examples/expert/data_cleaning.py @@ -0,0 +1,87 @@ +from typing import Dict, List, Tuple +from collections import defaultdict +import matplotlib.pyplot as plt +import desbordante as db +import networkx as nx +import time + + + +class DataCleaner: + def __init__(self, violations: List[Tuple[int, int]]) -> None: + self.graph: Dict[int, List[int]] = defaultdict(list) + for v1, v2 in violations: + if v1 != v2: + self.graph[v1].append(v2) + self.graph[v2].append(v1) + else: + self.graph[v1] = [v1] + self.nodes: List[int] = list(self.graph.keys()) + self.removed_nodes: List[int] = [] + + def __remove_highest_degree_node(self) -> None: + max_key = max(self.graph, key=lambda x: len(self.graph[x])) + for neighbor in self.graph[max_key]: + self.graph[neighbor].remove(max_key) + + del self.graph[max_key] + self.nodes.remove(max_key) + self.removed_nodes.append(max_key) + + # Check if the graph contains any edges + def __has_edges(self) -> bool: + return any(self.graph[node] for node in self.graph) + + # Remove highest degree node while graph has edges + def clean(self) -> None: + print("Cleaning algorithm started") + while self.__has_edges(): + self.__remove_highest_degree_node() + print("Cleaning algorithm finished") + + def draw(self, is_blocked: bool = True) -> None: + plt.figure() + G = nx.Graph() + G.add_nodes_from(self.nodes) + for node, neighbours in self.graph.items(): + [G.add_edge(node, neighbour) for neighbour in neighbours] + nx.draw(G, with_labels=True) + plt.show(block=is_blocked) + + +def main(): + TABLE_1 = '/home/joskiy/Projects/Desbordante/examples/datasets/taxes_2.csv' + DC = "!(s.State == t.State and s.Salary < t.Salary and s.FedTaxRate > t.FedTaxRate)" + SEPARATOR = ',' + HAS_HEADER = True + + print("Data loading started") + verificator = db.dc_verification.algorithms.Default() + verificator.load_data(table=(TABLE_1, SEPARATOR, HAS_HEADER)) + print("Data loading finished") + + DO_COLLECT_VIOLATIONS = True + + print("Algo execution started") + + verificator.execute(denial_constraint=DC, do_collect_violations=DO_COLLECT_VIOLATIONS) + + print("Algo execution finished") + + dc_holds = verificator.dc_holds() + + print("DC " + DC + " holds: " + str(dc_holds)) + + violations = verificator.get_violations() + cleaner = DataCleaner(violations) + + cleaner.draw(False) + cleaner.clean() + cleaner.draw() + + nodes = sorted(cleaner.removed_nodes) + print(f"Records to be removed: {", ".join(map(str, nodes))}") + + +if __name__ == "__main__": + main()