From 899dba7609ad1d781cab71baa31ac976f7c3ea31 Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Mon, 15 Jul 2024 22:03:45 +0100 Subject: [PATCH 1/8] Update to latest nightly ov release --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9196b56..66a1cc9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,9 +38,9 @@ function(get_linux_lsb_release_information) endfunction() set(OV_VERSION_SHORT "nightly") -set(OV_VERSION "2024.3.0.dev20240524_x86_64") +set(OV_VERSION "2024.4.0.dev20240715_x86_64") set(OV_STORAGE_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages") -set(OV_NIGHTLY_COMMIT "2024.3.0-15502-66093834e38") +set(OV_NIGHTLY_COMMIT "2024.4.0-16039-620d2a20c8c") if (WIN32) if(NOT OV_LIBRARY_URL) From b5c03c35050813265a406257649a14077c22e194 Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Tue, 16 Jul 2024 06:35:11 +0100 Subject: [PATCH 2/8] Add remote tensor example --- examples/cpp/CMakeLists.txt | 3 +- examples/cpp/main.cpp | 33 ++++++++++--------- .../intel_npu_acceleration_library/common.h | 1 + .../inference.h | 28 ++++++++++++++++ 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 1db5687..40c28dc 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -19,8 +19,7 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) FetchContent_Declare( intel_npu_acceleration_library - GIT_REPOSITORY "https://github.com/intel/intel-npu-acceleration-library" - GIT_TAG "main" + SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../" ) FetchContent_MakeAvailable(intel_npu_acceleration_library) diff --git a/examples/cpp/main.cpp b/examples/cpp/main.cpp index 2946105..6fbbd25 100644 --- a/examples/cpp/main.cpp +++ b/examples/cpp/main.cpp @@ -9,11 +9,13 @@ using namespace intel_npu_acceleration_library; #include int main() { - const size_t batch = 128, inC = 256, outC = 512, N = 100000; + const size_t batch = 128, inC = 256, outC = 512, N = 1; std::cout << "Create a ModelFactory" << std::endl; auto factory = std::make_shared("NPU"); + auto context = factory->get_context(); + // create parameter auto input = factory->parameter({batch, inC}, ov::element::f16); auto weights = factory->parameter({outC, inC}, ov::element::f16); @@ -31,16 +33,18 @@ int main() { std::cout << "Saving model to matmul.xml" << std::endl; factory->saveModel("matmul.xml"); - // Here you can create float16 buffers and run inference by using - half_ptr input_buffer = new uint16_t[batch * inC]; - half_ptr weights_buffer = new uint16_t[outC * inC]; - half_ptr bias_buffer = new uint16_t[outC]; - half_ptr output_buffer = new uint16_t[batch * outC]; + std::cout << "Creating a remote tensor" << std::endl; + auto input_buffer = context.create_l0_host_tensor(ov::element::f16, {batch, inC}, ov::intel_npu::TensorType::INPUT); + auto weights_buffer = + context.create_l0_host_tensor(ov::element::f16, {outC, inC}, ov::intel_npu::TensorType::INPUT); + auto bias_buffer = context.create_l0_host_tensor(ov::element::f16, {1, outC}, ov::intel_npu::TensorType::INPUT); + auto output_buffer = + context.create_l0_host_tensor(ov::element::f16, {batch, outC}, ov::intel_npu::TensorType::OUTPUT); - memset(input_buffer, 0, batch * inC * sizeof(uint16_t)); - memset(weights_buffer, 0, outC * inC * sizeof(uint16_t)); - memset(output_buffer, 0, batch * outC * sizeof(uint16_t)); - memset(bias_buffer, 0, outC * sizeof(uint16_t)); + std::memset(input_buffer.get(), 
0, input_buffer.get_byte_size()); + std::memset(weights_buffer.get(), 0, weights_buffer.get_byte_size()); + std::memset(bias_buffer.get(), 0, bias_buffer.get_byte_size()); + std::memset(output_buffer.get(), 0, output_buffer.get_byte_size()); factory->setInputTensor(input_buffer, 0); factory->setInputTensor(weights_buffer, 1); @@ -49,13 +53,10 @@ int main() { // Run inference std::cout << "Run inference on " << N << " workloads" << std::endl; - for (auto idx = 0; idx < N; idx++) + for (auto idx = 0; idx < N; idx++) { factory->run(); - std::cout << "Inference done" << std::endl; + } - delete[] input_buffer; - delete[] weights_buffer; - delete[] bias_buffer; - delete[] output_buffer; + std::cout << "Inference done" << std::endl; return 0; } \ No newline at end of file diff --git a/include/intel_npu_acceleration_library/common.h b/include/intel_npu_acceleration_library/common.h index cba90b8..298fd60 100644 --- a/include/intel_npu_acceleration_library/common.h +++ b/include/intel_npu_acceleration_library/common.h @@ -13,6 +13,7 @@ #include "openvino/opsets/opset7.hpp" #include "openvino/opsets/opset8.hpp" #include "openvino/opsets/opset9.hpp" +#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp" #include "openvino/runtime/intel_npu/properties.hpp" #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) diff --git a/include/intel_npu_acceleration_library/inference.h b/include/intel_npu_acceleration_library/inference.h index 9232165..df2e27a 100644 --- a/include/intel_npu_acceleration_library/inference.h +++ b/include/intel_npu_acceleration_library/inference.h @@ -126,6 +126,14 @@ class OVInferenceModel { wt_thread.join(); } + /** + * @brief Get the remote context + * + */ + auto get_context() { + return core.get_default_context(device).as(); + } + /** * @brief Save the model to a local path * @@ -201,6 +209,16 @@ class OVInferenceModel { infer_request.set_input_tensor(idx, X); } + /** + * @brief Set the input activations + * + * @param _X reference to a zero buffer tensor + * @param idx input tensor index + */ + void setInputTensor(ov::intel_npu::level_zero::ZeroBufferTensor& _X, size_t idx) { + infer_request.set_input_tensor(idx, _X); + } + /** * @brief Set the output activations * @@ -213,6 +231,16 @@ class OVInferenceModel { infer_request.set_output_tensor(idx, X); } + /** + * @brief Set the output activations + * + * @param _X reference to a zero buffer tensor + * @param idx output tensor index + */ + void setOutputTensor(ov::intel_npu::level_zero::ZeroBufferTensor& _X, size_t idx) { + infer_request.set_output_tensor(idx, _X); + } + /** * @brief Set the input and output activations * From 193124677f51ae8647b8aa87fad1968e817ece80 Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Tue, 16 Jul 2024 07:16:34 +0100 Subject: [PATCH 3/8] Fix multiple inferences --- examples/cpp/main.cpp | 6 +++--- include/intel_npu_acceleration_library/inference.h | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/cpp/main.cpp b/examples/cpp/main.cpp index 6fbbd25..a6f5636 100644 --- a/examples/cpp/main.cpp +++ b/examples/cpp/main.cpp @@ -9,7 +9,7 @@ using namespace intel_npu_acceleration_library; #include int main() { - const size_t batch = 128, inC = 256, outC = 512, N = 1; + const size_t batch = 128, inC = 256, outC = 512, N = 10000; std::cout << "Create a ModelFactory" << std::endl; auto factory = std::make_shared("NPU"); @@ -30,8 +30,8 @@ int main() { factory->compile(); // Save OV model - std::cout << "Saving model to matmul.xml" << std::endl; - 
factory->saveModel("matmul.xml"); + // std::cout << "Saving model to matmul.xml" << std::endl; + // factory->saveModel("matmul.xml"); std::cout << "Creating a remote tensor" << std::endl; auto input_buffer = context.create_l0_host_tensor(ov::element::f16, {batch, inC}, ov::intel_npu::TensorType::INPUT); diff --git a/include/intel_npu_acceleration_library/inference.h b/include/intel_npu_acceleration_library/inference.h index df2e27a..e798f77 100644 --- a/include/intel_npu_acceleration_library/inference.h +++ b/include/intel_npu_acceleration_library/inference.h @@ -95,8 +95,6 @@ class OVInferenceModel { compiled_model = core.compile_model(model, device); // Create inference request infer_request = compiled_model.create_infer_request(); - // First inference - infer_request.infer(); } /** From f2ff0cda0068437af21b6b9231a010ad6f6c9555 Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Tue, 16 Jul 2024 07:33:01 +0100 Subject: [PATCH 4/8] Create factory API --- examples/cpp/main.cpp | 12 +++---- .../inference.h | 36 +++++++++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/examples/cpp/main.cpp b/examples/cpp/main.cpp index a6f5636..ab13f62 100644 --- a/examples/cpp/main.cpp +++ b/examples/cpp/main.cpp @@ -14,8 +14,6 @@ int main() { std::cout << "Create a ModelFactory" << std::endl; auto factory = std::make_shared("NPU"); - auto context = factory->get_context(); - // create parameter auto input = factory->parameter({batch, inC}, ov::element::f16); auto weights = factory->parameter({outC, inC}, ov::element::f16); @@ -34,12 +32,10 @@ int main() { // factory->saveModel("matmul.xml"); std::cout << "Creating a remote tensor" << std::endl; - auto input_buffer = context.create_l0_host_tensor(ov::element::f16, {batch, inC}, ov::intel_npu::TensorType::INPUT); - auto weights_buffer = - context.create_l0_host_tensor(ov::element::f16, {outC, inC}, ov::intel_npu::TensorType::INPUT); - auto bias_buffer = context.create_l0_host_tensor(ov::element::f16, {1, outC}, ov::intel_npu::TensorType::INPUT); - auto output_buffer = - context.create_l0_host_tensor(ov::element::f16, {batch, outC}, ov::intel_npu::TensorType::OUTPUT); + auto input_buffer = factory->createRemoteInputTensor(0); + auto weights_buffer = factory->createRemoteInputTensor(1); + auto bias_buffer = factory->createRemoteInputTensor(2); + auto output_buffer = factory->createRemoteOutputTensor(0); std::memset(input_buffer.get(), 0, input_buffer.get_byte_size()); std::memset(weights_buffer.get(), 0, weights_buffer.get_byte_size()); diff --git a/include/intel_npu_acceleration_library/inference.h b/include/intel_npu_acceleration_library/inference.h index e798f77..3f1d1f7 100644 --- a/include/intel_npu_acceleration_library/inference.h +++ b/include/intel_npu_acceleration_library/inference.h @@ -173,6 +173,42 @@ class OVInferenceModel { } } + /** + * @brief Create a Remote Tensor object + * + * @param type element type + * @param shape element shape + * @param tensor_type element tensor type: INPUT, OUTPUT, BIND + * @return auto + */ + auto createRemoteTensor(const ov::element::Type type, const ov::Shape& shape, + const ov::intel_npu::TensorType tensor_type) { + ov::intel_npu::level_zero::ZeroContext context = get_context(); + return context.create_l0_host_tensor(type, shape, tensor_type); + } + + /** + * @brief Create a Remote Tensor object + * + * @param idx index of the input tensor + * @return auto + */ + auto createRemoteInputTensor(size_t idx) { + auto tensor = infer_request.get_input_tensor(idx); + return 
createRemoteTensor(tensor.get_element_type(), tensor.get_shape(), ov::intel_npu::TensorType::INPUT); + } + + /** + * @brief Create a Remote Tensor object + * + * @param idx index of the output tensor + * @return auto + */ + auto createRemoteOutputTensor(size_t idx) { + auto tensor = infer_request.get_output_tensor(idx); + return createRemoteTensor(tensor.get_element_type(), tensor.get_shape(), ov::intel_npu::TensorType::OUTPUT); + } + /** * @brief Get model input tensor * From 9760d57f8147fac700ce36091a225a04fd2f7d04 Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Tue, 16 Jul 2024 13:14:02 +0100 Subject: [PATCH 5/8] Create python bindings --- .../inference.h | 3 +- .../intel_npu_acceleration_library/tensor.h | 48 ++++++++++++ .../backend/bindings.py | 9 +++ .../backend/factory.py | 28 +------ .../backend/tensor.py | 76 +++++++++++++++++++ intel_npu_acceleration_library/dtypes.py | 37 +++++++++ src/bindings.cpp | 19 +++++ 7 files changed, 193 insertions(+), 27 deletions(-) create mode 100644 include/intel_npu_acceleration_library/tensor.h diff --git a/include/intel_npu_acceleration_library/inference.h b/include/intel_npu_acceleration_library/inference.h index 3f1d1f7..dc97794 100644 --- a/include/intel_npu_acceleration_library/inference.h +++ b/include/intel_npu_acceleration_library/inference.h @@ -19,6 +19,7 @@ #include #include "intel_npu_acceleration_library/common.h" #include "intel_npu_acceleration_library/parameters.h" +#include "intel_npu_acceleration_library/tensor.h" namespace intel_npu_acceleration_library { @@ -81,7 +82,7 @@ class OVInferenceModel { // set letency hint core.set_property(ov::cache_dir("cache")); core.set_property(device, ov::hint::performance_mode(ov::hint::PerformanceMode::THROUGHPUT)); - // core.set_property("NPU", ov::log::level(ov::log::Level::DEBUG)); + core.set_property("NPU", ov::log::level(ov::log::Level::DEBUG)); if (device == "NPU") { core.set_property(device, intel_npu_acceleration_library::npu_compiler_type("DRIVER")); if (profile) { diff --git a/include/intel_npu_acceleration_library/tensor.h b/include/intel_npu_acceleration_library/tensor.h new file mode 100644 index 0000000..1f21ef1 --- /dev/null +++ b/include/intel_npu_acceleration_library/tensor.h @@ -0,0 +1,48 @@ +// +// Copyright © 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 +// + +#include "intel_npu_acceleration_library/common.h" + +namespace intel_npu_acceleration_library { + +/** + * @brief Class representing a NPU tensor + * + */ +class Tensor { +private: + ov::intel_npu::level_zero::ZeroBufferTensor _remote_tensor; + void* data_ptr; + +public: + /** + * @brief Construct a new Tensor object + * + * @param dtype tensor datatype + * @param shape tensor shape + * @param data pointer to tensor data + * @param tensor_type tensor type. 
Choices between INPUT, OUTPUT, BINDED + * @param device target device for the tensor + */ + Tensor(ov::element::Type_t dtype, ov::Shape shape, void* data, + ov::intel_npu::TensorType tensor_type = ov::intel_npu::TensorType::INPUT, std::string device = "NPU") { + ov::Core core; + auto context = core.get_default_context(device).as(); + _remote_tensor = context.create_l0_host_tensor(dtype, shape, tensor_type); + data_ptr = _remote_tensor.get(); + std::memcpy(data_ptr, data, _remote_tensor.get_byte_size()); + } + + /** + * @brief Get the data pointer + * + * @return void* + */ + void* data() { + return data_ptr; + } +}; + +} // namespace intel_npu_acceleration_library \ No newline at end of file diff --git a/intel_npu_acceleration_library/backend/bindings.py b/intel_npu_acceleration_library/backend/bindings.py index 9e17fa9..587700b 100644 --- a/intel_npu_acceleration_library/backend/bindings.py +++ b/intel_npu_acceleration_library/backend/bindings.py @@ -88,6 +88,15 @@ def init_common(lib: ctypes.CDLL): lib.compressToI4.argtypes = [c_i8_array, c_u8_array, ctypes.c_int] + # Remote tensors + lib.to_npu.argtypes = [ctypes.c_int, c_u32_array, ctypes.c_char_p, ctypes.c_void_p] + lib.to_npu.restype = handler + + lib.remote_tensor_data.argtypes = [handler] + lib.remote_tensor_data.restype = ctypes.c_void_p + + lib.del_remote_tensor.argtypes = [handler] + def init_network_factory(lib: ctypes.CDLL): """Initialize Netowrk factory bindings. diff --git a/intel_npu_acceleration_library/backend/factory.py b/intel_npu_acceleration_library/backend/factory.py index 48108df..db0e1d1 100644 --- a/intel_npu_acceleration_library/backend/factory.py +++ b/intel_npu_acceleration_library/backend/factory.py @@ -7,7 +7,7 @@ from intel_npu_acceleration_library.backend.ops import get_supported_ops from intel_npu_acceleration_library.backend.bindings import lib as backend_lib from intel_npu_acceleration_library.backend.tensor import Tensor -from intel_npu_acceleration_library.dtypes import int4, bfloat16 +from intel_npu_acceleration_library.dtypes import get_backend_dtype from typing import Optional, Tuple, Any, Union, Sequence, TypeVar, Callable, cast, List from functools import partial import numpy.typing as npt @@ -115,34 +115,10 @@ def get_backend_dtype(self, dtype) -> ctypes.c_char_p: Args: dtype: numpy dtype - Raises: - RuntimeError: Unsupported datatype - Returns: ctypes.c_char_p: string representation of the dtype """ - if dtype in [np.int8, torch.int8]: - str_dtype = "int8" - elif dtype == np.uint8 or dtype == int4: - # u8 represents packed i4 dtypes - str_dtype = "int4" - elif dtype in [np.int16, torch.int16]: - str_dtype = "int16" - elif dtype in [np.int32, torch.int32]: - str_dtype = "int32" - elif dtype in [np.int64, torch.int64]: - str_dtype = "int64" - elif dtype in [np.float16, torch.float16]: - str_dtype = "float16" - elif dtype in [np.float32, torch.float32]: - str_dtype = "float32" - elif dtype in [np.float64, torch.float64]: - str_dtype = "float64" - elif dtype in [bfloat16, torch.bfloat16]: - str_dtype = "bfloat16" - else: - raise RuntimeError(f"DType is not supported {dtype}") - return ctypes.c_char_p(str_dtype.encode()) + return get_backend_dtype(dtype) @return_tensor def parameter( diff --git a/intel_npu_acceleration_library/backend/tensor.py b/intel_npu_acceleration_library/backend/tensor.py index 2236eda..9ed5f18 100644 --- a/intel_npu_acceleration_library/backend/tensor.py +++ b/intel_npu_acceleration_library/backend/tensor.py @@ -16,14 +16,90 @@ int32, int64, NPUDtype, + get_backend_dtype, ) 
from dataclasses import dataclass import functools +from math import prod import numpy as np import ctypes import torch +class RemoteTensor(torch.Tensor): + """ + Represent a remote tensor object. + + Attrs: + _remote_tensor (ctypes._Pointer): The pointer to the underlying remote tensor. + + Methods: + from_torch(x: torch.Tensor): Create a remote tensor from a torch tensor. + """ + + _remote_tensor = None + + @staticmethod + def __new__(cls, x: Any, remote_tensor: ctypes._Pointer, *args: Any, **kwargs: Any): + """ + Create a new remote tensor object. + + Args: + x (Any): tensor input + remote_tensor (ctypes._Pointer): remote tensor pointer + args (Any): additional arguments + kwargs (Any): additional keyword arguments + + Returns: + RemoteTensor: a RemoteTensor object + """ + return super().__new__(cls, x, *args, **kwargs) + + def __init__(self, x: Any, remote_tensor: ctypes._Pointer): + """ + Initialize the remote tensor object. + + Args: + x (Any): tensor input + remote_tensor (ctypes._Pointer): remote tensor pointer + """ + self._remote_tensor = remote_tensor + + # def __del__(self): + # if self._remote_tensor and backend_lib: + # backend_lib.del_remote_tensor(self._remote_tensor) + + @staticmethod + def from_torch(x: torch.Tensor) -> "RemoteTensor": + """ + Create a remote tensor from a torch tensor. + + Args: + x (torch.Tensor): The torch tensor. + + Returns: + RemoteTensor: The remote tensor. + """ + shape_arr = np.array(x.shape, dtype=np.uint32) + dtype_str = get_backend_dtype(x.dtype) + p = ctypes.cast(x.data_ptr(), ctypes.c_void_p) + + rt = backend_lib.to_npu(shape_arr.size, shape_arr, dtype_str, p) + + pointer = ctypes.cast( + backend_lib.remote_tensor_data(rt), + ctypes.POINTER(ctypes.c_uint8), + ) + + arr = (pointer._type_ * prod(x.shape) * x.element_size()).from_address( + ctypes.addressof(pointer.contents) + ) + + pt_tensor = torch.frombuffer(arr, dtype=x.dtype).view(*x.shape) + + return RemoteTensor(pt_tensor, rt) + + @dataclass class Tensor: """ diff --git a/intel_npu_acceleration_library/dtypes.py b/intel_npu_acceleration_library/dtypes.py index 8754e2f..55be13d 100644 --- a/intel_npu_acceleration_library/dtypes.py +++ b/intel_npu_acceleration_library/dtypes.py @@ -7,6 +7,7 @@ from typing import Union import numpy as np import torch +import ctypes @dataclass(frozen=True) @@ -81,6 +82,42 @@ def __repr__(self) -> str: return self.name +def get_backend_dtype(dtype) -> ctypes.c_char_p: + """Get the string representation of the dtype. 
+ + Args: + dtype: numpy dtype + + Raises: + RuntimeError: Unsupported datatype + + Returns: + ctypes.c_char_p: string representation of the dtype + """ + if dtype in [np.int8, torch.int8]: + str_dtype = "int8" + elif dtype == np.uint8 or dtype == int4: + # u8 represents packed i4 dtypes + str_dtype = "int4" + elif dtype in [np.int16, torch.int16]: + str_dtype = "int16" + elif dtype in [np.int32, torch.int32]: + str_dtype = "int32" + elif dtype in [np.int64, torch.int64]: + str_dtype = "int64" + elif dtype in [np.float16, torch.float16]: + str_dtype = "float16" + elif dtype in [np.float32, torch.float32]: + str_dtype = "float32" + elif dtype in [np.float64, torch.float64]: + str_dtype = "float64" + elif dtype in [bfloat16, torch.bfloat16]: + str_dtype = "bfloat16" + else: + raise RuntimeError(f"DType is not supported {dtype}") + return ctypes.c_char_p(str_dtype.encode()) + + float16 = NPUDtype( "fp16", 16, diff --git a/src/bindings.cpp b/src/bindings.cpp index a706b82..707f4c3 100644 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -17,6 +17,25 @@ intel_npu_acceleration_library_DLL_API uint32_t getNPUDriverVersion() { return intel_npu_acceleration_library::driver_version(core); } +// ######################## Remote Tensors ######################## + +intel_npu_acceleration_library_DLL_API intel_npu_acceleration_library::Tensor* to_npu(size_t size, + unsigned int* shape_data, + char* dtype, void* data) { + ov::element::Type_t ov_dtype = intel_npu_acceleration_library::dtype_from_string(std::string(dtype)); + std::vector shape(shape_data, shape_data + size); + + return new intel_npu_acceleration_library::Tensor(ov_dtype, shape, data); +} + +intel_npu_acceleration_library_DLL_API void* remote_tensor_data(intel_npu_acceleration_library::Tensor* rt) { + return rt->data(); +} + +intel_npu_acceleration_library_DLL_API void del_remote_tensor(intel_npu_acceleration_library::Tensor* rt) { + delete rt; +} + // ######################## Compression ######################## intel_npu_acceleration_library_DLL_API void compressToI4(const int8_t* src, uint8_t* dst, size_t size) { From f06e80608e396cc5f025f403d84de29fc691349f Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Tue, 16 Jul 2024 13:41:07 +0100 Subject: [PATCH 6/8] Remove useless code, implement to for tensors --- intel_npu_acceleration_library/device.py | 4 ++-- intel_npu_acceleration_library/nn/module.py | 21 --------------------- 2 files changed, 2 insertions(+), 23 deletions(-) diff --git a/intel_npu_acceleration_library/device.py b/intel_npu_acceleration_library/device.py index 988c315..28e8484 100644 --- a/intel_npu_acceleration_library/device.py +++ b/intel_npu_acceleration_library/device.py @@ -4,6 +4,7 @@ # from intel_npu_acceleration_library.nn.module import convert_to_npu_module +from intel_npu_acceleration_library.backend.tensor import RemoteTensor from torch.overrides import TorchFunctionMode from functools import lru_cache from typing import Any, MutableMapping @@ -165,8 +166,7 @@ def to(super_fn: Any, self: Any, *args: Any, **kwargs: Any): """ npu_device, args, kwargs = parse_to_arguments(*args, **kwargs) if npu_device: - # None for now, once the remote tensor feature lands, it can be converted to a remote tensor - pass + return super_fn(RemoteTensor.from_torch(self), *args, **kwargs) return super_fn(self, *args, **kwargs) diff --git a/intel_npu_acceleration_library/nn/module.py b/intel_npu_acceleration_library/nn/module.py index ef23c8e..9c5bf6d 100644 --- a/intel_npu_acceleration_library/nn/module.py +++ 
b/intel_npu_acceleration_library/nn/module.py @@ -67,25 +67,6 @@ def compute_input_signature( return "_".join(signature) -def patch_parameters(module: torch.nn.Module, model: NNFactory, recurse: bool = False): - """Patch the parameters of a PyTorch module with constants. - - Args: - module (torch.nn.Module): The PyTorch module. - model (NNFactory): The NNFactory instance. - recurse (bool, optional): Recurse over all submodules. Defaults to False. - """ - elements = list(module.named_parameters(recurse=recurse)) - for name, param in elements: - del module._parameters[name] - setattr(module, name, model.constant(param.data.detach().numpy())) - - buffers = list(module.named_buffers(recurse=recurse)) - for name, param in buffers: - del module._buffers[name] - setattr(module, name, model.constant(param.data.detach().numpy())) - - def patch_modules(module: torch.nn.Module, model: NNFactory): """Patch the modules of a PyTorch module with constants. @@ -97,7 +78,6 @@ def patch_modules(module: torch.nn.Module, model: NNFactory): for _, module in modules: if isinstance(module, Module): module.npu_top_level_module = False - # patch_parameters(module, model) patch_modules(module, model) @@ -224,7 +204,6 @@ def create_kwargs_from_list( npu_kwargs = create_kwargs_from_list(kwargs) patch_modules(self, model) - # patch_parameters(self, model) _ = self.forward(*npu_args, **npu_kwargs) model.compile() From 16c69bc8c8a6d68fec093c38e5ec8db1d8a5a050 Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Tue, 16 Jul 2024 14:39:37 +0100 Subject: [PATCH 7/8] Common OV core --- include/intel_npu_acceleration_library/common.h | 6 ++++++ include/intel_npu_acceleration_library/inference.h | 8 +------- include/intel_npu_acceleration_library/tensor.h | 14 +++++++++----- src/bindings.cpp | 6 ++---- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/include/intel_npu_acceleration_library/common.h b/include/intel_npu_acceleration_library/common.h index 298fd60..22ce837 100644 --- a/include/intel_npu_acceleration_library/common.h +++ b/include/intel_npu_acceleration_library/common.h @@ -24,6 +24,12 @@ namespace intel_npu_acceleration_library { +/** + * @brief OpenVINO core object + * + */ +ov::Core core; + static constexpr ov::Property npu_compiler_type{"NPU_COMPILER_TYPE"}; static constexpr ov::Property npu_parameters{"NPU_COMPILATION_MODE_PARAMS"}; diff --git a/include/intel_npu_acceleration_library/inference.h b/include/intel_npu_acceleration_library/inference.h index dc97794..15331b7 100644 --- a/include/intel_npu_acceleration_library/inference.h +++ b/include/intel_npu_acceleration_library/inference.h @@ -23,12 +23,6 @@ namespace intel_npu_acceleration_library { -/** - * @brief OpenVINO core object - * - */ -static ov::Core core; - /** * @brief Create a remote tensor * @@ -82,7 +76,7 @@ class OVInferenceModel { // set letency hint core.set_property(ov::cache_dir("cache")); core.set_property(device, ov::hint::performance_mode(ov::hint::PerformanceMode::THROUGHPUT)); - core.set_property("NPU", ov::log::level(ov::log::Level::DEBUG)); + // core.set_property("NPU", ov::log::level(ov::log::Level::DEBUG)); if (device == "NPU") { core.set_property(device, intel_npu_acceleration_library::npu_compiler_type("DRIVER")); if (profile) { diff --git a/include/intel_npu_acceleration_library/tensor.h b/include/intel_npu_acceleration_library/tensor.h index 1f21ef1..70f5523 100644 --- a/include/intel_npu_acceleration_library/tensor.h +++ b/include/intel_npu_acceleration_library/tensor.h @@ -28,11 +28,15 @@ class Tensor { */ 
Tensor(ov::element::Type_t dtype, ov::Shape shape, void* data, ov::intel_npu::TensorType tensor_type = ov::intel_npu::TensorType::INPUT, std::string device = "NPU") { - ov::Core core; - auto context = core.get_default_context(device).as(); - _remote_tensor = context.create_l0_host_tensor(dtype, shape, tensor_type); - data_ptr = _remote_tensor.get(); - std::memcpy(data_ptr, data, _remote_tensor.get_byte_size()); + if (!_isNPUAvailable(core)) { + // Cannot create NPU remote tensor... use the same pointer as before + data_ptr = data; + } else { + auto context = core.get_default_context(device).as(); + _remote_tensor = context.create_l0_host_tensor(dtype, shape, tensor_type); + data_ptr = _remote_tensor.get(); + std::memcpy(data_ptr, data, _remote_tensor.get_byte_size()); + } } /** diff --git a/src/bindings.cpp b/src/bindings.cpp index 707f4c3..c2d49c6 100644 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -8,13 +8,11 @@ extern "C" { intel_npu_acceleration_library_DLL_API bool isNPUAvailable() { - ov::Core core; - return intel_npu_acceleration_library::_isNPUAvailable(core); + return intel_npu_acceleration_library::_isNPUAvailable(intel_npu_acceleration_library::core); } intel_npu_acceleration_library_DLL_API uint32_t getNPUDriverVersion() { - ov::Core core; - return intel_npu_acceleration_library::driver_version(core); + return intel_npu_acceleration_library::driver_version(intel_npu_acceleration_library::core); } // ######################## Remote Tensors ######################## From 98dc1a26fc121bee5b8aad6381506919345f5ad6 Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Tue, 16 Jul 2024 15:16:59 +0100 Subject: [PATCH 8/8] fix dtype --- intel_npu_acceleration_library/dtypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intel_npu_acceleration_library/dtypes.py b/intel_npu_acceleration_library/dtypes.py index 55be13d..e996809 100644 --- a/intel_npu_acceleration_library/dtypes.py +++ b/intel_npu_acceleration_library/dtypes.py @@ -96,7 +96,7 @@ def get_backend_dtype(dtype) -> ctypes.c_char_p: """ if dtype in [np.int8, torch.int8]: str_dtype = "int8" - elif dtype == np.uint8 or dtype == int4: + elif dtype in [np.uint8, int4, torch.uint8]: # u8 represents packed i4 dtypes str_dtype = "int4" elif dtype in [np.int16, torch.int16]:
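

Usage note (not part of the series): the Python-side flow added in patch 5 and wired into `Tensor.to()` in patch 6 can be exercised directly through `RemoteTensor.from_torch`. The sketch below is a minimal example assuming the extension is built from this branch; it follows the call chain visible in `intel_npu_acceleration_library/backend/tensor.py` (`to_npu()` copies the data into a level-zero host tensor, `remote_tensor_data()` exposes the buffer, and the result is wrapped back into a `torch.Tensor` subclass). Per patch 7, when no NPU driver is available the C++ `Tensor` constructor falls back to the original host pointer, so the same code still runs.

    import torch
    from intel_npu_acceleration_library.backend.tensor import RemoteTensor

    # A float16 activation tensor, matching the shapes used in examples/cpp/main.cpp.
    x = torch.randn(128, 256).to(torch.float16)

    # from_torch() goes through the new to_npu()/remote_tensor_data() bindings:
    # the data is copied into an NPU host-visible tensor and the buffer is
    # re-exposed as a torch tensor, so it behaves like any other tensor.
    remote_x = RemoteTensor.from_torch(x)

    assert torch.equal(remote_x, x)

With patch 6 applied, moving a tensor or module to the npu device takes the same path: `parse_to_arguments` detects the target device and `to()` returns `super_fn(RemoteTensor.from_torch(self), ...)` instead of the previous no-op.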