From 004ff7a810a6dfb43e912a5806d66fc8cbda597c Mon Sep 17 00:00:00 2001
From: gAldeia
Date: Wed, 26 Jun 2024 12:12:17 -0300
Subject: [PATCH] linear complexity objective (alternative to the recursive
 version)

---
 docs/guide/saving_loading_populations.ipynb | 283 ++++++++++++--------
 pybrush/EstimatorInterface.py               |   5 +-
 src/bindings/bind_fitness.cpp               |   2 +
 src/bindings/bind_programs.h                |   1 +
 src/engine.cpp                              |   2 +-
 src/eval/evaluation.cpp                     |   9 +-
 src/ind/fitness.cpp                         |   2 +
 src/ind/fitness.h                           |  11 +-
 src/ind/individual.h                        |   2 +
 src/program/program.h                       |  10 +-
 src/program/tree_node.cpp                   |  24 ++
 src/program/tree_node.h                     |   1 +
 src/vary/variation.cpp                      |   3 +
 13 files changed, 237 insertions(+), 118 deletions(-)

diff --git a/docs/guide/saving_loading_populations.ipynb b/docs/guide/saving_loading_populations.ipynb
index e8187730..3ed2376a 100644
--- a/docs/guide/saving_loading_populations.ipynb
+++ b/docs/guide/saving_loading_populations.ipynb
@@ -48,78 +48,78 @@
      "output_type": "stream",
      "text": [
       "Generation 1/10 [//////                                            ]\n",
-      "Train Loss (Med): 14.24093 (60.79966)\n",
-      "Val Loss (Med): 14.24093 (60.79966)\n",
-      "Median Size (Max): 3 (18)\n",
-      "Median complexity (Max): 9 (351)\n",
-      "Time (s): 0.12213\n",
+      "Train Loss (Med): 16.41696 (74.37033)\n",
+      "Val Loss (Med): 16.41696 (74.37033)\n",
+      "Median Size (Max): 3 (12)\n",
+      "Median complexity (Max): 9 (156)\n",
+      "Time (s): 0.07226\n",
       "\n",
       "Generation 2/10 [///////////                                       ]\n",
-      "Train Loss (Med): 13.86337 (20.54475)\n",
-      "Val Loss (Med): 13.86337 (20.54475)\n",
-      "Median Size (Max): 3 (20)\n",
-      "Median complexity (Max): 9 (651)\n",
-      "Time (s): 0.20873\n",
+      "Train Loss (Med): 12.66635 (49.96683)\n",
+      "Val Loss (Med): 12.66635 (49.96683)\n",
+      "Median Size (Max): 3 (12)\n",
+      "Median complexity (Max): 9 (165)\n",
+      "Time (s): 0.12100\n",
       "\n",
       "Generation 3/10 [////////////////                                  ]\n",
-      "Train Loss (Med): 13.58168 (17.94969)\n",
-      "Val Loss (Med): 13.58168 (17.94969)\n",
-      "Median Size (Max): 3 (18)\n",
-      "Median complexity (Max): 9 (216)\n",
-      "Time (s): 0.30500\n",
+      "Train Loss (Med): 12.66635 (16.41696)\n",
+      "Val Loss (Med): 12.66635 (16.41696)\n",
+      "Median Size (Max): 5 (14)\n",
+      "Median complexity (Max): 33 (408)\n",
+      "Time (s): 0.16357\n",
       "\n",
       "Generation 4/10 [/////////////////////                             ]\n",
-      "Train Loss (Med): 13.58167 (17.94969)\n",
-      "Val Loss (Med): 13.58167 (17.94969)\n",
-      "Median Size (Max): 3 (18)\n",
-      "Median complexity (Max): 9 (478)\n",
-      "Time (s): 0.44846\n",
+      "Train Loss (Med): 10.97588 (17.85729)\n",
+      "Val Loss (Med): 10.97588 (17.85729)\n",
+      "Median Size (Max): 5 (14)\n",
+      "Median complexity (Max): 20 (360)\n",
+      "Time (s): 0.21556\n",
       "\n",
       "Generation 5/10 [//////////////////////////                        ]\n",
-      "Train Loss (Med): 13.58167 (17.94969)\n",
-      "Val Loss (Med): 13.58167 (17.94969)\n",
-      "Median Size (Max): 3 (18)\n",
-      "Median complexity (Max): 8 (297)\n",
-      "Time (s): 0.58464\n",
+      "Train Loss (Med): 10.97588 (16.95482)\n",
+      "Val Loss (Med): 10.97588 (16.95482)\n",
+      "Median Size (Max): 5 (15)\n",
+      "Median complexity (Max): 33 (399)\n",
+      "Time (s): 0.26767\n",
       "\n",
       "Generation 6/10 [///////////////////////////////                   ]\n",
-      "Train 
Loss (Med): 13.25836 (63.64049)\n", - "Val Loss (Med): 13.25836 (63.64049)\n", - "Median Size (Max): 3 (18)\n", - "Median complexity (Max): 7 (297)\n", - "Time (s): 0.71193\n", + "Train Loss (Med): 10.97588 (16.41696)\n", + "Val Loss (Med): 10.97588 (16.41696)\n", + "Median Size (Max): 5 (15)\n", + "Median complexity (Max): 33 (315)\n", + "Time (s): 0.33006\n", "\n", "Generation 7/10 [//////////////////////////////////// ]\n", - "Train Loss (Med): 10.63986 (63.64049)\n", - "Val Loss (Med): 10.63986 (63.64049)\n", - "Median Size (Max): 3 (18)\n", - "Median complexity (Max): 6 (567)\n", - "Time (s): 0.84199\n", + "Train Loss (Med): 10.97588 (16.41696)\n", + "Val Loss (Med): 10.97588 (16.41696)\n", + "Median Size (Max): 5 (15)\n", + "Median complexity (Max): 33 (273)\n", + "Time (s): 0.43463\n", "\n", "Generation 8/10 [///////////////////////////////////////// ]\n", - "Train Loss (Med): 10.28156 (63.64049)\n", - "Val Loss (Med): 10.28156 (63.64049)\n", - "Median Size (Max): 3 (18)\n", - "Median complexity (Max): 6 (270)\n", - "Time (s): 0.96073\n", + "Train Loss (Med): 10.97588 (15.46647)\n", + "Val Loss (Med): 10.97588 (15.46647)\n", + "Median Size (Max): 7 (15)\n", + "Median complexity (Max): 43 (273)\n", + "Time (s): 0.51012\n", "\n", "Generation 9/10 [////////////////////////////////////////////// ]\n", - "Train Loss (Med): 10.09177 (63.64049)\n", - "Val Loss (Med): 10.09177 (63.64049)\n", - "Median Size (Max): 3 (18)\n", - "Median complexity (Max): 6 (216)\n", - "Time (s): 1.06539\n", + "Train Loss (Med): 10.97588 (15.94172)\n", + "Val Loss (Med): 10.97588 (15.94172)\n", + "Median Size (Max): 6 (15)\n", + "Median complexity (Max): 34 (273)\n", + "Time (s): 0.58572\n", "\n", "Generation 10/10 [//////////////////////////////////////////////////]\n", - "Train Loss (Med): 10.09177 (63.64049)\n", - "Val Loss (Med): 10.09177 (63.64049)\n", - "Median Size (Max): 3 (20)\n", - "Median complexity (Max): 6 (510)\n", - "Time (s): 1.15215\n", + "Train Loss (Med): 10.97588 (15.94172)\n", + "Val Loss (Med): 10.97588 (15.94172)\n", + "Median Size (Max): 6 (15)\n", + "Median complexity (Max): 34 (273)\n", + "Time (s): 0.64205\n", "\n", - "Saved population to file /tmp/tmp91lebnyr/population.json\n", + "Saved population to file /tmp/tmpfphckt_3/population.json\n", "saving final population as archive...\n", - "score: 0.8883469950530682\n" + "score: 0.8785654367436365\n" ] } ], @@ -161,10 +161,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loaded population from /tmp/tmp91lebnyr/population.json of size = 200\n", + "Loaded population from /tmp/tmpfphckt_3/population.json of size = 200\n", "Completed 100% [====================]\n", "saving final population as archive...\n", - "score: 0.8893989188100573\n" + "score: 0.888055116477749\n" ] } ], @@ -200,7 +200,7 @@ "\n", "First, we run the evolution and save the population to a file; then, we load it and keep evolving the individuals.\n", "\n", - "What is different though is that the first run is optimizing `error` and `complexity`, and the second run is optimizing `average_precision_score` and `size`." + "What is different though is that the first run is optimizing `error` and `complexity`, and the second run is optimizing `average_precision_score` and `linear_complexity`." 
] }, { @@ -213,79 +213,79 @@ "output_type": "stream", "text": [ "Generation 1/10 [////// ]\n", - "Train Loss (Med): 0.54930 (0.69315)\n", - "Val Loss (Med): 0.54930 (0.69315)\n", - "Median Size (Max): 5 (10)\n", + "Train Loss (Med): 0.54851 (0.69315)\n", + "Val Loss (Med): 0.54851 (0.69315)\n", + "Median Size (Max): 5 (12)\n", "Median complexity (Max): 6 (270)\n", - "Time (s): 0.05091\n", + "Time (s): 0.03284\n", "\n", "Generation 2/10 [/////////// ]\n", - "Train Loss (Med): 0.54848 (0.69315)\n", - "Val Loss (Med): 0.54848 (0.69315)\n", + "Train Loss (Med): 0.54850 (0.69315)\n", + "Val Loss (Med): 0.54850 (0.69315)\n", "Median Size (Max): 5 (10)\n", - "Median complexity (Max): 6 (54)\n", - "Time (s): 0.08855\n", + "Median complexity (Max): 6 (165)\n", + "Time (s): 0.05459\n", "\n", "Generation 3/10 [//////////////// ]\n", - "Train Loss (Med): 0.54847 (0.69315)\n", - "Val Loss (Med): 0.54847 (0.69315)\n", - "Median Size (Max): 5 (10)\n", - "Median complexity (Max): 6 (54)\n", - "Time (s): 0.12188\n", + "Train Loss (Med): 0.54851 (0.69315)\n", + "Val Loss (Med): 0.54851 (0.69315)\n", + "Median Size (Max): 3 (10)\n", + "Median complexity (Max): 3 (165)\n", + "Time (s): 0.07147\n", "\n", "Generation 4/10 [///////////////////// ]\n", - "Train Loss (Med): 0.54847 (0.69315)\n", - "Val Loss (Med): 0.54847 (0.69315)\n", + "Train Loss (Med): 0.54851 (0.69315)\n", + "Val Loss (Med): 0.54851 (0.69315)\n", "Median Size (Max): 1 (10)\n", - "Median complexity (Max): 2 (52)\n", - "Time (s): 0.15164\n", + "Median complexity (Max): 2 (54)\n", + "Time (s): 0.08754\n", "\n", "Generation 5/10 [////////////////////////// ]\n", "Train Loss (Med): 0.54846 (0.69315)\n", "Val Loss (Med): 0.54846 (0.69315)\n", - "Median Size (Max): 1 (9)\n", - "Median complexity (Max): 1 (52)\n", - "Time (s): 0.18853\n", + "Median Size (Max): 1 (10)\n", + "Median complexity (Max): 2 (54)\n", + "Time (s): 0.11204\n", "\n", "Generation 6/10 [/////////////////////////////// ]\n", - "Train Loss (Med): 0.50911 (0.69315)\n", - "Val Loss (Med): 0.50911 (0.69315)\n", - "Median Size (Max): 1 (12)\n", - "Median complexity (Max): 1 (52)\n", - "Time (s): 0.23546\n", + "Train Loss (Med): 0.54846 (0.69315)\n", + "Val Loss (Med): 0.54846 (0.69315)\n", + "Median Size (Max): 1 (10)\n", + "Median complexity (Max): 1 (90)\n", + "Time (s): 0.13126\n", "\n", "Generation 7/10 [//////////////////////////////////// ]\n", - "Train Loss (Med): 0.50911 (0.69315)\n", - "Val Loss (Med): 0.50911 (0.69315)\n", - "Median Size (Max): 1 (12)\n", - "Median complexity (Max): 1 (44)\n", - "Time (s): 0.28747\n", + "Train Loss (Med): 0.54846 (0.69315)\n", + "Val Loss (Med): 0.54846 (0.69315)\n", + "Median Size (Max): 1 (10)\n", + "Median complexity (Max): 1 (54)\n", + "Time (s): 0.14825\n", "\n", "Generation 8/10 [///////////////////////////////////////// ]\n", - "Train Loss (Med): 0.50911 (0.69315)\n", - "Val Loss (Med): 0.50911 (0.69315)\n", - "Median Size (Max): 1 (14)\n", - "Median complexity (Max): 1 (120)\n", - "Time (s): 0.34002\n", + "Train Loss (Med): 0.54846 (0.69315)\n", + "Val Loss (Med): 0.54846 (0.69315)\n", + "Median Size (Max): 1 (10)\n", + "Median complexity (Max): 1 (52)\n", + "Time (s): 0.16841\n", "\n", "Generation 9/10 [////////////////////////////////////////////// ]\n", - "Train Loss (Med): 0.50911 (0.69315)\n", - "Val Loss (Med): 0.50911 (0.69315)\n", - "Median Size (Max): 1 (14)\n", - "Median complexity (Max): 1 (80)\n", - "Time (s): 0.39083\n", + "Train Loss (Med): 0.54846 (0.69315)\n", + "Val Loss (Med): 0.54846 (0.69315)\n", + "Median Size 
(Max): 1 (10)\n", + "Median complexity (Max): 1 (48)\n", + "Time (s): 0.19657\n", "\n", "Generation 10/10 [//////////////////////////////////////////////////]\n", - "Train Loss (Med): 0.50911 (0.69315)\n", - "Val Loss (Med): 0.50911 (0.69315)\n", - "Median Size (Max): 1 (12)\n", - "Median complexity (Max): 1 (44)\n", - "Time (s): 0.45610\n", + "Train Loss (Med): 0.54846 (0.69315)\n", + "Val Loss (Med): 0.54846 (0.69315)\n", + "Median Size (Max): 1 (10)\n", + "Median complexity (Max): 1 (48)\n", + "Time (s): 0.22023\n", "\n", - "Saved population to file /tmp/tmpqcnwu35t/population.json\n", + "Saved population to file /tmp/tmpe7n_mbgz/population.json\n", "saving final population as archive...\n", - "If(AIDS>15890.50,Logistic(13.52),If(Total>1572255.50,0.22,0.52))\n", - "score: 0.7\n" + "If(AIDS>15890.50,1.18*Logistic(1.69*MeanLabel),0.39*MeanLabel)\n", + "score: 0.68\n" ] } ], @@ -317,23 +317,86 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Generation 1/1 [//////////////////////////////////////////////////]\n", + "Generation 1/10 [////// ]\n", "Train Loss (Med): 0.46115 (0.31675)\n", "Val Loss (Med): 0.46115 (0.31675)\n", "Median Size (Max): 5 (9)\n", + "Median complexity (Max): 6 (180)\n", + "Time (s): 0.03686\n", + "\n", + "Generation 2/10 [/////////// ]\n", + "Train Loss (Med): 0.75212 (0.31675)\n", + "Val Loss (Med): 0.75212 (0.31675)\n", + "Median Size (Max): 5 (9)\n", "Median complexity (Max): 6 (120)\n", - "Time (s): 0.04622\n", + "Time (s): 0.06046\n", + "\n", + "Generation 3/10 [//////////////// ]\n", + "Train Loss (Med): 0.75212 (0.31675)\n", + "Val Loss (Med): 0.75212 (0.31675)\n", + "Median Size (Max): 4 (9)\n", + "Median complexity (Max): 2 (90)\n", + "Time (s): 0.07728\n", + "\n", + "Generation 4/10 [///////////////////// ]\n", + "Train Loss (Med): 0.75212 (0.00000)\n", + "Val Loss (Med): 0.75212 (0.00000)\n", + "Median Size (Max): 1 (9)\n", + "Median complexity (Max): 1 (90)\n", + "Time (s): 0.09751\n", + "\n", + "Generation 5/10 [////////////////////////// ]\n", + "Train Loss (Med): 0.75212 (0.00000)\n", + "Val Loss (Med): 0.75212 (0.00000)\n", + "Median Size (Max): 1 (6)\n", + "Median complexity (Max): 1 (90)\n", + "Time (s): 0.11214\n", + "\n", + "Generation 6/10 [/////////////////////////////// ]\n", + "Train Loss (Med): 0.75212 (0.00000)\n", + "Val Loss (Med): 0.75212 (0.00000)\n", + "Median Size (Max): 1 (6)\n", + "Median complexity (Max): 1 (90)\n", + "Time (s): 0.12957\n", + "\n", + "Generation 7/10 [//////////////////////////////////// ]\n", + "Train Loss (Med): 0.75258 (0.00000)\n", + "Val Loss (Med): 0.75258 (0.00000)\n", + "Median Size (Max): 1 (8)\n", + "Median complexity (Max): 1 (360)\n", + "Time (s): 0.14327\n", + "\n", + "Generation 8/10 [///////////////////////////////////////// ]\n", + "Train Loss (Med): 0.75212 (0.31675)\n", + "Val Loss (Med): 0.75212 (0.31675)\n", + "Median Size (Max): 1 (6)\n", + "Median complexity (Max): 1 (30)\n", + "Time (s): 0.15835\n", + "\n", + "Generation 9/10 [////////////////////////////////////////////// ]\n", + "Train Loss (Med): 0.75212 (0.31675)\n", + "Val Loss (Med): 0.75212 (0.31675)\n", + "Median Size (Max): 1 (6)\n", + "Median complexity (Max): 1 (30)\n", + "Time (s): 0.17249\n", + "\n", + "Generation 10/10 [//////////////////////////////////////////////////]\n", + "Train Loss (Med): 0.75212 (0.31675)\n", + "Val Loss (Med): 0.75212 (0.31675)\n", + "Median Size (Max): 1 (6)\n", + "Median complexity (Max): 
1 (30)\n",
+      "Time (s): 0.18584\n",
       "\n",
       "saving final population as archive...\n",
-      "Logistic(Sum(-1.2485952,Add(-1.25,0.08*AIDS)))\n",
-      "score: 0.54\n"
+      "Logistic(-0.07*Mul(0.00*AIDS,-0.03*MeanLabel))\n",
+      "score: 0.52\n"
      ]
     }
    ],
    "source": [
     "est = BrushClassifier(\n",
     "    functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],\n",
     "    #load_population=pop_file,\n",
-    "    objectives=[\"error\", \"complexity\"],\n",
+    "    objectives=[\"error\", \"linear_complexity\"],\n",
     "    scorer=\"average_precision_score\",\n",
-    "    max_gens=1,\n",
+    "    max_gens=10,\n",
     "    verbosity=2\n",
     ")\n",
     "\n",
diff --git a/pybrush/EstimatorInterface.py b/pybrush/EstimatorInterface.py
index 57529f92..063e8f58 100644
--- a/pybrush/EstimatorInterface.py
+++ b/pybrush/EstimatorInterface.py
@@ -67,8 +67,9 @@ class EstimatorInterface():
         uniformly distributed between 1 and `max_size`.
     objectives : list[str], default ["error", "size"]
         list with one or more objectives to use. The first objective is the main.
-        If `"error"` is used, then the metric in `scorer` will be used.
-        Possible values are "error", "size", "complexity", and "depth".
+        If `"error"` is used, then the metric in `scorer` will be used as the objective.
+        Possible values are "error", "size", "complexity", "linear_complexity",
+        and "depth".
     scorer : str, default None
         The metric to use for the "error" objective. If None, it will be set
         to "mse" for regression and "log" for binary classification.
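For illustration only (not part of this patch), a self-contained sketch of how the new objective would be selected through this interface. The `objectives` and `scorer` arguments mirror the notebook cell above; the import path, the synthetic data, and the sklearn-style `fit`/`score` calls are assumptions of the sketch:

    # sketch: selecting the new "linear_complexity" objective from Python
    # (synthetic data assumed; the docs notebooks use a real dataset)
    import numpy as np
    from pybrush import BrushClassifier

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 3))                    # 200 samples, 3 features
    y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(float)  # binary labels

    est = BrushClassifier(
        objectives=["error", "linear_complexity"],   # first objective is the main one
        scorer="average_precision_score",            # metric backing the "error" objective
        max_gens=10,
        verbosity=1,
    )
    est.fit(X, y)
    print(est.score(X, y))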
diff --git a/src/bindings/bind_fitness.cpp b/src/bindings/bind_fitness.cpp
index da8ef61d..549b4bbe 100644
--- a/src/bindings/bind_fitness.cpp
+++ b/src/bindings/bind_fitness.cpp
@@ -24,12 +24,14 @@ void bind_fitness(py::module& m)
         .def_property("loss_v", &br::Fitness::get_loss_v, &br::Fitness::set_loss_v)
         .def_property("size", &br::Fitness::get_size, &br::Fitness::set_size)
         .def_property("complexity", &br::Fitness::get_complexity, &br::Fitness::set_complexity)
+        .def_property("linear_complexity", &br::Fitness::get_linear_complexity, &br::Fitness::set_linear_complexity)
         .def_property("depth", &br::Fitness::get_depth, &br::Fitness::set_depth)

         .def_property_readonly("prev_loss", &br::Fitness::get_loss)
         .def_property_readonly("prev_loss_v", &br::Fitness::get_loss_v)
         .def_property_readonly("prev_size", &br::Fitness::get_size)
         .def_property_readonly("prev_complexity", &br::Fitness::get_complexity)
+        .def_property_readonly("prev_linear_complexity", &br::Fitness::get_linear_complexity)
         .def_property_readonly("prev_depth", &br::Fitness::get_depth)

         .def("valid", &br::Fitness::valid, "Check if the fitness is valid")
diff --git a/src/bindings/bind_programs.h b/src/bindings/bind_programs.h
index 49ca8ff7..c01bb6ab 100644
--- a/src/bindings/bind_programs.h
+++ b/src/bindings/bind_programs.h
@@ -46,6 +46,7 @@ void bind_program(py::module& m, string name)
         .def("get_weights", &T::get_weights)
         .def("size", &T::size, py::arg("include_weight")=true)
         .def("complexity", &T::complexity)
+        .def("linear_complexity", &T::linear_complexity)
         .def("depth", &T::depth)
         // .def("cross", &T::cross, py::return_value_policy::automatic,
         //     "Performs one attempt to stochastically swap subtrees between two programs and generate a child")
diff --git a/src/engine.cpp b/src/engine.cpp
index 1e6ea33c..c33a465f 100644
--- a/src/engine.cpp
+++ b/src/engine.cpp
@@ -32,7 +32,7 @@ void Engine<T>::init()
     float error_weight = Individual<T>::weightsMap[params.scorer];

     this->best_score = -error_weight*MAX_FLT;
-    this->best_complexity = -error_weight*MAX_FLT;
+    this->best_complexity = -error_weight*MAX_FLT; // complexity is used to break ties in best_score

     this->archive.set_objectives(params.get_objectives());
diff --git a/src/eval/evaluation.cpp b/src/eval/evaluation.cpp
index 15092efd..ac8d8057 100644
--- a/src/eval/evaluation.cpp
+++ b/src/eval/evaluation.cpp
@@ -74,6 +74,7 @@ void Evaluation<T>::assign_fit(Individual<T>& ind, const Dataset& data,
     ind.fitness.set_loss_v(f_v);
     ind.fitness.set_size(ind.get_size());
     ind.fitness.set_complexity(ind.get_complexity());
+    ind.fitness.set_linear_complexity(ind.get_linear_complexity());
     ind.fitness.set_depth(ind.get_depth());

     vector<float> values;
@@ -84,11 +85,13 @@
         if (n.compare(params.scorer)==0)
             values.push_back(val ? 
f_v : f);
         else if (n.compare("complexity")==0)
-            values.push_back(ind.program.complexity());
+            values.push_back(ind.get_complexity());
+        else if (n.compare("linear_complexity")==0)
+            values.push_back(ind.get_linear_complexity());
         else if (n.compare("size")==0)
-            values.push_back(ind.program.size());
+            values.push_back(ind.get_size());
         else if (n.compare("depth")==0)
-            values.push_back(ind.program.depth());
+            values.push_back(ind.get_depth());
         else
             HANDLE_ERROR_THROW(n+" is not a known objective");
     }
diff --git a/src/ind/fitness.cpp b/src/ind/fitness.cpp
index 4fdffb18..13a09a28 100644
--- a/src/ind/fitness.cpp
+++ b/src/ind/fitness.cpp
@@ -12,6 +12,7 @@ void to_json(json &j, const Fitness &f)
         {"loss", f.loss},
         {"loss_v", f.loss_v},
         {"complexity", f.complexity},
+        {"linear_complexity", f.linear_complexity},
         {"size", f.size},
         {"depth", f.depth},
         {"dcounter", f.dcounter},
@@ -29,6 +30,7 @@ void from_json(const json &j, Fitness& f)
     j.at("loss").get_to( f.loss );
     j.at("loss_v").get_to( f.loss_v );
     j.at("complexity").get_to( f.complexity );
+    j.at("linear_complexity").get_to( f.linear_complexity );
     j.at("size").get_to( f.size );
     j.at("depth").get_to( f.depth );
     j.at("dcounter").get_to( f.dcounter );
diff --git a/src/ind/fitness.h b/src/ind/fitness.h
index 3c82fedf..e6f4c312 100644
--- a/src/ind/fitness.h
+++ b/src/ind/fitness.h
@@ -29,6 +29,7 @@ struct Fitness {
     float loss_v;   ///< aggregate validation loss score

     unsigned int complexity;
+    unsigned int linear_complexity;
     unsigned int size;
     unsigned int depth;

@@ -38,6 +39,7 @@ struct Fitness {
     float prev_loss_v;

     unsigned int prev_complexity;
+    unsigned int prev_linear_complexity;
     unsigned int prev_size;
     unsigned int prev_depth;

@@ -73,11 +75,18 @@ struct Fitness {
     unsigned int get_size() const { return size; };
     unsigned int get_prev_size() const {return prev_size; };

-    void set_complexity(unsigned int new_c){ prev_complexity=complexity; complexity=new_c; };
+    void set_complexity(unsigned int new_c){
+        prev_complexity=complexity; complexity=new_c; };

     unsigned int get_complexity() const { return complexity; };
     unsigned int get_prev_complexity() const {return prev_complexity; };

+    void set_linear_complexity(unsigned int new_lc){
+        prev_linear_complexity=linear_complexity; linear_complexity=new_lc; };
+
+    unsigned int get_linear_complexity() const { return linear_complexity; };
+    unsigned int get_prev_linear_complexity() const {return prev_linear_complexity; };
+
     void set_depth(unsigned int new_d){ prev_depth=depth; depth=new_d; };
     unsigned int get_depth() const { return depth; };

diff --git a/src/ind/individual.h b/src/ind/individual.h
index 8009ad3a..e4ce6842 100644
--- a/src/ind/individual.h
+++ b/src/ind/individual.h
@@ -96,6 +96,7 @@ class Individual{
     unsigned int get_size() const { return program.size(); };
     unsigned int get_depth() const { return program.depth(); };
     unsigned int get_complexity() const { return program.complexity(); };
+    unsigned int get_linear_complexity() const { return program.linear_complexity(); };

     Program<T>& get_program() { return program; };

     string get_model(string fmt="compact", bool pretty=false) {
@@ -134,6 +135,7 @@ class Individual{
     // a minimization by default, thus "error" has weight -1.0)
     inline static std::map<std::string, float> weightsMap = {
         {"complexity",        -1.0},
+        {"linear_complexity", -1.0},
         {"size",              -1.0},
         {"mse",               -1.0},
         {"log",               -1.0},
diff --git a/src/program/program.h b/src/program/program.h
index a311603d..396d3d58 100644
--- a/src/program/program.h
+++ b/src/program/program.h
@@ -88,7 +88,7 @@ template<ProgramType PType> struct Program
         SSref = 
std::optional<std::reference_wrapper<SearchSpace>>{s}; }

-    /// @brief count the complexity of the program.
+    /// @brief count the (recursive) complexity of the program.
     /// @return int complexity.
     int complexity() const{
         auto head = Tree.begin();

         return head.node->get_complexity();
     }

+    /// @brief count the linear complexity of the program.
+    /// @return int linear complexity.
+    int linear_complexity() const{
+        auto head = Tree.begin();
+
+        return head.node->get_linear_complexity();
+    }
+
     /// @brief count the tree size of the program, including the weights in weighted nodes.
     /// @param include_weight whether to include the node's weight in the count.
     /// @return int number of nodes.
diff --git a/src/program/tree_node.cpp b/src/program/tree_node.cpp
index 2a186418..f0e57ce1 100644
--- a/src/program/tree_node.cpp
+++ b/src/program/tree_node.cpp
@@ -144,6 +144,30 @@ unordered_map<NodeType, int> operator_complexities = {
     {NodeType::CustomSplit , 5}
 };

+int TreeNode::get_linear_complexity() const
+{
+    int tree_complexity = operator_complexities.at(data.node_type);
+
+    auto child = first_child;
+    for(int i = 0; i < data.get_arg_count(); ++i)
+    {
+        tree_complexity += child->get_linear_complexity();
+        child = child->next_sibling;
+    }
+
+    // include the `w` and `*` if the node is weighted (and it is not a constant or mean label)
+    if (data.get_is_weighted()
+    &&  !(Is<NodeType::Constant>(data.node_type)
+    ||    Is<NodeType::MeanLabel>(data.node_type)) )
+        return operator_complexities.at(NodeType::Mul)
+             + operator_complexities.at(NodeType::Constant)
+             + tree_complexity;
+
+    return tree_complexity;
+};
+
 int TreeNode::get_complexity() const
 {
     int node_complexity = operator_complexities.at(data.node_type);
diff --git a/src/program/tree_node.h b/src/program/tree_node.h
index dc50f00a..2b21942d 100644
--- a/src/program/tree_node.h
+++ b/src/program/tree_node.h
@@ -51,6 +51,7 @@ class tree_node_<Node> { // size: 5*4=20 bytes (on 32 bit arch), can be reduced
         string get_tree_model(bool pretty=false, string offset="") const;

         int get_complexity() const;
+        int get_linear_complexity() const;
         int get_size(bool include_weight=true) const;
 };
 using TreeNode = class tree_node_<Node>;
diff --git a/src/vary/variation.cpp b/src/vary/variation.cpp
index 639bee92..1a0925cb 100644
--- a/src/vary/variation.cpp
+++ b/src/vary/variation.cpp
@@ -649,6 +649,7 @@ void Variation<T>::vary(Population<T>& pop, int island,
             ind.fitness.set_loss_v(mom.fitness.get_loss_v());
             ind.fitness.set_size(mom.fitness.get_size());
             ind.fitness.set_complexity(mom.fitness.get_complexity());
+            ind.fitness.set_linear_complexity(mom.fitness.get_linear_complexity());
             ind.fitness.set_depth(mom.fitness.get_depth());

             assert(ind.program.size()>0);
@@ -697,6 +698,8 @@ vector<float> Variation<T>::calculate_rewards(Population<T>& pop, int island)
                 delta = ind.fitness.get_loss_v()-ind.fitness.get_loss();
             else if (obj.compare("complexity")==0)
                 delta = ind.fitness.get_complexity()-ind.fitness.get_prev_complexity();
+            else if (obj.compare("linear_complexity")==0)
+                delta = ind.fitness.get_linear_complexity()-ind.fitness.get_prev_linear_complexity();
             else if (obj.compare("size")==0)
                 delta = ind.fitness.get_size()-ind.fitness.get_prev_size();
             else if (obj.compare("depth")==0)
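A note on the new traversal in tree_node.cpp: where the existing measure (the "recursive version" of the commit subject) can grow much faster than the node count (note complexity maxima in the hundreds for trees of at most 15 nodes in the notebook logs above), get_linear_complexity charges one cost per node and sums them, so it scales linearly with tree size; a weighted non-constant node additionally pays the Mul and Constant costs for its implicit `w*`. The following Python sketch (not part of the patch) mirrors that recurrence, with illustrative costs standing in for the full operator_complexities table:

    # sketch of TreeNode::get_linear_complexity; costs are illustrative stand-ins
    from dataclasses import dataclass, field

    COSTS = {"Add": 2, "Mul": 3, "Sin": 9, "Constant": 1, "Terminal": 1, "MeanLabel": 1}

    @dataclass
    class Node:
        node_type: str
        is_weighted: bool = False
        children: list = field(default_factory=list)

    def linear_complexity(n: Node) -> int:
        # post-order sum of per-node operator costs
        total = COSTS[n.node_type] + sum(linear_complexity(c) for c in n.children)
        # weighted nodes pay for the implicit `w` and `*`, unless the node is a
        # constant or mean label (there the weight is the value itself)
        if n.is_weighted and n.node_type not in ("Constant", "MeanLabel"):
            total += COSTS["Mul"] + COSTS["Constant"]
        return total

    # 2.5*Sin(x): Sin(9) + Terminal(1) + Mul(3) + Constant(1) = 14
    expr = Node("Sin", is_weighted=True, children=[Node("Terminal")])
    assert linear_complexity(expr) == 14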