Skip to content

Commit

Permalink
Implement python bindings with basic example
Browse files Browse the repository at this point in the history
  • Loading branch information
nemakin committed Dec 1, 2024
1 parent 763ceb6 commit ac6093a
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 49 deletions.
59 changes: 59 additions & 0 deletions examples/basic/verifying_pfd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import desbordante
import pandas as pd

ERROR = 0.3
PER_TUPLE = 'per_tuple'
PER_VALUE = 'per_value'
TABLE = 'examples/datasets/glitchy_sensor_2.csv'


def print_results(verifier):
    """Print a human-readable report for an already executed PFD verifier.

    The error reported by ``verifier.get_error()`` is compared with the
    module-level ``ERROR`` threshold. When the error is within the threshold
    only ``'PFD holds'`` is printed. Otherwise the actual error, the number of
    violating rows and clusters, and the rows of every violating cluster
    (looked up in the CSV file at ``TABLE``) are printed.

    Args:
        verifier: object providing ``get_error()``,
            ``get_num_violating_rows()``, ``get_num_violating_clusters()``
            and ``get_violating_clusters()``.
    """
    error = verifier.get_error()
    if error <= ERROR:
        print('PFD holds')
    else:
        print(f'PFD with error {ERROR} does not hold')
        print(f'But it holds with error {error}')
        print()
        print('Additional info:')
        print(f'Number of rows violating PFD: {verifier.get_num_violating_rows()}')
        print(f'Number of clusters violating PFD: {verifier.get_num_violating_clusters()}')
        print()

        table = pd.read_csv(TABLE)
        ordinal_names = ['First', 'Second', 'Third']
        for position, violating_cluster in enumerate(verifier.get_violating_clusters()):
            # Only three ordinals are spelled out; fall back to a numeric
            # label instead of raising IndexError on a fourth cluster.
            if position < len(ordinal_names):
                label = ordinal_names[position]
            else:
                label = f'#{position + 1}'
            print(f'{label} violating cluster:')
            # Clusters are matched against the DataFrame's index labels, as in
            # the original membership test. A vectorised filter keeps the rows
            # in table order and avoids the per-row dtype coercion of iterrows().
            print(table[table.index.isin(violating_cluster)])
    print()

# Create the PFD verifier and feed it the CSV file
# (comma separator, third tuple element True as in the original call).
algo = desbordante.pfd_verification.algorithms.Default()
algo.load_data(table=(TABLE, ',', True))

# Show the dataset that is being checked
print(f'Dataset: {TABLE}')
print(pd.read_csv(TABLE))
print()

separator = '-' * 80

# The same PFD (DeviceId) -> (Data) is verified twice: first with the
# PerValue error measure, then with the PerTuple one.
checks = (
    (PER_VALUE, f'Checking whether PFD (DeviceId) -> (Data) holds for {PER_VALUE} error measure'),
    (PER_TUPLE, f'Checking whether the same PFD holds for {PER_TUPLE} error measure:'),
)
for measure, headline in checks:
    # NOTE(review): ERROR is also passed as `error=` here, while print_results
    # does its own comparison against ERROR — confirm the verifier still
    # accepts/uses this option.
    algo.execute(lhs_indices=[1], rhs_indices=[2], error=ERROR, pfd_error_measure=measure)
    print(separator)
    print(headline)
    print(separator)
    print_results(algo)
