diff --git a/src/ppl/nn/engines/utils.cc b/src/ppl/nn/engines/utils.cc index 237e7d8ac..e4751c26b 100644 --- a/src/ppl/nn/engines/utils.cc +++ b/src/ppl/nn/engines/utils.cc @@ -84,11 +84,12 @@ RetCode CopyBuffer(const BufferDesc& src_buf, const TensorShape& src_shape, Devi /* -------------------------------------------------------------------------- */ RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device, - RuntimeConstantInfo* info, bool omit_data) { - info->Reshape(shape); + RuntimeConstantInfo* info) { + info->SetDevice(device); + + if (size > 0) { + info->Reshape(shape); - if (!omit_data) { - info->SetDevice(device); auto status = info->ReallocBuffer(); if (status != RC_SUCCESS) { LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(status); @@ -108,6 +109,7 @@ RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device, BufferInfo* info) { info->SetDevice(device); + auto status = info->ReallocBuffer(shape); if (status != RC_SUCCESS) { LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(status); @@ -123,8 +125,7 @@ RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& return RC_SUCCESS; } -RetCode LoadConstants(const ir::Graph& graph, Device* device, map* constants, - const std::set* data_omitted_constants) { +RetCode LoadConstants(const ir::Graph& graph, Device* device, map* constants) { auto topo = graph.topo.get(); auto graph_data = graph.data.get(); @@ -156,25 +157,15 @@ RetCode LoadConstants(const ir::Graph& graph, Device* device, mapfind(eid) != data_omitted_constants->end()); - } - - if (!omit_data) { - if (constant_ref->second.data.GetSize() == 0) { - if (tensor_shape.CalcBytesIncludingPadding() == 0) { - omit_data = true; - } else { - LOG(ERROR) << "constant [" << edge->GetName() << "] data size is 0 but shape size is not 0."; - return RC_INVALID_VALUE; - } - } + if (constant_ref->second.data.GetSize() != tensor_shape.CalcBytesIncludingPadding()) { + LOG(ERROR) << "constant [" << edge->GetName() << "] data size [" << constant_ref->second.data.GetSize() << + "] != shape size [" << tensor_shape.CalcBytesIncludingPadding() << "]"; + return RC_INVALID_VALUE; } RuntimeConstantInfo& constant_info = ret_pair.first->second; auto status = GenericLoadConstant(constant_ref->second.data.GetData(), constant_ref->second.data.GetSize(), - tensor_shape, device, &constant_info, omit_data); + tensor_shape, device, &constant_info); if (status != RC_SUCCESS) { LOG(ERROR) << "load constant[" << edge->GetName() << "] failed: " << GetRetCodeStr(status); return status; diff --git a/src/ppl/nn/engines/utils.h b/src/ppl/nn/engines/utils.h index c0c1fea88..ed92a6558 100644 --- a/src/ppl/nn/engines/utils.h +++ b/src/ppl/nn/engines/utils.h @@ -37,13 +37,12 @@ static inline ppl::common::RetCode CopyTensorBuffer(const TensorImpl& src, Tenso return CopyBuffer(src.GetBufferDesc(), *src.GetShape(), src.GetDevice(), dst, tmp_cpu_device); } -ppl::common::RetCode LoadConstants(const ir::Graph&, Device*, std::map*, - const std::set* = nullptr); +ppl::common::RetCode LoadConstants(const ir::Graph&, Device*, std::map*); ppl::common::RetCode LoadConstants(const ConstantVisitor&, Device*, std::map*); ppl::common::RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device, - RuntimeConstantInfo* info, bool omit_data = false); + RuntimeConstantInfo* info); ppl::common::RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device, BufferInfo* info); diff --git a/src/ppl/nn/engines/x86/engine.cc b/src/ppl/nn/engines/x86/engine.cc index 788035d9f..8e98d2f5c 100644 --- a/src/ppl/nn/engines/x86/engine.cc +++ b/src/ppl/nn/engines/x86/engine.cc @@ -142,6 +142,89 @@ ppl::common::RetCode X86Engine::CalDataOmittedConstants(const ir::Graph& graph, return ppl::common::RC_SUCCESS; } +static RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device, + RuntimeConstantInfo* info, bool omit_data) { + info->Reshape(shape); + + if (!omit_data) { + info->SetDevice(device); + auto status = info->ReallocBuffer(); + if (status != RC_SUCCESS) { + LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(status); + return status; + } + + status = device->CopyFromHost(&info->GetBufferDesc(), data, shape); + if (status != RC_SUCCESS) { + LOG(ERROR) << "copy constant failed: " << GetRetCodeStr(status); + return status; + } + } + + return RC_SUCCESS; +} + +static RetCode LoadConstants(const ir::Graph& graph, Device* device, map* constants, + const std::set* data_omitted_constants) { + auto topo = graph.topo.get(); + auto graph_data = graph.data.get(); + + for (uint32_t i = 0; i < topo->GetConstantCount(); ++i) { + auto eid = topo->GetConstant(i); + auto edge = topo->GetEdge(eid); + if (edge == nullptr) { + LOG(ERROR) << "cannot get edge of constant[edgeid=" << eid << "]"; + return RC_NOT_FOUND; + } + + auto ret_pair = constants->insert(make_pair(eid, RuntimeConstantInfo())); + if (!ret_pair.second) { + continue; + } + + auto shape_ref = graph_data->shapes.find(eid); + if (shape_ref == graph_data->shapes.end()) { + LOG(ERROR) << "cannot find shape of constant[" << edge->GetName() << "]"; + return RC_NOT_FOUND; + } + + TensorShape tensor_shape; + utils::IrShape2TensorShape(shape_ref->second, &tensor_shape); + + auto constant_ref = graph_data->constants.find(eid); + if (constant_ref == graph_data->constants.end()) { + LOG(ERROR) << "cannot find data of constant[" << edge->GetName() << "]"; + return RC_NOT_FOUND; + } + + bool omit_data = false; + if (data_omitted_constants != nullptr) { + omit_data = (data_omitted_constants->find(eid) != data_omitted_constants->end()); + } + + if (!omit_data) { + if (constant_ref->second.data.GetSize() == 0) { + if (tensor_shape.CalcBytesIncludingPadding() == 0) { + omit_data = true; + } else { + LOG(ERROR) << "constant [" << edge->GetName() << "] data size is 0 but shape size is not 0."; + return RC_INVALID_VALUE; + } + } + } + + RuntimeConstantInfo& constant_info = ret_pair.first->second; + auto status = GenericLoadConstant(constant_ref->second.data.GetData(), constant_ref->second.data.GetSize(), + tensor_shape, device, &constant_info, omit_data); + if (status != RC_SUCCESS) { + LOG(ERROR) << "load constant[" << edge->GetName() << "] failed: " << GetRetCodeStr(status); + return status; + } + } + + return RC_SUCCESS; +} + RetCode X86Engine::ProcessGraph(const utils::SharedResource& resource, ir::Graph* graph, RuntimePartitionInfo* info) { auto status = NNOptimizerManager::GetInstance()->Process(graph); if (status != RC_SUCCESS) { @@ -162,7 +245,7 @@ RetCode X86Engine::ProcessGraph(const utils::SharedResource& resource, ir::Graph return status; } - status = utils::LoadConstants(*graph, &device_, &info->constants, &data_omitted_constants); + status = LoadConstants(*graph, &device_, &info->constants, &data_omitted_constants); if (status != RC_SUCCESS) { LOG(ERROR) << "LoadConstants failed: " << GetRetCodeStr(status); return status;