Skip to content

Commit

Permalink
[c++] Add Bagging by Query for Lambdarank (#6623)
Browse files Browse the repository at this point in the history
* add bagging by query for lambdarank

* fix pre-commit

* fix bagging by query with cuda

* fix bagging by query test case

* fix bagging by query test case

* fix bagging by query test case

* add #include <vector>

* Update include/LightGBM/objective_function.h

Co-authored-by: Nikita Titov <[email protected]>

* Update tests/python_package_test/test_engine.py

Co-authored-by: Nikita Titov <[email protected]>

* Update tests/python_package_test/test_engine.py

Co-authored-by: Nikita Titov <[email protected]>

---------

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
shiyu1994 and StrikerRUS authored Oct 2, 2024
1 parent 59a3432 commit d1d218c
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 13 deletions.
4 changes: 4 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,10 @@ Learning Control Parameters

- random seed for bagging

- ``bagging_by_query`` :raw-html:`<a id="bagging_by_query" title="Permalink to this parameter" href="#bagging_by_query">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool

- whether to do bagging sample by query

- ``feature_fraction`` :raw-html:`<a id="feature_fraction" title="Permalink to this parameter" href="#feature_fraction">&#x1F517;&#xFE0E;</a>`, default = ``1.0``, type = double, aliases: ``sub_feature``, ``colsample_bytree``, constraints: ``0.0 < feature_fraction <= 1.0``

- LightGBM will randomly select a subset of features on each iteration (tree) if ``feature_fraction`` is smaller than ``1.0``. For example, if you set it to ``0.8``, LightGBM will select 80% of features before training each tree
Expand Down
3 changes: 3 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ struct Config {
// desc = random seed for bagging
int bagging_seed = 3;

// desc = whether to do bagging sample by query
bool bagging_by_query = false;

// alias = sub_feature, colsample_bytree
// check = >0.0
// check = <=1.0
Expand Down
5 changes: 5 additions & 0 deletions include/LightGBM/cuda/cuda_objective_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class CUDAObjectiveInterface: public HOST_OBJECTIVE {
SynchronizeCUDADevice(__FILE__, __LINE__);
}

void GetGradients(const double* scores, const data_size_t /*num_sampled_queries*/, const data_size_t* /*sampled_query_indices*/, score_t* gradients, score_t* hessians) const override {
LaunchGetGradientsKernel(scores, gradients, hessians);
SynchronizeCUDADevice(__FILE__, __LINE__);
}

void RenewTreeOutputCUDA(const double* score, const data_size_t* data_indices_in_leaf, const data_size_t* num_data_in_leaf,
const data_size_t* data_start_in_leaf, const int num_leaves, double* leaf_value) const override {
global_timer.Start("CUDAObjectiveInterface::LaunchRenewTreeOutputCUDAKernel");
Expand Down
11 changes: 11 additions & 0 deletions include/LightGBM/objective_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ class ObjectiveFunction {
virtual void GetGradients(const double* score,
score_t* gradients, score_t* hessians) const = 0;

/*!
* \brief calculating first order derivative of loss function, used only for bagging by query in lambdarank
* \param score prediction score in this round
* \param num_sampled_queries number of in-bag queries
* \param sampled_query_indices indices of in-bag queries
* \gradients Output gradients
* \hessians Output hessians
*/
virtual void GetGradients(const double* score, const data_size_t /*num_sampled_queries*/, const data_size_t* /*sampled_query_indices*/,
score_t* gradients, score_t* hessians) const { GetGradients(score, gradients, hessians); }

virtual const char* GetName() const = 0;

virtual bool IsConstantHessian() const { return false; }
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/sample_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class SampleStrategy {

bool NeedResizeGradients() const { return need_resize_gradients_; }

virtual data_size_t num_sampled_queries() const { return 0; }

virtual const data_size_t* sampled_query_indices() const { return nullptr; }

protected:
const Config* config_;
const Dataset* train_data_;
Expand Down
95 changes: 91 additions & 4 deletions src/boosting/bagging.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define LIGHTGBM_BOOSTING_BAGGING_HPP_

#include <string>
#include <vector>

namespace LightGBM {

Expand All @@ -17,8 +18,11 @@ class BaggingSampleStrategy : public SampleStrategy {
config_ = config;
train_data_ = train_data;
num_data_ = train_data->num_data();
num_queries_ = train_data->metadata().num_queries();
query_boundaries_ = train_data->metadata().query_boundaries();
objective_function_ = objective_function;
num_tree_per_iteration_ = num_tree_per_iteration;
num_threads_ = OMP_NUM_THREADS();
}

~BaggingSampleStrategy() {}
Expand All @@ -27,9 +31,10 @@ class BaggingSampleStrategy : public SampleStrategy {
Common::FunctionTimer fun_timer("GBDT::Bagging", global_timer);
// if need bagging
if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0) ||
need_re_bagging_) {
need_re_bagging_) {
need_re_bagging_ = false;
auto left_cnt = bagging_runner_.Run<true>(
if (!config_->bagging_by_query) {
auto left_cnt = bagging_runner_.Run<true>(
num_data_,
[=](int, data_size_t cur_start, data_size_t cur_cnt, data_size_t* left,
data_size_t*) {
Expand All @@ -43,7 +48,60 @@ class BaggingSampleStrategy : public SampleStrategy {
return cur_left_count;
},
bag_data_indices_.data());
bag_data_cnt_ = left_cnt;
bag_data_cnt_ = left_cnt;
} else {
num_sampled_queries_ = bagging_runner_.Run<true>(
num_queries_,
[=](int, data_size_t cur_start, data_size_t cur_cnt, data_size_t* left,
data_size_t*) {
data_size_t cur_left_count = 0;
cur_left_count = BaggingHelper(cur_start, cur_cnt, left);
return cur_left_count;
}, bag_query_indices_.data());

sampled_query_boundaries_[0] = 0;
OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(num_threads_)
for (data_size_t i = 0; i < num_sampled_queries_; ++i) {
OMP_LOOP_EX_BEGIN();
sampled_query_boundaries_[i + 1] = query_boundaries_[bag_query_indices_[i] + 1] - query_boundaries_[bag_query_indices_[i]];
OMP_LOOP_EX_END();
}
OMP_THROW_EX();

const int num_blocks = Threading::For<data_size_t>(0, num_sampled_queries_ + 1, 128, [this](int thread_index, data_size_t start_index, data_size_t end_index) {
for (data_size_t i = start_index + 1; i < end_index; ++i) {
sampled_query_boundaries_[i] += sampled_query_boundaries_[i - 1];
}
sampled_query_boundaires_thread_buffer_[thread_index] = sampled_query_boundaries_[end_index - 1];
});

for (int thread_index = 1; thread_index < num_blocks; ++thread_index) {
sampled_query_boundaires_thread_buffer_[thread_index] += sampled_query_boundaires_thread_buffer_[thread_index - 1];
}

Threading::For<data_size_t>(0, num_sampled_queries_ + 1, 128, [this](int thread_index, data_size_t start_index, data_size_t end_index) {
if (thread_index > 0) {
for (data_size_t i = start_index; i < end_index; ++i) {
sampled_query_boundaries_[i] += sampled_query_boundaires_thread_buffer_[thread_index - 1];
}
}
});

bag_data_cnt_ = sampled_query_boundaries_[num_sampled_queries_];

Threading::For<data_size_t>(0, num_sampled_queries_, 1, [this](int /*thread_index*/, data_size_t start_index, data_size_t end_index) {
for (data_size_t sampled_query_id = start_index; sampled_query_id < end_index; ++sampled_query_id) {
const data_size_t query_index = bag_query_indices_[sampled_query_id];
const data_size_t data_index_start = query_boundaries_[query_index];
const data_size_t data_index_end = query_boundaries_[query_index + 1];
const data_size_t sampled_query_start = sampled_query_boundaries_[sampled_query_id];
for (data_size_t i = data_index_start; i < data_index_end; ++i) {
bag_data_indices_[sampled_query_start + i - data_index_start] = i;
}
}
});
}
Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
// set bagging data to tree learner
if (!is_use_subset_) {
Expand Down Expand Up @@ -108,7 +166,14 @@ class BaggingSampleStrategy : public SampleStrategy {
cuda_bag_data_indices_.Resize(num_data_);
}
#endif // USE_CUDA
bagging_runner_.ReSize(num_data_);
if (!config_->bagging_by_query) {
bagging_runner_.ReSize(num_data_);
} else {
bagging_runner_.ReSize(num_queries_);
sampled_query_boundaries_.resize(num_queries_ + 1, 0);
sampled_query_boundaires_thread_buffer_.resize(num_threads_, 0);
bag_query_indices_.resize(num_data_);
}
bagging_rands_.clear();
for (int i = 0;
i < (num_data_ + bagging_rand_block_ - 1) / bagging_rand_block_; ++i) {
Expand Down Expand Up @@ -153,6 +218,14 @@ class BaggingSampleStrategy : public SampleStrategy {
return false;
}

data_size_t num_sampled_queries() const override {
return num_sampled_queries_;
}

const data_size_t* sampled_query_indices() const override {
return bag_query_indices_.data();
}

private:
data_size_t BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer) {
if (cnt <= 0) {
Expand Down Expand Up @@ -202,6 +275,20 @@ class BaggingSampleStrategy : public SampleStrategy {

/*! \brief whether need restart bagging in continued training */
bool need_re_bagging_;
/*! \brief number of threads */
int num_threads_;
/*! \brief query boundaries of the in-bag queries */
std::vector<data_size_t> sampled_query_boundaries_;
/*! \brief buffer for calculating sampled_query_boundaries_ */
std::vector<data_size_t> sampled_query_boundaires_thread_buffer_;
/*! \brief in-bag query indices */
std::vector<data_size_t, Common::AlignmentAllocator<data_size_t, kAlignedSize>> bag_query_indices_;
/*! \brief number of queries in the training dataset */
data_size_t num_queries_;
/*! \brief number of in-bag queries */
data_size_t num_sampled_queries_;
/*! \brief query boundaries of the whole training dataset */
const data_size_t* query_boundaries_;
};

} // namespace LightGBM
Expand Down
14 changes: 11 additions & 3 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,14 @@ void GBDT::Boosting() {
}
// objective function will calculate gradients and hessians
int64_t num_score = 0;
objective_function_->
GetGradients(GetTrainingScore(&num_score), gradients_pointer_, hessians_pointer_);
if (config_->bagging_by_query) {
data_sample_strategy_->Bagging(iter_, tree_learner_.get(), gradients_.data(), hessians_.data());
objective_function_->
GetGradients(GetTrainingScore(&num_score), data_sample_strategy_->num_sampled_queries(), data_sample_strategy_->sampled_query_indices(), gradients_pointer_, hessians_pointer_);
} else {
objective_function_->
GetGradients(GetTrainingScore(&num_score), gradients_pointer_, hessians_pointer_);
}
}

void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
Expand Down Expand Up @@ -366,7 +372,9 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
}

// bagging logic
data_sample_strategy_->Bagging(iter_, tree_learner_.get(), gradients_.data(), hessians_.data());
if (!config_->bagging_by_query) {
data_sample_strategy_->Bagging(iter_, tree_learner_.get(), gradients_.data(), hessians_.data());
}
const bool is_use_subset = data_sample_strategy_->is_use_subset();
const data_size_t bag_data_cnt = data_sample_strategy_->bag_data_cnt();
const std::vector<data_size_t, Common::AlignmentAllocator<data_size_t, kAlignedSize>>& bag_data_indices = data_sample_strategy_->bag_data_indices();
Expand Down
5 changes: 5 additions & 0 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,11 @@ void Config::CheckParamConflict(const std::unordered_map<std::string, std::strin
Log::Warning("Found boosting=goss. For backwards compatibility reasons, LightGBM interprets this as boosting=gbdt, data_sample_strategy=goss."
"To suppress this warning, set data_sample_strategy=goss instead.");
}

if (bagging_by_query && data_sample_strategy != std::string("bagging")) {
Log::Warning("bagging_by_query=true is only compatible with data_sample_strategy=bagging. Setting bagging_by_query=false.");
bagging_by_query = false;
}
}

std::string Config::ToString() const {
Expand Down
6 changes: 6 additions & 0 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"neg_bagging_fraction",
"bagging_freq",
"bagging_seed",
"bagging_by_query",
"feature_fraction",
"feature_fraction_bynode",
"feature_fraction_seed",
Expand Down Expand Up @@ -377,6 +378,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetInt(params, "bagging_seed", &bagging_seed);

GetBool(params, "bagging_by_query", &bagging_by_query);

GetDouble(params, "feature_fraction", &feature_fraction);
CHECK_GT(feature_fraction, 0.0);
CHECK_LE(feature_fraction, 1.0);
Expand Down Expand Up @@ -688,6 +691,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[neg_bagging_fraction: " << neg_bagging_fraction << "]\n";
str_buf << "[bagging_freq: " << bagging_freq << "]\n";
str_buf << "[bagging_seed: " << bagging_seed << "]\n";
str_buf << "[bagging_by_query: " << bagging_by_query << "]\n";
str_buf << "[feature_fraction: " << feature_fraction << "]\n";
str_buf << "[feature_fraction_bynode: " << feature_fraction_bynode << "]\n";
str_buf << "[feature_fraction_seed: " << feature_fraction_seed << "]\n";
Expand Down Expand Up @@ -813,6 +817,7 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
{"neg_bagging_fraction", {"neg_sub_row", "neg_subsample", "neg_bagging"}},
{"bagging_freq", {"subsample_freq"}},
{"bagging_seed", {"bagging_fraction_seed"}},
{"bagging_by_query", {}},
{"feature_fraction", {"sub_feature", "colsample_bytree"}},
{"feature_fraction_bynode", {"sub_feature_bynode", "colsample_bynode"}},
{"feature_fraction_seed", {}},
Expand Down Expand Up @@ -957,6 +962,7 @@ const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
{"neg_bagging_fraction", "double"},
{"bagging_freq", "int"},
{"bagging_seed", "int"},
{"bagging_by_query", "bool"},
{"feature_fraction", "double"},
{"feature_fraction_bynode", "double"},
{"feature_fraction_seed", "int"},
Expand Down
18 changes: 12 additions & 6 deletions src/objective/rank_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,21 @@ class RankingObjective : public ObjectiveFunction {
pos_biases_.resize(num_position_ids_, 0.0);
}

void GetGradients(const double* score, score_t* gradients,
score_t* hessians) const override {
void GetGradients(const double* score, const data_size_t num_sampled_queries, const data_size_t* sampled_query_indices,
score_t* gradients, score_t* hessians) const override {
const data_size_t num_queries = (sampled_query_indices == nullptr ? num_queries_ : num_sampled_queries);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(guided)
for (data_size_t i = 0; i < num_queries_; ++i) {
const data_size_t start = query_boundaries_[i];
const data_size_t cnt = query_boundaries_[i + 1] - query_boundaries_[i];
for (data_size_t i = 0; i < num_queries; ++i) {
const data_size_t query_index = (sampled_query_indices == nullptr ? i : sampled_query_indices[i]);
const data_size_t start = query_boundaries_[query_index];
const data_size_t cnt = query_boundaries_[query_index + 1] - query_boundaries_[query_index];
std::vector<double> score_adjusted;
if (num_position_ids_ > 0) {
for (data_size_t j = 0; j < cnt; ++j) {
score_adjusted.push_back(score[start + j] + pos_biases_[positions_[start + j]]);
}
}
GetGradientsForOneQuery(i, cnt, label_ + start, num_position_ids_ > 0 ? score_adjusted.data() : score + start,
GetGradientsForOneQuery(query_index, cnt, label_ + start, num_position_ids_ > 0 ? score_adjusted.data() : score + start,
gradients + start, hessians + start);
if (weights_ != nullptr) {
for (data_size_t j = 0; j < cnt; ++j) {
Expand All @@ -84,6 +86,10 @@ class RankingObjective : public ObjectiveFunction {
}
}

void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
GetGradients(score, num_queries_, nullptr, gradients, hessians);
}

virtual void GetGradientsForOneQuery(data_size_t query_id, data_size_t cnt,
const label_t* label,
const double* score, score_t* lambdas,
Expand Down
23 changes: 23 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4509,3 +4509,26 @@ def test_quantized_training():
quant_bst = lgb.train(bst_params, ds, num_boost_round=10)
quant_rmse = np.sqrt(np.mean((quant_bst.predict(X) - y) ** 2))
assert quant_rmse < rmse + 6.0


def test_bagging_by_query_in_lambdarank():
rank_example_dir = Path(__file__).absolute().parents[2] / "examples" / "lambdarank"
X_train, y_train = load_svmlight_file(str(rank_example_dir / "rank.train"))
q_train = np.loadtxt(str(rank_example_dir / "rank.train.query"))
X_test, y_test = load_svmlight_file(str(rank_example_dir / "rank.test"))
q_test = np.loadtxt(str(rank_example_dir / "rank.test.query"))
params = {"objective": "lambdarank", "verbose": -1, "metric": "ndcg", "ndcg_eval_at": [5]}
lgb_train = lgb.Dataset(X_train, y_train, group=q_train, params=params)
lgb_test = lgb.Dataset(X_test, y_test, group=q_test, params=params)
gbm = lgb.train(params, lgb_train, num_boost_round=50, valid_sets=[lgb_test])
ndcg_score = gbm.best_score["valid_0"]["ndcg@5"]

params.update({"bagging_by_query": True, "bagging_fraction": 0.1, "bagging_freq": 1})
gbm_bagging_by_query = lgb.train(params, lgb_train, num_boost_round=50, valid_sets=[lgb_test])
ndcg_score_bagging_by_query = gbm_bagging_by_query.best_score["valid_0"]["ndcg@5"]

params.update({"bagging_by_query": False, "bagging_fraction": 0.1, "bagging_freq": 1})
gbm_no_bagging_by_query = lgb.train(params, lgb_train, num_boost_round=50, valid_sets=[lgb_test])
ndcg_score_no_bagging_by_query = gbm_no_bagging_by_query.best_score["valid_0"]["ndcg@5"]
assert ndcg_score_bagging_by_query >= ndcg_score - 0.1
assert ndcg_score_no_bagging_by_query >= ndcg_score - 0.1

0 comments on commit d1d218c

Please sign in to comment.