Skip to content

Commit

Permalink
Fixed meanlabel. Updated how we initialize weights
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Oct 24, 2023
1 parent aee7b30 commit 8c0979d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ template<PT PType> struct Program
out += fmt::format("{}\n", extras);

auto get_id = [](const auto& n){
if (Is<NodeType::Terminal, NodeType::MeanLabel>(n->data.node_type))
if (Is<NodeType::Terminal>(n->data.node_type))
return n->data.get_name(false);

return fmt::format("{}",fmt::ptr(n)).substr(2);
Expand Down
8 changes: 5 additions & 3 deletions src/search_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ float calc_initial_weight(const ArrayXf& value, const ArrayXf& y)
float prob_change = std::abs(slope(data.col(0).array() , // x=variable
data.col(1).array() )); // y=target

// having a minimum feature weight if it was not set to zero
if (std::abs(prob_change)<1e-4)
prob_change = 1e-1;

// prob_change will evaluate to nan if variance(x)==0. Features with
// zero variance should not be used (as they behave just like a constant).
if (std::isnan(prob_change))
prob_change = 0.0;
else
// having a minimum feature weight if it was not set to zero
prob_change += 1e-1;

return prob_change;
}
Expand Down Expand Up @@ -131,6 +132,7 @@ vector<Node> generate_terminals(const Dataset& d, const bool weights_init)
return sum / count;
};

// constants for each type
auto cXf = Node(NodeType::Constant, Signature<ArrayXf()>{}, true, "Cf");
float floats_avg_weights = signature_avg(cXf.ret_type);
cXf.set_prob_change(floats_avg_weights);
Expand Down
15 changes: 4 additions & 11 deletions tests/cpp/test_search_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ TEST(SearchSpace, Initialization)

// different weights to check if searchspace is initialized correctnly
unordered_map<string, float> user_ops = {
{"Add", 1},
{"Sub", 1},
{"Div", .5},
{"Mul", 0.5}
{"Add", 1},
{"Sub", 1},
{"Div", .5},
{"Mul", 0.5}
};

SearchSpace SS;
Expand All @@ -45,28 +45,21 @@ TEST(SearchSpace, Initialization)
ArrayXf expected_weights_Xf(4); // 5 elements (x3, x4, x5, c, meanLabel)
expected_weights_Xf << 0.80240685, 0.19270448, 0.5994426, 0.531518, 0.531518;

// terminals that arent constant will have a minimum value
expected_weights_Xf = expected_weights_Xf + minimum_prob;

auto actual_weights_f = SS.terminal_weights.at(DataType::ArrayF);
Eigen::Map<ArrayXf> actual_weights_Xf(actual_weights_f.data(), actual_weights_f.size());

ASSERT_TRUE(expected_weights_Xf.isApprox(actual_weights_Xf));


ArrayXf expected_weights_Xi(2); // 2 elements (x2 and c)
expected_weights_Xi << 0.2736814, 0.2736814;
expected_weights_Xi = expected_weights_Xi + minimum_prob;

auto actual_weights_i = SS.terminal_weights.at(DataType::ArrayI);
Eigen::Map<ArrayXf> actual_weights_Xi(actual_weights_i.data(), actual_weights_i.size());

ASSERT_TRUE(expected_weights_Xi.isApprox(actual_weights_Xi));


ArrayXf expected_weights_Xb(2); // 2 elements (x0 and c)
expected_weights_Xb << 0.8117065, 0.8117065;
expected_weights_Xb = expected_weights_Xb + minimum_prob;

auto actual_weights_b = SS.terminal_weights.at(DataType::ArrayB);
Eigen::Map<ArrayXf> actual_weights_Xb(actual_weights_b.data(), actual_weights_b.size());
Expand Down

0 comments on commit 8c0979d

Please sign in to comment.