Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] move x86 specific to engines/x86 #966

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 12 additions & 21 deletions src/ppl/nn/engines/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ RetCode CopyBuffer(const BufferDesc& src_buf, const TensorShape& src_shape, Devi
/* -------------------------------------------------------------------------- */

RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
RuntimeConstantInfo* info, bool omit_data) {
info->Reshape(shape);
RuntimeConstantInfo* info) {
info->SetDevice(device);

if (size > 0) {
info->Reshape(shape);

if (!omit_data) {
info->SetDevice(device);
auto status = info->ReallocBuffer();
if (status != RC_SUCCESS) {
LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(status);
Expand All @@ -108,6 +109,7 @@ RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape&
RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
BufferInfo* info) {
info->SetDevice(device);

auto status = info->ReallocBuffer(shape);
if (status != RC_SUCCESS) {
LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(status);
Expand All @@ -123,8 +125,7 @@ RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape&
return RC_SUCCESS;
}

RetCode LoadConstants(const ir::Graph& graph, Device* device, map<edgeid_t, RuntimeConstantInfo>* constants,
const std::set<edgeid_t>* data_omitted_constants) {
RetCode LoadConstants(const ir::Graph& graph, Device* device, map<edgeid_t, RuntimeConstantInfo>* constants) {
auto topo = graph.topo.get();
auto graph_data = graph.data.get();

Expand Down Expand Up @@ -156,25 +157,15 @@ RetCode LoadConstants(const ir::Graph& graph, Device* device, map<edgeid_t, Runt
return RC_NOT_FOUND;
}

bool omit_data = false;
if (data_omitted_constants != nullptr) {
omit_data = (data_omitted_constants->find(eid) != data_omitted_constants->end());
}

if (!omit_data) {
if (constant_ref->second.data.GetSize() == 0) {
if (tensor_shape.CalcBytesIncludingPadding() == 0) {
omit_data = true;
} else {
LOG(ERROR) << "constant [" << edge->GetName() << "] data size is 0 but shape size is not 0.";
return RC_INVALID_VALUE;
}
}
if (constant_ref->second.data.GetSize() != tensor_shape.CalcBytesIncludingPadding()) {
LOG(ERROR) << "constant [" << edge->GetName() << "] data size [" << constant_ref->second.data.GetSize() <<
"] != shape size [" << tensor_shape.CalcBytesIncludingPadding() << "]";
return RC_INVALID_VALUE;
}

RuntimeConstantInfo& constant_info = ret_pair.first->second;
auto status = GenericLoadConstant(constant_ref->second.data.GetData(), constant_ref->second.data.GetSize(),
tensor_shape, device, &constant_info, omit_data);
tensor_shape, device, &constant_info);
if (status != RC_SUCCESS) {
LOG(ERROR) << "load constant[" << edge->GetName() << "] failed: " << GetRetCodeStr(status);
return status;
Expand Down
5 changes: 2 additions & 3 deletions src/ppl/nn/engines/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,12 @@ static inline ppl::common::RetCode CopyTensorBuffer(const TensorImpl& src, Tenso
return CopyBuffer(src.GetBufferDesc(), *src.GetShape(), src.GetDevice(), dst, tmp_cpu_device);
}

ppl::common::RetCode LoadConstants(const ir::Graph&, Device*, std::map<edgeid_t, RuntimeConstantInfo>*,
const std::set<edgeid_t>* = nullptr);
ppl::common::RetCode LoadConstants(const ir::Graph&, Device*, std::map<edgeid_t, RuntimeConstantInfo>*);

ppl::common::RetCode LoadConstants(const ConstantVisitor&, Device*, std::map<edgeid_t, BufferInfo>*);

ppl::common::RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
RuntimeConstantInfo* info, bool omit_data = false);
RuntimeConstantInfo* info);

ppl::common::RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
BufferInfo* info);
Expand Down
85 changes: 84 additions & 1 deletion src/ppl/nn/engines/x86/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,89 @@ ppl::common::RetCode X86Engine::CalDataOmittedConstants(const ir::Graph& graph,
return ppl::common::RC_SUCCESS;
}

// Applies `shape` to `info` and, unless `omit_data` is set, binds `info` to
// `device`, reallocates its buffer and copies `data` from host memory into it.
//
// When `omit_data` is true only the shape is recorded: the device is not set
// and no buffer is allocated. `size` is accepted but not read by this function.
//
// Returns RC_SUCCESS, or the failing status of ReallocBuffer() / CopyFromHost().
static RetCode GenericLoadConstant(const void* data, uint64_t size, const TensorShape& shape, Device* device,
                                   RuntimeConstantInfo* info, bool omit_data) {
    info->Reshape(shape);

    // Shape-only constant: nothing to allocate or copy.
    if (omit_data) {
        return RC_SUCCESS;
    }

    info->SetDevice(device);

    const auto alloc_status = info->ReallocBuffer();
    if (alloc_status != RC_SUCCESS) {
        LOG(ERROR) << "alloc buffer for constant failed: " << GetRetCodeStr(alloc_status);
        return alloc_status;
    }

    const auto copy_status = device->CopyFromHost(&info->GetBufferDesc(), data, shape);
    if (copy_status != RC_SUCCESS) {
        LOG(ERROR) << "copy constant failed: " << GetRetCodeStr(copy_status);
        return copy_status;
    }

    return RC_SUCCESS;
}

// Fills `*constants` with one RuntimeConstantInfo per constant edge of `graph`,
// loading each through GenericLoadConstant() on `device`.
//
// - Edge ids already present in `*constants` are left untouched.
// - A constant gets its shape only (no data load) when its edge id is in
//   `data_omitted_constants` (which may be nullptr), or when both its data size
//   and its padded shape size are zero.
//
// Returns RC_NOT_FOUND if an edge, its shape, or its data entry is missing,
// RC_INVALID_VALUE if the data is empty while the shape is not, or whatever
// status GenericLoadConstant() reports. On an error return, the entry inserted
// for the failing edge stays in `*constants`.
static RetCode LoadConstants(const ir::Graph& graph, Device* device, map<edgeid_t, RuntimeConstantInfo>* constants,
                             const std::set<edgeid_t>* data_omitted_constants) {
    auto topo = graph.topo.get();
    auto graph_data = graph.data.get();

    for (uint32_t idx = 0; idx < topo->GetConstantCount(); ++idx) {
        auto eid = topo->GetConstant(idx);

        auto edge = topo->GetEdge(eid);
        if (edge == nullptr) {
            LOG(ERROR) << "cannot get edge of constant[edgeid=" << eid << "]";
            return RC_NOT_FOUND;
        }

        // Reserve the slot first; an existing entry means this constant was handled already.
        auto ret_pair = constants->insert(make_pair(eid, RuntimeConstantInfo()));
        if (!ret_pair.second) {
            continue;
        }

        auto shape_ref = graph_data->shapes.find(eid);
        if (shape_ref == graph_data->shapes.end()) {
            LOG(ERROR) << "cannot find shape of constant[" << edge->GetName() << "]";
            return RC_NOT_FOUND;
        }

        TensorShape tensor_shape;
        utils::IrShape2TensorShape(shape_ref->second, &tensor_shape);

        auto constant_ref = graph_data->constants.find(eid);
        if (constant_ref == graph_data->constants.end()) {
            LOG(ERROR) << "cannot find data of constant[" << edge->GetName() << "]";
            return RC_NOT_FOUND;
        }
        auto& constant_data = constant_ref->second.data;

        bool omit_data = (data_omitted_constants != nullptr) && (data_omitted_constants->count(eid) > 0);

        // Empty data is acceptable only for an empty shape; treat that case as data-omitted.
        if (!omit_data && constant_data.GetSize() == 0) {
            if (tensor_shape.CalcBytesIncludingPadding() != 0) {
                LOG(ERROR) << "constant [" << edge->GetName() << "] data size is 0 but shape size is not 0.";
                return RC_INVALID_VALUE;
            }
            omit_data = true;
        }

        RuntimeConstantInfo& constant_info = ret_pair.first->second;
        auto status = GenericLoadConstant(constant_data.GetData(), constant_data.GetSize(), tensor_shape, device,
                                          &constant_info, omit_data);
        if (status != RC_SUCCESS) {
            LOG(ERROR) << "load constant[" << edge->GetName() << "] failed: " << GetRetCodeStr(status);
            return status;
        }
    }

    return RC_SUCCESS;
}

RetCode X86Engine::ProcessGraph(const utils::SharedResource& resource, ir::Graph* graph, RuntimePartitionInfo* info) {
auto status = NNOptimizerManager::GetInstance()->Process(graph);
if (status != RC_SUCCESS) {
Expand All @@ -162,7 +245,7 @@ RetCode X86Engine::ProcessGraph(const utils::SharedResource& resource, ir::Graph
return status;
}

status = utils::LoadConstants(*graph, &device_, &info->constants, &data_omitted_constants);
status = LoadConstants(*graph, &device_, &info->constants, &data_omitted_constants);
if (status != RC_SUCCESS) {
LOG(ERROR) << "LoadConstants failed: " << GetRetCodeStr(status);
return status;
Expand Down
Loading