diff --git a/source/adapters/level_zero/memory.cpp b/source/adapters/level_zero/memory.cpp index e21470eee2..271411c4b3 100644 --- a/source/adapters/level_zero/memory.cpp +++ b/source/adapters/level_zero/memory.cpp @@ -2081,10 +2081,11 @@ static ur_result_t ZeDeviceMemAllocHelper(void **ResultPtr, return UR_RESULT_SUCCESS; } -ur_result_t _ur_buffer::getZeHandle(char *&ZeHandle, access_mode_t AccessMode, - ur_device_handle_t Device, - const ur_event_handle_t *phWaitEvents, - uint32_t numWaitEvents) { +ur_result_t _ur_buffer::getBufferZeHandle(char *&ZeHandle, + access_mode_t AccessMode, + ur_device_handle_t Device, + const ur_event_handle_t *phWaitEvents, + uint32_t numWaitEvents) { // NOTE: There might be no valid allocation at all yet and we get // here from piEnqueueKernelLaunch that would be doing the buffer @@ -2393,7 +2394,7 @@ ur_result_t _ur_buffer::free() { // Buffer constructor _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr, bool ImportedHostPtr = false) - : ur_mem_handle_t_(Context), Size(Size) { + : ur_mem_handle_t_(mem_type_t::buffer, Context), Size(Size) { // We treat integrated devices (physical memory shared with the CPU) // differently from discrete devices (those with distinct memories). @@ -2422,13 +2423,13 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr, _ur_buffer::_ur_buffer(ur_context_handle_t Context, ur_device_handle_t Device, size_t Size) - : ur_mem_handle_t_(Context, Device), Size(Size) {} + : ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {} // Interop-buffer constructor _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, ur_device_handle_t Device, char *ZeMemHandle, bool OwnZeMemHandle) - : ur_mem_handle_t_(Context, Device), Size(Size) { + : ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) { // Device == nullptr means host allocation Allocations[Device].ZeHandle = ZeMemHandle; @@ -2449,11 +2450,38 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, LastDeviceWithValidAllocation = Device; } -ur_result_t _ur_buffer::getZeHandlePtr(char **&ZeHandlePtr, - access_mode_t AccessMode, - ur_device_handle_t Device, - const ur_event_handle_t *phWaitEvents, - uint32_t numWaitEvents) { +ur_result_t ur_mem_handle_t_::getZeHandle(char *&ZeHandle, access_mode_t mode, + ur_device_handle_t Device, + const ur_event_handle_t *phWaitEvents, + uint32_t numWaitEvents) { + switch (mem_type) { + case ur_mem_handle_t_::image: + return reinterpret_cast<_ur_image *>(this)->getImageZeHandle( + ZeHandle, mode, Device, phWaitEvents, numWaitEvents); + case ur_mem_handle_t_::buffer: + return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandle( + ZeHandle, mode, Device, phWaitEvents, numWaitEvents); + } + abort(); +} + +ur_result_t ur_mem_handle_t_::getZeHandlePtr( + char **&ZeHandlePtr, access_mode_t mode, ur_device_handle_t Device, + const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) { + switch (mem_type) { + case ur_mem_handle_t_::image: + return reinterpret_cast<_ur_image *>(this)->getImageZeHandlePtr( + ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents); + case ur_mem_handle_t_::buffer: + return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandlePtr( + ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents); + } + abort(); +} + +ur_result_t _ur_buffer::getBufferZeHandlePtr( + char **&ZeHandlePtr, access_mode_t AccessMode, ur_device_handle_t Device, + const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) { char *ZeHandle; UR_CALL( getZeHandle(ZeHandle, AccessMode, Device, phWaitEvents, numWaitEvents)); diff --git a/source/adapters/level_zero/memory.hpp b/source/adapters/level_zero/memory.hpp index 8c3820f1a1..315b816055 100644 --- a/source/adapters/level_zero/memory.hpp +++ b/source/adapters/level_zero/memory.hpp @@ -70,34 +70,37 @@ struct ur_mem_handle_t_ : _ur_object { // Keeps device of this memory handle ur_device_handle_t UrDevice; + // Whether this is an image or buffer + enum mem_type_t { image, buffer }; + mem_type_t mem_type; + // Enumerates all possible types of accesses. enum access_mode_t { unknown, read_write, read_only, write_only }; // Interface of the _ur_mem object // Get the Level Zero handle of the current memory object - virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t, - ur_device_handle_t Device, - const ur_event_handle_t *phWaitEvents, - uint32_t numWaitEvents) = 0; + ur_result_t getZeHandle(char *&ZeHandle, access_mode_t, + ur_device_handle_t Device, + const ur_event_handle_t *phWaitEvents, + uint32_t numWaitEvents); // Get a pointer to the Level Zero handle of the current memory object - virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t, - ur_device_handle_t Device, - const ur_event_handle_t *phWaitEvents, - uint32_t numWaitEvents) = 0; + ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t, + ur_device_handle_t Device, + const ur_event_handle_t *phWaitEvents, + uint32_t numWaitEvents); // Method to get type of the derived object (image or buffer) - virtual bool isImage() const = 0; - - virtual ~ur_mem_handle_t_() = default; + bool isImage() const { return mem_type == mem_type_t::image; } protected: - ur_mem_handle_t_(ur_context_handle_t Context) - : UrContext{Context}, UrDevice{nullptr} {} + ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context) + : UrContext{Context}, UrDevice{nullptr}, mem_type(type) {} - ur_mem_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device) - : UrContext{Context}, UrDevice(Device) {} + ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context, + ur_device_handle_t Device) + : UrContext{Context}, UrDevice(Device), mem_type(type) {} }; struct _ur_buffer final : ur_mem_handle_t_ { @@ -110,8 +113,8 @@ struct _ur_buffer final : ur_mem_handle_t_ { // Sub-buffer constructor _ur_buffer(_ur_buffer *Parent, size_t Origin, size_t Size) - : ur_mem_handle_t_(Parent->UrContext), Size(Size), - SubBuffer{{Parent, Origin}} { + : ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext), + Size(Size), SubBuffer{{Parent, Origin}} { // Retain the Parent Buffer due to the Creation of the SubBuffer. Parent->RefCount.increment(); } @@ -127,16 +130,15 @@ struct _ur_buffer final : ur_mem_handle_t_ { // up-to-date and any data copies needed for that are performed under // the hood. // - virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t, - ur_device_handle_t Device, - const ur_event_handle_t *phWaitEvents, - uint32_t numWaitEvents) override; - virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t, - ur_device_handle_t Device, - const ur_event_handle_t *phWaitEvents, - uint32_t numWaitEvents) override; + ur_result_t getBufferZeHandle(char *&ZeHandle, access_mode_t, + ur_device_handle_t Device, + const ur_event_handle_t *phWaitEvents, + uint32_t numWaitEvents); + ur_result_t getBufferZeHandlePtr(char **&ZeHandlePtr, access_mode_t, + ur_device_handle_t Device, + const ur_event_handle_t *phWaitEvents, + uint32_t numWaitEvents); - bool isImage() const override { return false; } bool isSubBuffer() const { return SubBuffer != std::nullopt; } // Frees all allocations made for the buffer. @@ -206,35 +208,33 @@ struct _ur_buffer final : ur_mem_handle_t_ { struct _ur_image final : ur_mem_handle_t_ { // Image constructor _ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage) - : ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {} + : ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {} _ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage, bool OwnZeMemHandle) - : ur_mem_handle_t_(UrContext), ZeImage{ZeImage} { + : ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} { OwnNativeHandle = OwnZeMemHandle; } - virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t, - ur_device_handle_t, - const ur_event_handle_t *phWaitEvents, - uint32_t numWaitEvents) override { + ur_result_t getImageZeHandle(char *&ZeHandle, access_mode_t, + ur_device_handle_t, + const ur_event_handle_t *phWaitEvents, + uint32_t numWaitEvents) { std::ignore = phWaitEvents; std::ignore = numWaitEvents; ZeHandle = reinterpret_cast(ZeImage); return UR_RESULT_SUCCESS; } - virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t, - ur_device_handle_t, - const ur_event_handle_t *phWaitEvents, - uint32_t numWaitEvents) override { + ur_result_t getImageZeHandlePtr(char **&ZeHandlePtr, access_mode_t, + ur_device_handle_t, + const ur_event_handle_t *phWaitEvents, + uint32_t numWaitEvents) { std::ignore = phWaitEvents; std::ignore = numWaitEvents; ZeHandlePtr = reinterpret_cast(&ZeImage); return UR_RESULT_SUCCESS; } - bool isImage() const override { return true; } - // Keep the descriptor of the image ZeStruct ZeImageDesc;