[fix] move x86 specific to engines/x86
ouonline committed Aug 15, 2024
1 parent 5eab117 commit 5daf820
Showing 3 changed files with 98 additions and 25 deletions.
src/ppl/nn/engines/utils.cc (12 additions, 21 deletions)
@@ -84,11 +84,12 @@ RetCode CopyBuffer(const BufferDesc& src_buf, const TensorShape& src_shape, Devi
 /* -------------------------------------------------------------------------- */
 
 RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
-                            RuntimeConstantInfo* info, bool omit_data) {
-    info->Reshape(shape);
+                            RuntimeConstantInfo* info) {
+    info->SetDevice(device);
+
+    if (size > 0) {
+        info->Reshape(shape);
 
-    if (!omit_data) {
-        info->SetDevice(device);
         auto status = info->ReallocBuffer();
         if (status != RC_SUCCESS) {
             LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(status);
@@ -108,6 +109,7 @@ RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape&
 RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
                             BufferInfo* info) {
     info->SetDevice(device);
+
     auto status = info->ReallocBuffer(shape);
     if (status != RC_SUCCESS) {
         LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(status);
@@ -123,8 +125,7 @@ RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape&
     return RC_SUCCESS;
 }
 
-RetCode LoadConstants(const ir::Graph& graph, Device* device, map<edgeid_t, RuntimeConstantInfo>* constants,
-                      const std::set<edgeid_t>* data_omitted_constants) {
+RetCode LoadConstants(const ir::Graph& graph, Device* device, map<edgeid_t, RuntimeConstantInfo>* constants) {
     auto topo = graph.topo.get();
     auto graph_data = graph.data.get();
 
@@ -156,25 +157,15 @@ RetCode LoadConstants(const ir::Graph& graph, Device* device, map<edgeid_t, Runt
             return RC_NOT_FOUND;
         }
 
-        bool omit_data = false;
-        if (data_omitted_constants != nullptr) {
-            omit_data = (data_omitted_constants->find(eid) != data_omitted_constants->end());
-        }
-
-        if (!omit_data) {
-            if (constant_ref->second.data.GetSize() == 0) {
-                if (tensor_shape.CalcBytesIncludingPadding() == 0) {
-                    omit_data = true;
-                } else {
-                    LOG(ERROR) << "constant [" << edge->GetName() << "] data size is 0 but shape size is not 0.";
-                    return RC_INVALID_VALUE;
-                }
-            }
-        }
+        if (constant_ref->second.data.GetSize() != tensor_shape.CalcBytesIncludingPadding()) {
+            LOG(ERROR) << "constant [" << edge->GetName() << "] data size [" << constant_ref->second.data.GetSize() <<
+                "] != shape size [" << tensor_shape.CalcBytesIncludingPadding() << "]";
+            return RC_INVALID_VALUE;
+        }
 
         RuntimeConstantInfo& constant_info = ret_pair.first->second;
         auto status = GenericLoadConstant(constant_ref->second.data.GetData(), constant_ref->second.data.GetSize(),
-                                          tensor_shape, device, &constant_info, omit_data);
+                                          tensor_shape, device, &constant_info);
         if (status != RC_SUCCESS) {
             LOG(ERROR) << "load constant[" << edge->GetName() << "] failed: " << GetRetCodeStr(status);
             return status;
src/ppl/nn/engines/utils.h (2 additions, 3 deletions)
@@ -37,13 +37,12 @@ static inline ppl::common::RetCode CopyTensorBuffer(const TensorImpl& src, Tenso
     return CopyBuffer(src.GetBufferDesc(), *src.GetShape(), src.GetDevice(), dst, tmp_cpu_device);
 }
 
-ppl::common::RetCode LoadConstants(const ir::Graph&, Device*, std::map<edgeid_t, RuntimeConstantInfo>*,
-                                   const std::set<edgeid_t>* = nullptr);
+ppl::common::RetCode LoadConstants(const ir::Graph&, Device*, std::map<edgeid_t, RuntimeConstantInfo>*);
 
 ppl::common::RetCode LoadConstants(const ConstantVisitor&, Device*, std::map<edgeid_t, BufferInfo>*);
 
 ppl::common::RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
-                                         RuntimeConstantInfo* info, bool omit_data = false);
+                                         RuntimeConstantInfo* info);
 
 ppl::common::RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
                                          BufferInfo* info);
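
For orientation, here is a minimal caller-side sketch of the trimmed-down generic helper declared above. The wrapper function, namespace usage, and include path are illustrative assumptions and are not part of this commit:

// Hedged sketch: using utils::LoadConstants after the data_omitted_constants
// parameter was removed. Only the LoadConstants signature comes from the diff.
#include <map>
#include "ppl/nn/engines/utils.h" // assumed include path

using namespace ppl::nn;

ppl::common::RetCode LoadAllConstants(const ir::Graph& graph, Device* device,
                                      std::map<edgeid_t, RuntimeConstantInfo>* constants) {
    // Engines that need to skip constant data (such as x86) now keep their own
    // LoadConstants variant; the generic helper simply loads everything.
    return utils::LoadConstants(graph, device, constants);
}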
src/ppl/nn/engines/x86/engine.cc (84 additions, 1 deletion)
@@ -142,6 +142,89 @@ ppl::common::RetCode X86Engine::CalDataOmittedConstants(const ir::Graph& graph,
     return ppl::common::RC_SUCCESS;
 }
 
+static RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
+                                   RuntimeConstantInfo* info, bool omit_data) {
+    info->Reshape(shape);
+
+    if (!omit_data) {
+        info->SetDevice(device);
+        auto status = info->ReallocBuffer();
+        if (status != RC_SUCCESS) {
+            LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(status);
+            return status;
+        }
+
+        status = device->CopyFromHost(&info->GetBufferDesc(), data, shape);
+        if (status != RC_SUCCESS) {
+            LOG(ERROR) << "copy constant failed: " << GetRetCodeStr(status);
+            return status;
+        }
+    }
+
+    return RC_SUCCESS;
+}
+
+static RetCode LoadConstants(const ir::Graph& graph, Device* device, map<edgeid_t, RuntimeConstantInfo>* constants,
+                             const std::set<edgeid_t>* data_omitted_constants) {
+    auto topo = graph.topo.get();
+    auto graph_data = graph.data.get();
+
+    for (uint32_t i = 0; i < topo->GetConstantCount(); ++i) {
+        auto eid = topo->GetConstant(i);
+        auto edge = topo->GetEdge(eid);
+        if (edge == nullptr) {
+            LOG(ERROR) << "cannot get edge of constant[edgeid=" << eid << "]";
+            return RC_NOT_FOUND;
+        }
+
+        auto ret_pair = constants->insert(make_pair(eid, RuntimeConstantInfo()));
+        if (!ret_pair.second) {
+            continue;
+        }
+
+        auto shape_ref = graph_data->shapes.find(eid);
+        if (shape_ref == graph_data->shapes.end()) {
+            LOG(ERROR) << "cannot find shape of constant[" << edge->GetName() << "]";
+            return RC_NOT_FOUND;
+        }
+
+        TensorShape tensor_shape;
+        utils::IrShape2TensorShape(shape_ref->second, &tensor_shape);
+
+        auto constant_ref = graph_data->constants.find(eid);
+        if (constant_ref == graph_data->constants.end()) {
+            LOG(ERROR) << "cannot find data of constant[" << edge->GetName() << "]";
+            return RC_NOT_FOUND;
+        }
+
+        bool omit_data = false;
+        if (data_omitted_constants != nullptr) {
+            omit_data = (data_omitted_constants->find(eid) != data_omitted_constants->end());
+        }
+
+        if (!omit_data) {
+            if (constant_ref->second.data.GetSize() == 0) {
+                if (tensor_shape.CalcBytesIncludingPadding() == 0) {
+                    omit_data = true;
+                } else {
+                    LOG(ERROR) << "constant [" << edge->GetName() << "] data size is 0 but shape size is not 0.";
+                    return RC_INVALID_VALUE;
+                }
+            }
+        }
+
+        RuntimeConstantInfo& constant_info = ret_pair.first->second;
+        auto status = GenericLoadConstant(constant_ref->second.data.GetData(), constant_ref->second.data.GetSize(),
+                                          tensor_shape, device, &constant_info, omit_data);
+        if (status != RC_SUCCESS) {
+            LOG(ERROR) << "load constant[" << edge->GetName() << "] failed: " << GetRetCodeStr(status);
+            return status;
+        }
+    }
+
+    return RC_SUCCESS;
+}
+
 RetCode X86Engine::ProcessGraph(const utils::SharedResource& resource, ir::Graph* graph, RuntimePartitionInfo* info) {
     auto status = NNOptimizerManager::GetInstance()->Process(graph);
     if (status != RC_SUCCESS) {
@@ -162,7 +245,7 @@ RetCode X86Engine::ProcessGraph(const utils::SharedResource& resource, ir::Graph
         return status;
     }
 
-    status = utils::LoadConstants(*graph, &device_, &info->constants, &data_omitted_constants);
+    status = LoadConstants(*graph, &device_, &info->constants, &data_omitted_constants);
     if (status != RC_SUCCESS) {
        LOG(ERROR) << "LoadConstants failed: " << GetRetCodeStr(status);
        return status;
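
Taken together, the x86 hunks keep the omit-data behavior local to the engine. A condensed sketch of the resulting flow inside X86Engine::ProcessGraph, with surrounding code abridged and the CalDataOmittedConstants parameter list partly inferred rather than shown by the diff:

// Sketch only: feeding the x86-local LoadConstants added in this commit.
std::set<edgeid_t> data_omitted_constants;
auto status = CalDataOmittedConstants(*graph, &data_omitted_constants); // x86-specific analysis; signature partly assumed
if (status != RC_SUCCESS) {
    return status;
}

// The file-local LoadConstants honors the omitted set;
// the generic utils::LoadConstants no longer accepts it.
status = LoadConstants(*graph, &device_, &info->constants, &data_omitted_constants);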
