diff --git a/runtime/onert/core/src/loader/train/CheckpointLoader.cc b/runtime/onert/core/src/loader/train/CheckpointLoader.cc
index 951b0180f27..36c95d35d9d 100644
--- a/runtime/onert/core/src/loader/train/CheckpointLoader.cc
+++ b/runtime/onert/core/src/loader/train/CheckpointLoader.cc
@@ -32,6 +32,48 @@ using namespace train;
 using namespace checkpoint;
 using namespace exec;
 
+struct DataBufferPair
+{
+  uint32_t offset;
+  uint32_t size;
+};
+
+class DataBuffer
+{
+public:
+  DataBuffer() = default;
+  DataBuffer(uint32_t size) { setSize(size); }
+
+  void setSize(uint32_t size)
+  {
+    _offset.resize(size);
+    _size.resize(size);
+  }
+
+  char *getOffsetBuf() { return reinterpret_cast<char *>(_offset.data()); }
+
+  // This function should be called after loading the _offset buffer.
+  void calculateSize(uint32_t next_start_offset)
+  {
+    assert(_offset.size() == _size.size());
+    for (size_t i = 0; i < _offset.size() - 1; ++i)
+      _size[i] = _offset[i + 1] - _offset[i];
+    _size.back() = next_start_offset - _offset.back();
+  }
+
+  // offset, size
+  DataBufferPair operator[](uint32_t i) const
+  {
+    assert(_offset.size() == _size.size());
+    assert(i < _offset.size());
+    return DataBufferPair{_offset[i], _size[i]};
+  }
+
+private:
+  std::vector<uint32_t> _offset;
+  std::vector<uint32_t> _size;
+};
+
 class CheckpointLoader final
 {
 public:
@@ -62,6 +104,30 @@ class CheckpointLoader final
     if (_header.schema != checkpoint::SCHEMA_VERSION)
       throw std::runtime_error{"Invalid SCHEMA VERSION"};
 
+    _tensor_data.setSize(_header.length);
+    _file.read(_tensor_data.getOffsetBuf(), _header.length * sizeof(uint32_t));
+    if (_file.fail())
+      throw std::runtime_error{"Failed to load tensor data"};
+    _tensor_data.calculateSize(_header.opt1_offset);
+
+    if (_header.opt1_offset)
+    {
+      DataBuffer opt1_data(_header.length);
+      _file.seekg(_header.opt1_offset, std::ios::beg);
+      _file.read(opt1_data.getOffsetBuf(), _header.length * sizeof(uint32_t));
+      opt1_data.calculateSize(_header.opt2_offset);
+      _optimizer_data.emplace_back(std::move(opt1_data));
+    }
+
+    if (_header.opt2_offset)
+    {
+      DataBuffer opt2_data(_header.length);
+      _file.seekg(_header.opt2_offset, std::ios::beg);
+      _file.read(opt2_data.getOffsetBuf(), _header.length * sizeof(uint32_t));
+      opt2_data.calculateSize(_header.other_offset);
+      _optimizer_data.emplace_back(std::move(opt2_data));
+    }
+
     if ((filesize - _header.other_offset) != sizeof(_footer))
       throw std::runtime_error{"Invalid checkpoint file footer data"};
 
@@ -76,10 +142,93 @@ class CheckpointLoader final
     _file.close();
   }
 
+  void updateTensor(const std::unique_ptr<exec::Execution> &exec)
+  {
+    uint32_t vindex = 0;
+    exec->iterateTrainableTensors(
+      [&](const ir::OperandIndex &, const backend::train::ITrainableTensor *) { vindex++; });
+
+    if (_header.length != vindex)
+      throw std::runtime_error{
+        "Invalid number of tensors between TrainingInfo and checkpoint file"};
+
+    // Reset EOF bit
+    _file.clear();
+
+    vindex = 0;
+    exec->iterateTrainableTensors(
+      [&](const ir::OperandIndex &, const backend::train::ITrainableTensor *tensor) {
+        assert(tensor);
+        assert(_tensor_data[vindex].size == tensor->total_size());
+        _file.seekg(_tensor_data[vindex].offset, std::ios::beg);
+        _file.read(reinterpret_cast<char *>(tensor->buffer()), tensor->total_size());
+        vindex++;
+      });
+  }
+
+  void updateOptimizer(const std::unique_ptr<ir::train::TrainingInfo> &train_info,
+                       const std::unique_ptr<exec::Execution> &exec)
+  {
+    ir::train::OptimizerCode ckpt_opt_code = ir::train::OptimizerCode::SGD;
+    // TODO Support other optimizers
+    if (_optimizer_data.size() == 2)
+      ckpt_opt_code = ir::train::OptimizerCode::Adam;
+
+    if (ckpt_opt_code != train_info->optimizerInfo().optim_code)
+      throw std::runtime_error{
+        "Not compatible optimizer type between TrainingInfo and checkpoint file"};
+
+    switch (train_info->optimizerInfo().optim_code)
+    {
+      case ir::train::OptimizerCode::Adam:
+        updateAdamOptimizer(exec);
+        break;
+      default:
+        break;
+    }
+  }
+
+  void updateTrainingInfo(const std::unique_ptr<ir::train::TrainingInfo> &train_info)
+  {
+    // TODO Verify cur_step value
+    train_info->trainingStep() = _footer.cur_step;
+  }
+
+private:
+  void updateAdamOptimizer(const std::unique_ptr<exec::Execution> &exec)
+  {
+    // Adam optimizer has two optimizer variables. (mean, variance)
+    [[maybe_unused]] constexpr auto ADAM_VARIABLE_COUNT = 2;
+
+    // Reset EOF bit
+    _file.clear();
+
+    auto vindex = 0;
+    exec->iterateTrainableTensors(
+      [&](const ir::OperandIndex &, const backend::train::ITrainableTensor *tensor) {
+        assert(tensor);
+        auto trainable_tensor = const_cast<backend::train::ITrainableTensor *>(tensor);
+        const auto opt_vars = trainable_tensor->optVars();
+
+        // Untrainable tensor should not have any optimizer variables.
+        assert(opt_vars.size() == ADAM_VARIABLE_COUNT || opt_vars.size() == 0);
+
+        for (size_t i = 0; i < opt_vars.size(); ++i)
+        {
+          assert(opt_vars[i]->total_size() == _optimizer_data[i][vindex].size);
+          _file.seekg(_optimizer_data[i][vindex].offset, std::ios::beg);
+          _file.read(reinterpret_cast<char *>(opt_vars[i]->buffer()), opt_vars[i]->total_size());
+        }
+        vindex++;
+      });
+  }
+
 private:
   std::ifstream _file;
   checkpoint::Header _header;
   checkpoint::Footer _footer;
+  DataBuffer _tensor_data;
+  std::vector<DataBuffer> _optimizer_data;
 };
 
 } // namespace
@@ -96,11 +245,9 @@ void loadCheckpoint(const std::string &filename,
                     const std::unique_ptr<ir::train::TrainingInfo> &train_info,
                     const std::unique_ptr<exec::Execution> &exec)
 {
   CheckpointLoader loader(filename);
-
-  // TODO Load tensor data
-  UNUSED_RELEASE(exec);
-  // TODO Update step in train_info
-  UNUSED_RELEASE(train_info);
+  loader.updateTensor(exec);
+  loader.updateOptimizer(train_info, exec);
+  loader.updateTrainingInfo(train_info);
 }
 
 } // namespace train
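Note on the layout assumed by `DataBuffer`: for each section (tensor data, opt1, opt2) the checkpoint stores a table of `uint32_t` start offsets, one per tensor; entry sizes are not stored but derived, each entry ending where the next one begins and the last entry ending where the following section begins. A minimal standalone sketch of that derivation, mirroring `DataBuffer::calculateSize()` with hypothetical offsets (not taken from a real checkpoint):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
  // Hypothetical offset table for three tensors; the next section (e.g. opt1)
  // is assumed to start at file offset 1000.
  std::vector<uint32_t> offset{100, 228, 740};
  const uint32_t next_start_offset = 1000;

  // Same computation as DataBuffer::calculateSize(): each entry's size is the
  // gap to the next offset, and the last entry is bounded by the next section.
  std::vector<uint32_t> size(offset.size());
  for (size_t i = 0; i < offset.size() - 1; ++i)
    size[i] = offset[i + 1] - offset[i];
  size.back() = next_start_offset - offset.back();

  for (size_t i = 0; i < offset.size(); ++i)
    std::cout << "tensor " << i << ": offset=" << offset[i] << ", size=" << size[i] << "\n";
  // Expected sizes: 128, 512, 260
}
```

This is also why the constructor passes `_header.opt1_offset`, `_header.opt2_offset`, and `_header.other_offset` as the `next_start_offset` bounds for the tensor, opt1, and opt2 sections respectively.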