Skip to content

Commit

Permalink
add with pbuffer info to model, allow xgb model to be saved in a more…
Browse files Browse the repository at this point in the history
… memory compact way
  • Loading branch information
tqchen committed May 6, 2015
1 parent 3b46977 commit 7f7947f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 20 deletions.
8 changes: 7 additions & 1 deletion src/gbm/gbtree-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ class GBTree : public IGradBooster {
}
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
fo.Write(&mparam, sizeof(ModelParam));
if (with_pbuffer) {
fo.Write(&mparam, sizeof(ModelParam));
} else {
ModelParam p = mparam;
p.num_pbuffer = 0;
fo.Write(&p, sizeof(ModelParam));
}
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(fo);
}
Expand Down
34 changes: 19 additions & 15 deletions src/learner/learner-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,9 @@ class BoostLearner : public rabit::Serializable {
/*!
* \brief load model from stream
* \param fi input stream
* \param with_pbuffer whether to load with predict buffer
* \param calc_num_feature whether to call InitTrainer with calc_num_feature
*/
inline void LoadModel(utils::IStream &fi,
bool with_pbuffer = true,
bool calc_num_feature = true) {
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"BoostLearner: wrong model format");
Expand Down Expand Up @@ -189,15 +187,15 @@ class BoostLearner : public rabit::Serializable {
char tmp[32];
utils::SPrintf(tmp, sizeof(tmp), "%u", mparam.num_class);
obj_->SetParam("num_class", tmp);
gbm_->LoadModel(fi, with_pbuffer);
if (!with_pbuffer || distributed_mode == 2) {
gbm_->LoadModel(fi, mparam.saved_with_pbuffer != 0);
if (mparam.saved_with_pbuffer == 0) {
gbm_->ResetPredBuffer(pred_buffer_size);
}
}
// rabit load model from rabit checkpoint: restores this learner from the
// checkpoint stream by forwarding to LoadModel.
// NOTE(review): the two LoadModel calls below read like the before/after
// versions of one changed line (the three-argument call matches a LoadModel
// signature that still takes with_pbuffer) - confirm which one is live.
virtual void Load(rabit::Stream *fi) {
// for row split, we should not keep pbuffer
// three-argument form: requests the pbuffer unless distributed_mode is 2,
// and passes false as the last argument (calc_num_feature)
this->LoadModel(*fi, distributed_mode != 2, false);
// two-argument form: only passes false after the stream
this->LoadModel(*fi, false);
}
// rabit save model to rabit checkpoint
virtual void Save(rabit::Stream *fo) const {
Expand All @@ -218,36 +216,39 @@ class BoostLearner : public rabit::Serializable {
if (header == "bs64") {
utils::Base64InStream bsin(fi);
bsin.InitPosition();
this->LoadModel(bsin);
this->LoadModel(bsin, true);
} else if (header == "binf") {
this->LoadModel(*fi);
this->LoadModel(*fi, true);
} else {
delete fi;
fi = utils::IStream::Create(fname, "r");
this->LoadModel(*fi);
this->LoadModel(*fi, true);
}
delete fi;
}
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
fo.Write(&mparam, sizeof(ModelParam));
/*!
* \brief write the learner to a stream: a ModelParam header first, then
*  name_obj_ and name_gbm_ (presumably the objective and booster names -
*  confirm), then whatever gbm_->SaveModel emits
* \param fo output stream
* \param with_pbuffer recorded in the written header as saved_with_pbuffer
*  and forwarded unchanged to gbm_->SaveModel
*/
inline void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
// work on a copy: this method is const, so mparam itself is not modified
ModelParam p = mparam;
// store the flag in the header as an int (1 for true, 0 for false)
p.saved_with_pbuffer = static_cast<int>(with_pbuffer);
// the write order below is the stream layout; do not reorder
fo.Write(&p, sizeof(ModelParam));
fo.Write(name_obj_);
fo.Write(name_gbm_);
gbm_->SaveModel(fo, with_pbuffer);
}
/*!
* \brief save model into file; the model goes through a Base64OutStream
*  behind a "bs64\t" header when save_base64 is non-zero or fname is
*  "stdout", otherwise it is written directly behind a "binf" header
* \param fname file name
* \param with_pbuffer whether to save the pbuffer together with the model
* NOTE(review): the one-parameter signature and the one-argument SaveModel
*  calls below look like pre-change lines sitting next to their
*  replacements - confirm which of each pair is live
*/
inline void SaveModel(const char *fname) const {
inline void SaveModel(const char *fname, bool with_pbuffer) const {
// fo is released by the explicit delete at the end of the function
utils::IStream *fo = utils::IStream::Create(fname, "w");
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
// 5-byte header marking the base64 variant
fo->Write("bs64\t", 5);
utils::Base64OutStream bout(fo);
this->SaveModel(bout);
this->SaveModel(bout, with_pbuffer);
// finish the base64 stream, passing '\n' as the end character
bout.Finish('\n');
} else {
// 4-byte header marking the direct binary variant
fo->Write("binf", 4);
this->SaveModel(*fo);
this->SaveModel(*fo, with_pbuffer);
}
delete fo;
}
Expand Down Expand Up @@ -442,14 +443,17 @@ class BoostLearner : public rabit::Serializable {
unsigned num_feature;
/*! \brief number of classes, if it is multi-class classification */
int num_class;
/*! \brief whether the model itself is saved with pbuffer */
int saved_with_pbuffer;
/*! \brief reserved field */
int reserved[31];
int reserved[30];
/*! \brief constructor: zero the whole struct, then set the defaults */
ModelParam(void) {
// clear every byte of the object before the explicit assignments below
std::memset(this, 0, sizeof(ModelParam));
base_score = 0.5f;
num_feature = 0;
num_class = 0;
// NOTE(review): reserved is already zeroed by the whole-struct memset
// above, so this call looks redundant - presumably the pre-change line
// that the whole-struct memset replaced; confirm
std::memset(reserved, 0, sizeof(reserved));
saved_with_pbuffer = 0;
}
/*!
* \brief set parameters from outside
Expand Down
6 changes: 5 additions & 1 deletion src/xgboost_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class BoostLearnTask {
if (!strcmp("name_pred", name)) name_pred = val;
if (!strcmp("dsplit", name)) data_split = val;
if (!strcmp("dump_stats", name)) dump_model_stats = atoi(val);
if (!strcmp("save_pbuffer", name)) save_with_pbuffer = atoi(val);
if (!strncmp("eval[", name, 5)) {
char evname[256];
utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1, "must specify evaluation name for display");
Expand Down Expand Up @@ -115,6 +116,7 @@ class BoostLearnTask {
model_dir_path = "./";
data_split = "NONE";
load_part = 0;
save_with_pbuffer = 0;
data = NULL;
}
~BoostLearnTask(void){
Expand Down Expand Up @@ -241,7 +243,7 @@ class BoostLearnTask {
}
/*!
* \brief save the learner to fname; does nothing unless the rabit rank is 0
* \param fname file name
* NOTE(review): the two learner.SaveModel calls look like the before/after
*  versions of one changed line - confirm which one is live
*/
inline void SaveModel(const char *fname) const {
// every rank other than 0 returns without saving
if (rabit::GetRank() != 0) return;
learner.SaveModel(fname);
// second argument is true exactly when save_with_pbuffer is non-zero
learner.SaveModel(fname, save_with_pbuffer != 0);
}
inline void SaveModel(int i) const {
char fname[256];
Expand Down Expand Up @@ -297,6 +299,8 @@ class BoostLearnTask {
int pred_margin;
/*! \brief whether dump statistics along with model */
int dump_model_stats;
/*! \brief whether to save the prediction buffer along with the model */
int save_with_pbuffer;
/*! \brief name of feature map */
std::string name_fmap;
/*! \brief name of dump file */
Expand Down
6 changes: 3 additions & 3 deletions wrapper/xgboost_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ class Booster: public learner::BoostLearner {
}
/*!
* \brief load the model from an in-memory buffer and mark it initialized
* \param buf start of the buffer
* \param size size passed to the buffer stream together with buf
* NOTE(review): the two LoadModel calls look like the before/after versions
*  of one changed line - confirm which one is live
*/
inline void LoadModelFromBuffer(const void *buf, size_t size) {
// const is cast away to fit MemoryFixSizeBuffer's constructor
utils::MemoryFixSizeBuffer fs((void*)buf, size);
learner::BoostLearner::LoadModel(fs);
learner::BoostLearner::LoadModel(fs, true);
this->init_model = true;
}
inline const char *GetModelRaw(bst_ulong *out_len) {
model_str.resize(0);
utils::MemoryBufferStream fs(&model_str);
learner::BoostLearner::SaveModel(fs);
learner::BoostLearner::SaveModel(fs, false);
*out_len = static_cast<bst_ulong>(model_str.length());
if (*out_len == 0) {
return NULL;
Expand Down Expand Up @@ -323,7 +323,7 @@ extern "C"{
static_cast<Booster*>(handle)->LoadModel(fname);
}
/*!
* \brief save the booster behind handle to fname
* \param handle pointer that is cast to const Booster* without any check
* \param fname file name forwarded to Booster::SaveModel
* NOTE(review): the two SaveModel calls look like the before/after versions
*  of one changed line - confirm which one is live
*/
void XGBoosterSaveModel(const void *handle, const char *fname) {
static_cast<const Booster*>(handle)->SaveModel(fname);
// two-argument form: passes false as the second argument
static_cast<const Booster*>(handle)->SaveModel(fname, false);
}
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);
Expand Down

0 comments on commit 7f7947f

Please sign in to comment.