Skip to content

Commit

Permalink
Fix dot print (#24)
Browse files Browse the repository at this point in the history
* removed renumbering (still need to correct for splits) and fixed printing dot files

* Add smoke test for dot output

* Add a method to state_merger to print dot output to a stream rather than a file to aid in testing.

---------

Co-authored-by: sverwer <[email protected]>
Co-authored-by: Tom Catshoek <[email protected]>
  • Loading branch information
3 people authored Jan 18, 2024
1 parent 0b6669c commit ad2a819
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 21 deletions.
31 changes: 10 additions & 21 deletions source/apta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ apta::apta(){
}

void apta::print_dot(iostream& output){
int ncounter = 0;
/*for(APTA_iterator Ait = APTA_iterator(root); *Ait != 0; ++Ait){
(*Ait)->number = ncounter++;
}*/

output << "digraph DFA {\n";
output << "\t" << root->find()->number << " [label=\"root\" shape=box];\n";
output << "\t\tI -> " << root->find()->number << ";\n";
Expand All @@ -67,13 +62,14 @@ void apta::print_dot(iostream& output){

output << "\t" << n->number << " [ label=\"";
if(DEBUGGING){
output << n << ":#" << "\n";
output << "rep#" << n->representative << "\n";
output << n << ":#" << "\\n";
output << "rep#" << n->representative << "\\n";
}
output << n->number << " #" << n->size << "\" ";
output << n->number << " #" << n->size << " ";
n->data->print_state_label(output);
//output << "\" ";
n->data->print_state_style(output);
output << "\" ";

if (n->is_red()) output << ", style=filled, fillcolor=\"firebrick1\"";
else if (n->is_blue()) output << ", style=filled, fillcolor=\"dodgerblue1\"";
else if (n->is_white()) output << ", style=filled, fillcolor=\"ghostwhite\"";
Expand All @@ -99,18 +95,14 @@ void apta::print_dot(iostream& output){
}

output << "\t\t" << n->number << " -> " << child->number << " [label=\"";

output << inputdata::get_symbol(it->first);

output << inputdata::get_symbol(it->first) << " ";
n->data->print_transition_label(output, it->first);

for(auto & min_attribute_value : g->min_attribute_values){
output << "\n" << inputdata::get_attribute(min_attribute_value.first) << " >= " << min_attribute_value.second;
output << "\\n" << inputdata::get_attribute(min_attribute_value.first) << " >= " << min_attribute_value.second;
}
for(auto & max_attribute_value : g->max_attribute_values){
output << "\n" << inputdata::get_attribute(max_attribute_value.first) << " < " << max_attribute_value.second;
output << "\\n" << inputdata::get_attribute(max_attribute_value.first) << " < " << max_attribute_value.second;
}

output << "\" ";
output << ", penwidth=" << log(1 + n->size);
output << " ];\n";
Expand Down Expand Up @@ -169,10 +161,6 @@ void apta::print_json(iostream& output){
set_json_depths();
int count = 0;
root->depth = 0;
for(merged_APTA_iterator Ait = merged_APTA_iterator(root); *Ait != nullptr; ++Ait){
apta_node* n = *Ait;
n->number = count++;
}

output << "{\n";
output << "\t\"types\" : [\n";
Expand Down Expand Up @@ -434,7 +422,8 @@ void apta_node::initialize(apta_node* n){
representative_of = nullptr;
tails_head = nullptr;
access_trace = nullptr;
number = -1;
// keep the old node number
// number = -1;
size = 0;
final = 0;
depth = 0;
Expand Down
6 changes: 6 additions & 0 deletions source/state_merger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,12 @@ void state_merger::print_dot(const string& file_name)
output1.close();
}

void state_merger::print_dot(ostream& output)
{
todot();
output << dot_output;
}

void state_merger::print_json(const string& file_name)
{
tojson();
Expand Down
1 change: 1 addition & 0 deletions source/state_merger.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class state_merger{
bool split_init(apta_node *red, tail *t, int attr, int depth, bool evaluate, bool perform, bool test);

void print_dot(const string& file_name);
void print_dot(ostream& output);

void print_json(const string& file_name);

Expand Down
69 changes: 69 additions & 0 deletions tests/smoketest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "evaluation_factory.h"
#include "parameters.h"

using Catch::Matchers::Equals;

//TODO: refactor: These should probably be taken out of main.cpp
evaluation_function* get_evaluation();
void print_current_automaton(state_merger*, const string&, const string&);
Expand Down Expand Up @@ -73,3 +75,70 @@ TEST_CASE( "Smoke test: greedy edsm on stamina 1_training", "[smoke]" ) {

delete merger;
}

// This tests verifies that the dot file output works as expected
// It will need updating whenever the dot output changes
// TODO: Figure out a way to check if the dot output is valid without running graphviz
TEST_CASE( "Smoke test: dot output", "[smoke]" ) {
HEURISTIC_NAME = "evidence_driven";
DATA_NAME = "edsm_data";

evaluation_function *eval = get_evaluation();
REQUIRE(eval != nullptr);

std::string input = "12 4\n"
"1 3 a b c\n"
"1 3 a b d\n"
"0 2 a b\n"
"0 2 a a\n"
"0 2 b b\n"
"0 1 c\n"
"0 2 c c\n"
"0 1 d\n"
"0 2 d d\n"
"0 1 a\n"
"0 2 b c\n"
"0 2 b d\n";
std::istringstream input_stream(input);

auto* id = new inputdata();
id->read_abbadingo_header(input_stream);

apta* the_apta = new apta();
auto* merger = new state_merger(id, eval, the_apta);
the_apta->set_context(merger);
eval->set_context(merger);

id->read_abbadingo_file(input_stream);
eval->initialize_before_adding_traces();
id->add_traces_to_apta(the_apta);
eval->initialize_after_adding_traces(merger);

greedy_run(merger);

std::stringstream dot_stream;
merger->print_dot(dot_stream);
std::string actual_output = dot_stream.str();

std::string expected_output = "// produced with flexfringe // \n"
"digraph DFA {\n"
"\t-1 [label=\"root\" shape=box];\n"
"\t\tI -> -1;\n"
"\t-1 [ label=\"-1 #25 fin: 0:8 , \n"
" path: 1:2 , 0:15 , \" , style=filled, fillcolor=\"firebrick1\", width=1.44882, height=1.44882, penwidth=3.2581];\n"
"\t\t-1 -> 1 [label=\"a \" , penwidth=3.2581 ];\n"
"\t\t-1 -> -1 [label=\"b \" , penwidth=3.2581 ];\n"
"\t\t-1 -> -1 [label=\"c \" , penwidth=3.2581 ];\n"
"\t\t-1 -> -1 [label=\"d \" , penwidth=3.2581 ];\n"
"\t1 [ label=\"1 #8 fin: 0:2 , \n"
" path: 1:4 , 0:2 , \" , style=filled, fillcolor=\"firebrick1\", width=1.16228, height=1.16228, penwidth=2.19722];\n"
"\t\t1 -> -1 [label=\"a \" , penwidth=2.19722 ];\n"
"\t\t1 -> 1 [label=\"b \" , penwidth=2.19722 ];\n"
"\t\t1 -> 3 [label=\"c \" , penwidth=2.19722 ];\n"
"\t\t1 -> 3 [label=\"d \" , penwidth=2.19722 ];\n"
"\t3 [ label=\"3 #2 fin: 1:2 , \n"
" path: \" , style=filled, fillcolor=\"firebrick1\", width=0.741276, height=0.741276, penwidth=1.09861];\n"
"}\n";

REQUIRE_THAT(actual_output, Equals(expected_output));
}

0 comments on commit ad2a819

Please sign in to comment.