diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp index 0a1ee4f9807d..c868c302a267 100644 --- a/src/gbm/gbtree-inl.hpp +++ b/src/gbm/gbtree-inl.hpp @@ -64,7 +64,13 @@ class GBTree : public IGradBooster { } virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const { utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree"); - fo.Write(&mparam, sizeof(ModelParam)); + if (with_pbuffer) { + fo.Write(&mparam, sizeof(ModelParam)); + } else { + ModelParam p = mparam; + p.num_pbuffer = 0; + fo.Write(&p, sizeof(ModelParam)); + } for (size_t i = 0; i < trees.size(); ++i) { trees[i]->SaveModel(fo); } diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp index 9ceec969e1d0..5a080d5b1b2c 100644 --- a/src/learner/learner-inl.hpp +++ b/src/learner/learner-inl.hpp @@ -157,11 +157,9 @@ class BoostLearner : public rabit::Serializable { /*! * \brief load model from stream * \param fi input stream - * \param with_pbuffer whether to load with predict buffer * \param calc_num_feature whether call InitTrainer with calc_num_feature */ inline void LoadModel(utils::IStream &fi, - bool with_pbuffer = true, bool calc_num_feature = true) { utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0, "BoostLearner: wrong model format"); @@ -189,15 +187,15 @@ class BoostLearner : public rabit::Serializable { char tmp[32]; utils::SPrintf(tmp, sizeof(tmp), "%u", mparam.num_class); obj_->SetParam("num_class", tmp); - gbm_->LoadModel(fi, with_pbuffer); - if (!with_pbuffer || distributed_mode == 2) { + gbm_->LoadModel(fi, mparam.saved_with_pbuffer != 0); + if (mparam.saved_with_pbuffer == 0) { gbm_->ResetPredBuffer(pred_buffer_size); } } // rabit load model from rabit checkpoint virtual void Load(rabit::Stream *fi) { // for row split, we should not keep pbuffer - this->LoadModel(*fi, distributed_mode != 2, false); + this->LoadModel(*fi, false); } // rabit save model to rabit checkpoint virtual void Save(rabit::Stream *fo) const { @@ -218,18 +216,20 @@ 
class BoostLearner : public rabit::Serializable { if (header == "bs64") { utils::Base64InStream bsin(fi); bsin.InitPosition(); - this->LoadModel(bsin); + this->LoadModel(bsin, true); } else if (header == "binf") { - this->LoadModel(*fi); + this->LoadModel(*fi, true); } else { delete fi; fi = utils::IStream::Create(fname, "r"); - this->LoadModel(*fi); + this->LoadModel(*fi, true); } delete fi; } - inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const { - fo.Write(&mparam, sizeof(ModelParam)); + inline void SaveModel(utils::IStream &fo, bool with_pbuffer) const { + ModelParam p = mparam; + p.saved_with_pbuffer = static_cast<int>(with_pbuffer); + fo.Write(&p, sizeof(ModelParam)); fo.Write(name_obj_); fo.Write(name_gbm_); gbm_->SaveModel(fo, with_pbuffer); @@ -237,17 +237,18 @@ class BoostLearner : public rabit::Serializable { /*! * \brief save model into file * \param fname file name + * \param with_pbuffer whether save pbuffer together */ - inline void SaveModel(const char *fname) const { + inline void SaveModel(const char *fname, bool with_pbuffer) const { utils::IStream *fo = utils::IStream::Create(fname, "w"); if (save_base64 != 0 || !strcmp(fname, "stdout")) { fo->Write("bs64\t", 5); utils::Base64OutStream bout(fo); - this->SaveModel(bout); + this->SaveModel(bout, with_pbuffer); bout.Finish('\n'); } else { fo->Write("binf", 4); - this->SaveModel(*fo); + this->SaveModel(*fo, with_pbuffer); } delete fo; } @@ -442,14 +443,17 @@ class BoostLearner : public rabit::Serializable { unsigned num_feature; /* \brief number of class, if it is multi-class classification */ int num_class; + /*! \brief whether the model itself is saved with pbuffer */ + int saved_with_pbuffer; /*! \brief reserved field */ - int reserved[31]; + int reserved[30]; /*! \brief constructor */ ModelParam(void) { + std::memset(this, 0, sizeof(ModelParam)); base_score = 0.5f; num_feature = 0; num_class = 0; - std::memset(reserved, 0, sizeof(reserved)); + saved_with_pbuffer = 0; } /*! 
* \brief set parameters from outside diff --git a/src/xgboost_main.cpp b/src/xgboost_main.cpp index ad87f8879c2a..769e3be3b1ed 100644 --- a/src/xgboost_main.cpp +++ b/src/xgboost_main.cpp @@ -87,6 +87,7 @@ class BoostLearnTask { if (!strcmp("name_pred", name)) name_pred = val; if (!strcmp("dsplit", name)) data_split = val; if (!strcmp("dump_stats", name)) dump_model_stats = atoi(val); + if (!strcmp("save_pbuffer", name)) save_with_pbuffer = atoi(val); if (!strncmp("eval[", name, 5)) { char evname[256]; utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1, "must specify evaluation name for display"); @@ -115,6 +116,7 @@ class BoostLearnTask { model_dir_path = "./"; data_split = "NONE"; load_part = 0; + save_with_pbuffer = 0; data = NULL; } ~BoostLearnTask(void){ @@ -241,7 +243,7 @@ class BoostLearnTask { } inline void SaveModel(const char *fname) const { if (rabit::GetRank() != 0) return; - learner.SaveModel(fname); + learner.SaveModel(fname, save_with_pbuffer != 0); } inline void SaveModel(int i) const { char fname[256]; @@ -297,6 +299,8 @@ class BoostLearnTask { int pred_margin; /*! \brief whether dump statistics along with model */ int dump_model_stats; + /*! \brief whether save prediction buffer */ + int save_with_pbuffer; /*! \brief name of feature map */ std::string name_fmap; /*! 
\brief name of dump file */ diff --git a/wrapper/xgboost_wrapper.cpp b/wrapper/xgboost_wrapper.cpp index 8ec3aa3f4b25..be2a2001cdb8 100644 --- a/wrapper/xgboost_wrapper.cpp +++ b/wrapper/xgboost_wrapper.cpp @@ -58,13 +58,13 @@ class Booster: public learner::BoostLearner { } inline void LoadModelFromBuffer(const void *buf, size_t size) { utils::MemoryFixSizeBuffer fs((void*)buf, size); - learner::BoostLearner::LoadModel(fs); + learner::BoostLearner::LoadModel(fs, true); this->init_model = true; } inline const char *GetModelRaw(bst_ulong *out_len) { model_str.resize(0); utils::MemoryBufferStream fs(&model_str); - learner::BoostLearner::SaveModel(fs); + learner::BoostLearner::SaveModel(fs, false); *out_len = static_cast<bst_ulong>(model_str.length()); if (*out_len == 0) { return NULL; @@ -323,7 +323,7 @@ extern "C"{ static_cast<Booster*>(handle)->LoadModel(fname); } void XGBoosterSaveModel(const void *handle, const char *fname) { - static_cast<const Booster*>(handle)->SaveModel(fname); + static_cast<const Booster*>(handle)->SaveModel(fname, false); } void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) { static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);