From fdfb2497191ad0377846bfb322a34f9f9f9bc32c Mon Sep 17 00:00:00 2001 From: Balyshev Artem <43214667+BalyshevArtem@users.noreply.github.com> Date: Fri, 14 Jun 2024 19:40:33 +0300 Subject: [PATCH] [onert-micro] Introduce OMCheckpointLoader and OMCheckpointSaver (#13147) This pr introduces OMCheckpointLoader and OMCheckpointSaver entities. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- .../include/core/train/OMCheckpointLoader.h | 65 +++ .../include/core/train/OMCheckpointSaver.h | 66 +++ .../src/core/train/OMCheckpointLoader.cpp | 287 ++++++++++++ .../src/core/train/OMCheckpointSaver.cpp | 433 ++++++++++++++++++ 4 files changed, 851 insertions(+) create mode 100644 onert-micro/onert-micro/include/core/train/OMCheckpointLoader.h create mode 100644 onert-micro/onert-micro/include/core/train/OMCheckpointSaver.h create mode 100644 onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp create mode 100644 onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp diff --git a/onert-micro/onert-micro/include/core/train/OMCheckpointLoader.h b/onert-micro/onert-micro/include/core/train/OMCheckpointLoader.h new file mode 100644 index 00000000000..626c61e06e6 --- /dev/null +++ b/onert-micro/onert-micro/include/core/train/OMCheckpointLoader.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef ONERT_MICRO_CORE_TRAIN_CHECKPOINT_LOADER_H
+#define ONERT_MICRO_CORE_TRAIN_CHECKPOINT_LOADER_H
+
+#include "OMStatus.h"
+#include "OMConfig.h"
+#include "core/OMRuntimeContext.h"
+#include "core/OMRuntimeStorage.h"
+#include "core/train/OMTrainingStorage.h"
+
+namespace onert_micro
+{
+namespace core
+{
+namespace train
+{
+
+/*
+ * Class to load checkpoint files
+ * Note: class is stateless, all functionality is provided by static methods
+ */
+class OMCheckpointLoader
+{
+public:
+  OMCheckpointLoader() = default;
+  OMCheckpointLoader(const OMCheckpointLoader &) = delete;
+  OMCheckpointLoader(OMCheckpointLoader &&) = delete;
+  OMCheckpointLoader &operator=(const OMCheckpointLoader &) = delete;
+  OMCheckpointLoader &operator=(OMCheckpointLoader &&) = delete;
+  ~OMCheckpointLoader() = default;
+
+  // Validate checkpoint data and load the saved state from it:
+  //  - weights of trainable layers are copied into the model constant buffers,
+  //  - Adam m and v buffers (if they were saved) are stored in train_storage,
+  //  - current step and epoch numbers are stored in config.training_context.
+  // Note: data is a raw pointer without a length, so it has to point to a complete checkpoint;
+  // offsets read from the data are not checked against any buffer size.
+  // Returns Ok on success, otherwise an error status (FailReadCheckpointFile when the header
+  // or the stored offsets are not valid).
+  // To check checkpoint file format please see https://github.com/Samsung/ONE/discussions/13037
+  static OMStatus loadCheckpointData(core::OMRuntimeContext &context,
+                                     OMTrainingStorage &train_storage, const char *data,
+                                     OMConfig &config);
+
+private:
+  // Check magic number, schema version, number of tensors and the offsets stored in the header
+  static OMStatus validateCheckpointData(core::OMRuntimeContext &context, const char *data,
+                                         OMConfig &config);
+
+  // Copy weights, optimizer state and training progress from the checkpoint data
+  static OMStatus loadBuffers(core::OMRuntimeContext &context, OMTrainingStorage &train_storage,
+                              const char *data, OMConfig &config);
+};
+
+} // namespace train
+} // namespace core
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_CORE_TRAIN_CHECKPOINT_LOADER_H
diff --git a/onert-micro/onert-micro/include/core/train/OMCheckpointSaver.h b/onert-micro/onert-micro/include/core/train/OMCheckpointSaver.h
new file mode 100644
index 00000000000..fbce18eaea4
--- /dev/null
+++ b/onert-micro/onert-micro/include/core/train/OMCheckpointSaver.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_CORE_TRAIN_CHECKPOINT_SAVER_H
+#define ONERT_MICRO_CORE_TRAIN_CHECKPOINT_SAVER_H
+
+#include "OMStatus.h"
+#include "OMConfig.h"
+#include "core/OMRuntimeContext.h"
+#include "core/OMRuntimeStorage.h"
+#include "core/train/OMTrainingStorage.h"
+
+#include <cstddef>
+#include <vector>
+
+namespace onert_micro
+{
+namespace core
+{
+namespace train
+{
+
+/*
+ * Class to save checkpoints
+ * Note: class is stateless, all functionality is provided by static methods
+ */
+class OMCheckpointSaver
+{
+public:
+  OMCheckpointSaver() = default;
+  OMCheckpointSaver(const OMCheckpointSaver &) = delete;
+  OMCheckpointSaver(OMCheckpointSaver &&) = delete;
+  OMCheckpointSaver &operator=(const OMCheckpointSaver &) = delete;
+  OMCheckpointSaver &operator=(OMCheckpointSaver &&) = delete;
+  ~OMCheckpointSaver() = default;
+
+  // Create checkpoint data for current state
+  // Previous content of data is dropped; data is resized to the checkpoint size and filled with
+  // the header, weights of trainable layers, Adam m and v buffers (only for the Adam optimizer
+  // with a non-reset state) and current step and epoch numbers taken from config.
+  // To check checkpoint file format please see https://github.com/Samsung/ONE/discussions/13037
+  static OMStatus createCheckpointData(core::OMRuntimeContext &context,
+                                       OMTrainingStorage &train_storage, std::vector<char> &data,
+                                       const OMConfig &config);
+
+private:
+  // Calculate size in bytes of the checkpoint data for current state
+  static size_t calculateFileSize(core::OMRuntimeContext &context, OMTrainingStorage &train_storage,
+                                  const OMConfig &config);
+
+  // Fill everything after the magic number, schema version and reserved bytes
+  // Note: data has to be already resized to the size returned by calculateFileSize
+  static OMStatus writeOffsetsAndBuffers(core::OMRuntimeContext &context,
+                                         OMTrainingStorage &train_storage, const OMConfig &config,
+                                         std::vector<char> &data);
+};
+
+} // namespace train
+} // namespace core
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_CORE_TRAIN_CHECKPOINT_SAVER_H
diff --git a/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp
b/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp
new file mode 100644
index 00000000000..db61df5a2b1
--- /dev/null
+++ b/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp
@@ -0,0 +1,287 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/OMDataType.h"
+#include "core/memory/OMMemoryManager.h"
+#include "core/train/OMCheckpointLoader.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+
+using namespace onert_micro::core::train;
+using namespace onert_micro::core;
+using namespace onert_micro::train;
+using namespace onert_micro;
+
+namespace
+{
+
+// Values expected in the first fields of the checkpoint data
+constexpr uint16_t MAGIC_NUMBER = 429;
+constexpr uint8_t SCHEMA_VERSION = 1;
+
+// Byte positions of the header fields inside the checkpoint data
+// Note: WEIGHT_OFFSET is the start of the table with one 4 byte offset per tensor
+enum WOFFieldsOffsets
+{
+  MAGIC_NUMBER_OFFSET = 0,
+  SCHEMA_VERSION_OFFSET = 2,
+  M_OFFSET = 4,
+  V_OFFSET = 8,
+  OTHER_PARAMS_OFFSET = 12,
+  NUM_BUFFERS_OFFSET = 16,
+  WEIGHT_OFFSET = 20,
+};
+
+// Layers with trainable weights
+// Note: needed to skip layers which have const inputs that are not trainable (for example
+// Reshape)
+bool isTrainableWeights(const circle::OperatorCode *opcode)
+{
+  switch (opcode->builtin_code())
+  {
+    case circle::BuiltinOperator_FULLY_CONNECTED:
+    case circle::BuiltinOperator_CONV_2D:
+      return true;
+    default:
+      return false;
+  }
+}
+
+} // namespace
+
+// Check the header of the checkpoint data against the model and the training config.
+// Returns FailReadCheckpointFile when magic number, schema version or number of tensors do not
+// match, when other parameters offset is zero, or when optimizer is SGD but Adam offsets are set.
+OMStatus OMCheckpointLoader::validateCheckpointData(core::OMRuntimeContext &context,
+                                                    const char *data, OMConfig &config)
+{
+  // Validate magic number
+  uint16_t mag_num = 0;
+  std::memcpy(&mag_num, &data[MAGIC_NUMBER_OFFSET], sizeof(mag_num));
+  assert(mag_num == MAGIC_NUMBER && "False MAGIC NUMBER, check correctness of checkpoint file");
+  if (mag_num != MAGIC_NUMBER)
+    return FailReadCheckpointFile;
+
+  // Validate schema version
+  uint8_t version = 0;
+  std::memcpy(&version, &data[SCHEMA_VERSION_OFFSET], sizeof(version));
+  assert(version == SCHEMA_VERSION &&
+         "False SCHEMA_VERSION NUMBER, check correctness of checkpoint file");
+  if (version != SCHEMA_VERSION)
+    return FailReadCheckpointFile;
+
+  // Validate count of tensors
+  uint32_t num_tensors = context.getCircleTensors()->size();
+  uint32_t num_tensors_in_file = 0;
+  std::memcpy(&num_tensors_in_file, &data[NUM_BUFFERS_OFFSET], sizeof(num_tensors_in_file));
+  assert(num_tensors == num_tensors_in_file &&
+         "Number of tensors in circle and in checkpoint file should be the same");
+  if (num_tensors != num_tensors_in_file)
+    return FailReadCheckpointFile;
+
+  // Validate m, v and other parameters offset
+  uint32_t m_offset = 0;
+  uint32_t v_offset = 0;
+  uint32_t other_params_offset = 0;
+  std::memcpy(&m_offset, &data[M_OFFSET], sizeof(m_offset));
+  std::memcpy(&v_offset, &data[V_OFFSET], sizeof(v_offset));
+  std::memcpy(&other_params_offset, &data[OTHER_PARAMS_OFFSET], sizeof(other_params_offset));
+
+  // Step and epoch numbers are always stored, so this offset can not be zero
+  assert(other_params_offset > 0);
+  if (other_params_offset == 0)
+    return FailReadCheckpointFile;
+
+  // SGD has no m and v buffers
+  if (config.training_context.optimizer == SGD)
+  {
+    assert(m_offset == 0 and v_offset == 0);
+    if (m_offset != 0 or v_offset != 0)
+      return FailReadCheckpointFile;
+  }
+
+  return Ok;
+}
+
+// Copy saved state from the checkpoint data:
+//  - weights of trainable layers into the model constant buffers,
+//  - for Adam: m and v buffers into the optimizer state (the state is fully reset first),
+//  - step and epoch numbers into config.training_context.
+// Note: on an error the function returns immediately, buffers handled before the error keep
+// the loaded values.
+OMStatus OMCheckpointLoader::loadBuffers(core::OMRuntimeContext &context,
+                                         OMTrainingStorage &train_storage, const char *data,
+                                         OMConfig &config)
+{
+  OMStatus status = Ok;
+
+  // Read v, m and other params offsets
+  uint32_t m_offset = 0;
+  uint32_t v_offset = 0;
+  uint32_t other_params_offset = 0;
+  std::memcpy(&m_offset, &data[M_OFFSET], sizeof(m_offset));
+  std::memcpy(&v_offset, &data[V_OFFSET], sizeof(v_offset));
+  std::memcpy(&other_params_offset, &data[OTHER_PARAMS_OFFSET], sizeof(other_params_offset));
+
+  const uint32_t weight_offset_pos = WEIGHT_OFFSET;
+
+  // If optimizer is Adam then reset its state
+  optimizers::Adam *adam_opt = nullptr;
+  if (config.training_context.optimizer == ADAM)
+  {
+    adam_opt = train_storage.getAdam();
+    assert(adam_opt != nullptr);
+    // Without asserts a missing optimizer must not be dereferenced
+    if (adam_opt == nullptr)
+      return FailReadCheckpointFile;
+
+    adam_opt->fullReset();
+  }
+
+  auto tensors = context.getCircleTensors();
+  const auto *operators = context.getCircleOperators();
+  const auto num_kernels = operators->size();
+  // Zero num_of_train_layers means that all layers are handled
+  uint32_t num_train_layers = config.training_context.num_of_train_layers == 0
+                                ? num_kernels
+                                : config.training_context.num_of_train_layers;
+  uint32_t last_node_pos = std::min<uint32_t>(num_kernels, num_train_layers);
+  // Goes among trainable ops, starting from the last operator of the graph
+  const auto *op_codes = context.getCircleOpcodes();
+  for (uint32_t index = 0; index < last_node_pos; ++index)
+  {
+    uint32_t cur_op_index = num_kernels - index - 1;
+    auto *cur_op = operators->operator[](cur_op_index);
+
+    // Operator code is the same for all inputs, so check it once per operator
+    uint32_t cur_opcode_index = cur_op->opcode_index();
+
+    assert(cur_opcode_index < op_codes->size());
+
+    const auto opcode = op_codes->operator[](cur_opcode_index);
+
+    // Check it is operator with trainable consts
+    if (not isTrainableWeights(opcode))
+      continue;
+
+    auto input_tensors = cur_op->inputs();
+
+    for (uint32_t i = 0; i < input_tensors->size(); ++i)
+    {
+      const auto input_tensor_index = input_tensors->operator[](i);
+      // Check is const
+      if (not context.isConstTensor(input_tensor_index))
+        continue;
+
+      // Get current weight file pos
+      uint32_t cur_weight_offset_pos = weight_offset_pos + input_tensor_index * 4;
+
+      // Read current tensor offset
+      uint32_t cur_tensor_offset = 0;
+      std::memcpy(&cur_tensor_offset, &data[cur_weight_offset_pos], sizeof(cur_tensor_offset));
+
+      // Check is it save data or not
+      // Note: zero means there are no data - it is error
+      if (cur_tensor_offset == 0)
+        return FailReadCheckpointFile;
+
+      // Read weight data and save it
+      const auto tensor = tensors->operator[](input_tensor_index);
+      assert(tensor != nullptr);
+
+      OMRuntimeShape shape(tensor);
+      // NOTE(review): sizeof is evaluated at compile time, so this value is the same for every
+      // tensor type. The saver uses the same expression, so both sides agree on buffer sizes -
+      // confirm it matches the element size of all supported weight types.
+      auto type_size = sizeof(OMDataType(tensor->type()));
+
+      size_t buffer_size = shape.flatSize() * type_size;
+      // Get pointer to the data in model
+      uint8_t *weight_data_in_model_ptr = nullptr;
+
+      status = context.getConstDataByTensorIndex(&weight_data_in_model_ptr, input_tensor_index);
+      assert(status == Ok);
+      assert(weight_data_in_model_ptr != nullptr);
+      if (status != Ok)
+        return status;
+      // Ok status with empty pointer is an error too, do not report success for it
+      if (weight_data_in_model_ptr == nullptr)
+        return FailReadCheckpointFile;
+
+      std::memcpy(weight_data_in_model_ptr, &data[cur_tensor_offset], buffer_size);
+
+      if (config.training_context.optimizer == SGD)
+        continue;
+
+      // For Adam read m and v buffer
+      assert(config.training_context.optimizer == ADAM);
+
+      // If no saved Adam state then continue
+      if (m_offset == 0 or v_offset == 0)
+      {
+        assert(v_offset == 0);
+        assert(m_offset == 0);
+        continue;
+      }
+
+      // Get current v and m file pos
+      uint32_t cur_m_offset_pos = m_offset + input_tensor_index * 4;
+      uint32_t cur_v_offset_pos = v_offset + input_tensor_index * 4;
+
+      // Read current v and m offset
+      uint32_t cur_m_offset = 0;
+      uint32_t cur_v_offset = 0;
+      std::memcpy(&cur_m_offset, &data[cur_m_offset_pos], sizeof(cur_m_offset));
+      std::memcpy(&cur_v_offset, &data[cur_v_offset_pos], sizeof(cur_v_offset));
+
+      // Cannot be zero due to weight already not zero
+      assert(cur_m_offset != 0 and cur_v_offset != 0);
+      if (cur_m_offset == 0 or cur_v_offset == 0)
+        return FailReadCheckpointFile;
+
+      assert(adam_opt != nullptr);
+      if (adam_opt == nullptr)
+        return FailReadCheckpointFile;
+
+      // Allocate memory for current m buffer, read it and save it in adam optimizer state
+      // Note: the buffer is passed to the optimizer before the next allocation; presumably the
+      // optimizer state keeps the pointer, so it is not lost if the next allocation fails
+      uint8_t *m_buffer = nullptr;
+      status = memory::OMMemoryManager::allocateMemory(buffer_size, &m_buffer);
+      assert(status == Ok);
+      assert(m_buffer != nullptr);
+      if (status != Ok)
+        return status;
+      if (m_buffer == nullptr)
+        return FailReadCheckpointFile;
+
+      std::memcpy(m_buffer, &data[cur_m_offset], buffer_size);
+      adam_opt->setExponentAvgDataByTensorIndex(input_tensor_index, m_buffer);
+
+      // The same for current v buffer
+      uint8_t *v_buffer = nullptr;
+      status = memory::OMMemoryManager::allocateMemory(buffer_size, &v_buffer);
+      assert(status == Ok);
+      assert(v_buffer != nullptr);
+      if (status != Ok)
+        return status;
+      if (v_buffer == nullptr)
+        return FailReadCheckpointFile;
+
+      std::memcpy(v_buffer, &data[cur_v_offset], buffer_size);
+      adam_opt->setExponentAvgSquaresDataByTensorIndex(input_tensor_index, v_buffer);
+    }
+  }
+
+  // Read other parameters: cur step num and cur epoch num
+  uint32_t cur_step_offset = other_params_offset;
+  uint32_t cur_epoch_offset = other_params_offset + 4;
+
+  uint32_t cur_step = 0;
+  std::memcpy(&cur_step, &data[cur_step_offset], sizeof(cur_step));
+
+  uint32_t cur_epoch = 0;
+  std::memcpy(&cur_epoch, &data[cur_epoch_offset], sizeof(cur_epoch));
+
+  // Save it in config
+  config.training_context.num_step = cur_step;
+  config.training_context.num_epoch = cur_epoch;
+
+  return status;
+}
+
+// Validate checkpoint data and then load weights, optimizer state and training progress from it.
+// Returns FailReadCheckpointFile for empty data pointer or not valid checkpoint data.
+// To check checkpoint file format please see https://github.com/Samsung/ONE/discussions/13037
+OMStatus OMCheckpointLoader::loadCheckpointData(core::OMRuntimeContext &context,
+                                                OMTrainingStorage &train_storage, const char *data,
+                                                OMConfig &config)
+{
+  // Nothing can be read from empty pointer
+  assert(data != nullptr);
+  if (data == nullptr)
+    return FailReadCheckpointFile;
+
+  // Validate current checkpoint file data
+  OMStatus status = validateCheckpointData(context, data, config);
+  assert(status == Ok);
+  if (status != Ok)
+    return status;
+
+  // Read and save buffer
+  status = loadBuffers(context, train_storage, data, config);
+  assert(status == Ok);
+
+  return status;
+}
diff --git a/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp b/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp
new file mode 100644
index 00000000000..ca32c5eeec6
--- /dev/null
+++ b/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp
@@ -0,0 +1,433 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/OMDataType.h"
+#include "core/train/OMCheckpointSaver.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <vector>
+
+using namespace onert_micro::core::train;
+using namespace onert_micro::core;
+using namespace onert_micro::train;
+using namespace onert_micro;
+
+namespace
+{
+
+// Layers with trainable weights
+// Note: needed to skip layers which have const inputs that are not trainable (for example
+// Reshape)
+bool isTrainableWeights(const circle::OperatorCode *opcode)
+{
+  switch (opcode->builtin_code())
+  {
+    case circle::BuiltinOperator_FULLY_CONNECTED:
+    case circle::BuiltinOperator_CONV_2D:
+      return true;
+    default:
+      return false;
+  }
+}
+
+constexpr uint16_t MAGIC_NUMBER = 429;
+constexpr uint8_t SCHEMA_VERSION = 1;
+
+// Byte positions of the header fields inside the checkpoint data
+// Note: WEIGHT_OFFSET is the start of the table with one 4 byte offset per tensor
+enum WOFFieldsOffsets
+{
+  MAGIC_NUMBER_OFFSET = 0,
+  SCHEMA_VERSION_OFFSET = 2,
+  M_OFFSET = 4,
+  V_OFFSET = 8,
+  OTHER_PARAMS_OFFSET = 12,
+  NUM_BUFFERS_OFFSET = 16,
+  WEIGHT_OFFSET = 20,
+};
+
+} // namespace
+
+/**
+ * Calculate result buffer size
+ *
+ * The size consists of the 20 byte header, one 4 byte offset per tensor, data of all trainable
+ * weights and 8 bytes for current step and epoch. If Adam optimizer has a non-reset state then
+ * two more offset tables and two more copies of the weights size (m and v buffers) are added.
+ **/
+size_t OMCheckpointSaver::calculateFileSize(core::OMRuntimeContext &context,
+                                            OMTrainingStorage &train_storage,
+                                            const OMConfig &config)
+{
+  size_t result = 0;
+
+  // 2 bytes for Magic Number
+  result += 2;
+
+  // 1 byte for Schema version
+  result += 1;
+
+  // 1 byte for Reserved field
+  result += 1;
+
+  // 4 bytes for Adam's state m buffers offset value
+  result += 4;
+
+  // 4 bytes for Adam's state v buffers offset value
+  result += 4;
+
+  // 4 bytes for others parameters offset value
+  result += 4;
+
+  // 4 bytes for number of tensors
+  result += 4;
+
+  auto tensors = context.getCircleTensors();
+
+  auto tensors_size = tensors->size();
+
+  // tensors_size * 4 bytes for buffers offsets
+  result += tensors_size * 4;
+
+  const auto *operators = context.getCircleOperators();
+  const auto num_kernels = operators->size();
+  // Zero num_of_train_layers means that all layers are handled
+  uint32_t num_train_layers = config.training_context.num_of_train_layers == 0
+                                ? num_kernels
+                                : config.training_context.num_of_train_layers;
+  uint32_t last_node_pos = std::min<uint32_t>(num_kernels, num_train_layers);
+  // Goes among trainable ops, starting from the last operator of the graph
+  size_t buffer_size = 0;
+  const auto *op_codes = context.getCircleOpcodes();
+  for (uint32_t index = 0; index < last_node_pos; ++index)
+  {
+    uint32_t cur_op_index = num_kernels - index - 1;
+    auto *cur_op = operators->operator[](cur_op_index);
+
+    // Operator code is the same for all inputs, so check it once per operator
+    uint32_t cur_opcode_index = cur_op->opcode_index();
+
+    assert(cur_opcode_index < op_codes->size());
+
+    const auto opcode = op_codes->operator[](cur_opcode_index);
+
+    // Check it is operator with trainable consts
+    if (not isTrainableWeights(opcode))
+      continue;
+
+    auto input_tensors = cur_op->inputs();
+
+    for (uint32_t i = 0; i < input_tensors->size(); ++i)
+    {
+      const auto input_tensor_index = input_tensors->operator[](i);
+      // Check is const
+      if (not context.isConstTensor(input_tensor_index))
+        continue;
+
+      const auto tensor = context.getTensorByIndex(input_tensor_index);
+      OMRuntimeShape shape(tensor);
+
+      // NOTE(review): sizeof is evaluated at compile time, so this value is the same for every
+      // tensor type - confirm it matches the element size of all supported weight types
+      auto type_size = sizeof(OMDataType(tensor->type()));
+
+      buffer_size += type_size * shape.flatSize();
+    }
+  }
+
+  // If we use Adam optimizer then need to add Adam specific buffers
+  assert((config.training_context.optimizer == SGD or
+          config.training_context.optimizer == ADAM) &&
+         "Unsupported type");
+  if (config.training_context.optimizer == ADAM)
+  {
+    const auto adam_opt = train_storage.getAdam();
+    assert(adam_opt != nullptr);
+
+    // Check is it save any state
+    if (adam_opt != nullptr and not adam_opt->isReset())
+    {
+      // If yes - then just buffer_size = buffer_size * 3 (original weights and two buffers from
+      // Adam state)
+      buffer_size *= 3;
+
+      // Add offsets for m
+      result += tensors_size * 4;
+      // Add offsets for v
+      result += tensors_size * 4;
+    }
+  }
+
+  // Add buffer size
+  result += buffer_size;
+
+  // 4 bytes to save information about current step
+  result += 4;
+
+  // 4 bytes to save information about current epoch
+  result += 4;
+
+  return result;
+}
+
+// Write everything that follows the first 4 bytes of the checkpoint (magic number, schema
+// version and reserved field are not written here):
+//  - m, v and other parameters offsets and number of buffers in the header,
+//  - table of weights offsets (zero means no buffer for the tensor) followed by weights data,
+//  - for Adam with a non-reset state: m offsets table, m buffers, v offsets table, v buffers,
+//  - current step and epoch numbers.
+// Note: data has to be already resized to the size returned by calculateFileSize.
+OMStatus OMCheckpointSaver::writeOffsetsAndBuffers(core::OMRuntimeContext &context,
+                                                   OMTrainingStorage &train_storage,
+                                                   const OMConfig &config, std::vector<char> &data)
+{
+  // All numeric fields of the file are 4 byte values written at a byte position
+  const auto write_u32 = [&data](uint32_t pos, uint32_t value) {
+    std::memcpy(data.data() + pos, &value, sizeof(value));
+  };
+
+  auto tensors = context.getCircleTensors();
+  const uint32_t tensors_size = tensors->size();
+
+  // Write number of buffers
+  write_u32(NUM_BUFFERS_OFFSET, tensors_size);
+
+  // Calculate buffers offsets, set all to zeros
+  // Zero value means that there is no buffer for this tensor
+  std::vector<uint32_t> offsets(tensors_size, 0);
+  // Start offset for buffers
+  uint32_t cur_offset = WEIGHT_OFFSET + tensors_size * 4;
+
+  // To calculate sum of saved buffers sizes
+  uint32_t acc_buffer_size = 0;
+
+  const auto *operators = context.getCircleOperators();
+  const auto num_kernels = operators->size();
+  // Zero num_of_train_layers means that all layers are handled
+  uint32_t num_train_layers = config.training_context.num_of_train_layers == 0
+                                ? num_kernels
+                                : config.training_context.num_of_train_layers;
+  uint32_t last_node_pos = std::min<uint32_t>(num_kernels, num_train_layers);
+
+  const auto *op_codes = context.getCircleOpcodes();
+
+  // Goes among trainable ops, starting from the last operator of the graph
+  for (uint32_t index = 0; index < last_node_pos; ++index)
+  {
+    uint32_t cur_op_index = num_kernels - index - 1;
+    auto *cur_op = operators->operator[](cur_op_index);
+
+    // Operator code is the same for all inputs, so check it once per operator
+    uint32_t cur_opcode_index = cur_op->opcode_index();
+
+    assert(cur_opcode_index < op_codes->size());
+
+    const auto opcode = op_codes->operator[](cur_opcode_index);
+
+    // Check it is operator with trainable consts
+    if (not isTrainableWeights(opcode))
+      continue;
+
+    auto input_tensors = cur_op->inputs();
+
+    for (uint32_t i = 0; i < input_tensors->size(); ++i)
+    {
+      const auto input_tensor_index = input_tensors->operator[](i);
+      // Check is const
+      if (not context.isConstTensor(input_tensor_index))
+        continue;
+
+      // Found trainable weight tensor, lets calculate its size and save offset for current buffer
+      const auto tensor = context.getTensorByIndex(input_tensor_index);
+      OMRuntimeShape shape(tensor);
+
+      // NOTE(review): sizeof is evaluated at compile time, so this value is the same for every
+      // tensor type - confirm it matches the element size of all supported weight types
+      auto type_size = sizeof(OMDataType(tensor->type()));
+
+      size_t buffer_size = type_size * shape.flatSize();
+      // Save for current tensor index its offset
+      offsets[input_tensor_index] = cur_offset;
+      // Get buffer data
+      uint8_t *tensor_data = nullptr;
+      OMStatus status = context.getConstDataByTensorIndex(&tensor_data, input_tensor_index);
+      assert(status == Ok);
+      if (status != Ok)
+        return status;
+      assert(tensor_data != nullptr);
+
+      // Write buffer data into vector
+      std::memcpy(data.data() + cur_offset, tensor_data, buffer_size);
+      // Move offset
+      cur_offset += buffer_size;
+      // Save buffers size
+      acc_buffer_size += buffer_size;
+    }
+  }
+
+  // Now cur_offset points to last position after adding all buffers
+  // Let's handle with Adam buffers offsets
+  assert((config.training_context.optimizer == ADAM or
+          config.training_context.optimizer == SGD) &&
+         "Unsupported type");
+
+  // Adam state is saved only for Adam optimizer which holds not reset state
+  const auto adam_opt =
+    config.training_context.optimizer == ADAM ? train_storage.getAdam() : nullptr;
+  assert(config.training_context.optimizer != ADAM or adam_opt != nullptr);
+  const bool save_adam_state = adam_opt != nullptr and not adam_opt->isReset();
+
+  // Note: offset = 0 - means there are no such buffers
+  uint32_t m_offset = 0;
+  uint32_t v_offset = 0;
+  if (save_adam_state)
+  {
+    // m offsets table starts right after weights data
+    m_offset = cur_offset;
+    // v offsets table starts after m offsets table and m buffers
+    v_offset = m_offset + acc_buffer_size + tensors_size * 4;
+
+    // Move m and v to buffers offsets
+    uint32_t m_buffer_offset = m_offset + tensors_size * 4;
+    uint32_t v_buffer_offset = v_offset + tensors_size * 4;
+
+    // Adam buffers offsets, zero means that there is no buffer for this tensor
+    std::vector<uint32_t> m_offsets(tensors_size, 0);
+    std::vector<uint32_t> v_offsets(tensors_size, 0);
+
+    // Goes among trainable ops
+    for (uint32_t index = 0; index < last_node_pos; ++index)
+    {
+      uint32_t cur_op_index = num_kernels - index - 1;
+      auto *cur_op = operators->operator[](cur_op_index);
+
+      uint32_t cur_opcode_index = cur_op->opcode_index();
+
+      assert(cur_opcode_index < op_codes->size());
+
+      const auto opcode = op_codes->operator[](cur_opcode_index);
+
+      // Check it is operator with trainable consts
+      if (not isTrainableWeights(opcode))
+        continue;
+
+      auto input_tensors = cur_op->inputs();
+
+      for (uint32_t i = 0; i < input_tensors->size(); ++i)
+      {
+        const auto input_tensor_index = input_tensors->operator[](i);
+        // Check is const
+        if (not context.isConstTensor(input_tensor_index))
+          continue;
+
+        // Found trainable weight tensor, lets calculate its size and save offset for current buffer
+        const auto tensor = context.getTensorByIndex(input_tensor_index);
+        OMRuntimeShape shape(tensor);
+
+        auto type_size = sizeof(OMDataType(tensor->type()));
+
+        size_t buffer_size = type_size * shape.flatSize();
+        // Save for current tensor index its offset
+        m_offsets[input_tensor_index] = m_buffer_offset;
+        v_offsets[input_tensor_index] = v_buffer_offset;
+
+        // Obtain m and v data from train storage
+        uint8_t *m_data = adam_opt->getExponentAvgDataByTensorIndex(input_tensor_index);
+        assert(m_data != nullptr);
+        uint8_t *v_data = adam_opt->getExponentAvgSquaresDataByTensorIndex(input_tensor_index);
+        assert(v_data != nullptr);
+
+        // Write m data
+        std::memcpy(data.data() + m_buffer_offset, m_data, buffer_size);
+        // Write v data
+        std::memcpy(data.data() + v_buffer_offset, v_data, buffer_size);
+
+        // Move m and v buffers offsets
+        m_buffer_offset += buffer_size;
+        v_buffer_offset += buffer_size;
+      }
+    }
+
+    // Write m and v offsets tables
+    for (uint32_t i = 0; i < tensors_size; ++i)
+    {
+      write_u32(m_offset + i * 4, m_offsets[i]);
+      write_u32(v_offset + i * 4, v_offsets[i]);
+    }
+  }
+
+  // Save offsets for the m and v offset fields
+  write_u32(M_OFFSET, m_offset);
+  write_u32(V_OFFSET, v_offset);
+
+  // Write weights offsets table
+  for (uint32_t i = 0; i < tensors_size; ++i)
+    write_u32(WEIGHT_OFFSET + i * 4, offsets[i]);
+
+  // Save other parameters offset: 20 initial bytes + tensors_size * 4 bytes for buffer offsets +
+  // buffer size
+  uint32_t other_parameters_offset = WEIGHT_OFFSET + tensors_size * 4 + acc_buffer_size;
+  // Adam case need add two more acc_buffer_size
+  if (save_adam_state)
+  {
+    other_parameters_offset += acc_buffer_size * 2;
+    other_parameters_offset += tensors_size * 4 * 2;
+  }
+
+  // Step and epoch are the last 8 bytes, they have to fit into the prepared buffer
+  assert(other_parameters_offset + 8 <= data.size());
+
+  // Write this offset
+  write_u32(OTHER_PARAMS_OFFSET, other_parameters_offset);
+
+  // Write current step
+  // Note: stored as 4 byte values, because the next field starts 4 bytes later
+  write_u32(other_parameters_offset, config.training_context.num_step);
+
+  // Write current epoch
+  write_u32(other_parameters_offset + 4, config.training_context.num_epoch);
+
+  return Ok;
+}
+
+// To check checkpoint file format please see
https://github.com/Samsung/ONE/discussions/13037
+// Create checkpoint data for current state: previous content of data is dropped, data is
+// resized to the calculated checkpoint size, then the magic number, the schema version and all
+// offsets and buffers are written.
+// Returns status of writing offsets and buffers; on an error data is left partially filled.
+OMStatus OMCheckpointSaver::createCheckpointData(core::OMRuntimeContext &context,
+                                                 OMTrainingStorage &train_storage,
+                                                 std::vector<char> &data, const OMConfig &config)
+{
+  // Clear data
+  data.clear();
+
+  // Obtain file size and resize vector
+  // Note: vector was cleared, so resize fills whole buffer with zeros
+  const size_t data_size = calculateFileSize(context, train_storage, config);
+  data.resize(data_size);
+
+  // Point to start of the buffer
+  char *cur_ptr = data.data();
+
+  // Write MAGIC_NUMBER
+  std::memcpy(cur_ptr, &MAGIC_NUMBER, sizeof(MAGIC_NUMBER));
+  cur_ptr += 2;
+
+  // Write SCHEMA_VERSION
+  std::memcpy(cur_ptr, &SCHEMA_VERSION, sizeof(SCHEMA_VERSION));
+
+  // Miss RESERVED field: it keeps zero value set by resize
+
+  // Writes buffers and offsets
+  OMStatus status = writeOffsetsAndBuffers(context, train_storage, config, data);
+
+  assert(status == Ok);
+
+  // Report the failure to the caller also when asserts are disabled
+  return status;
+}