From 20e540d2e373226b23357464d43c86de1f8ef81e Mon Sep 17 00:00:00 2001 From: Cheng Tang Date: Mon, 11 Mar 2024 16:27:02 +0000 Subject: [PATCH 1/7] a hacky prototype --- cmake/ext_tests.cmake | 2 +- includes/custom_op_lite.h | 243 ++++++++++++++---------- includes/onnxruntime_cpp_api_legacy.hpp | 13 ++ test/shared_test/test_ortops_math.cc | 14 ++ 4 files changed, 172 insertions(+), 100 deletions(-) diff --git a/cmake/ext_tests.cmake b/cmake/ext_tests.cmake index fe710401b..64a343827 100644 --- a/cmake/ext_tests.cmake +++ b/cmake/ext_tests.cmake @@ -165,7 +165,7 @@ else() LIBRARIES ${extensions_test_libraries} TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data) - target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS}) + target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) target_compile_definitions(extensions_test PUBLIC ${OCOS_COMPILE_DEFINITIONS}) if(use_extensions_shared_library) diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index 365b020a7..42a681f89 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -12,13 +12,7 @@ namespace Custom { class TensorBase { public: - TensorBase(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : api_(api), - ctx_(ctx), - indice_(indice), - is_input_(is_input) {} + TensorBase() {} virtual ~TensorBase() = default; operator bool() const { @@ -60,10 +54,6 @@ class TensorBase { virtual size_t SizeInBytes() const = 0; protected: - const OrtW::CustomOpApi& api_; - OrtKernelContext& ctx_; - size_t indice_; - bool is_input_; std::optional> shape_; ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; const char* mem_type_ = "Cpu"; @@ -118,40 +108,105 @@ struct Span { #endif +class TensorStorage { +public: + virtual const void* DataRaw() const = 0; + virtual void* Allocate(const std::vector& shape) = 0; +}; + +class EagerTensorStorage : public TensorStorage { +public: + EagerTensorStorage(void* buffer, + ONNXTensorElementDataType type) : buffer_(buffer), type_(type) {} + const void* DataRaw() const override { + return buffer_; + } + + void* Allocate(const std::vector& shape) override { + if (!buffer_) { + // TODO: allocated with ORT allocator + int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + // TODO: get size of type + auto buffer_size = n_elem * 4; + buffer_holder_ = std::make_unique(buffer_size); + buffer_ = buffer_holder_.get(); + } + return buffer_; + } +private: + void* buffer_; + std::unique_ptr buffer_holder_; + ONNXTensorElementDataType type_; +}; + +class OrtTensorStorage : public TensorStorage { +public: + OrtTensorStorage(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : api_(api), ctx_(ctx), indice_(indice) { + if (is_input){ + auto input_count = api_.KernelContext_GetInputCount(&ctx_); + if (indice >= input_count) { + ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); + } + const_value_ = api_.KernelContext_GetInput(&ctx_, indice); + } + } + + const void* DataRaw() const override { + return api_.GetTensorRawData(const_value_); + } + + void* Allocate(const std::vector& shape) override { + if (!const_value_) { + const_value_ = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); + } + return api_.GetTensorMutableRawData(const_cast(const_value_)); + } + +private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; + const OrtValue* const_value_{}; // for input + +}; + 
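// A minimal usage sketch, mirroring the eager_poc test added below (names
// as in that test; illustrative only): EagerTensorStorage lets the same
// Tensor<T> wrap a caller-owned host buffer outside of an OrtKernelContext,
// while OrtTensorStorage keeps the existing kernel-context path:
//
//   std::vector<float> data = {0.0f, 0.2f, -1.3f, 1.5f};
//   ortc::Tensor<float> input(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
//                             data.data(), std::vector<int64_t>{2, 2});
//   const float* p = input.Data();  // served by EagerTensorStorage::DataRaw()
//
// Passing nullptr as the buffer defers allocation to Tensor<T>::Allocate(),
// which at this stage hard-codes a 4-byte element size (see the TODOs above).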
template class Tensor : public TensorBase { public: using TT = typename std::remove_reference::type; + Tensor(ONNXTensorElementDataType type, + void* buffer_ptr, + std::optional> shape) : storage_(std::make_unique(buffer_ptr, type)){ + shape_ = shape; + } + Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { + bool is_input) : storage_(std::make_unique(api, ctx, indice, is_input)) { + // init metadata if (is_input) { - auto input_count = api_.KernelContext_GetInputCount(&ctx_); - if (indice >= input_count) { - ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); - } - const_value_ = api_.KernelContext_GetInput(&ctx_, indice); - auto* info = api_.GetTensorTypeAndShape(const_value_); - shape_ = api_.GetTensorShape(info); - type_ = api_.GetTensorElementType(info); - api_.ReleaseTensorTypeAndShapeInfo(info); + const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); + auto* info = api.GetTensorTypeAndShape(const_value); + shape_ = api.GetTensorShape(info); + type_ = api.GetTensorElementType(info); + api.ReleaseTensorTypeAndShapeInfo(info); const OrtMemoryInfo* mem_info = {}; - api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info)); + api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); if (mem_info) { - api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); + api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); } } } const TT* Data() const { - return api_.GetTensorData(const_value_); + return static_cast(storage_->DataRaw()); } const void* DataRaw() const override { - return reinterpret_cast(Data()); + return storage_->DataRaw(); } size_t SizeInBytes() const override { @@ -159,13 +214,13 @@ class Tensor : public TensorBase { } TT* Allocate(const std::vector& shape) { - if (!data_) { - OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); + // it should be OK to allocate multiple times + void* buffer = storage_->Allocate(shape); + if (!shape_.has_value()) shape_ = shape; - data_ = api_.GetTensorMutableData(out); - } - return data_; + return static_cast(buffer); } + const Span& AsSpan() { if (!shape_.has_value() || shape_->size() != 1) { ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); @@ -173,6 +228,7 @@ class Tensor : public TensorBase { span_.Assign(Data(), (*shape_)[0]); return span_; } + const T& AsScalar() { if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); @@ -181,8 +237,7 @@ class Tensor : public TensorBase { } private: - const OrtValue* const_value_{}; // for input - TT* data_{}; // for output + std::unique_ptr storage_; Span span_; }; @@ -194,10 +249,9 @@ class Tensor : public TensorBase { Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { + bool is_input) : api_(api), + ctx_(ctx), + indice_(indice) { if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { @@ -268,6 +322,9 @@ class Tensor : public TensorBase { } private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; std::vector input_strings_; // for input }; @@ -280,11 +337,10 @@ class Tensor : public TensorBase { 
Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { - if (is_input_) { + bool is_input) : api_(api), + ctx_(ctx), + indice_(indice) { + if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); @@ -347,6 +403,9 @@ class Tensor : public TensorBase { } private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; std::vector chars_; // for input std::vector input_string_views_; // for input }; @@ -358,40 +417,33 @@ struct Tensor : public TensorBase { Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { + bool is_input) : storage_(std::make_unique(api, ctx, indice, is_input)) { + // init metadata type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - if (is_input_) { - auto input_count = api_.KernelContext_GetInputCount(&ctx_); - if (indice >= input_count) { - ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); - } - const_value_ = api_.KernelContext_GetInput(&ctx_, indice); - auto* info = api_.GetTensorTypeAndShape(const_value_); - shape_ = api_.GetTensorShape(info); - type_ = api_.GetTensorElementType(info); - api_.ReleaseTensorTypeAndShapeInfo(info); + if (is_input) { + const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); + auto* info = api.GetTensorTypeAndShape(const_value); + shape_ = api.GetTensorShape(info); + type_ = api.GetTensorElementType(info); + api.ReleaseTensorTypeAndShapeInfo(info); const OrtMemoryInfo* mem_info = {}; - api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info)); + api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); if (mem_info) { - api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); + api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); } } } const MFloat16* Data() const { - return reinterpret_cast(api_.GetTensorData(const_value_)); + return reinterpret_cast(storage_->DataRaw()); } MFloat16* Allocate(const std::vector& shape) { - if (!data_) { - OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); + // it should be OK to allocate multiple times + void* buffer = storage_->Allocate(shape); + if (!shape_.has_value()) shape_ = shape; - data_ = reinterpret_cast(api_.GetTensorMutableData(out)); - } - return data_; + return reinterpret_cast(buffer); } const Span& AsSpan() { @@ -403,7 +455,7 @@ struct Tensor : public TensorBase { } const void* DataRaw() const override { - return reinterpret_cast(Data()); + return storage_->DataRaw(); } virtual size_t SizeInBytes() const override { @@ -411,8 +463,7 @@ struct Tensor : public TensorBase { } private: - const OrtValue* const_value_{}; // for input - MFloat16* data_{}; // for output + std::unique_ptr storage_; }; template <> @@ -420,40 +471,33 @@ struct Tensor : public TensorBase { Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - if (is_input_) { - auto input_count = api_.KernelContext_GetInputCount(&ctx_); - if (indice >= input_count) { - ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); - } - const_value_ = api_.KernelContext_GetInput(&ctx_, indice); - auto* info = api_.GetTensorTypeAndShape(const_value_); - shape_ 
= api_.GetTensorShape(info); - type_ = api_.GetTensorElementType(info); - api_.ReleaseTensorTypeAndShapeInfo(info); + bool is_input) : storage_(std::make_unique(api, ctx, indice, is_input)) { + // init metadata + type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; + if (is_input) { + const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); + auto* info = api.GetTensorTypeAndShape(const_value); + shape_ = api.GetTensorShape(info); + type_ = api.GetTensorElementType(info); + api.ReleaseTensorTypeAndShapeInfo(info); const OrtMemoryInfo* mem_info = {}; - api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info)); + api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); if (mem_info) { - api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); + api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); } } } const BFloat16* Data() const { - return reinterpret_cast(api_.GetTensorData(const_value_)); + return reinterpret_cast(storage_->DataRaw()); } BFloat16* Allocate(const std::vector& shape) { - if (!data_) { - OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); + // it should be OK to allocate multiple times + void* buffer = storage_->Allocate(shape); + if (!shape_.has_value()) shape_ = shape; - data_ = reinterpret_cast(api_.GetTensorMutableData(out)); - } - return data_; + return reinterpret_cast(buffer); } const Span& AsSpan() { @@ -465,7 +509,7 @@ struct Tensor : public TensorBase { } const void* DataRaw() const override { - return reinterpret_cast(Data()); + return storage_->DataRaw(); } virtual size_t SizeInBytes() const override { @@ -473,8 +517,7 @@ struct Tensor : public TensorBase { } private: - const OrtValue* const_value_{}; // for input - BFloat16* data_{}; // for output + std::unique_ptr storage_; }; #endif @@ -487,10 +530,9 @@ struct Variadic : public TensorBase { Variadic(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { + bool is_input) : api_(api), + ctx_(ctx), + indice_(indice) { #if ORT_API_VERSION < 14 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION); #endif @@ -578,6 +620,9 @@ struct Variadic : public TensorBase { } private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; TensorPtrs tensors_; }; diff --git a/includes/onnxruntime_cpp_api_legacy.hpp b/includes/onnxruntime_cpp_api_legacy.hpp index 99f452282..55bad7728 100644 --- a/includes/onnxruntime_cpp_api_legacy.hpp +++ b/includes/onnxruntime_cpp_api_legacy.hpp @@ -30,6 +30,9 @@ struct CustomOpApi { template const T* GetTensorData(_Inout_ const OrtValue* value) const; + void* GetTensorMutableRawData(_Inout_ OrtValue* value) const; + const void* GetTensorRawData(_Inout_ const OrtValue* value) const; + std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const; void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const; size_t KernelContext_GetInputCount(const OrtKernelContext* context) const; @@ -162,6 +165,16 @@ inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const return GetTensorMutableData(const_cast(value)); } +inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const { + void* data = nullptr; + ThrowOnError(api_.GetTensorMutableData(value, &data)); + return data; +} + +inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* 
value) const {
+  return GetTensorMutableRawData(const_cast<OrtValue*>(value));
+}
+
 inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
   std::vector<int64_t> output(GetDimensionsCount(info));
   GetDimensions(info, output.data(), output.size());
diff --git a/test/shared_test/test_ortops_math.cc b/test/shared_test/test_ortops_math.cc
index 72da1f985..ad56bdc76 100644
--- a/test/shared_test/test_ortops_math.cc
+++ b/test/shared_test/test_ortops_math.cc
@@ -6,6 +6,20 @@

 #include "ocos.h"
 #include "test_kernel.hpp"
+#include "operators/math/negpos.hpp"
+
+TEST(math_operator, eager_poc){
+  std::vector<float> input_data = {0.0f, 0.2f, -1.3f, 1.5f};
+
+  ortc::Tensor<float> input(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+                            input_data.data(),
+                            std::vector<int64_t>{2, 2});
+  ortc::Tensor<float> output1(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, nullptr, std::vector<int64_t>{2, 2});
+  ortc::Tensor<float> output2(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, nullptr, std::vector<int64_t>{2, 2});
+
+  auto result = neg_pos(input, output1, output2);
+  assert(!result);
+}

 TEST(math_operator, segment_extraction) {
   auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

From 0ee757e55a6681b7fa370e213faa1e596dfd1787 Mon Sep 17 00:00:00 2001
From: Cheng Tang
Date: Sat, 16 Mar 2024 00:23:17 +0000
Subject: [PATCH 2/7] refactor the code

---
 includes/custom_op_lite.h            | 515 ++++++++-------------
 includes/tensor_api.h                | 249 +++++++++++++
 test/shared_test/test_ortops_math.cc |  11 +-
 3 files changed, 412 insertions(+), 363 deletions(-)
 create mode 100644 includes/tensor_api.h

diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h
index 42a681f89..a94ed2a7a 100644
--- a/includes/custom_op_lite.h
+++ b/includes/custom_op_lite.h
@@ -2,163 +2,74 @@
 // Licensed under the MIT License.
#pragma once -#include "onnxruntime_customop.hpp" -#include "onnxruntime_f16.h" + #include #include +#include "tensor_api.h" namespace Ort { namespace Custom { -class TensorBase { - public: - TensorBase() {} - - virtual ~TensorBase() = default; - operator bool() const { - return shape_.has_value(); - } - const std::vector& Shape() const { - if (shape_.has_value()) { - return *shape_; - } else { - ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); - } - } - ONNXTensorElementDataType Type() const { - return type_; - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); - } else { - ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); - } - } - std::string Shape2Str() const { - if (shape_.has_value()) { - std::string shape_str; - for (const auto& dim : *shape_) { - shape_str.append(std::to_string(dim)); - shape_str.append(", "); +class OrtKernelArg { +public: + OrtKernelArg(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : api_(api), ctx_(ctx), indice_(indice) { + if (is_input) { + const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); + const OrtMemoryInfo* mem_info = {}; + api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); + if (mem_info) { + api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); } - return shape_str; - } else { - return "empty"; } } + bool IsCpuTensor() const { return strcmp("Cpu", mem_type_) == 0; } - virtual const void* DataRaw() const = 0; - virtual size_t SizeInBytes() const = 0; - protected: - std::optional> shape_; - ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; +protected: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; const char* mem_type_ = "Cpu"; }; -template -struct Span { - const T* data_ = {}; - size_t size_ = {}; - void Assign(const T* data, size_t size) { - data_ = data; - size_ = size; - } - size_t size() const { return size_; } - T operator[](size_t indice) const { - return data_[indice]; - } - const T* data() const { return data_; } -}; - -#if ORT_API_VERSION >= 16 - -template <> -struct Span { - const MFloat16* data_ = {}; - size_t size_ = {}; - void Assign(const MFloat16* data, size_t size) { - data_ = data; - size_ = size; - } - size_t size() const { return size_; } - MFloat16 operator[](size_t indice) const { - return data_[indice]; - } - const MFloat16* data() const { return data_; } -}; - -template <> -struct Span { - const BFloat16* data_ = {}; - size_t size_ = {}; - void Assign(const BFloat16* data, size_t size) { - data_ = data; - size_ = size; - } - size_t size() const { return size_; } - BFloat16 operator[](size_t indice) const { - return data_[indice]; - } - const BFloat16* data() const { return data_; } -}; - -#endif - -class TensorStorage { -public: - virtual const void* DataRaw() const = 0; - virtual void* Allocate(const std::vector& shape) = 0; -}; - -class EagerTensorStorage : public TensorStorage { +class OrtKernelContextStorage : public ITensorStorage { public: - EagerTensorStorage(void* buffer, - ONNXTensorElementDataType type) : buffer_(buffer), type_(type) {} - const void* DataRaw() const override { - return buffer_; - } - - void* Allocate(const std::vector& shape) override { - if (!buffer_) { - // TODO: allocated with ORT allocator - int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); - 
// TODO: get size of type - auto buffer_size = n_elem * 4; - buffer_holder_ = std::make_unique(buffer_size); - buffer_ = buffer_holder_.get(); - } - return buffer_; - } -private: - void* buffer_; - std::unique_ptr buffer_holder_; - ONNXTensorElementDataType type_; -}; - -class OrtTensorStorage : public TensorStorage { -public: - OrtTensorStorage(const OrtW::CustomOpApi& api, + OrtKernelContextStorage(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, bool is_input) : api_(api), ctx_(ctx), indice_(indice) { if (is_input){ - auto input_count = api_.KernelContext_GetInputCount(&ctx_); + auto input_count = api.KernelContext_GetInputCount(&ctx); if (indice >= input_count) { ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); } - const_value_ = api_.KernelContext_GetInput(&ctx_, indice); + const_value_ = api.KernelContext_GetInput(&ctx, indice); + auto* info = api.GetTensorTypeAndShape(const_value_); + shape_ = api.GetTensorShape(info); + api.ReleaseTensorTypeAndShapeInfo(info); } } + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; + } + + virtual bool IsInitialized() const override { + return shape_.has_value(); + } + const void* DataRaw() const override { return api_.GetTensorRawData(const_value_); } - void* Allocate(const std::vector& shape) override { + void* Initialize(const std::vector& shape, size_t element_size) override { if (!const_value_) { const_value_ = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); } @@ -170,88 +81,29 @@ class OrtTensorStorage : public TensorStorage { OrtKernelContext& ctx_; size_t indice_; const OrtValue* const_value_{}; // for input - + std::optional> shape_; }; template -class Tensor : public TensorBase { - public: - using TT = typename std::remove_reference::type; - Tensor(ONNXTensorElementDataType type, - void* buffer_ptr, - std::optional> shape) : storage_(std::make_unique(buffer_ptr, type)){ - shape_ = shape; - } - - Tensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : storage_(std::make_unique(api, ctx, indice, is_input)) { - // init metadata - if (is_input) { - const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); - auto* info = api.GetTensorTypeAndShape(const_value); - shape_ = api.GetTensorShape(info); - type_ = api.GetTensorElementType(info); - api.ReleaseTensorTypeAndShapeInfo(info); - const OrtMemoryInfo* mem_info = {}; - api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); - if (mem_info) { - api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); - } - } - } - const TT* Data() const { - return static_cast(storage_->DataRaw()); - } - - const void* DataRaw() const override { - return storage_->DataRaw(); - } - - size_t SizeInBytes() const override { - return NumberOfElement() * sizeof(TT); - } - - TT* Allocate(const std::vector& shape) { - // it should be OK to allocate multiple times - void* buffer = storage_->Allocate(shape); - if (!shape_.has_value()) - shape_ = shape; - return static_cast(buffer); - } - - const Span& AsSpan() { - if (!shape_.has_value() || shape_->size() != 1) { - ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); - } - span_.Assign(Data(), (*shape_)[0]); - return span_; - } - - const T& AsScalar() { - if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { - 
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); - } - return *Data(); +class OrtTensor : public OrtKernelArg, public Tensor { +public: + OrtTensor(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : OrtKernelArg(api, ctx, indice, is_input), + Tensor(std::make_unique(api, ctx, indice, is_input)){ } - - private: - std::unique_ptr storage_; - Span span_; }; template <> -class Tensor : public TensorBase { +class Tensor : public OrtKernelArg, public Arg { public: using strings = std::vector; Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : api_(api), - ctx_(ctx), - indice_(indice) { + bool is_input) : OrtKernelArg(api, ctx, indice, is_input) { if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { @@ -261,7 +113,6 @@ class Tensor : public TensorBase { auto* const_value = api_.KernelContext_GetInput(&ctx_, indice); auto* info = api_.GetTensorTypeAndShape(const_value); shape_ = api_.GetTensorShape(info); - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; api_.ReleaseTensorTypeAndShapeInfo(info); size_t num_chars; @@ -287,13 +138,44 @@ class Tensor : public TensorBase { const strings& Data() const { return input_strings_; } - const void* DataRaw() const override { + + const std::vector& Shape() const { + if (shape_.has_value()) { + return *shape_; + } else { + ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); + } + } + + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); + } else { + ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); + } + } + + std::string Shape2Str() const { + if (shape_.has_value()) { + std::string shape_str; + for (const auto& dim : *shape_) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + + const void* DataRaw() const { if (input_strings_.size() != 1) { ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); } return reinterpret_cast(input_strings_[0].c_str()); } - size_t SizeInBytes() const override { + size_t SizeInBytes() const { if (input_strings_.size() != 1) { ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); } @@ -322,14 +204,22 @@ class Tensor : public TensorBase { } private: - const OrtW::CustomOpApi& api_; - OrtKernelContext& ctx_; - size_t indice_; std::vector input_strings_; // for input + std::optional> shape_; +}; + +// to make the metaprogramming magic happy. 
+template <> +class OrtTensor : public Tensor{ +public: + OrtTensor(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : Tensor(api, ctx, indice, is_input) {}; }; template <> -class Tensor : public TensorBase { +class Tensor : public OrtKernelArg, public Arg { public: using strings = std::vector; using string_views = std::vector; @@ -337,9 +227,7 @@ class Tensor : public TensorBase { Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : api_(api), - ctx_(ctx), - indice_(indice) { + bool is_input) : OrtKernelArg(api, ctx, indice, is_input) { if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { @@ -348,7 +236,6 @@ class Tensor : public TensorBase { auto* const_value = api_.KernelContext_GetInput(&ctx_, indice); auto* info = api_.GetTensorTypeAndShape(const_value); shape_ = api_.GetTensorShape(info); - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; api_.ReleaseTensorTypeAndShapeInfo(info); size_t num_chars; @@ -370,6 +257,14 @@ class Tensor : public TensorBase { } } } + const std::vector& Shape() const { + if (shape_.has_value()) { + return *shape_; + } else { + ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); + } + } + int64_t NumberOfElement() const { if (shape_.has_value()) { return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); @@ -377,16 +272,29 @@ class Tensor : public TensorBase { return 0; } } + std::string Shape2Str() const { + if (shape_.has_value()) { + std::string shape_str; + for (const auto& dim : *shape_) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + const string_views& Data() const { return input_string_views_; } - const void* DataRaw() const override { + const void* DataRaw() const { if (input_string_views_.size() != 1) { ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); } return reinterpret_cast(input_string_views_[0].data()); } - size_t SizeInBytes() const override { + size_t SizeInBytes() const { if (input_string_views_.size() != 1) { ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); } @@ -403,141 +311,35 @@ class Tensor : public TensorBase { } private: - const OrtW::CustomOpApi& api_; - OrtKernelContext& ctx_; - size_t indice_; std::vector chars_; // for input std::vector input_string_views_; // for input + std::optional> shape_; }; -#if ORT_API_VERSION >= 16 - -template <> -struct Tensor : public TensorBase { - Tensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : storage_(std::make_unique(api, ctx, indice, is_input)) { - // init metadata - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - if (is_input) { - const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); - auto* info = api.GetTensorTypeAndShape(const_value); - shape_ = api.GetTensorShape(info); - type_ = api.GetTensorElementType(info); - api.ReleaseTensorTypeAndShapeInfo(info); - const OrtMemoryInfo* mem_info = {}; - api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); - if (mem_info) { - api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); - } - } - } - - const MFloat16* Data() const { - return reinterpret_cast(storage_->DataRaw()); - } - - MFloat16* Allocate(const std::vector& shape) { - // it should be OK to allocate multiple times - void* buffer = 
storage_->Allocate(shape); - if (!shape_.has_value()) - shape_ = shape; - return reinterpret_cast(buffer); - } - - const Span& AsSpan() { - ORTX_CXX_API_THROW("AsSpan for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION); - } - - const MFloat16& AsScalar() { - ORTX_CXX_API_THROW("AsScalar for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION); - } - - const void* DataRaw() const override { - return storage_->DataRaw(); - } - - virtual size_t SizeInBytes() const override { - return NumberOfElement() * sizeof(uint16_t); - } - - private: - std::unique_ptr storage_; -}; - +// to make the metaprogramming magic happy. template <> -struct Tensor : public TensorBase { - Tensor(const OrtW::CustomOpApi& api, +class OrtTensor : public Tensor{ +public: + OrtTensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : storage_(std::make_unique(api, ctx, indice, is_input)) { - // init metadata - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; - if (is_input) { - const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); - auto* info = api.GetTensorTypeAndShape(const_value); - shape_ = api.GetTensorShape(info); - type_ = api.GetTensorElementType(info); - api.ReleaseTensorTypeAndShapeInfo(info); - const OrtMemoryInfo* mem_info = {}; - api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); - if (mem_info) { - api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); - } - } - } - - const BFloat16* Data() const { - return reinterpret_cast(storage_->DataRaw()); - } - - BFloat16* Allocate(const std::vector& shape) { - // it should be OK to allocate multiple times - void* buffer = storage_->Allocate(shape); - if (!shape_.has_value()) - shape_ = shape; - return reinterpret_cast(buffer); - } - - const Span& AsSpan() { - ORTX_CXX_API_THROW("AsSpan for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION); - } - - const BFloat16& AsScalar() { - ORTX_CXX_API_THROW("AsScalar for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION); - } - - const void* DataRaw() const override { - return storage_->DataRaw(); - } - - virtual size_t SizeInBytes() const override { - return NumberOfElement() * sizeof(uint16_t); - } - - private: - std::unique_ptr storage_; + bool is_input) : Tensor(api, ctx, indice, is_input) {} }; -#endif - -using TensorPtr = std::unique_ptr; +using TensorPtr = std::unique_ptr; using TensorPtrs = std::vector; // Represent variadic input or output -struct Variadic : public TensorBase { +struct Variadic : public OrtKernelArg, public Arg { Variadic(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : api_(api), - ctx_(ctx), - indice_(indice) { + bool is_input) : OrtKernelArg(api, ctx, indice, is_input) { #if ORT_API_VERSION < 14 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION); #endif if (is_input) { - auto input_count = api_.KernelContext_GetInputCount(&ctx_); + auto input_count = api.KernelContext_GetInputCount(&ctx_); for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input); auto* info = api_.GetTensorTypeAndShape(const_value); @@ -546,40 +348,40 @@ struct Variadic : public TensorBase { TensorPtr tensor; switch (type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - tensor = std::make_unique>(api, ctx, 
ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; default: ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); @@ -593,7 +395,7 @@ struct Variadic : public TensorBase { } template T* AllocateOutput(size_t ith_output, const std::vector& shape) { - auto tensor = std::make_unique>(api_, ctx_, ith_output, false); + auto tensor = std::make_unique>(api_, ctx_, ith_output, false); auto raw_output = tensor.get()->Allocate(shape); tensors_.emplace_back(tensor.release()); return raw_output; @@ -604,11 +406,11 @@ struct Variadic : public TensorBase { tensors_.emplace_back(tensor.release()); return output; } - const void* DataRaw() const override { + const void* DataRaw() const { ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION); return nullptr; } - size_t SizeInBytes() const override { + size_t SizeInBytes() const { ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION); return 0; } @@ -620,9 +422,6 @@ struct Variadic : public TensorBase { } private: - const OrtW::CustomOpApi& api_; - OrtKernelContext& ctx_; - size_t indice_; TensorPtrs tensors_; }; @@ -718,7 +517,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if*>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -726,7 +525,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if&>::value, 
std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -735,7 +534,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { static typename std::enable_if*>>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -748,8 +547,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if*>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ @@ -759,8 +558,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if&>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ @@ -771,8 +570,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { static typename std::enable_if*>>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ @@ -787,8 +586,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, 
OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ @@ -799,8 +598,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { static typename std::enable_if>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ @@ -816,7 +615,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if*>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -824,7 +623,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if&>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -833,7 +632,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { static typename std::enable_if*>>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ if (ith_output < num_output) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ diff --git a/includes/tensor_api.h b/includes/tensor_api.h new file mode 100644 index 000000000..5fdbbbaa5 --- /dev/null +++ b/includes/tensor_api.h @@ -0,0 +1,249 @@ +#include +#include +#include +#include "onnxruntime_customop.hpp" +#include "onnxruntime_f16.h" + +namespace Ort { +namespace Custom { + +// 
this is for the ORT custom op template magic +class Arg { +}; + +template +struct Span { + const T* data_ = {}; + size_t size_ = {}; + void Assign(const T* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + T operator[](size_t indice) const { + return data_[indice]; + } + const T* data() const { return data_; } +}; + + +#if ORT_API_VERSION >= 16 + +template <> +struct Span { + const MFloat16* data_ = {}; + size_t size_ = {}; + void Assign(const MFloat16* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + MFloat16 operator[](size_t indice) const { + return data_[indice]; + } + const MFloat16* data() const { return data_; } +}; + +template <> +struct Span { + const BFloat16* data_ = {}; + size_t size_ = {}; + void Assign(const BFloat16* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + BFloat16 operator[](size_t indice) const { + return data_[indice]; + } + const BFloat16* data() const { return data_; } +}; + +#endif + +class ITensorStorage{ +public: + virtual const std::vector& Shape() const = 0; + virtual const void* DataRaw() const = 0; + virtual bool IsInitialized() const = 0; + virtual void* Initialize(const std::vector& shape, size_t element_size) = 0; +}; + +class IAllocator { +public: + virtual void* Alloc(size_t size) = 0; + virtual void Free(void* p) = 0; +}; + +// TODO: remove this + +class TestAllocator : public IAllocator { +public: + void* Alloc(size_t size) override { + return malloc(size); + } + + void Free(void* p) override { + if (p){ + free(p); + } + } +}; + +class OrtEagerTensorStorage : public ITensorStorage { +public: + OrtEagerTensorStorage(const std::vector& shape, + void* buffer) : buffer_(buffer), shape_(shape){ + + } + + OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){ + } + + virtual ~OrtEagerTensorStorage(){ + if (allocator_ && buffer_) + allocator_->Free(buffer_); + } + + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; + } + + virtual bool IsInitialized() const override { + return shape_.has_value(); + } + + const void* DataRaw() const override { + return buffer_; + } + + void* Initialize(const std::vector& shape, size_t element_size) override { + if (IsInitialized()) + return buffer_; + assert(allocator_); + shape_ = shape; + int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + auto buffer_size = n_elem * element_size; + buffer_ = allocator_->Alloc(buffer_size); + return buffer_; + } + +private: + void* buffer_ {}; + std::optional> shape_; + // caller need to make sure the allocator is alive + IAllocator* allocator_; +}; + +template +class Tensor : public Arg { + public: + using TT = typename std::remove_reference::type; + Tensor(std::unique_ptr tensor_storage) : storage_(std::move(tensor_storage)){ + } + + Tensor(const std::vector& shape, void* buffer) : Tensor(std::make_unique(shape, buffer)) {} + + Tensor(IAllocator* allocator) : storage_(std::make_unique(allocator)){} + + virtual ~Tensor() = default; + + operator bool() const { + return storage_->IsInitialized(); + } + + const std::vector& Shape() const { + return storage_->Shape(); + } + + int64_t NumberOfElement() const { + auto& shape = storage_->Shape(); + return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + } + + std::string Shape2Str() const { + if 
(storage_->IsInitialized()) { + std::string shape_str; + auto& shape = storage_->Shape(); + for (const auto& dim : shape) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + const TT* Data() const { +#if ORT_API_VERSION >= 16 + if constexpr (std::is_same::value || std::is_same::value) + return reinterpret_cast(storage_->DataRaw()); + else +#endif + return static_cast(storage_->DataRaw()); + } + + const void* DataRaw() const { + return storage_->DataRaw(); + } + + size_t SizeInBytes() const { + return NumberOfElement() * sizeof(TT); + } + + TT* Allocate(const std::vector& shape) { + // it should be OK to allocate multiple times + void* buffer = storage_->Initialize(shape, sizeof(TT)); +#if ORT_API_VERSION >= 16 + if constexpr (std::is_same::value || std::is_same::value) + return reinterpret_cast(buffer); + else +#endif + return static_cast(buffer); + } + + const Span& AsSpan() { +#if ORT_API_VERSION >= 16 + if constexpr (std::is_same::value || std::is_same::value) { + ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION); + } + else{ +#endif + auto& shape = storage_->Shape(); + if (shape.size() != 1) { + ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + } + span_.Assign(Data(), shape[0]); + return span_; +#if ORT_API_VERSION >= 16 + } +#endif + } + + const T& AsScalar() { +#if ORT_API_VERSION >= 16 + if constexpr (std::is_same::value || std::is_same::value) { + ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION); + } + else{ +#endif + auto& shape = storage_->Shape(); + if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) { + ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + } + return *Data(); +#if ORT_API_VERSION >= 16 + } +#endif + } + + private: + std::unique_ptr storage_; + Span span_; +}; + +} +} diff --git a/test/shared_test/test_ortops_math.cc b/test/shared_test/test_ortops_math.cc index ad56bdc76..c8168083c 100644 --- a/test/shared_test/test_ortops_math.cc +++ b/test/shared_test/test_ortops_math.cc @@ -9,16 +9,17 @@ #include "operators/math/negpos.hpp" TEST(math_operator, eager_poc){ + auto test_allocator = std::make_unique(); std::vector input_data = {0.0f, 0.2f, -1.3f, 1.5f}; - ortc::Tensor input(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - input_data.data(), - std::vector{2, 2}); - ortc::Tensor output1(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, nullptr, std::vector{2, 2}); - ortc::Tensor output2(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, nullptr, std::vector{2, 2}); + ortc::Tensor input(std::vector{2, 2}, input_data.data()); + + ortc::Tensor output1(test_allocator.get()); + ortc::Tensor output2(test_allocator.get()); auto result = neg_pos(input, output1, output2); assert(!result); + assert(output1.Shape() == input.Shape() && output2.Shape() == input.Shape()); } TEST(math_operator, segment_extraction) { From 42af3a371d12e45432e8d8445d6fc7d0db550caf Mon Sep 17 00:00:00 2001 From: Cheng Tang Date: Tue, 19 Mar 2024 03:56:37 +0000 Subject: [PATCH 3/7] support named argument dict based struct; support string tensor --- cmake/ext_tests.cmake | 2 + includes/custom_op_lite.h | 222 +++++++--------- includes/tensor_api.h | 245 ++++++++++++++++++ operators/math/negpos.hpp | 1 + operators/tokenizer/basic_tokenizer.cc | 20 +- 
operators/tokenizer/basic_tokenizer.hpp | 16 +- operators/tokenizer/bert_tokenizer_decoder.cc | 34 +-- .../tokenizer/bert_tokenizer_decoder.hpp | 23 +- test/shared_test/test_ortops_tokenizer.cc | 17 ++ 9 files changed, 425 insertions(+), 155 deletions(-) diff --git a/cmake/ext_tests.cmake b/cmake/ext_tests.cmake index 64a343827..941fa0b38 100644 --- a/cmake/ext_tests.cmake +++ b/cmake/ext_tests.cmake @@ -167,6 +167,8 @@ else() target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(extensions_test PRIVATE ocos_operators) + target_compile_definitions(extensions_test PUBLIC ${OCOS_COMPILE_DEFINITIONS}) if(use_extensions_shared_library) target_compile_definitions(extensions_test PUBLIC ORT_EXTENSIONS_UNIT_TEST_USE_EXTENSIONS_SHARED_LIBRARY) diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index a94ed2a7a..b43e8ff9c 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -6,6 +6,7 @@ #include #include #include "tensor_api.h" +#include "onnxruntime_cpp_api_legacy.hpp" namespace Ort { namespace Custom { @@ -95,15 +96,13 @@ class OrtTensor : public OrtKernelArg, public Tensor { } }; -template <> -class Tensor : public OrtKernelArg, public Arg { - public: +class OrtStringTensorStorage : public IStringTensorStorage{ +public: using strings = std::vector; - - Tensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : OrtKernelArg(api, ctx, indice, is_input) { + OrtStringTensorStorage(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : api_(api), ctx_(ctx), indice_(indice){ if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { @@ -118,8 +117,9 @@ class Tensor : public OrtKernelArg, public Arg { size_t num_chars; OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars)); std::vector chars(num_chars + 1, '\0'); - auto num_strings = NumberOfElement(); - std::vector offsets(NumberOfElement()); + assert((*shape_).size() == 1); + auto num_strings = (*shape_)[0]; + std::vector offsets((*shape_)[0]); OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, (void*)chars.data(), num_chars, @@ -135,53 +135,25 @@ class Tensor : public OrtKernelArg, public Arg { } } } - const strings& Data() const { - return input_strings_; - } - - const std::vector& Shape() const { - if (shape_.has_value()) { - return *shape_; - } else { - ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); - } - } - - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); - } else { - ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); - } - } - std::string Shape2Str() const { - if (shape_.has_value()) { - std::string shape_str; - for (const auto& dim : *shape_) { - shape_str.append(std::to_string(dim)); - shape_str.append(", "); - } - return shape_str; - } else { - return "empty"; - } + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; } - - const void* DataRaw() const { + virtual const void* DataRaw() const override { if (input_strings_.size() != 1) { ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); } return reinterpret_cast(input_strings_[0].c_str()); } 
- size_t SizeInBytes() const { - if (input_strings_.size() != 1) { - ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); - } - return input_strings_[0].size(); + + virtual bool IsInitialized() const override { + return shape_.has_value(); } - void SetStringOutput(const strings& ss, const std::vector& dims) { + + virtual void SetStringOutput(const strings& ss, const std::vector& dims) override { std::vector raw; for (const auto& s : ss) { raw.push_back(s.data()); @@ -189,45 +161,32 @@ class Tensor : public OrtKernelArg, public Arg { auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size()); OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size())); } - void SetStringOutput(const std::vector& ss, const std::vector& dims) { + + virtual void SetStringOutput(const std::vector& ss, const std::vector& dims) override { auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size()); OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size())); } - const Span& AsSpan() { - ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION); - } - const std::string& AsScalar() { - if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { - ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); - } - return input_strings_[0]; + + const strings& Data() const override { + return input_strings_; } - private: - std::vector input_strings_; // for input +private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; + std::vector input_strings_; std::optional> shape_; }; -// to make the metaprogramming magic happy. 
-template <> -class OrtTensor : public Tensor{ -public: - OrtTensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : Tensor(api, ctx, indice, is_input) {}; -}; - -template <> -class Tensor : public OrtKernelArg, public Arg { - public: - using strings = std::vector; - using string_views = std::vector; - Tensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : OrtKernelArg(api, ctx, indice, is_input) { +class OrtStringViewTensorStorage : public IStringTensorStorage{ +public: + using strings = std::vector; + OrtStringViewTensorStorage(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : api_(api), ctx_(ctx), indice_(indice){ if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { @@ -242,7 +201,7 @@ class Tensor : public OrtKernelArg, public Arg { OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars)); chars_.resize(num_chars + 1, '\0'); - auto num_strings = static_cast(NumberOfElement()); + auto num_strings = static_cast((*shape_)[0]); if (num_strings) { std::vector offsets(num_strings); OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, @@ -257,60 +216,40 @@ class Tensor : public OrtKernelArg, public Arg { } } } - const std::vector& Shape() const { - if (shape_.has_value()) { - return *shape_; - } else { - ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); - } - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } - std::string Shape2Str() const { - if (shape_.has_value()) { - std::string shape_str; - for (const auto& dim : *shape_) { - shape_str.append(std::to_string(dim)); - shape_str.append(", "); - } - return shape_str; - } else { - return "empty"; - } + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; } - const string_views& Data() const { - return input_string_views_; - } - const void* DataRaw() const { + virtual const void* DataRaw() const override { if (input_string_views_.size() != 1) { ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); } return reinterpret_cast(input_string_views_[0].data()); } - size_t SizeInBytes() const { - if (input_string_views_.size() != 1) { - ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); - } - return input_string_views_[0].size(); + + virtual bool IsInitialized() const override { + return shape_.has_value(); } - const Span& AsSpan() { - ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION); + + virtual void SetStringOutput(const strings& ss, const std::vector& dims) override { + ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION); } - std::string_view AsScalar() { - if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { - ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); - } - return input_string_views_[0]; + + virtual void SetStringOutput(const std::vector& ss, const std::vector& dims) override { + ORTX_CXX_API_THROW("Set output for string view tensor is not supported", 
ORT_RUNTIME_EXCEPTION); } - private: + const strings& Data() const override { + return input_string_views_; + } + +private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; std::vector chars_; // for input std::vector input_string_views_; // for input std::optional> shape_; @@ -318,12 +257,25 @@ class Tensor : public OrtKernelArg, public Arg { // to make the metaprogramming magic happy. template <> -class OrtTensor : public Tensor{ +class OrtTensor : public OrtKernelArg, + public Tensor{ public: OrtTensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : Tensor(api, ctx, indice, is_input) {} + bool is_input) : OrtKernelArg(api, ctx, indice, is_input), + Tensor(std::make_unique(api, ctx, indice, is_input)) {} +}; + +template <> +class OrtTensor : public OrtKernelArg, + public Tensor{ +public: + OrtTensor(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : OrtKernelArg(api, ctx, indice, is_input), + Tensor(std::make_unique(api, ctx, indice, is_input)) {} }; using TensorPtr = std::unique_ptr; @@ -401,7 +353,7 @@ struct Variadic : public OrtKernelArg, public Arg { return raw_output; } Tensor& AllocateStringTensor(size_t ith_output) { - auto tensor = std::make_unique>(api_, ctx_, ith_output, false); + auto tensor = std::make_unique>(api_, ctx_, ith_output, false); Tensor& output = *tensor; tensors_.emplace_back(tensor.release()); return output; @@ -915,6 +867,20 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFn compute_fn_; }; +class OrtAttributeReader { +public: + OrtAttributeReader(const OrtApi& api, const OrtKernelInfo& info) : base_kernel_(api, info) { + } + + template + T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept { + return base_kernel_.TryToGetAttributeWithDefault(name, default_value); + } + +private: + BaseKernel base_kernel_; +}; + template struct OrtLiteCustomStruct : public OrtLiteCustomOp { template @@ -951,7 +917,13 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->custom_op_ = std::make_unique(*ort_api, *info); + + if constexpr (std::is_constructible::value){ + kernel->custom_op_ = std::make_unique(*ort_api, *info); + } + else { + kernel->custom_op_ = std::make_unique(OrtAttributeReader(*ort_api, *info)); + } auto self = static_cast(this_); kernel->ep_ = self->execution_provider_; kernel->api_ = std::make_unique(*ort_api); diff --git a/includes/tensor_api.h b/includes/tensor_api.h index 5fdbbbaa5..0923e4f5a 100644 --- a/includes/tensor_api.h +++ b/includes/tensor_api.h @@ -245,5 +245,250 @@ class Tensor : public Arg { Span span_; }; +template +class IStringTensorStorage{ +public: + using strings = std::vector; + virtual const std::vector& Shape() const = 0; + virtual const void* DataRaw() const = 0; + virtual const strings& Data() const = 0; + virtual bool IsInitialized() const = 0; + virtual void SetStringOutput(const strings& ss, const std::vector& dims) = 0; + virtual void SetStringOutput(const std::vector& ss, const std::vector& dims) = 0; +}; + +template +class EagerStringTensorStorage : public IStringTensorStorage{ +public: + using strings = std::vector; + EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector{ss.size()}){} + + EagerStringTensorStorage() {} + + const std::vector& Shape() const override { + if (!IsInitialized()) + 
ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; + } + + virtual const void* DataRaw() const override { + if (input_strings_.size() != 1) { + ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + if constexpr (std::is_same::value) + return reinterpret_cast(input_strings_[0].data()); + else + return reinterpret_cast(input_strings_[0].c_str()); + } + + virtual bool IsInitialized() const override { + return shape_.has_value(); + } + + virtual void SetStringOutput(const strings& ss, const std::vector& dims) override { + if constexpr (std::is_same::value) + ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION); + input_strings_.assign(ss.begin(), ss.end()); + shape_ = dims; + } + + const strings& Data() const override { + return input_strings_; + } + + virtual void SetStringOutput(const std::vector& ss, const std::vector& dims) override { + if constexpr (std::is_same::value) + ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION); + + for (const char* s : ss){ + input_strings_.push_back(s); + } + shape_ = dims; + } + +private: + std::vector input_strings_; + std::optional> shape_; +}; + +template <> +class Tensor : public Arg { + public: + using strings = std::vector; + + Tensor(std::unique_ptr> storage) : storage_(std::move(storage)) {} + + Tensor(const strings& ss) : storage_(std::make_unique>(ss)) {} + + Tensor() : storage_(std::make_unique>()) {} + + const strings& Data() const { + return storage_->Data(); + } + + const std::vector& Shape() const { + return storage_->Shape(); + } + + int64_t NumberOfElement() const { + auto& shape = storage_->Shape(); + return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + } + + std::string Shape2Str() const { + if (storage_->IsInitialized()) { + std::string shape_str; + auto& shape = storage_->Shape(); + for (const auto& dim : shape) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + const void* DataRaw() const { + return storage_->DataRaw(); + } + + size_t SizeInBytes() const { + auto& ss = storage_->Data(); + if (ss.size() != 1) { + ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return ss[0].size(); + } + + void SetStringOutput(const strings& ss, const std::vector& dims) { + storage_->SetStringOutput(ss, dims); + } + void SetStringOutput(const std::vector& ss, const std::vector& dims) { + storage_->SetStringOutput(ss, dims); + } + const Span& AsSpan() { + ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION); + } + const std::string& AsScalar() { + auto& ss = storage_->Data(); + if (ss.size() != 1) { + ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + } + return ss[0]; + } + + private: + std::unique_ptr> storage_; +}; + + +template <> +class Tensor : public Arg { + public: + using strings = std::vector; + + Tensor(std::unique_ptr> storage) : storage_(std::move(storage)) {} + + Tensor(const strings& ss) : storage_(std::make_unique>(ss)) {} + + const strings& Data() const { + return storage_->Data(); + } + + const std::vector& Shape() const { + return storage_->Shape(); + } + + int64_t NumberOfElement() const { + auto& shape = storage_->Shape(); + return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + } + + std::string 
Shape2Str() const { + if (storage_->IsInitialized()) { + std::string shape_str; + auto& shape = storage_->Shape(); + for (const auto& dim : shape) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + const void* DataRaw() const { + return storage_->DataRaw(); + } + + size_t SizeInBytes() const { + auto& ss = storage_->Data(); + if (ss.size() != 1) { + ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return ss[0].size(); + } + + void SetStringOutput(const strings& ss, const std::vector& dims) { + storage_->SetStringOutput(ss, dims); + } + void SetStringOutput(const std::vector& ss, const std::vector& dims) { + storage_->SetStringOutput(ss, dims); + } + const Span& AsSpan() { + ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION); + } + const std::string_view& AsScalar() { + auto& ss = storage_->Data(); + if (ss.size() != 1) { + ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + } + return ss[0]; + } + + private: + std::unique_ptr> storage_; +}; + + +template +class NamedArgumentDict{ +public: + using ValueTuple = std::tuple; + + NamedArgumentDict(const std::vector& keys, const std::tuple& args) : names_(keys), entries_(args) { + } + + template + T TryToGetAttributeWithDefault(const char* name, const T& default_value) const { + return TryToGetAttributeWithDefaultInternal<0>(name, default_value); + } + +private: + template + typename std::enable_if::type + TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const { + return default_value; + } + + template + typename std::enable_if::type + TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const { + if (names_[I] == name){ + if constexpr (std::is_same, T>::value) + return std::get(entries_); + else + throw std::runtime_error("name matched but type is not"); + } + return TryToGetAttributeWithDefaultInternal(name, default_value); + } + + std::vector names_; + std::tuple entries_; + +}; + } } diff --git a/operators/math/negpos.hpp b/operators/math/negpos.hpp index 62dc4f34e..83b82534f 100644 --- a/operators/math/negpos.hpp +++ b/operators/math/negpos.hpp @@ -24,3 +24,4 @@ OrtStatusPtr neg_pos(const ortc::Tensor& input, return nullptr; } + diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc index 8c9a11f8d..3a2a9e06c 100644 --- a/operators/tokenizer/basic_tokenizer.cc +++ b/operators/tokenizer/basic_tokenizer.cc @@ -81,16 +81,16 @@ std::vector BasicTokenizer::Tokenize(ustring text) { return result; } -KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) { - bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true); - bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true); - bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false); - bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false); - bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true); - - tokenizer_ = std::make_shared(do_lower_case, tokenize_chinese_chars, strip_accents, - tokenize_punctuation, remove_control_chars); -} +// KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) { +// bool do_lower_case = 
TryToGetAttributeWithDefault("do_lower_case", true); +// bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true); +// bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false); +// bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false); +// bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true); + +// tokenizer_ = std::make_shared(do_lower_case, tokenize_chinese_chars, strip_accents, +// tokenize_punctuation, remove_control_chars); +// } void KernelBasicTokenizer::Compute(std::string_view input, ortc::Tensor& output) const { diff --git a/operators/tokenizer/basic_tokenizer.hpp b/operators/tokenizer/basic_tokenizer.hpp index 713bd956f..85c71fab8 100644 --- a/operators/tokenizer/basic_tokenizer.hpp +++ b/operators/tokenizer/basic_tokenizer.hpp @@ -21,8 +21,20 @@ class BasicTokenizer { bool remove_control_chars_; }; -struct KernelBasicTokenizer : BaseKernel { - KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info); +struct KernelBasicTokenizer { + + template + KernelBasicTokenizer(const T& dict) { + bool do_lower_case = dict.TryToGetAttributeWithDefault("do_lower_case", true); + bool tokenize_chinese_chars = dict.TryToGetAttributeWithDefault("tokenize_chinese_chars", true); + bool strip_accents = dict.TryToGetAttributeWithDefault("strip_accents", false); + bool tokenize_punctuation = dict.TryToGetAttributeWithDefault("tokenize_punctuation", false); + bool remove_control_chars = dict.TryToGetAttributeWithDefault("remove_control_chars", true); + + tokenizer_ = std::make_shared(do_lower_case, tokenize_chinese_chars, strip_accents, + tokenize_punctuation, remove_control_chars); + } + void Compute(std::string_view input, ortc::Tensor& output) const; diff --git a/operators/tokenizer/bert_tokenizer_decoder.cc b/operators/tokenizer/bert_tokenizer_decoder.cc index e8742c53c..d03131fd3 100644 --- a/operators/tokenizer/bert_tokenizer_decoder.cc +++ b/operators/tokenizer/bert_tokenizer_decoder.cc @@ -119,22 +119,24 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new return false; } -KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) { - std::string vocab = ort_.KernelInfoGetAttribute(&info, "vocab_file"); - std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]")); - std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]")); - std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]")); - std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]")); - std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]")); - std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##")); - - use_indices_ = TryToGetAttributeWithDefault("use_indices", false); - skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", false); - clean_up_tokenization_spaces_ = TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true); - - decoder_ = std::make_shared(vocab, unk_token, sep_token, pad_token, - cls_token, mask_token, suffix_indicator); -} +// template +// KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const T& dict) { +// //std::string vocab = ort_.KernelInfoGetAttribute(&info, "vocab_file"); +// std::string vocab = dict.TryToGetAttributeWithDefault("vocab_file", std::string("")); 
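[The constructors being commented out here give way to templated constructors that need only duck-typed attribute lookup: roughly, any type exposing TryToGetAttributeWithDefault<T>(name, default) can drive a kernel — OrtAttributeReader in graph mode, NamedArgumentDict in eager mode. A minimal sketch of the pattern follows; ToyDict and ToyKernel are illustrative stand-ins, not code from this patch.]

// Duck-typed attribute lookup: the kernel never names a concrete dict type.
#include <iostream>
#include <map>
#include <string>

struct ToyDict {
  std::map<std::string, bool> values;
  template <typename T>
  T TryToGetAttributeWithDefault(const char* name, const T& def) const {
    auto it = values.find(name);
    return it == values.end() ? def : static_cast<T>(it->second);
  }
};

struct ToyKernel {
  template <typename TDict>
  explicit ToyKernel(const TDict& dict)
      : lower_(dict.TryToGetAttributeWithDefault("do_lower_case", true)) {}
  bool lower_;
};

int main() {
  ToyDict d{{{"do_lower_case", false}}};
  ToyKernel k(d);  // constructed with no OrtApi / OrtKernelInfo in sight
  std::cout << k.lower_ << "\n";
  return 0;
}

[This is why the CreateKernel lambda above can branch on std::is_constructible: kernels keeping the old (OrtApi, OrtKernelInfo) constructor still work, while converted kernels take any attribute reader.]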
+// std::string unk_token = dict.TryToGetAttributeWithDefault("unk_token", std::string("[UNK]")); +// std::string sep_token = dict.TryToGetAttributeWithDefault("sep_token", std::string("[SEP]")); +// std::string pad_token = dict.TryToGetAttributeWithDefault("pad_token", std::string("[PAD]")); +// std::string cls_token = dict.TryToGetAttributeWithDefault("cls_token", std::string("[CLS]")); +// std::string mask_token = dict.TryToGetAttributeWithDefault("mask_token", std::string("[MASK]")); +// std::string suffix_indicator = dict.TryToGetAttributeWithDefault("suffix_indicator", std::string("##")); + +// use_indices_ = dict.TryToGetAttributeWithDefault("use_indices", false); +// skip_special_tokens_ = dict.TryToGetAttributeWithDefault("skip_special_tokens", false); +// clean_up_tokenization_spaces_ = dict.TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true); + +// decoder_ = std::make_shared(vocab, unk_token, sep_token, pad_token, +// cls_token, mask_token, suffix_indicator); +// } void KernelBertTokenizerDecoder::Compute(const ortc::Tensor& ids, const ortc::Tensor& positions, diff --git a/operators/tokenizer/bert_tokenizer_decoder.hpp b/operators/tokenizer/bert_tokenizer_decoder.hpp index 16441c484..e5aa34859 100644 --- a/operators/tokenizer/bert_tokenizer_decoder.hpp +++ b/operators/tokenizer/bert_tokenizer_decoder.hpp @@ -29,8 +29,27 @@ class BertTokenizerDecoder { bool RemoveTokenizeSpace(int64_t pre_token_id, int64_t new_token_id); }; -struct KernelBertTokenizerDecoder : BaseKernel { - KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info); +struct KernelBertTokenizerDecoder { + + template + KernelBertTokenizerDecoder(const T& dict) { + //std::string vocab = ort_.KernelInfoGetAttribute(&info, "vocab_file"); + std::string vocab = dict.TryToGetAttributeWithDefault("vocab_file", std::string("")); + std::string unk_token = dict.TryToGetAttributeWithDefault("unk_token", std::string("[UNK]")); + std::string sep_token = dict.TryToGetAttributeWithDefault("sep_token", std::string("[SEP]")); + std::string pad_token = dict.TryToGetAttributeWithDefault("pad_token", std::string("[PAD]")); + std::string cls_token = dict.TryToGetAttributeWithDefault("cls_token", std::string("[CLS]")); + std::string mask_token = dict.TryToGetAttributeWithDefault("mask_token", std::string("[MASK]")); + std::string suffix_indicator = dict.TryToGetAttributeWithDefault("suffix_indicator", std::string("##")); + + use_indices_ = dict.TryToGetAttributeWithDefault("use_indices", false); + skip_special_tokens_ = dict.TryToGetAttributeWithDefault("skip_special_tokens", false); + clean_up_tokenization_spaces_ = dict.TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true); + + decoder_ = std::make_shared(vocab, unk_token, sep_token, pad_token, + cls_token, mask_token, suffix_indicator); + } + void Compute(const ortc::Tensor& ids, const ortc::Tensor& positions, ortc::Tensor& output) const; diff --git a/test/shared_test/test_ortops_tokenizer.cc b/test/shared_test/test_ortops_tokenizer.cc index 5933f358f..b87e3d557 100644 --- a/test/shared_test/test_ortops_tokenizer.cc +++ b/test/shared_test/test_ortops_tokenizer.cc @@ -7,6 +7,23 @@ #include "ocos.h" #include "test_kernel.hpp" +#include "operators/tokenizer/basic_tokenizer.hpp" + +TEST(basic_tokenizer, eager) { + std::string test_case = "I mean, you’ll need something to talk about next Sunday, right?"; + std::vector expect_result = {"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", 
"?"}; + + ortc::NamedArgumentDict dict({"do_lower_case", "tokenize_chinese_chars", "strip_accents", "tokenize_punctuation", "remove_control_chars"}, + std::make_tuple(false, true, true, true, true)); + + KernelBasicTokenizer tokenizer(dict); + + //ortc::Tensor input(std::vector{test_case}); + ortc::Tensor output; + tokenizer.Compute(test_case, output); + EXPECT_EQ(output.Data(), expect_result); +} + TEST(tokenizer_opertors, test_bert_tokenizer) { auto ort_env = std::make_unique(ORT_LOGGING_LEVEL_WARNING, "Default"); From 69fcc2d96a0f0a301f2124863136609809357f46 Mon Sep 17 00:00:00 2001 From: Cheng Tang Date: Mon, 1 Apr 2024 14:46:08 -0700 Subject: [PATCH 4/7] cuda test --- CMakeLists.txt | 2 +- build.sh | 2 +- includes/custom_op_lite.h | 143 ++++++++++++++++++++++++ includes/onnxruntime_cpp_api_legacy.hpp | 2 +- includes/tensor_api.h | 8 +- 5 files changed, 149 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f67e31e49..09e70a8da 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -797,7 +797,7 @@ if(_BUILD_SHARED_LIBRARY) standardize_output_folder(extensions_shared) if(LINUX OR ANDROID) - set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver") + # set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver") # strip if not a debug build if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s") diff --git a/build.sh b/build.sh index 3b7379c4c..4522e66a5 100755 --- a/build.sh +++ b/build.sh @@ -27,4 +27,4 @@ if [ -n "$cuda_arch" ]; then param="$@ -DCMAKE_CUDA_ARCHITECTURE=$cuda_arch ../../.." fi # it looks the parallel build on CI pipeline machine causes crashes. -cmake $param && cmake --build . --config $BUILD_FLAVOR --parallel "${CPU_NUMBER}" +cmake "$@" ../../.. "-DOCOS_USE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" && cmake --build . 
--config $BUILD_FLAVOR --parallel "${CPU_NUMBER}" diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index e1be3f473..b0143af38 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -377,6 +377,36 @@ struct Variadic : public OrtKernelArg, public Arg { TensorPtrs tensors_; }; +class OrtGraphKernelContext : public KernelContext { +public: + OrtGraphKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { + OrtMemoryInfo* info; + OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info)); + OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &allocator_)); + api.ReleaseMemoryInfo(info); + } + + virtual ~OrtGraphKernelContext(){ + if (allocator_){ + api_.ReleaseAllocator(allocator_); + } + } + + void* AllocScratchBuffer(size_t size) override{ + return allocator_->Alloc(allocator_, size); + } + + void FreeScratchBuffer(void* p) override { + if (p){ + allocator_->Free(allocator_, p); + } + } + +private: + const OrtApi& api_; + OrtAllocator* allocator_; +}; + #ifdef USE_CUDA enum CudaResource { @@ -412,6 +442,89 @@ struct CudaContext { int device_id = 0; }; + +class OrtGraphCudaKernelContext : public CUDAKernelContext { +public: + static const int cuda_resource_ver = 1; + + OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { + api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_); + if (!cuda_stream_) { + ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION); + } + api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_); + if (!cublas_) { + ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION); + } + void* resource = nullptr; + OrtStatusPtr result = api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource); + if (result) { + ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION); + } + memcpy(&device_id_, &resource, sizeof(int)); + + OrtMemoryInfo* info; + OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info)); + OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_)); + api.ReleaseMemoryInfo(info); + + OrtMemoryInfo* cuda_mem_info; + OrtW::ThrowOnError(api, api.CreateMemoryInfo("GPU", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info)); + OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_)); + api.ReleaseMemoryInfo(cuda_mem_info); + + } + + virtual ~OrtGraphCudaKernelContext(){ + if (cpu_allocator_){ + api_.ReleaseAllocator(cpu_allocator_); + } + if (cuda_allocator_){ + api_.ReleaseAllocator(cuda_allocator_); + } + } + + void* AllocScratchBuffer(size_t size) override{ + return cpu_allocator_->Alloc(cpu_allocator_, size); + } + + void FreeScratchBuffer(void* p) override { + if (p){ + cpu_allocator_->Free(cpu_allocator_, p); + } + } + + void* AllocCudaScratchBuffer(size_t size) override { + return cuda_allocator_->Alloc(cuda_allocator_, size); + } + + void FreeCudaScratchBuffer(void* p) override { + if (p){ + cuda_allocator_->Free(cuda_allocator_, p); + } + } + + void* GetCudaStream() const override { + return cuda_stream_; + } + + void* GetCublasHandle() const override { + return cublas_; + } + + int GetCudaDeviceId() const override { + 
return device_id_; + } + +private: + const OrtApi& api_; + OrtAllocator* cpu_allocator_; + OrtAllocator* cuda_allocator_; + void* cuda_stream_ = {}; + void* cublas_ = {}; + int device_id_ = 0; +}; + #endif // using mf16_t = uint16_t; @@ -444,6 +557,24 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + tensors.push_back(std::make_unique(api->GetOrtApi(), *context)); + std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; + auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + tensors.push_back(std::make_unique(api->GetOrtApi(), *context)); + std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; + auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + #if ORT_API_VERSION >= 14 template static typename std::enable_if::value, std::tuple>::type @@ -653,6 +784,18 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + #if ORT_API_VERSION >= 14 template static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type diff --git a/includes/onnxruntime_cpp_api_legacy.hpp b/includes/onnxruntime_cpp_api_legacy.hpp index 16bd8131e..ddacb70d1 100644 --- a/includes/onnxruntime_cpp_api_legacy.hpp +++ b/includes/onnxruntime_cpp_api_legacy.hpp @@ -40,7 +40,7 @@ struct CustomOpApi { size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const; OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) const; - + void ThrowOnError(OrtStatus* status) const { OrtW::ThrowOnError(api_, status); } diff --git a/includes/tensor_api.h b/includes/tensor_api.h index 0923e4f5a..4a1e8c86d 100644 --- a/includes/tensor_api.h +++ b/includes/tensor_api.h @@ -1,16 +1,14 @@ +#pragma once #include #include #include #include "onnxruntime_customop.hpp" #include "onnxruntime_f16.h" +#include "kernel_context.h" namespace Ort { namespace Custom { -// this is for the ORT custom op template magic -class Arg { -}; - template struct Span { const T* data_ = {}; @@ -69,6 +67,7 @@ class ITensorStorage{ virtual void* Initialize(const std::vector& shape, size_t element_size) = 0; }; + class IAllocator { public: virtual void* Alloc(size_t size) = 0; @@ -76,7 +75,6 @@ class IAllocator { }; // TODO: remove this - class TestAllocator : public IAllocator { public: void* Alloc(size_t size) override { From cf7d14bc9c2a1a51de5c02a657d00a1943fbef55 Mon Sep 17 00:00:00 2001 From: Cheng Tang Date: Tue, 2 Apr 2024 10:41:17 -0700 Subject: [PATCH 5/7] add missing file --- includes/kernel_context.h | 34 
+++++++++++++++++++++++++++++++ operators/math/cuda/negpos_def.cc | 4 ++-- operators/math/cuda/negpos_def.h | 2 +- 3 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 includes/kernel_context.h diff --git a/includes/kernel_context.h b/includes/kernel_context.h new file mode 100644 index 000000000..520056503 --- /dev/null +++ b/includes/kernel_context.h @@ -0,0 +1,34 @@ +#pragma once +#include +#include +#include + +namespace Ort { +namespace Custom { + +// this is for the ORT custom op template magic +class Arg { +}; + +class KernelContext : public Arg{ +public: + virtual void* AllocScratchBuffer(size_t size) = 0; + virtual void FreeScratchBuffer(void* p) = 0; + // TODO: threadpool? +}; + +#ifdef USE_CUDA +class CUDAKernelContext : public KernelContext { +public: + virtual void* AllocCudaScratchBuffer(size_t size) = 0; + virtual void FreeCudaScratchBuffer(void* p) = 0; + virtual void* GetCudaStream() const = 0; + virtual void* GetCublasHandle() const = 0; + virtual int GetCudaDeviceId() const = 0; +}; +#endif + +// TODO: helper func to create context from global ORT env. + +} +} \ No newline at end of file diff --git a/operators/math/cuda/negpos_def.cc b/operators/math/cuda/negpos_def.cc index b1a78b8be..9d9c6e16c 100644 --- a/operators/math/cuda/negpos_def.cc +++ b/operators/math/cuda/negpos_def.cc @@ -4,7 +4,7 @@ #include #include -OrtStatusPtr neg_pos_cuda(const Ort::Custom::CudaContext& ctx, +OrtStatusPtr neg_pos_cuda(Ort::Custom::CUDAKernelContext& ctx, const ortc::Tensor& input, ortc::Tensor& out0_tensor, ortc::Tensor& out1_tensor) { @@ -13,6 +13,6 @@ OrtStatusPtr neg_pos_cuda(const Ort::Custom::CudaContext& ctx, float* out1 = out1_tensor.Allocate(input.Shape()); const float* X = input.Data(); - neg_pos_impl(reinterpret_cast(ctx.cuda_stream), X, out0, out1, size); + neg_pos_impl(reinterpret_cast(ctx.GetCudaStream()), X, out0, out1, size); return nullptr; } diff --git a/operators/math/cuda/negpos_def.h b/operators/math/cuda/negpos_def.h index 3ae0f4ef9..5479c7ada 100644 --- a/operators/math/cuda/negpos_def.h +++ b/operators/math/cuda/negpos_def.h @@ -4,7 +4,7 @@ #pragma once #include "ocos.h" -OrtStatusPtr neg_pos_cuda(const Ort::Custom::CudaContext& ctx, +OrtStatusPtr neg_pos_cuda(Ort::Custom::CUDAKernelContext& ctx, const ortc::Tensor& input, ortc::Tensor& out0_tensor, ortc::Tensor& out1_tensor); From 31150c15374e679136475aad0d0584876655dbf7 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Wed, 3 Apr 2024 15:20:44 -0700 Subject: [PATCH 6/7] add UT for neg_pos_cuda in eager mode and fix build break in Windows --- includes/tensor_api.h | 2 +- test/shared_test/test_ortops_cuda.cc | 55 ++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/includes/tensor_api.h b/includes/tensor_api.h index 4a1e8c86d..17f456189 100644 --- a/includes/tensor_api.h +++ b/includes/tensor_api.h @@ -259,7 +259,7 @@ template class EagerStringTensorStorage : public IStringTensorStorage{ public: using strings = std::vector; - EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector{ss.size()}){} + EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector{static_cast(ss.size())}){} EagerStringTensorStorage() {} diff --git a/test/shared_test/test_ortops_cuda.cc b/test/shared_test/test_ortops_cuda.cc index dc9ae35b2..cdc753eb5 100644 --- a/test/shared_test/test_ortops_cuda.cc +++ b/test/shared_test/test_ortops_cuda.cc @@ -6,8 +6,12 @@ #include "gtest/gtest.h" #include "ocos.h" #include "test_kernel.hpp" +#include 
"operators\math\cuda\negpos_def.h" +#include "kernel_context.h" #ifdef USE_CUDA +#include +#include TEST(CudaOp, test_fastgelu) { auto ort_env = std::make_unique(ORT_LOGGING_LEVEL_WARNING, "Default"); @@ -35,4 +39,55 @@ TEST(CudaOp, test_fastgelu) { TestInference(*ort_env, model_path.c_str(), inputs, outputs); } +class MockCudaKernelContext : public Ort::Custom::CUDAKernelContext { +public: + MockCudaKernelContext() { cudaStreamCreate(&stream); } + ~MockCudaKernelContext() { cudaStreamDestroy(stream); } + void* AllocScratchBuffer(size_t size) override { return nullptr; } + void FreeScratchBuffer(void* p) override {} + void* AllocCudaScratchBuffer(size_t size) override { return nullptr; } + void FreeCudaScratchBuffer(void* p) override {} + void* GetCudaStream() const override { return static_cast(stream); } + void* GetCublasHandle() const override { return nullptr; } + int GetCudaDeviceId() const override { return 0; } + +private: + cudaStream_t stream; +}; + +class CudaAllocator : public Ort::Custom::IAllocator { +public: + void* Alloc(size_t size) override { + void* p = nullptr; + cudaMalloc((void**)&p, size); + return p; + } + void Free(void* p) override { cudaFree(p); } +}; + +TEST(CudaOp, test_eager_negpos) { + MockCudaKernelContext mock_cuda_kc; + std::vector input_data = {0.0f, 0.2f, -1.3f, 1.5f}; + std::unique_ptr cuda_alloc = std::make_unique(); + void* device_input = cuda_alloc->Alloc(sizeof(float) * input_data.size()); + cudaMemcpyAsync(device_input, input_data.data(), sizeof(float)*input_data.size(), cudaMemcpyHostToDevice, static_cast(mock_cuda_kc.GetCudaStream())); + + ortc::Tensor input(std::vector{2, 2}, device_input); + ortc::Tensor output1(cuda_alloc.get()); + ortc::Tensor output2(cuda_alloc.get()); + neg_pos_cuda(mock_cuda_kc, input, output1, output2); + + float* host_output1 = (float*)malloc(sizeof(float) * input_data.size()); + float* host_output2 = (float*)malloc(sizeof(float) * input_data.size()); + cudaMemcpyAsync(host_output1, output1.DataRaw(), sizeof(float)*input_data.size(), cudaMemcpyDeviceToHost, static_cast(mock_cuda_kc.GetCudaStream())); + cudaMemcpyAsync(host_output2, output2.DataRaw(), sizeof(float)*input_data.size(), cudaMemcpyDeviceToHost, static_cast(mock_cuda_kc.GetCudaStream())); + ASSERT_NEAR(host_output1[1], input_data[1], 0.01f); + ASSERT_NEAR(host_output2[2], input_data[2], 0.01f); + ASSERT_NEAR(host_output1[3], input_data[3], 0.01f); + + cuda_alloc->Free(device_input); + free(host_output1); + free(host_output2); +} + #endif \ No newline at end of file From d9be6e43935ce0729c4d284a360554d54ae2b7fb Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 3 Apr 2024 15:50:12 -0700 Subject: [PATCH 7/7] fix Linux build break --- test/shared_test/test_ortops_cuda.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/shared_test/test_ortops_cuda.cc b/test/shared_test/test_ortops_cuda.cc index cdc753eb5..c62977990 100644 --- a/test/shared_test/test_ortops_cuda.cc +++ b/test/shared_test/test_ortops_cuda.cc @@ -6,10 +6,10 @@ #include "gtest/gtest.h" #include "ocos.h" #include "test_kernel.hpp" -#include "operators\math\cuda\negpos_def.h" #include "kernel_context.h" #ifdef USE_CUDA +#include "operators/math/cuda/negpos_def.h" #include #include