17 changes: 17 additions & 0 deletions examples/datasets/glitchy_sensor_2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Id,DeviceId,Data
1,D-1,1001
2,D-1,1002
3,D-1,1003
4,D-1,1004
5,D-1,1005
6,D-1,1006
7,D-2,1000
8,D-2,1001
9,D-2,1000
10,D-3,1010
11,D-4,1011
12,D-4,1011
13,D-5,1015
14,D-5,1014
15,D-5,1015
16,D-5,1015
19 changes: 7 additions & 12 deletions src/core/algorithms/fd/pfd_verifier/pfd_stats_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,23 @@ namespace algos {
class PFDStatsCalculator {
private:
std::shared_ptr<ColumnLayoutRelationData> relation_;
config::ErrorType max_fd_error_;
config::ErrorMeasureType error_measure_;
config::PfdErrorMeasureType error_measure_;

std::vector<model::PLI::Cluster> clusters_violating_pfd_;
size_t num_rows_violating_pfd_ = 0;
config::ErrorType error_ = 0.0;

public:
explicit PFDStatsCalculator(std::shared_ptr<ColumnLayoutRelationData> relation,
config::ErrorMeasureType measure, config::ErrorType max_fd_error)
: relation_(std::move(relation)), max_fd_error_(max_fd_error), error_measure_(measure) {}
config::PfdErrorMeasureType measure)
: relation_(std::move(relation)), error_measure_(measure) {}

void ResetState() {
clusters_violating_pfd_.clear();
num_rows_violating_pfd_ = 0;
error_ = 0;
}

bool PFDHolds() const {
return error_ <= max_fd_error_;
}

size_t GetNumViolatingClusters() const {
return clusters_violating_pfd_.size();
}
Expand Down Expand Up @@ -79,17 +74,17 @@ class PFDStatsCalculator {
clusters_violating_pfd_.push_back(x_cluster);
}
num_rows_violating_pfd_ += x_cluster_size - max;
sum += error_measure_ == +ErrorMeasure::per_tuple
sum += error_measure_ == +PfdErrorMeasure::per_tuple
? static_cast<double>(max)
: static_cast<double>(max) / x_cluster_size;
cluster_rows_count += x_cluster.size();
}
unsigned int unique_rows =
static_cast<unsigned int>(x_pli->GetRelationSize() - cluster_rows_count);
double probability =
static_cast<double>(sum + unique_rows) / (error_measure_ == +ErrorMeasure::per_tuple
? x_pli->GetRelationSize()
: x_index.size() + unique_rows);
static_cast<double>(sum + unique_rows) /
(error_measure_ == +PfdErrorMeasure::per_tuple ? x_pli->GetRelationSize()
: x_index.size() + unique_rows);
error_ = 1.0 - probability;
}
};
Expand Down
9 changes: 3 additions & 6 deletions src/core/algorithms/fd/pfd_verifier/pfd_verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "config/names.h"
#include "config/tabular_data/input_table/option.h"
#include "equal_nulls/option.h"
#include "error/option.h"
#include "error_measure/option.h"
#include "indices/option.h"

Expand All @@ -19,13 +18,12 @@ void PFDVerifier::RegisterOptions() {
RegisterOption(config::kEqualNullsOpt(&is_null_equal_null_));
RegisterOption(config::kLhsIndicesOpt(&lhs_indices_, get_schema_cols));
RegisterOption(config::kRhsIndicesOpt(&rhs_indices_, get_schema_cols));
RegisterOption(config::kErrorMeasureOpt(&error_measure_));
RegisterOption(config::kErrorOpt(&max_fd_error_));
RegisterOption(config::kPfdErrorMeasureOpt(&error_measure_));
}

void PFDVerifier::MakeExecuteOptsAvailable() {
using namespace config::names;
MakeOptionsAvailable({kLhsIndices, kRhsIndices, kErrorMeasure, kError});
MakeOptionsAvailable({kLhsIndices, kRhsIndices, kPfdErrorMeasure});
}

