Develop qnn zh #43

Merged: 6 commits, Dec 27, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -205,7 +205,7 @@ if(QNN)
set(CMAKE_LD_FLAGS "-shared -s -fPIC -pthread -fvisibility=hidden -flto")

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/backends/QNN)
-add_executable(qnn_test ${PROJECT_SOURCE_DIR}/demo/qnn/qnn_test.cpp ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_CPU} ${DIR_SRC_EXP} ${DIR_SRC} )
+add_executable(qnn_test ${PROJECT_SOURCE_DIR}/demo/qnn/qnn_test.cpp ${PROJECT_SOURCE_DIR}/demo/qnn/qnn_wrapper.hpp ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_CPU} ${DIR_SRC_EXP} ${DIR_SRC} )
add_executable(silu_test ${PROJECT_SOURCE_DIR}/demo/qnn/silu_test.cpp ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_CPU} ${DIR_SRC_EXP} ${DIR_SRC} )
target_link_libraries(qnn_test MLLM_CPU MLLM_QNN ${CMAKE_DL_LIBS})
target_link_libraries(silu_test MLLM_CPU MLLM_QNN ${CMAKE_DL_LIBS})
60 changes: 2 additions & 58 deletions demo/qnn/qnn_test.cpp
@@ -9,7 +9,7 @@
#include "tokenizers/BPE/Bpe.hpp"
#include "backends/QNN/QNNBackend.hpp"
#include "memory/SystemMemoryManager.hpp"
#include "backends/QNN/op/QNNAdd.hpp"
#include "qnn_wrapper.hpp"

using namespace mllm;

@@ -30,63 +30,7 @@ int main() {

// build graph
std::cout << "build graph" << std::endl;
-    // graph add node
-    uint32_t dimensions[] = {1, 2, 2, 2};
-    qbn->modelAddTensor("x", // Node Name
-                        (Qnn_Tensor_t){
-                            .version = QNN_TENSOR_VERSION_1,
-                            {.v1 = {
-                                 .id = 0,
-                                 .name = "x",
-                                 .type = QNN_TENSOR_TYPE_APP_WRITE,
-                                 .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER,
-                                 .dataType = QNN_DATATYPE_FLOAT_32,
-                                 .quantizeParams = {QNN_DEFINITION_UNDEFINED,
-                                                    QNN_QUANTIZATION_ENCODING_UNDEFINED,
-                                                    {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}},
-                                 .rank = 4,
-                                 .dimensions = dimensions,
-                                 .memType = QNN_TENSORMEMTYPE_RAW,
-                                 {.clientBuf = {.data = nullptr,
-                                                .dataSize = 0}}}}});
-
-    float data[] = {1, 2, 3, 4, 5, 6, 7, 8};
-    qbn->modelAddTensor("y", // Node Name
-                        (Qnn_Tensor_t){
-                            .version = QNN_TENSOR_VERSION_1,
-                            {.v1 = {
-                                 .id = 0,
-                                 .name = "y",
-                                 .type = QNN_TENSOR_TYPE_STATIC,
-                                 .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER,
-                                 .dataType = QNN_DATATYPE_FLOAT_32,
-                                 .quantizeParams = {QNN_DEFINITION_UNDEFINED,
-                                                    QNN_QUANTIZATION_ENCODING_UNDEFINED,
-                                                    {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}},
-                                 .rank = 4,
-                                 .dimensions = dimensions,
-                                 .memType = QNN_TENSORMEMTYPE_RAW,
-                                 {.clientBuf = {.data = data,
-                                                .dataSize = 32}}}}});
-
-    vector<Qnn_Tensor_t> outputs = {
-        (Qnn_Tensor_t){
-            .version = QNN_TENSOR_VERSION_1,
-            {.v1 = {
-                 .id = 0,
-                 .name = "add-output",
-                 .type = QNN_TENSOR_TYPE_APP_READ,
-                 .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER,
-                 .dataType = QNN_DATATYPE_FLOAT_32,
-                 .quantizeParams = {QNN_DEFINITION_UNDEFINED,
-                                    QNN_QUANTIZATION_ENCODING_UNDEFINED,
-                                    {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}},
-                 .rank = 4,
-                 .dimensions = dimensions,
-                 .memType = QNN_TENSORMEMTYPE_RAW,
-                 {.clientBuf = {.data = nullptr,
-                                .dataSize = 0}}}}}};
-    qbn->graphAddNode("qnn-add", "ElementWiseAdd", {"x", "y"}, outputs, "qti.aisw");
+    testMatMul(qbn);
// graph compile
std::cout << "graph compile" << std::endl;
qbn->graphFinilize();
76 changes: 76 additions & 0 deletions demo/qnn/qnn_wrapper.hpp
@@ -0,0 +1,76 @@
#include <iostream>

#include "backends/QNN/QNNBackend.hpp"
#include "memory/SystemMemoryManager.hpp"
#include "backends/QNN/op/QNNAdd.hpp"

using namespace mllm;

void testMatMul(QNNBackend *qbn) {
// graph add node
uint32_t dimensions0[] = {1, 2, 2, 2};
uint32_t dimensions1[] = {1, 1, 4, 2};
qbn->modelAddTensor("x", // Node Name
(Qnn_Tensor_t){
.version = QNN_TENSOR_VERSION_1,
{.v1 = {
.id = 0,
.name = "x",
.type = QNN_TENSOR_TYPE_APP_WRITE,
.dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER,
.dataType = QNN_DATATYPE_FLOAT_32,
.quantizeParams = {QNN_DEFINITION_UNDEFINED,
QNN_QUANTIZATION_ENCODING_UNDEFINED,
{.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}},
.rank = 4,
.dimensions = dimensions0,
.memType = QNN_TENSORMEMTYPE_RAW,
{.clientBuf = {.data = nullptr,
.dataSize = 0}}}}});

float data[] = {1, 2, 3, 4, 5, 6, 7, 8};
qbn->modelAddTensor("y", // Node Name
(Qnn_Tensor_t){
.version = QNN_TENSOR_VERSION_1,
{.v1 = {
.id = 0,
.name = "y",
.type = QNN_TENSOR_TYPE_STATIC,
.dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER,
.dataType = QNN_DATATYPE_FLOAT_32,
.quantizeParams = {QNN_DEFINITION_UNDEFINED,
QNN_QUANTIZATION_ENCODING_UNDEFINED,
{.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}},
.rank = 4,
.dimensions = dimensions1,
.memType = QNN_TENSORMEMTYPE_RAW,
{.clientBuf = {.data = data,
.dataSize = 32}}}}});

uint32_t dimensionsOut[] = {1, 2, 2, 4};
vector<Qnn_Tensor_t> outputs = {
(Qnn_Tensor_t){
.version = QNN_TENSOR_VERSION_1,
{.v1 = {
.id = 0,
.name = "add-output",
.type = QNN_TENSOR_TYPE_APP_READ,
.dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER,
.dataType = QNN_DATATYPE_FLOAT_32,
.quantizeParams = {QNN_DEFINITION_UNDEFINED,
QNN_QUANTIZATION_ENCODING_UNDEFINED,
{.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}},
.rank = 4,
.dimensions = dimensionsOut,
.memType = QNN_TENSORMEMTYPE_RAW,
{.clientBuf = {.data = nullptr,
.dataSize = 0}}}}}};
vector<Qnn_Param_t> paramsMatmul = {
{.paramType = QNN_PARAMTYPE_SCALAR,
.name = "transpose_in0",
{.scalarParam = (Qnn_Scalar_t){QNN_DATATYPE_BOOL_8, {.bool8Value = 0}}}},
{.paramType = QNN_PARAMTYPE_SCALAR,
.name = "transpose_in1",
{.scalarParam = (Qnn_Scalar_t){QNN_DATATYPE_BOOL_8, {.bool8Value = 1}}}}};
qbn->graphAddNode("qnn-add", "MatMul", {"x", "y"}, outputs, paramsMatmul, "qti.aisw");
}
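A note on the shapes above: x is {1, 2, 2, 2} and y is {1, 1, 4, 2}, backed by 8 floats (32 bytes, matching dataSize). With transpose_in1 set, the last two dimensions of y are read as 2x4, so each 2x2 matrix in x multiplies a broadcast 2x4 matrix, which yields the {1, 2, 2, 4} declared in dimensionsOut. A minimal standalone sketch of that shape rule (the matmulOutDims helper is illustrative only, not part of this PR):

#include <array>
#include <cstdint>
#include <iostream>

// Illustrative helper (not from the PR): output dims of a rank-4 MatMul with
// QNN-style transpose flags; leading dims follow in0 and are assumed to
// broadcast against in1.
std::array<uint32_t, 4> matmulOutDims(const std::array<uint32_t, 4> &in0,
                                      const std::array<uint32_t, 4> &in1,
                                      bool transposeIn0, bool transposeIn1) {
    uint32_t m = transposeIn0 ? in0[3] : in0[2]; // rows of each result matrix
    uint32_t n = transposeIn1 ? in1[2] : in1[3]; // cols of each result matrix
    return {in0[0], in0[1], m, n};
}

int main() {
    auto out = matmulOutDims({1, 2, 2, 2}, {1, 1, 4, 2}, false, true);
    std::cout << out[0] << "x" << out[1] << "x" << out[2] << "x" << out[3]
              << std::endl; // prints 1x2x2x4, matching dimensionsOut above
}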
12 changes: 9 additions & 3 deletions src/backends/QNN/QNNBackend.cpp
@@ -15,6 +15,7 @@
#include "PAL/DynamicLoading.hpp"
#include "PAL/GetOpt.hpp"
#include "QnnSampleAppUtils.hpp"
#include "QnnTypes.h"
#include "QnnWrapperUtils.hpp"
#include "DynamicLoadUtil.hpp"
#include "Types.hpp"
@@ -46,7 +47,7 @@ QNNBackend::QNNBackend(shared_ptr<MemoryManager> mm) : Backend(mm) {

std::string modelPath = "/qnn-projects/QNN-test-libs/example_libs/x86_64-linux-clang/libqnn_model_float.so";
std::string backEndPath = "/qnn-projects/QNN-test-libs/libQnnCpu.so";
-    std::string inputListPaths = "/qnn-projects/QNN-test-libs/input_list_float.txt";
+    std::string inputListPaths = "/qnn-projects/mllm/bin/input-list.txt";
std::string opPackagePaths = "/qnn-projects/QNN-test-libs/libQnnCpuOpPackageExample.so:QnnOpPackage_interfaceProvider";
    // TODO: make these configurable
m_debug = true;
@@ -182,15 +183,20 @@ qnn_wrapper_api::ModelError_t QNNBackend::graphAddNode(string name,
string nodeType,
std::vector<const char *> inputTensorNames,
std::vector<Qnn_Tensor_t> outputTensors,
+                                                    std::vector<Qnn_Param_t> params,
string packageName) {
qnn_wrapper_api::ModelError_t err = qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR;
+    Qnn_Param_t* paramsPtr = nullptr;
+    if (params.size() > 0) {
+        paramsPtr = params.data();
+    }
VALIDATE(qnnModel.addNode(
QNN_OPCONFIG_VERSION_1, // Op_Config_t Version
name.c_str(), // Node Name
packageName.c_str(), // Package Name
nodeType.c_str(), // Qnn Node Type
-        nullptr,                  // Node Params
-        0,                        // Num Node Params
+        paramsPtr,                // Node Params
+        params.size(),            // Num Node Params
inputTensorNames.data(), // Input Tensor Names
inputTensorNames.size(), // Num Input Tensor Names
outputTensors.data(), // Output Tensors
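Call sites that need no operator parameters keep working by passing an empty vector, which the body above forwards to qnnModel.addNode() as a null pointer with a count of zero. A hedged usage sketch (the node and tensor names are placeholders, not from this diff):

#include <vector>

#include "backends/QNN/QNNBackend.hpp"

using namespace mllm;

// Sketch: a parameterless op such as the ElementWiseAdd demo supplies an
// empty params vector, reproducing the pre-change nullptr/0 behavior.
void addNodeWithoutParams(QNNBackend *qbn, std::vector<Qnn_Tensor_t> &outputs) {
    std::vector<Qnn_Param_t> noParams;
    qbn->graphAddNode("qnn-add", "ElementWiseAdd", {"x", "y"}, outputs,
                      noParams, "qti.aisw");
}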
28 changes: 13 additions & 15 deletions src/backends/QNN/QNNBackend.hpp
@@ -4,6 +4,7 @@
#include "Backend.hpp"
#include "Op.hpp"
#include "OpDefined.hpp"
#include "QnnTypes.h"
#include "Types.hpp"
#include "MemoryManager.hpp"
#include "NetParameter.hpp"
@@ -20,19 +21,17 @@ using std::shared_ptr;
using namespace qnn;
using namespace qnn::tools;


namespace mllm {

enum class StatusCode {
-SUCCESS,
-FAILURE,
-FAILURE_INPUT_LIST_EXHAUSTED,
-FAILURE_SYSTEM_ERROR,
-FAILURE_SYSTEM_COMMUNICATION_ERROR,
-QNN_FEATURE_UNSUPPORTED
+    SUCCESS,
+    FAILURE,
+    FAILURE_INPUT_LIST_EXHAUSTED,
+    FAILURE_SYSTEM_ERROR,
+    FAILURE_SYSTEM_COMMUNICATION_ERROR,
+    QNN_FEATURE_UNSUPPORTED
};


class Op;

class Tensor;
@@ -83,6 +82,7 @@ class QNNBackend : public Backend {

qnn_wrapper_api::ModelError_t graphAddNode(string name, string nodeType,
std::vector<const char *> inputTensorNames, std::vector<Qnn_Tensor_t> outputTensors,
+                                               std::vector<Qnn_Param_t> params,
string packageName);
qnn_wrapper_api::ModelError_t graphFinilize();
qnn_wrapper_api::ModelError_t modelAddTensor(const char *nodeName, Qnn_Tensor_t tensor);
@@ -97,7 +97,7 @@ class QNNBackend : public Backend {

// @brief Print a message to STDERR then exit with a non-zero
void reportError(const std::string &err);

+    StatusCode initialize();

StatusCode initializeBackend();
@@ -132,7 +132,6 @@ class QNNBackend : public Backend {

StatusCode verifyFailReturnStatus(Qnn_ErrorHandle_t errCode);


StatusCode extractBackendProfilingInfo(Qnn_ProfileHandle_t profileHandle);

StatusCode extractProfilingSubEvents(QnnProfile_EventId_t profileEventId);
@@ -155,7 +154,7 @@ class QNNBackend : public Backend {
std::vector<std::string> m_opPackagePaths;
std::string m_outputPath;
QnnBackend_Config_t **m_backendConfig = nullptr;
-    Qnn_ContextHandle_t m_context = nullptr;
+    Qnn_ContextHandle_t m_context = nullptr;
QnnContext_Config_t **m_contextConfig = nullptr;
bool m_debug;
iotensor::OutputDataType m_outputDataType;
@@ -174,15 +173,14 @@ class QNNBackend : public Backend {
iotensor::IOTensor m_ioTensor;
bool m_isBackendInitialized;
bool m_isContextCreated;
-    Qnn_ProfileHandle_t m_profileBackendHandle = nullptr;
+    Qnn_ProfileHandle_t m_profileBackendHandle = nullptr;
qnn_wrapper_api::GraphConfigInfo_t **m_graphConfigsInfo = nullptr;
uint32_t m_graphConfigsInfoCount;
-    Qnn_LogHandle_t m_logHandle = nullptr;
+    Qnn_LogHandle_t m_logHandle = nullptr;
Qnn_BackendHandle_t m_backendHandle = nullptr;
-    Qnn_DeviceHandle_t m_deviceHandle = nullptr;
+    Qnn_DeviceHandle_t m_deviceHandle = nullptr;
};


} // namespace mllm

#endif // MLLM_QNNBACKEND_H
2 changes: 0 additions & 2 deletions src/backends/QNN/op/QNNAdd.cpp
@@ -25,8 +25,6 @@ ErrorCode QNNAdd::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {

ErrorCode QNNAdd::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
    // graph add node
-    // TODO: check if name_ is set in Op
    return graphAddNode(name(), "Add", inputs, outputs);
-    return NO_ERROR;
}
} // namespace mllm
19 changes: 19 additions & 0 deletions src/backends/QNN/op/QNNCausalMask.cpp
@@ -0,0 +1,19 @@

#include "QNNCausalMask.hpp"
#include "Types.hpp"
#include "QNNCommonOp.hpp"

namespace mllm {
QNNCausalMask::QNNCausalMask(Backend *bn, string opName) :
QNNCommonOp(bn, opName) {
}

ErrorCode QNNCausalMask::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
return NO_ERROR;
}

ErrorCode QNNCausalMask::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
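    // NOTE: currently wired to the elementwise "Add" QNN node type, the same mapping QNNAdd::setUp uses.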
return graphAddNode(name(), "Add", inputs, outputs);
}
} // namespace mllm

24 changes: 24 additions & 0 deletions src/backends/QNN/op/QNNCausalMask.hpp
@@ -0,0 +1,24 @@

#ifndef MLLM_QNNCAUSALMASK_H
#define MLLM_QNNCAUSALMASK_H

#include "QNNCommonOp.hpp"
namespace mllm {
class QNNCausalMask : public QNNCommonOp {
public:
QNNCausalMask(Backend *bn, string opName);
virtual ~QNNCausalMask() = default;
virtual ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
virtual ErrorCode setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
};

class QNNCausalMaskCreator : public QNNBackend::Creator {
public:
virtual Op *create(OpParam op_param, Backend *bn, string name) const {
return new QNNCausalMask(bn, name);
}
};

} // namespace mllm

#endif
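The creator follows the factory pattern used by the backend's other ops; its registration with QNNBackend is not part of this diff. A hypothetical wiring sketch (only the Creator API above is from the PR; the factory function and its use are assumptions):

#include "backends/QNN/op/QNNCausalMask.hpp"

using namespace mllm;

// Hypothetical glue (not in this PR): a function a backend op-factory map
// could point at to instantiate the QNN causal-mask op.
Op *makeQNNCausalMask(OpParam op_param, Backend *bn, string name) {
    static QNNCausalMaskCreator creator; // stateless, safe to share
    return creator.create(op_param, bn, name);
}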