Skip to content

Commit

Permalink
lots of docs
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Jun 11, 2024
1 parent f5d26eb commit 5210dd7
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 38 deletions.
4 changes: 2 additions & 2 deletions src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ using namespace Eval;
using namespace Var;
using namespace nlohmann;


template <ProgramType T> class Engine{
template <ProgramType T>
class Engine{
public:
Engine(const Parameters& p=Parameters())
: params(p)
Expand Down
39 changes: 35 additions & 4 deletions src/eval/evaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,20 @@ using namespace Pop;
namespace Eval {

template<ProgramType T>
/**
* @class Evaluation
* @brief Class for evaluating the fitness of individuals in a population.
*/
class Evaluation {
public:
Scorer<T> S;

// TODO: make eval update loss_v accordingly, and set it to the same as train loss if there is no batch or no validation
/**
* @brief Constructor for Evaluation class.
* @details Initializes the scorer based on the program type.
*/
Evaluation(){
// TODO: make eval update loss_v accordingly, and set it to the same as train loss if there is no batch or no validation

string scorer;
if ( (T == Brush::ProgramType::MulticlassClassifier)
|| (T == Brush::ProgramType::Representer) )
Expand All @@ -38,10 +46,27 @@ class Evaluation {
};
~Evaluation(){};

/**
* @brief Set the scorer for evaluation.
* @param scorer The scorer to be set.
*/
void set_scorer(string scorer){this->S.set_scorer(scorer);};

/**
* @brief Get the current scorer.
* @return The current scorer.
*/
string get_scorer(){return this->S.get_scorer();};

/// fitness of population.
/**
* @brief Update the fitness of individuals in a population.
* @param pop The population to update.
* @param island The island index.
* @param data The dataset for evaluation.
* @param params The parameters for evaluation.
* @param fit Flag indicating whether to update fitness.
* @param validation Flag indicating whether to perform validation.
*/
void update_fitness(Population<T>& pop,
int island,
const Dataset& data,
Expand All @@ -50,7 +75,13 @@ class Evaluation {
bool validation=false
);

/// assign fitness to an individual.
/**
* @brief Assign fitness to an individual.
* @param ind The individual to assign fitness to.
* @param data The dataset for evaluation.
* @param params The parameters for evaluation.
* @param val Flag indicating whether it is validation fitness.
*/
void assign_fit(Individual<T>& ind, const Dataset& data,
const Parameters& params, bool val=false);

Expand Down
57 changes: 53 additions & 4 deletions src/eval/metrics.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,86 @@
#include "../data/data.h"

namespace Brush {
/**
* @namespace Eval
* @brief Namespace containing scoring functions for evaluation metrics.
*/
namespace Eval {

/* Scoring functions */

// regression ------------------------------------------------------------------
/// mean squared error

/**
* @brief Calculates the mean squared error between the predicted values and the true values.
* @param y The true values.
* @param yhat The predicted values.
* @param loss Reference to store the calculated losses for each sample.
* @param class_weights The optional class weights (not used for MSE).
* @return The mean squared error.
*/
float mse(const VectorXf& y, const VectorXf& yhat, VectorXf& loss,
const vector<float>& class_weights=vector<float>() );

// binary classification -------------------------------------------------------
/// log loss (2 methods below)

/**
* @brief Calculates the log loss between the predicted probabilities and the true labels.
* @param y The true labels.
* @param predict_proba The predicted probabilities.
* @param class_weights The optional class weights.
* @return The log loss.
*/
VectorXf log_loss(const VectorXf& y, const VectorXf& predict_proba,
const vector<float>& class_weights=vector<float>());

/**
* @brief Calculates the mean log loss between the predicted probabilities and the true labels.
* @param y The true labels.
* @param predict_proba The predicted probabilities.
* @param loss Reference to store the calculated losses for each sample.
* @param class_weights The optional class weights.
* @return The mean log loss.
*/
float mean_log_loss(const VectorXf& y, const VectorXf& predict_proba, VectorXf& loss,
const vector<float>& class_weights = vector<float>());

/**
* @brief Calculates the average precision score between the predicted probabilities and the true labels.
* @param y The true labels.
* @param predict_proba The predicted probabilities.
* @param loss Reference to store the calculated losses for each sample.
* @param class_weights The optional class weights.
* @return The average precision score.
*/
float average_precision_score(const VectorXf& y, const VectorXf& predict_proba,
VectorXf& loss,
const vector<float>& class_weights=vector<float>());

// multiclass classification ---------------------------------------------------
/// multinomial log loss (2 methods below)

/**
* @brief Calculates the multinomial log loss between the predicted probabilities and the true labels.
* @param y The true labels.
* @param predict_proba The predicted probabilities.
* @param class_weights The optional class weights.
* @return The multinomial log loss.
*/
VectorXf multi_log_loss(const VectorXf& y, const ArrayXXf& predict_proba,
const vector<float>& class_weights=vector<float>());

/**
* @brief Calculates the mean multinomial log loss between the predicted probabilities and the true labels.
* @param y The true labels.
* @param predict_proba The predicted probabilities.
* @param loss Reference to store the calculated losses for each sample.
* @param class_weights The optional class weights.
* @return The mean multinomial log loss.
*/
float mean_multi_log_loss(const VectorXf& y, const ArrayXXf& predict_proba,
VectorXf& loss,
const vector<float>& class_weights=vector<float>());


} // Eval
} // Brush

Expand Down
28 changes: 20 additions & 8 deletions src/ind/fitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
using namespace nlohmann;

namespace Brush{


/**
* @brief Represents the fitness of an individual in the Brush namespace.
*
* The `Fitness` struct stores various attributes related to the fitness of an individual in the Brush namespace.
* It includes the aggregate loss score, aggregate validation loss score, complexity, size, depth, dominance counter,
* dominated individuals, Pareto front rank, crowding distance on the Pareto front, weighted values, and weights.
*
* The struct provides getter and setter methods for accessing and modifying these attributes.
* It also includes methods for calculating the hash value, setting values, clearing values, checking validity,
* and performing comparison operations.
*
* Additionally, there are methods for converting the `Fitness` object to JSON format and vice versa.
*/
struct Fitness {
// the loss is used in evolutionary functions

Expand All @@ -25,6 +38,12 @@ struct Fitness {
unsigned int rank; ///< pareto front rank
float crowding_dist; ///< crowding distance on the Pareto front

vector<float> values;
vector<float> weights;

// weighted values
vector<float> wvalues;

void set_dominated(vector<unsigned int>& dom){ dominated=dom; };
vector<unsigned int> get_dominated() const { return dominated; };

Expand Down Expand Up @@ -52,12 +71,6 @@ struct Fitness {
void set_crowding_dist(float cd){ crowding_dist=cd; };
float get_crowding_dist() const { return crowding_dist; };

vector<float> values;
vector<float> weights;

// weighted values
vector<float> wvalues;

// Constructor with initializer list for weights
Fitness(const vector<float>& w={}) : values(), wvalues(), weights(w) {
dcounter = 0;
Expand Down Expand Up @@ -91,7 +104,6 @@ struct Fitness {
if (v.size() != weights.size()) {
throw std::length_error("Assigned values have not the same length than current values");
}
// fmt::print("updated values\n");

values.resize(0);
for (const auto& element : v) {
Expand Down
89 changes: 72 additions & 17 deletions src/pop/archive.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,93 @@ using namespace Sel;

namespace Pop{

/**
* @brief The Archive struct represents a collection of individual programs.
*
* The Archive struct is used to store individual programs in a collection. It provides
* functionality for initializing, updating, and sorting the archive based on complexity
* or objectives. The archive is not meant to be operated on by several threads at once.
*
* @tparam T The program type.
*/
template<ProgramType T>
struct Archive
{
// I don't need shared pointers here (this is not supposed to be operated
// on by several threads)
vector<Individual<T>> individuals; ///< individual programs in the archive
bool sort_complexity; ///< whether to sort archive by complexity
NSGA2<T> selector; ///< using NSGA2 in survival mode (nsga2 does not implement selection)

// using NSGA2 in survival mode (nsga2 does not implement selection)
NSGA2<T> selector;

/**
* @brief Default constructor for the Archive struct.
*/
Archive();

/**
* @brief Initializes the archive with individuals from a population.
* @param pop The population from which to initialize the archive.
*/
void init(Population<T>& pop);

/**
* @brief Updates the archive with individuals from a population.
* @param pop The population from which to update the archive.
* @param params The parameters for the update.
*/
void update(Population<T>& pop, const Parameters& params);


/**
* @brief Sets the objectives for the archive.
*
* This function sets the objectives for the archive. The objectives are used for
* sorting the archive.
*
* @param objectives The objectives to set for the archive.
*/
void set_objectives(vector<string> objectives);

/// Sort population in increasing complexity.
static bool sortComplexity(const Individual<T>& lhs,
const Individual<T>& rhs);
/**
* @brief Sorts the population in increasing complexity.
*
* This static function is used to sort the population in increasing complexity.
* It is used as a comparison function for sorting algorithms.
*
* @param lhs The left-hand side individual to compare.
* @param rhs The right-hand side individual to compare.
*/
static bool sortComplexity(const Individual<T>& lhs, const Individual<T>& rhs);

/**
* @brief Sorts the population by the first objective.
*
* This static function is used to sort the population by the first objective.
* It is used as a comparison function for sorting algorithms.
*
* @param lhs The left-hand side individual to compare.
* @param rhs The right-hand side individual to compare.
*/
static bool sortObj1(const Individual<T>& lhs, const Individual<T>& rhs);

/// Sort population by first objective.
static bool sortObj1(const Individual<T>& lhs,
const Individual<T>& rhs);
/**
* @brief Checks if two individuals have the same fitness complexity.
*
* This static function is used to check if two individuals have the same fitness complexity.
* It is used as a comparison function for finding duplicates in the population.
*
* @param lhs The left-hand side individual to compare.
* @param rhs The right-hand side individual to compare.
*/
static bool sameFitComplexity(const Individual<T>& lhs, const Individual<T>& rhs);

/// check for repeats
static bool sameFitComplexity(const Individual<T>& lhs,
const Individual<T>& rhs);
static bool sameObjectives(const Individual<T>& lhs,
const Individual<T>& rhs);
/**
* @brief Checks if two individuals have the same objectives.
*
* This static function is used to check if two individuals have the same objectives.
* It is used as a comparison function for finding duplicates in the population.
*
* @param lhs The left-hand side individual to compare.
* @param rhs The right-hand side individual to compare.
*/
static bool sameObjectives(const Individual<T>& lhs, const Individual<T>& rhs);
};

//serialization
Expand Down
3 changes: 0 additions & 3 deletions src/selection/lexicase.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@
#include "selection_operator.h"
#include "../util/utils.h"


namespace Brush {
namespace Sel {


using namespace Brush;
using namespace Pop;
using namespace Sel;


////////////////////////////////////////////////////////////// Declarations
/*!
* @class Lexicase
* @brief Lexicase selection operator.
Expand Down

0 comments on commit 5210dd7

Please sign in to comment.