void PFDVerifier::LoadDataInternal() {
Expand All @@ -37,8 +35,7 @@ void PFDVerifier::LoadDataInternal() {

unsigned long long PFDVerifier::ExecuteInternal() {
auto start_time = std::chrono::system_clock::now();
stats_calculator_ =
std::make_unique<PFDStatsCalculator>(relation_, error_measure_, max_fd_error_);
stats_calculator_ = std::make_unique<PFDStatsCalculator>(relation_, error_measure_);
VerifyPFD();
auto elapsed_milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now() - start_time);
Expand Down
8 changes: 1 addition & 7 deletions src/core/algorithms/fd/pfd_verifier/pfd_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ class PFDVerifier : public Algorithm {
config::IndicesType lhs_indices_;
config::IndicesType rhs_indices_;
config::EqNullsType is_null_equal_null_;
config::ErrorType max_fd_error_;
config::ErrorMeasureType error_measure_ = +ErrorMeasure::per_tuple;
config::PfdErrorMeasureType error_measure_ = +PfdErrorMeasure::per_tuple;

std::shared_ptr<ColumnLayoutRelationData> relation_;
std::unique_ptr<PFDStatsCalculator> stats_calculator_;
Expand All @@ -41,11 +40,6 @@ class PFDVerifier : public Algorithm {
std::shared_ptr<model::PLI const> CalculatePLI(config::IndicesType const& indices) const;

public:
bool PFDHolds() const {
assert(stats_calculator_);
return stats_calculator_->PFDHolds();
}

size_t GetNumViolatingClusters() const {
assert(stats_calculator_);
return stats_calculator_->GetNumViolatingClusters();
Expand Down
4 changes: 3 additions & 1 deletion src/python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "nd/bind_nd.h"
#include "nd/bind_nd_verification.h"
#include "od/bind_od.h"
#include "pfd/bind_pfd_verification.h"
#include "sfd/bind_sfd.h"
#include "statistics/bind_statistics.h"
#include "ucc/bind_ucc.h"
Expand Down Expand Up @@ -60,7 +61,8 @@ PYBIND11_MODULE(desbordante, module, pybind11::mod_gil_not_used()) {
BindNdVerification,
BindSFD,
BindMd,
BindDCVerification}) {
BindDCVerification,
BindPfdVerification}) {
bind_func(module);
}
}
Expand Down
25 changes: 25 additions & 0 deletions src/python_bindings/pfd/bind_pfd_verification.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "bind_pfd_verification.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "algorithms/fd/pfd_verifier/pfd_verifier.h"
#include "py_util/bind_primitive.h"

namespace {
// Short alias for the pybind11 namespace; the anonymous namespace keeps it
// local to this translation unit.
namespace py = pybind11;
} // namespace

namespace python_bindings {
/// @brief Registers the `pfd_verification` submodule on @p main_module.
///
/// The submodule receives the bindings produced by BindPrimitiveNoBase for
/// algos::PFDVerifier (registered under the name "PFDVerifier"), extended
/// with the result getters `get_num_violating_clusters`,
/// `get_num_violating_rows`, `get_violating_clusters` and `get_error`.
///
/// @param main_module module object the submodule is attached to.
void BindPfdVerification(py::module_& main_module) {
    using namespace algos;

    // def_submodule() already attaches the new module to main_module under
    // the attribute "pfd_verification", so no explicit attr() assignment with
    // the same name is needed afterwards.
    auto pfd_verification_module = main_module.def_submodule("pfd_verification");

    BindPrimitiveNoBase<PFDVerifier>(pfd_verification_module, "PFDVerifier")
            .def("get_num_violating_clusters", &PFDVerifier::GetNumViolatingClusters)
            .def("get_num_violating_rows", &PFDVerifier::GetNumViolatingRows)
            .def("get_violating_clusters", &PFDVerifier::GetViolatingClusters)
            .def("get_error", &PFDVerifier::GetError);
}
}  // namespace python_bindings
7 changes: 7 additions & 0 deletions src/python_bindings/pfd/bind_pfd_verification.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#pragma once

#include <pybind11/pybind11.h>

namespace python_bindings {
/// Registers the PFD verification Python bindings on the given module.
/// @param main_module module object that receives the bindings.
void BindPfdVerification(pybind11::module_& main_module);
} // namespace python_bindings
48 changes: 25 additions & 23 deletions src/tests/test_pfd_verifier.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <gtest/gtest.h>

#include "algorithms/algo_factory.h"
#include "algorithms/fd/pfdtane/pfd_verifier/pfd_verifier.h"
#include "algorithms/fd/pfd_verifier/pfd_verifier.h"
#include "all_csv_configs.h"
#include "config/indices/type.h"
#include "config/names.h"
Expand All @@ -19,16 +19,17 @@ struct PFDVerifyingParams {
std::vector<model::PLI::Cluster> const clusters_violating_pfd;

PFDVerifyingParams(config::IndicesType lhs_indices, config::IndicesType rhs_indices,
config::ErrorMeasureType error_measure, config::ErrorType error,
config::PfdErrorMeasureType error_measure, config::ErrorType error,
size_t num_violating_clusters, size_t num_violating_rows,
std::vector<model::PLI::Cluster> clusters_violating_pfd,
CSVConfig const& csv_config)
: params({{onam::kCsvConfig, csv_config},
: params({
{onam::kCsvConfig, csv_config},
{onam::kEqualNulls, true},
{onam::kLhsIndices, std::move(lhs_indices)},
{onam::kRhsIndices, std::move(rhs_indices)},
{onam::kErrorMeasure, error_measure},
{onam::kError, error}}),
{onam::kPfdErrorMeasure, error_measure},
}),
expected_error(error),
num_violating_clusters(num_violating_clusters),
num_violating_rows(num_violating_rows),
Expand All @@ -44,7 +45,6 @@ TEST_P(TestPFDVerifying, DefaultTest) {
auto verifier = algos::CreateAndLoadAlgorithm<algos::PFDVerifier>(p.params);
double const eps = 0.0001;
verifier->Execute();
EXPECT_TRUE(verifier->PFDHolds());
EXPECT_NEAR(p.expected_error, verifier->GetError(), eps);
EXPECT_EQ(p.num_violating_clusters, verifier->GetNumViolatingClusters());
EXPECT_EQ(p.num_violating_rows, verifier->GetNumViolatingRows());
Expand All @@ -53,21 +53,23 @@ TEST_P(TestPFDVerifying, DefaultTest) {

INSTANTIATE_TEST_SUITE_P(
PFDVerifierTestSuite, TestPFDVerifying,
::testing::Values(PFDVerifyingParams({2}, {3}, +algos::ErrorMeasure::per_value, 0.0625, 1,
1, {{0, 1}}, kTestFD),
PFDVerifyingParams({0, 1}, {4}, +algos::ErrorMeasure::per_value, 0.166667,
2, 2, {{0, 1, 2}, {6, 7, 8}}, kTestFD),
PFDVerifyingParams({4}, {5}, +algos::ErrorMeasure::per_value, 0.3334, 4,
4, {{0, 8}, {1, 2}, {3, 4, 5}, {9, 10, 11}}, kTestFD),
PFDVerifyingParams({5}, {1}, +algos::ErrorMeasure::per_value, 0.0, 0, 0,
{}, kTestFD),
PFDVerifyingParams({2}, {3}, +algos::ErrorMeasure::per_tuple, 0.0834, 1,
1, {{0, 1}}, kTestFD),
PFDVerifyingParams({0, 1}, {4}, +algos::ErrorMeasure::per_tuple, 0.1667,
2, 2, {{0, 1, 2}, {6, 7, 8}}, kTestFD),
PFDVerifyingParams({4}, {5}, +algos::ErrorMeasure::per_tuple, 0.3334, 4,
4, {{0, 8}, {1, 2}, {3, 4, 5}, {9, 10, 11}}, kTestFD),
PFDVerifyingParams({5}, {1}, +algos::ErrorMeasure::per_tuple, 0.0, 0, 0,
{}, kTestFD)));
::testing::Values(PFDVerifyingParams({2}, {3}, +algos::PfdErrorMeasure::per_value, 0.0625,
1, 1, {{0, 1}}, kTestFD),
PFDVerifyingParams({0, 1}, {4}, +algos::PfdErrorMeasure::per_value,
0.166667, 2, 2, {{0, 1, 2}, {6, 7, 8}}, kTestFD),
PFDVerifyingParams({4}, {5}, +algos::PfdErrorMeasure::per_value, 0.3334,
4, 4, {{0, 8}, {1, 2}, {3, 4, 5}, {9, 10, 11}},
kTestFD),
PFDVerifyingParams({5}, {1}, +algos::PfdErrorMeasure::per_value, 0.0, 0,
0, {}, kTestFD),
PFDVerifyingParams({2}, {3}, +algos::PfdErrorMeasure::per_tuple, 0.0834,
1, 1, {{0, 1}}, kTestFD),
PFDVerifyingParams({0, 1}, {4}, +algos::PfdErrorMeasure::per_tuple,
0.1667, 2, 2, {{0, 1, 2}, {6, 7, 8}}, kTestFD),
PFDVerifyingParams({4}, {5}, +algos::PfdErrorMeasure::per_tuple, 0.3334,
4, 4, {{0, 8}, {1, 2}, {3, 4, 5}, {9, 10, 11}},
kTestFD),
PFDVerifyingParams({5}, {1}, +algos::PfdErrorMeasure::per_tuple, 0.0, 0,
0, {}, kTestFD)));

} // namespace tests
} // namespace tests

0 comments on commit ac6093a

Please sign in to comment.