[onert] Load tensor and optimizer data from the checkpoint file (#13918)
This commit loads tensor and optimizer data from the checkpoint file.

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
jyoungyun authored Sep 9, 2024
1 parent 12941da commit d1b25ee
Showing 1 changed file with 157 additions and 5 deletions.
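
For orientation, the loader added below implies a checkpoint layout along the following lines. This is a sketch reconstructed only from the code in this diff, not the authoritative format definition; it names just the fields the loader actually touches (schema, length, opt1_offset, opt2_offset, other_offset, cur_step).

// Reconstructed sketch of the checkpoint layout as CheckpointLoader reads it:
//
//   [Header]        schema, length (number of trainable tensors),
//                   opt1_offset, opt2_offset, other_offset
//   [Tensor table]  length uint32_t absolute offsets, one per tensor
//   [Tensor data]   raw tensor buffers; entry i spans offset[i]..offset[i+1],
//                   and the last entry ends at opt1_offset
//   [Optimizer 1]   optional, starts at opt1_offset: its own length-entry
//                   offset table plus data, terminated by opt2_offset
//   [Optimizer 2]   optional, starts at opt2_offset, same structure,
//                   terminated by other_offset
//   [Footer]        starts at other_offset; holds cur_step
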
runtime/onert/core/src/loader/train/CheckpointLoader.cc (157 additions, 5 deletions)
@@ -32,6 +32,48 @@ using namespace train;
using namespace checkpoint;
using namespace exec;

struct DataBufferPair
{
uint32_t offset;
uint32_t size;
};

class DataBuffer
{
public:
DataBuffer() = default;
DataBuffer(uint32_t size) { setSize(size); }

void setSize(uint32_t size)
{
_offset.resize(size);
_size.resize(size);
}

char *getOffsetBuf() { return reinterpret_cast<char *>(_offset.data()); }

// This function should be called after loading the _offset buffer.
void calculateSize(uint32_t next_start_offset)
{
assert(!_offset.empty());
assert(_offset.size() == _size.size());
for (size_t i = 0; i < _offset.size() - 1; ++i)
_size[i] = _offset[i + 1] - _offset[i];
_size.back() = next_start_offset - _offset.back();
}

// Returns the (offset, size) pair for the i-th entry.
DataBufferPair operator[](uint32_t i) const
{
assert(_offset.size() == _size.size());
assert(i < _offset.size());
return DataBufferPair{_offset[i], _size[i]};
}

private:
std::vector<uint32_t> _offset;
std::vector<uint32_t> _size;
};
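
// Editor's note (not part of this commit): a worked example of calculateSize().
// Suppose three entries are recorded at offsets {120, 200, 260} and the next
// section starts at byte 300 (passed in as next_start_offset). Then
//   _size[0] = 200 - 120 = 80
//   _size[1] = 260 - 200 = 60
//   _size[2] = 300 - 260 = 40
// and operator[](1) returns DataBufferPair{200, 60}.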

class CheckpointLoader final
{
public:
@@ -62,6 +104,33 @@ class CheckpointLoader final
if (_header.schema != checkpoint::SCHEMA_VERSION)
throw std::runtime_error{"Invalid SCHEMA VERSION"};

// Read the tensor offset table that follows the header, then derive each
// entry's size from the gap to the next offset (the last entry is terminated
// by opt1_offset).
_tensor_data.setSize(_header.length);
_file.read(_tensor_data.getOffsetBuf(),
static_cast<std::streamsize>(_header.length * sizeof(uint32_t)));
if (_file.fail())
throw std::runtime_error{"Failed to load tensor data"};
_tensor_data.calculateSize(_header.opt1_offset);

if (_header.opt1_offset)
{
// First optimizer-state section: its own offset table of _header.length
// entries, terminated by opt2_offset.
DataBuffer opt1_data(_header.length);
_file.seekg(static_cast<std::streamoff>(_header.opt1_offset), std::ios::beg);
_file.read(opt1_data.getOffsetBuf(),
static_cast<std::streamsize>(_header.length * sizeof(uint32_t)));
if (_file.fail())
throw std::runtime_error{"Failed to load optimizer data"};
opt1_data.calculateSize(_header.opt2_offset);
_optimizer_data.emplace_back(std::move(opt1_data));
}

if (_header.opt2_offset)
{
// Second optimizer-state section, terminated by other_offset.
DataBuffer opt2_data(_header.length);
_file.seekg(static_cast<std::streamoff>(_header.opt2_offset), std::ios::beg);
_file.read(opt2_data.getOffsetBuf(),
static_cast<std::streamsize>(_header.length * sizeof(uint32_t)));
if (_file.fail())
throw std::runtime_error{"Failed to load optimizer data"};
opt2_data.calculateSize(_header.other_offset);
_optimizer_data.emplace_back(std::move(opt2_data));
}

if ((filesize - _header.other_offset) != sizeof(_footer))
throw std::runtime_error{"Invalid checkpoint file footer data"};

@@ -76,10 +145,95 @@ class CheckpointLoader final
_file.close();
}

void updateTensor(const std::unique_ptr<exec::Execution> &exec)
{
// First pass: count the trainable tensors so the count can be validated
// against the checkpoint header.
uint32_t vindex = 0;
exec->iterateTrainableTensors(
[&](const ir::OperandIndex &, const backend::train::ITrainableTensor *) { vindex++; });

if (_header.length != vindex)
throw std::runtime_error{
"Invalid number of tensors between TrainingInfo and checkpoint file"};

// Reset EOF bit
_file.clear();

// Second pass: seek to each tensor's recorded offset and read its data in place.
vindex = 0;
exec->iterateTrainableTensors(
[&](const ir::OperandIndex &, const backend::train::ITrainableTensor *tensor) {
assert(tensor);
assert(_tensor_data[vindex].size == tensor->total_size());
_file.seekg(static_cast<std::streamoff>(_tensor_data[vindex].offset), std::ios::beg);
_file.read(reinterpret_cast<char *>(tensor->buffer()),
static_cast<std::streamsize>(tensor->total_size()));
vindex++;
});
}

void updateOptimizer(const std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::unique_ptr<onert::exec::Execution> &exec)
{
// Infer the optimizer recorded in the checkpoint from the number of optimizer
// sections: SGD stores no extra state, Adam stores two variables per tensor.
ir::train::OptimizerCode ckpt_opt_code = ir::train::OptimizerCode::SGD;
// TODO Support other optimizers
if (_optimizer_data.size() == 2)
ckpt_opt_code = ir::train::OptimizerCode::Adam;

if (ckpt_opt_code != train_info->optimizerInfo().optim_code)
throw std::runtime_error{
"Not compatible optimizer type between TrainingInfo and checkpoint file"};

switch (train_info->optimizerInfo().optim_code)
{
case ir::train::OptimizerCode::Adam:
updateAdamOptimizer(exec);
break;
default:
break;
}
}

void updateTrainingInfo(const std::unique_ptr<ir::train::TrainingInfo> &train_info)
{
// TODO Verify cur_step value
train_info->trainingStep() = _footer.cur_step;
}

private:
void updateAdamOptimizer(const std::unique_ptr<onert::exec::Execution> &exec)
{
// The Adam optimizer keeps two optimizer variables per tensor (mean and variance).
[[maybe_unused]] const std::size_t ADAM_VARIABLE_COUNT = 2;

// Reset EOF bit
_file.clear();

uint32_t vindex = 0;
exec->iterateTrainableTensors([&](const ir::OperandIndex &,
const backend::train::ITrainableTensor *tensor) {
assert(tensor);
auto trainable_tensor = const_cast<backend::train::ITrainableTensor *>(tensor);
const auto opt_vars = trainable_tensor->optVars();

// Each tensor carries either the full set of Adam variables or none at all
// (a tensor that the optimizer does not update has no optimizer variables).
assert(opt_vars.size() == ADAM_VARIABLE_COUNT || opt_vars.size() == 0);

for (size_t i = 0; i < opt_vars.size(); ++i)
{
assert(opt_vars[i]->total_size() == _optimizer_data[i][vindex].size);
_file.seekg(static_cast<std::streamoff>(_optimizer_data[i][vindex].offset), std::ios::beg);
_file.read(reinterpret_cast<char *>(opt_vars[i]->buffer()),
static_cast<std::streamsize>(opt_vars[i]->total_size()));
}
vindex++;
});
}
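
// Editor's note (not part of this commit): the indexing _optimizer_data[i][vindex]
// implies the checkpoint groups optimizer state by variable first and by tensor
// second, i.e. the opt1 section holds the first Adam variable of every trainable
// tensor and the opt2 section holds the second. Which section corresponds to the
// mean and which to the variance is an assumption, not stated in this diff.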

private:
std::ifstream _file;
checkpoint::Header _header;
checkpoint::Footer _footer;
DataBuffer _tensor_data;
std::vector<DataBuffer> _optimizer_data;
};

} // namespace
@@ -96,11 +250,9 @@ void loadCheckpoint(const std::string &filename,
const std::unique_ptr<onert::exec::Execution> &exec)
{
CheckpointLoader loader(filename);

// TODO Load tensor data
UNUSED_RELEASE(exec);
// TODO Update step in train_info
UNUSED_RELEASE(train_info);
loader.updateTensor(exec);
loader.updateOptimizer(train_info, exec);
loader.updateTrainingInfo(train_info);
}

} // namespace train
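
As a supplement, here is a small standalone program (an editor's sketch, independent of onert and not part of this commit) that exercises the same offset-table pattern the loader relies on: absolute offsets are stored up front, each entry's size is the gap to the next offset, and the last entry is terminated by the start of the following section. The file name and helper names are illustrative only.

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

int main()
{
  const std::string path = "toy.ckpt"; // illustrative file name
  const std::vector<std::string> payloads = {"alpha", "beta", "gamma"};

  // Write phase: [offset table][payload data]; offsets are absolute positions.
  {
    std::ofstream out(path, std::ios::binary);
    const uint32_t table_bytes = static_cast<uint32_t>(payloads.size() * sizeof(uint32_t));
    std::vector<uint32_t> offsets;
    uint32_t pos = table_bytes; // data starts right after the table
    for (const auto &p : payloads)
    {
      offsets.push_back(pos);
      pos += static_cast<uint32_t>(p.size());
    }
    out.write(reinterpret_cast<const char *>(offsets.data()), table_bytes);
    for (const auto &p : payloads)
      out.write(p.data(), static_cast<std::streamsize>(p.size()));
  }

  // Read phase: load the offset table, derive sizes from consecutive offsets,
  // then seek to each offset and read exactly that many bytes.
  std::ifstream in(path, std::ios::binary);
  std::vector<uint32_t> offsets(payloads.size());
  in.read(reinterpret_cast<char *>(offsets.data()),
          static_cast<std::streamsize>(offsets.size() * sizeof(uint32_t)));

  in.seekg(0, std::ios::end);
  const std::streamoff end_of_section = in.tellg(); // stand-in for opt1_offset

  in.clear(); // reset any stream flags before seeking, as the loader does
  for (size_t i = 0; i < offsets.size(); ++i)
  {
    const uint32_t next =
      (i + 1 < offsets.size()) ? offsets[i + 1] : static_cast<uint32_t>(end_of_section);
    const uint32_t size = next - offsets[i];
    std::string buf(size, '\0');
    in.seekg(static_cast<std::streamoff>(offsets[i]), std::ios::beg);
    in.read(&buf[0], static_cast<std::streamsize>(size));
    std::cout << i << ": " << buf << '\n'; // prints alpha, beta, gamma in turn
  }
  return 0;
}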
