Skip to content

Commit

Permalink
Merge pull request dmlc#260 from dmlc/colopt
Browse files Browse the repository at this point in the history
Colopt
  • Loading branch information
tqchen committed Apr 25, 2015
2 parents f28a7a0 + 5870b47 commit 4275434
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 84 deletions.
2 changes: 1 addition & 1 deletion src/io/libsvm_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class LibSVMPageFactory {
int maxthread;
#pragma omp parallel
{
maxthread = omp_get_num_threads();
maxthread = omp_get_num_procs();
}
maxthread = std::max(maxthread / 2, 1);
nthread_ = std::min(maxthread, nthread);
Expand Down
17 changes: 14 additions & 3 deletions src/io/page_dmatrix-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace io {
class ThreadRowPageIterator: public utils::IIterator<RowBatch> {
public:
ThreadRowPageIterator(void) {
itr.SetParam("buffer_size", "2");
itr.SetParam("buffer_size", "4");
page_ = NULL;
base_rowid_ = 0;
}
Expand Down Expand Up @@ -109,7 +109,7 @@ class DMatrixPageBase : public DataMatrix {
std::string fname = fname_;
int tmagic;
utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format");
utils::Check(tmagic == magic, "invalid format,magic number mismatch");
this->CheckMagic(tmagic);
this->info.LoadBinary(fi);
// load in the row data file
fname += ".row.blob";
Expand Down Expand Up @@ -203,6 +203,7 @@ class DMatrixPageBase : public DataMatrix {

protected:
virtual void set_cache_file(const std::string &cache_file) = 0;
virtual void CheckMagic(int tmagic) = 0;
/*! \brief row iterator */
ThreadRowPageIterator *iter_;
};
Expand All @@ -221,6 +222,11 @@ class DMatrixPage : public DMatrixPageBase<0xffffab02> {
virtual void set_cache_file(const std::string &cache_file) {
fmat_->set_cache_file(cache_file);
}
/*!
 * \brief validate the magic number read from a binary file header
 *  Both the 0xffffab02 and the 0xffffab03 page formats are accepted.
 * \param tmagic magic number read from the file
 */
virtual void CheckMagic(int tmagic) {
  const bool known_magic =
      tmagic == DMatrixPageBase<0xffffab02>::kMagic ||
      tmagic == DMatrixPageBase<0xffffab03>::kMagic;
  utils::Check(known_magic, "invalid format,magic number mismatch");
}
/*! \brief the real fmatrix */
FMatrixPage *fmat_;
};
Expand All @@ -238,7 +244,12 @@ class DMatrixHalfRAM : public DMatrixPageBase<0xffffab03> {
/*! \return the feature matrix held by this object */
virtual IFMatrix *fmat(void) const {
  return this->fmat_;
}
virtual void set_cache_file(const std::string &cache_file) {
virtual void set_cache_file(const std::string &cache_file) {
}
/*!
 * \brief validate the magic number read from a binary file header
 *  A file written with either the 0xffffab02 or the 0xffffab03 magic passes.
 * \param tmagic magic number read from the file
 */
virtual void CheckMagic(int tmagic) {
  const bool is_page_format =
      (tmagic == DMatrixPageBase<0xffffab02>::kMagic);
  const bool is_half_ram_format =
      (tmagic == DMatrixPageBase<0xffffab03>::kMagic);
  utils::Check(is_page_format || is_half_ram_format,
               "invalid format,magic number mismatch");
}
/*! \brief the real fmatrix */
IFMatrix *fmat_;
Expand Down
218 changes: 138 additions & 80 deletions src/io/page_fmatrix-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,124 @@ class ThreadColPageIterator: public utils::IIterator<ColBatch> {
std::vector<SparseBatch::Inst> col_data_;
utils::ThreadBuffer<SparsePage*, SparsePageFactory> itr;
};

struct ColConvertFactory {
inline bool Init(void) {
return true;
}
inline void Setup(float pkeep,
size_t num_col,
utils::IIterator<RowBatch> *iter,
std::vector<bst_uint> *buffered_rowset,
const std::vector<bool> *enabled) {
pkeep_ = pkeep;
num_col_ = num_col;
iter_ = iter;
buffered_rowset_ = buffered_rowset;
enabled_ = enabled;
}
inline SparsePage *Create(void) {
return new SparsePage();
}
inline void FreeSpace(SparsePage *a) {
delete a;
}
inline void SetParam(const char *name, const char *val) {}
inline bool LoadNext(SparsePage *val) {
tmp_.Clear();
size_t btop = buffered_rowset_->size();
while (iter_->Next()) {
const RowBatch &batch = iter_->Value();
for (size_t i = 0; i < batch.size; ++i) {
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep_ == 1.0f || random::SampleBinary(pkeep_)) {
buffered_rowset_->push_back(ridx);
tmp_.Push(batch[i]);
}
}
if (tmp_.MemCostBytes() >= kPageSize) {
this->MakeColPage(tmp_, BeginPtr(*buffered_rowset_) + btop,
*enabled_, val);
return true;
}
}
if (tmp_.Size() != 0){
this->MakeColPage(tmp_, BeginPtr(*buffered_rowset_) + btop,
*enabled_, val);
return true;
} else {
return false;
}
}
inline void Destroy(void) {}
inline void BeforeFirst(void) {}
inline void MakeColPage(const SparsePage &prow,
const bst_uint *ridx,
const std::vector<bool> &enabled,
SparsePage *pcol) {
pcol->Clear();
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
int max_nthread = std::max(omp_get_num_procs() / 2 - 4, 1);
if (nthread > max_nthread) {
nthread = max_nthread;
}
}
pcol->Clear();
utils::ParallelGroupBuilder<SparseBatch::Entry>
builder(&pcol->offset, &pcol->data);
builder.InitBudget(num_col_, nthread);
bst_omp_uint ndata = static_cast<bst_uint>(prow.Size());
#pragma omp parallel for schedule(static) num_threads(nthread)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j];
if (enabled[e.index]) {
builder.AddBudget(e.index, tid);
}
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static) num_threads(nthread)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j];
builder.Push(e.index,
SparseBatch::Entry(ridx[i], e.fvalue),
tid);
}
}
utils::Assert(pcol->Size() == num_col_, "inconsistent col data");
// sort columns
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
for (bst_omp_uint i = 0; i < ncol; ++i) {
if (pcol->offset[i] < pcol->offset[i + 1]) {
std::sort(BeginPtr(pcol->data) + pcol->offset[i],
BeginPtr(pcol->data) + pcol->offset[i + 1],
SparseBatch::Entry::CmpValue);
}
}
}
// probability of keep
float pkeep_;
// number of columns
size_t num_col_;
// row batch iterator
utils::IIterator<RowBatch> *iter_;
// buffered rowset
std::vector<bst_uint> *buffered_rowset_;
// enabled marks
const std::vector<bool> *enabled_;
// internal temp cache
SparsePage tmp_;
/*! \brief page size 256 M */
static const size_t kPageSize = 256 << 20UL;
};
/*!
* \brief sparse matrix that support column access, CSC
*/
Expand Down Expand Up @@ -165,101 +283,41 @@ class FMatrixPage : public IFMatrix {
* \param pkeep probability to keep a row
*/
inline void InitColData(const std::vector<bool> &enabled, float pkeep) {
SparsePage prow, pcol;
size_t btop = 0;
// clear rowset
buffered_rowset_.clear();
col_size_.resize(info.num_col());
std::fill(col_size_.begin(), col_size_.end(), 0);
utils::FileStream fo;
fo = utils::FileStream(utils::FopenCheck(col_data_name_.c_str(), "wb"));
size_t bytes_write = 0;
double tstart = rabit::utils::GetTime();
// start working
iter_->BeforeFirst();
while (iter_->Next()) {
const RowBatch &batch = iter_->Value();
for (size_t i = 0; i < batch.size; ++i) {
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
buffered_rowset_.push_back(ridx);
prow.Push(batch[i]);
if (prow.MemCostBytes() >= kPageSize) {
bytes_write += prow.MemCostBytes();
this->PushColPage(prow, BeginPtr(buffered_rowset_) + btop,
enabled, &pcol, &fo);
btop += prow.Size();
prow.Clear();

double tdiff = rabit::utils::GetTime() - tstart;
utils::Printf("Writting to %s in %g MB/s, %lu MB written\n",
col_data_name_.c_str(),
(bytes_write >> 20UL) / tdiff,
(bytes_write >> 20UL));
}
}
double tstart = rabit::utils::GetTime();
size_t bytes_write = 0;
utils::ThreadBuffer<SparsePage*, ColConvertFactory> citer;
citer.SetParam("buffer_size", "2");
citer.get_factory().Setup(pkeep, info.num_col(),
iter_, &buffered_rowset_, &enabled);
citer.Init();
SparsePage *pcol;
while (citer.Next(pcol)) {
for (size_t i = 0; i < pcol->Size(); ++i) {
col_size_[i] += pcol->offset[i + 1] - pcol->offset[i];
}
}
if (prow.Size() != 0) {
this->PushColPage(prow, BeginPtr(buffered_rowset_) + btop,
enabled, &pcol, &fo);
pcol->Save(&fo);
size_t spage = pcol->MemCostBytes();
bytes_write += spage;
double tnow = rabit::utils::GetTime();
double tdiff = tnow - tstart;
utils::Printf("Writting to %s in %g MB/s, %lu MB written current speed:%g MB/s\n",
col_data_name_.c_str(),
(bytes_write >> 20UL) / tdiff,
(bytes_write >> 20UL));
}
fo.Close();
num_buffered_row_ = buffered_rowset_.size();
fo = utils::FileStream(utils::FopenCheck(col_meta_name_.c_str(), "wb"));
this->SaveMeta(&fo);
fo.Close();
}
inline void PushColPage(const SparsePage &prow,
const bst_uint *ridx,
const std::vector<bool> &enabled,
SparsePage *pcol,
utils::IStream *fo) {
pcol->Clear();
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
pcol->Clear();
utils::ParallelGroupBuilder<SparseBatch::Entry>
builder(&pcol->offset, &pcol->data);
builder.InitBudget(info.num_col(), nthread);
bst_omp_uint ndata = static_cast<bst_uint>(prow.Size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j];
if (enabled[e.index]) {
builder.AddBudget(e.index, tid);
}
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j];
builder.Push(e.index,
SparseBatch::Entry(ridx[i], e.fvalue),
tid);
}
}
utils::Assert(pcol->Size() == info.num_col(), "inconsistent col data");
// sort columns
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < ncol; ++i) {
if (pcol->offset[i] < pcol->offset[i + 1]) {
std::sort(BeginPtr(pcol->data) + pcol->offset[i],
BeginPtr(pcol->data) + pcol->offset[i + 1], Entry::CmpValue);
}
col_size_[i] += pcol->offset[i + 1] - pcol->offset[i];
}
pcol->Save(fo);
}

private:
/*! \brief page size 256 M */
Expand Down

0 comments on commit 4275434

Please sign in to comment.