diff --git a/runtime/onert/api/nnfw/src/nnfw_api_internal.cc b/runtime/onert/api/nnfw/src/nnfw_api_internal.cc index 6a2ab2e4ae3..2489fcd3286 100644 --- a/runtime/onert/api/nnfw/src/nnfw_api_internal.cc +++ b/runtime/onert/api/nnfw/src/nnfw_api_internal.cc @@ -1719,7 +1719,7 @@ NNFW_STATUS nnfw_session::train_import_checkpoint(const char *path) try { - onert::loader::loadCheckpoint(_execution, _train_info, path); + onert::loader::loadCheckpoint(path, _train_info, _execution); } catch (const std::exception &e) { diff --git a/runtime/onert/core/include/loader/CheckpointLoader.h b/runtime/onert/core/include/loader/CheckpointLoader.h index dd433895dfe..7ff3220c992 100644 --- a/runtime/onert/core/include/loader/CheckpointLoader.h +++ b/runtime/onert/core/include/loader/CheckpointLoader.h @@ -40,9 +40,9 @@ namespace onert namespace loader { -void loadCheckpoint(const std::unique_ptr &exec, const -std::unique_ptr &train_info, - const std::string &filename); +void loadCheckpoint(const std::string &filename, + const std::unique_ptr &train_info, + const std::unique_ptr &exec); } // namespace loader } // namespace onert diff --git a/runtime/onert/core/src/loader/CheckpointLoader.cc b/runtime/onert/core/src/loader/CheckpointLoader.cc index 42b502753e4..b6005144edd 100644 --- a/runtime/onert/core/src/loader/CheckpointLoader.cc +++ b/runtime/onert/core/src/loader/CheckpointLoader.cc @@ -49,7 +49,7 @@ struct __attribute__((packed)) Header struct DataBufferPair { - DataBufferPair(uint32_t _offset, uint32_t _size): offset{_offset}, size{_size} + DataBufferPair(uint32_t _offset, uint32_t _size) : offset{_offset}, size{_size} { // DO NOTHING } @@ -63,19 +63,20 @@ struct DataBuffer std::vector offset; std::vector size; - void resize(uint32_t length) { + void resize(uint32_t length) + { offset.resize(length); size.resize(length); } - char *getOffsetBuf() { - return reinterpret_cast(offset.data()); - } + char *getOffsetBuf() { return reinterpret_cast(offset.data()); } - void calculateSize(uint32_t next_beg_offset) { + void calculateSize(uint32_t next_beg_offset) + { assert(offset.size() == size.size()); uint32_t cur = offset[0]; - for (size_t i = 1; i < offset.size(); ++i) { + for (size_t i = 1; i < offset.size(); ++i) + { size[i - 1] = offset[i] - cur; cur = offset[i]; } @@ -83,7 +84,8 @@ struct DataBuffer } // offset, size - DataBufferPair operator[](uint32_t i) { + DataBufferPair operator[](uint32_t i) + { assert(offset.size() == size.size()); assert(i <= offset.size()); return DataBufferPair{offset[i], size[i]}; @@ -105,7 +107,7 @@ class CheckpointLoader _file.seekg(0, std::ios::end); const auto filesize = _file.tellg(); _file.seekg(0, std::ios::beg); - + if (filesize < static_cast(sizeof(_header))) throw std::runtime_error{"Invalid checkpoint file data"}; @@ -123,7 +125,7 @@ class CheckpointLoader _tensor_data.resize(_header.length); _file.read(_tensor_data.getOffsetBuf(), _header.length * sizeof(uint32_t)); _tensor_data.calculateSize(_header.opt1_offset); - + if (_header.opt1_offset) { _opt1_data.resize(_header.length); @@ -171,9 +173,9 @@ class CheckpointLoader DataBuffer _opt2_data; }; - -void loadCheckpoint(const std::unique_ptr &exec, const std::unique_ptr &train_info, - const std::string &filename) +void loadCheckpoint(const std::string &filename, + const std::unique_ptr &train_info, + const std::unique_ptr &exec) { CheckpointLoader loader(filename); loader.updateTensor(exec); diff --git a/tests/tools/onert_train/src/onert_train.cc b/tests/tools/onert_train/src/onert_train.cc index 1b4a5a2b33f..a7b0dd3fbc6 100644 --- a/tests/tools/onert_train/src/onert_train.cc +++ b/tests/tools/onert_train/src/onert_train.cc @@ -161,7 +161,12 @@ int main(const int argc, char **argv) // prepare execution // TODO When nnfw_{prepare|run} are failed, can't catch the time - measure.run(PhaseType::PREPARE, [&]() { NNPR_ENSURE_STATUS(nnfw_train_prepare(session)); }); + measure.run(PhaseType::PREPARE, [&]() { + NNPR_ENSURE_STATUS(nnfw_train_prepare(session)); + + if (auto name = args.getCheckpointFilename(); name != "") + NNPR_ENSURE_STATUS(nnfw_train_import_checkpoint(session, name.c_str())); + }); // prepare input and expected tensor info lists std::vector input_infos;