diff --git a/scripts/templates/queue_api.cpp.mako b/scripts/templates/queue_api.cpp.mako index 14def952ac..6d7d2e6f38 100644 --- a/scripts/templates/queue_api.cpp.mako +++ b/scripts/templates/queue_api.cpp.mako @@ -23,9 +23,10 @@ from templates import helper as th // Do not edit. This file is auto generated from a template: scripts/templates/queue_api.cpp.mako #include "queue_api.hpp" +#include "queue_handle.hpp" #include "ur_util.hpp" -ur_queue_handle_t_::~ur_queue_handle_t_() {} +ur_queue_t_::~ur_queue_t_() {} ## FUNCTION ################################################################### namespace ${x}::level_zero { @@ -37,7 +38,7 @@ ${th.make_func_name(n, tags, obj)}( %endfor ) try { - return ${obj['params'][0]['name']}->${th.transform_queue_related_function_name(n, tags, obj, format=["name"])}; + return ${obj['params'][0]['name']}->get().${th.transform_queue_related_function_name(n, tags, obj, format=["name"])}; } catch(...) { return exceptionToResult(std::current_exception()); } diff --git a/scripts/templates/queue_api.hpp.mako b/scripts/templates/queue_api.hpp.mako index b39226e798..4e9dd28913 100644 --- a/scripts/templates/queue_api.hpp.mako +++ b/scripts/templates/queue_api.hpp.mako @@ -27,8 +27,8 @@ from templates import helper as th #include #include -struct ur_queue_handle_t_ { - virtual ~ur_queue_handle_t_(); +struct ur_queue_t_ { + virtual ~ur_queue_t_(); virtual void deferEventFree(ur_event_handle_t hEvent) = 0; diff --git a/source/adapters/level_zero/v2/command_buffer.cpp b/source/adapters/level_zero/v2/command_buffer.cpp index eace40918b..6bf3fdf040 100644 --- a/source/adapters/level_zero/v2/command_buffer.cpp +++ b/source/adapters/level_zero/v2/command_buffer.cpp @@ -12,6 +12,7 @@ #include "../helpers/kernel_helpers.hpp" #include "../ur_interface_loader.hpp" #include "logger/ur_logger.hpp" +#include "queue_handle.hpp" namespace { @@ -141,7 +142,7 @@ ur_result_t urCommandBufferEnqueueExp( ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueCommandBuffer( + return hQueue->get().enqueueCommandBuffer( hCommandBuffer->commandListManager.getZeCommandList(), phEvent, numEventsInWaitList, phEventWaitList); } catch (...) { diff --git a/source/adapters/level_zero/v2/command_list_manager.cpp b/source/adapters/level_zero/v2/command_list_manager.cpp index 3592a1227e..5371fe076d 100644 --- a/source/adapters/level_zero/v2/command_list_manager.cpp +++ b/source/adapters/level_zero/v2/command_list_manager.cpp @@ -17,7 +17,7 @@ ur_command_list_manager::ur_command_list_manager( ur_context_handle_t context, ur_device_handle_t device, v2::raii::command_list_unique_handle &&commandList, v2::event_flags_t flags, - ur_queue_handle_t queue) + ur_queue_t_ *queue) : context(context), device(device), eventPool(context->eventPoolCache.borrow(device->Id.value(), flags)), zeCommandList(std::move(commandList)), queue(queue) { diff --git a/source/adapters/level_zero/v2/command_list_manager.hpp b/source/adapters/level_zero/v2/command_list_manager.hpp index b24433044a..1967d706d2 100644 --- a/source/adapters/level_zero/v2/command_list_manager.hpp +++ b/source/adapters/level_zero/v2/command_list_manager.hpp @@ -21,7 +21,7 @@ struct ur_command_list_manager : public _ur_object { ur_device_handle_t device, v2::raii::command_list_unique_handle &&commandList, v2::event_flags_t flags = v2::EVENT_FLAGS_COUNTER, - ur_queue_handle_t_ *queue = nullptr); + ur_queue_t_ *queue = nullptr); ~ur_command_list_manager(); ur_result_t appendKernelLaunch(ur_kernel_handle_t hKernel, uint32_t workDim, @@ -47,6 +47,6 @@ struct ur_command_list_manager : public _ur_object { ur_device_handle_t device; v2::raii::cache_borrowed_event_pool eventPool; v2::raii::command_list_unique_handle zeCommandList; - ur_queue_handle_t queue; + ur_queue_t_ *queue; std::vector waitList; }; diff --git a/source/adapters/level_zero/v2/event.cpp b/source/adapters/level_zero/v2/event.cpp index d2332ddafb..82f08dfd6b 100644 --- a/source/adapters/level_zero/v2/event.cpp +++ b/source/adapters/level_zero/v2/event.cpp @@ -15,6 +15,7 @@ #include "event_pool.hpp" #include "event_provider.hpp" #include "queue_api.hpp" +#include "queue_handle.hpp" #include "../ur_interface_loader.hpp" @@ -93,7 +94,7 @@ ur_event_handle_t_::ur_event_handle_t_( : hContext(hContext), event_pool(pool), hZeEvent(std::move(hZeEvent)), flags(flags), profilingData(getZeEvent()) {} -void ur_event_handle_t_::resetQueueAndCommand(ur_queue_handle_t hQueue, +void ur_event_handle_t_::resetQueueAndCommand(ur_queue_t_ *hQueue, ur_command_t commandType) { this->hQueue = hQueue; this->commandType = commandType; @@ -182,7 +183,7 @@ ur_event_handle_t_::getEventEndTimestampAndHandle() { return {profilingData.eventEndTimestampAddr(), getZeEvent()}; } -ur_queue_handle_t ur_event_handle_t_::getQueue() const { return hQueue; } +ur_queue_t_ *ur_event_handle_t_::getQueue() const { return hQueue; } ur_context_handle_t ur_event_handle_t_::getContext() const { return hContext; } diff --git a/source/adapters/level_zero/v2/event.hpp b/source/adapters/level_zero/v2/event.hpp index f4a2bb8c11..8365832224 100644 --- a/source/adapters/level_zero/v2/event.hpp +++ b/source/adapters/level_zero/v2/event.hpp @@ -15,6 +15,7 @@ #include #include +#include "adapters/level_zero/v2/queue_api.hpp" #include "common.hpp" #include "event_provider.hpp" @@ -61,7 +62,7 @@ struct ur_event_handle_t_ : _ur_object { const ur_event_native_properties_t *pProperties); // Set the queue and command that this event is associated with - void resetQueueAndCommand(ur_queue_handle_t hQueue, ur_command_t commandType); + void resetQueueAndCommand(ur_queue_t_ *hQueue, ur_command_t commandType); // releases event immediately ur_result_t forceRelease(); @@ -86,7 +87,7 @@ struct ur_event_handle_t_ : _ur_object { bool isProfilingEnabled() const; // Queue associated with this event. Can be nullptr (for native events) - ur_queue_handle_t getQueue() const; + ur_queue_t_ *getQueue() const; // Context associated with this event ur_context_handle_t getContext() const; @@ -119,7 +120,7 @@ struct ur_event_handle_t_ : _ur_object { // queue and commandType that this event is associated with, set by enqueue // commands - ur_queue_handle_t hQueue = nullptr; + ur_queue_t_ *hQueue = nullptr; ur_command_t commandType = UR_COMMAND_FORCE_UINT32; v2::event_flags_t flags; diff --git a/source/adapters/level_zero/v2/kernel.cpp b/source/adapters/level_zero/v2/kernel.cpp index db10ff03ff..6da466ff4d 100644 --- a/source/adapters/level_zero/v2/kernel.cpp +++ b/source/adapters/level_zero/v2/kernel.cpp @@ -14,6 +14,7 @@ #include "kernel.hpp" #include "memory.hpp" #include "queue_api.hpp" +#include "queue_handle.hpp" #include "../device.hpp" #include "../helpers/kernel_helpers.hpp" @@ -656,8 +657,9 @@ ur_result_t urKernelGetSuggestedLocalWorkSize( std::copy(pGlobalWorkSize, pGlobalWorkSize + workDim, globalWorkSize3D); ur_device_handle_t hDevice; - UR_CALL(hQueue->queueGetInfo(UR_QUEUE_INFO_DEVICE, sizeof(hDevice), - reinterpret_cast(&hDevice), nullptr)); + UR_CALL(hQueue->get().queueGetInfo(UR_QUEUE_INFO_DEVICE, sizeof(hDevice), + reinterpret_cast(&hDevice), + nullptr)); UR_CALL(getSuggestedLocalWorkSize(hDevice, hKernel->getZeHandle(hDevice), globalWorkSize3D, localWorkSize)); diff --git a/source/adapters/level_zero/v2/queue_api.cpp b/source/adapters/level_zero/v2/queue_api.cpp index 28ff527413..2805f50d92 100644 --- a/source/adapters/level_zero/v2/queue_api.cpp +++ b/source/adapters/level_zero/v2/queue_api.cpp @@ -15,42 +15,44 @@ // scripts/templates/queue_api.cpp.mako #include "queue_api.hpp" +#include "queue_handle.hpp" #include "ur_util.hpp" -ur_queue_handle_t_::~ur_queue_handle_t_() {} +ur_queue_t_::~ur_queue_t_() {} namespace ur::level_zero { ur_result_t urQueueGetInfo(ur_queue_handle_t hQueue, ur_queue_info_t propName, size_t propSize, void *pPropValue, size_t *pPropSizeRet) try { - return hQueue->queueGetInfo(propName, propSize, pPropValue, pPropSizeRet); + return hQueue->get().queueGetInfo(propName, propSize, pPropValue, + pPropSizeRet); } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urQueueRetain(ur_queue_handle_t hQueue) try { - return hQueue->queueRetain(); + return hQueue->get().queueRetain(); } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urQueueRelease(ur_queue_handle_t hQueue) try { - return hQueue->queueRelease(); + return hQueue->get().queueRelease(); } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urQueueGetNativeHandle(ur_queue_handle_t hQueue, ur_queue_native_desc_t *pDesc, ur_native_handle_t *phNativeQueue) try { - return hQueue->queueGetNativeHandle(pDesc, phNativeQueue); + return hQueue->get().queueGetNativeHandle(pDesc, phNativeQueue); } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urQueueFinish(ur_queue_handle_t hQueue) try { - return hQueue->queueFinish(); + return hQueue->get().queueFinish(); } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urQueueFlush(ur_queue_handle_t hQueue) try { - return hQueue->queueFlush(); + return hQueue->get().queueFlush(); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -59,7 +61,7 @@ ur_result_t urEnqueueKernelLaunch( const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueKernelLaunch( + return hQueue->get().enqueueKernelLaunch( hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -69,16 +71,16 @@ ur_result_t urEnqueueEventsWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueEventsWait(numEventsInWaitList, phEventWaitList, - phEvent); + return hQueue->get().enqueueEventsWait(numEventsInWaitList, phEventWaitList, + phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urEnqueueEventsWaitWithBarrier( ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueEventsWaitWithBarrier(numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueEventsWaitWithBarrier(numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -88,9 +90,9 @@ ur_result_t urEnqueueMemBufferRead(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemBufferRead(hBuffer, blockingRead, offset, size, pDst, - numEventsInWaitList, phEventWaitList, - phEvent); + return hQueue->get().enqueueMemBufferRead(hBuffer, blockingRead, offset, size, + pDst, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -98,9 +100,9 @@ ur_result_t urEnqueueMemBufferWrite( ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite, size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemBufferWrite(hBuffer, blockingWrite, offset, size, - pSrc, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueMemBufferWrite(hBuffer, blockingWrite, offset, + size, pSrc, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -111,7 +113,7 @@ ur_result_t urEnqueueMemBufferReadRect( size_t hostRowPitch, size_t hostSlicePitch, void *pDst, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemBufferReadRect( + return hQueue->get().enqueueMemBufferReadRect( hBuffer, blockingRead, bufferOrigin, hostOrigin, region, bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, pDst, numEventsInWaitList, phEventWaitList, phEvent); @@ -125,7 +127,7 @@ ur_result_t urEnqueueMemBufferWriteRect( size_t hostRowPitch, size_t hostSlicePitch, void *pSrc, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemBufferWriteRect( + return hQueue->get().enqueueMemBufferWriteRect( hBuffer, blockingWrite, bufferOrigin, hostOrigin, region, bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, pSrc, numEventsInWaitList, phEventWaitList, phEvent); @@ -139,9 +141,9 @@ ur_result_t urEnqueueMemBufferCopy(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemBufferCopy(hBufferSrc, hBufferDst, srcOffset, - dstOffset, size, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueMemBufferCopy( + hBufferSrc, hBufferDst, srcOffset, dstOffset, size, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -152,7 +154,7 @@ ur_result_t urEnqueueMemBufferCopyRect( size_t srcSlicePitch, size_t dstRowPitch, size_t dstSlicePitch, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemBufferCopyRect( + return hQueue->get().enqueueMemBufferCopyRect( hBufferSrc, hBufferDst, srcOrigin, dstOrigin, region, srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, numEventsInWaitList, phEventWaitList, phEvent); @@ -166,9 +168,9 @@ ur_result_t urEnqueueMemBufferFill(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemBufferFill(hBuffer, pPattern, patternSize, offset, - size, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueMemBufferFill(hBuffer, pPattern, patternSize, + offset, size, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -177,7 +179,7 @@ ur_result_t urEnqueueMemImageRead( ur_rect_offset_t origin, ur_rect_region_t region, size_t rowPitch, size_t slicePitch, void *pDst, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemImageRead( + return hQueue->get().enqueueMemImageRead( hImage, blockingRead, origin, region, rowPitch, slicePitch, pDst, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -188,7 +190,7 @@ ur_result_t urEnqueueMemImageWrite( ur_rect_offset_t origin, ur_rect_region_t region, size_t rowPitch, size_t slicePitch, void *pSrc, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemImageWrite( + return hQueue->get().enqueueMemImageWrite( hImage, blockingWrite, origin, region, rowPitch, slicePitch, pSrc, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -201,9 +203,9 @@ urEnqueueMemImageCopy(ur_queue_handle_t hQueue, ur_mem_handle_t hImageSrc, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemImageCopy(hImageSrc, hImageDst, srcOrigin, dstOrigin, - region, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueMemImageCopy( + hImageSrc, hImageDst, srcOrigin, dstOrigin, region, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -214,9 +216,9 @@ ur_result_t urEnqueueMemBufferMap(ur_queue_handle_t hQueue, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent, void **ppRetMap) try { - return hQueue->enqueueMemBufferMap(hBuffer, blockingMap, mapFlags, offset, - size, numEventsInWaitList, phEventWaitList, - phEvent, ppRetMap); + return hQueue->get().enqueueMemBufferMap(hBuffer, blockingMap, mapFlags, + offset, size, numEventsInWaitList, + phEventWaitList, phEvent, ppRetMap); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -224,8 +226,8 @@ ur_result_t urEnqueueMemUnmap(ur_queue_handle_t hQueue, ur_mem_handle_t hMem, void *pMappedPtr, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueMemUnmap(hMem, pMappedPtr, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueMemUnmap(hMem, pMappedPtr, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -234,8 +236,9 @@ ur_result_t urEnqueueUSMFill(ur_queue_handle_t hQueue, void *pMem, size_t size, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueUSMFill(pMem, patternSize, pPattern, size, - numEventsInWaitList, phEventWaitList, phEvent); + return hQueue->get().enqueueUSMFill(pMem, patternSize, pPattern, size, + numEventsInWaitList, phEventWaitList, + phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -244,9 +247,9 @@ ur_result_t urEnqueueUSMMemcpy(ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueUSMMemcpy(blocking, pDst, pSrc, size, - numEventsInWaitList, phEventWaitList, - phEvent); + return hQueue->get().enqueueUSMMemcpy(blocking, pDst, pSrc, size, + numEventsInWaitList, phEventWaitList, + phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -255,15 +258,15 @@ ur_result_t urEnqueueUSMPrefetch(ur_queue_handle_t hQueue, const void *pMem, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueUSMPrefetch(pMem, size, flags, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueUSMPrefetch( + pMem, size, flags, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, ur_usm_advice_flags_t advice, ur_event_handle_t *phEvent) try { - return hQueue->enqueueUSMAdvise(pMem, size, advice, phEvent); + return hQueue->get().enqueueUSMAdvise(pMem, size, advice, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -273,9 +276,9 @@ ur_result_t urEnqueueUSMFill2D(ur_queue_handle_t hQueue, void *pMem, size_t height, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueUSMFill2D(pMem, pitch, patternSize, pPattern, width, - height, numEventsInWaitList, phEventWaitList, - phEvent); + return hQueue->get().enqueueUSMFill2D(pMem, pitch, patternSize, pPattern, + width, height, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -285,9 +288,9 @@ ur_result_t urEnqueueUSMMemcpy2D(ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueUSMMemcpy2D(blocking, pDst, dstPitch, pSrc, srcPitch, - width, height, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueUSMMemcpy2D( + blocking, pDst, dstPitch, pSrc, srcPitch, width, height, + numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -296,7 +299,7 @@ ur_result_t urEnqueueDeviceGlobalVariableWrite( bool blockingWrite, size_t count, size_t offset, const void *pSrc, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueDeviceGlobalVariableWrite( + return hQueue->get().enqueueDeviceGlobalVariableWrite( hProgram, name, blockingWrite, count, offset, pSrc, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -307,7 +310,7 @@ ur_result_t urEnqueueDeviceGlobalVariableRead( bool blockingRead, size_t count, size_t offset, void *pDst, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueDeviceGlobalVariableRead( + return hQueue->get().enqueueDeviceGlobalVariableRead( hProgram, name, blockingRead, count, offset, pDst, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -320,9 +323,9 @@ ur_result_t urEnqueueReadHostPipe(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueReadHostPipe(hProgram, pipe_symbol, blocking, pDst, - size, numEventsInWaitList, phEventWaitList, - phEvent); + return hQueue->get().enqueueReadHostPipe(hProgram, pipe_symbol, blocking, + pDst, size, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -333,9 +336,9 @@ ur_result_t urEnqueueWriteHostPipe(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueWriteHostPipe(hProgram, pipe_symbol, blocking, pSrc, - size, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueWriteHostPipe(hProgram, pipe_symbol, blocking, + pSrc, size, numEventsInWaitList, + phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -347,7 +350,7 @@ ur_result_t urBindlessImagesImageCopyExp( ur_exp_image_copy_region_t *pCopyRegion, ur_exp_image_copy_flags_t imageCopyFlags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->bindlessImagesImageCopyExp( + return hQueue->get().bindlessImagesImageCopyExp( pSrc, pDst, pSrcImageDesc, pDstImageDesc, pSrcImageFormat, pDstImageFormat, pCopyRegion, imageCopyFlags, numEventsInWaitList, phEventWaitList, phEvent); @@ -358,7 +361,7 @@ ur_result_t urBindlessImagesWaitExternalSemaphoreExp( ur_queue_handle_t hQueue, ur_exp_external_semaphore_handle_t hSemaphore, bool hasWaitValue, uint64_t waitValue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->bindlessImagesWaitExternalSemaphoreExp( + return hQueue->get().bindlessImagesWaitExternalSemaphoreExp( hSemaphore, hasWaitValue, waitValue, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -368,7 +371,7 @@ ur_result_t urBindlessImagesSignalExternalSemaphoreExp( ur_queue_handle_t hQueue, ur_exp_external_semaphore_handle_t hSemaphore, bool hasSignalValue, uint64_t signalValue, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->bindlessImagesSignalExternalSemaphoreExp( + return hQueue->get().bindlessImagesSignalExternalSemaphoreExp( hSemaphore, hasSignalValue, signalValue, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -379,7 +382,7 @@ ur_result_t urEnqueueCooperativeKernelLaunchExp( const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueCooperativeKernelLaunchExp( + return hQueue->get().enqueueCooperativeKernelLaunchExp( hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -388,8 +391,8 @@ ur_result_t urEnqueueCooperativeKernelLaunchExp( ur_result_t urEnqueueTimestampRecordingExp( ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueTimestampRecordingExp(blocking, numEventsInWaitList, - phEventWaitList, phEvent); + return hQueue->get().enqueueTimestampRecordingExp( + blocking, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -400,7 +403,7 @@ ur_result_t urEnqueueKernelLaunchCustomExp( const ur_exp_launch_property_t *launchPropList, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueKernelLaunchCustomExp( + return hQueue->get().enqueueKernelLaunchCustomExp( hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, numPropsInLaunchPropList, launchPropList, numEventsInWaitList, phEventWaitList, phEvent); @@ -412,7 +415,7 @@ ur_result_t urEnqueueEventsWaitWithBarrierExt( const ur_exp_enqueue_ext_properties_t *pProperties, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueEventsWaitWithBarrierExt( + return hQueue->get().enqueueEventsWaitWithBarrierExt( pProperties, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { return exceptionToResult(std::current_exception()); @@ -424,7 +427,7 @@ ur_result_t urEnqueueNativeCommandExp( const ur_exp_enqueue_native_command_properties_t *pProperties, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) try { - return hQueue->enqueueNativeCommandExp( + return hQueue->get().enqueueNativeCommandExp( pfnNativeEnqueue, data, numMemsInMemList, phMemList, pProperties, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { diff --git a/source/adapters/level_zero/v2/queue_api.hpp b/source/adapters/level_zero/v2/queue_api.hpp index 88d812bbba..bb2a282490 100644 --- a/source/adapters/level_zero/v2/queue_api.hpp +++ b/source/adapters/level_zero/v2/queue_api.hpp @@ -19,8 +19,8 @@ #include #include -struct ur_queue_handle_t_ { - virtual ~ur_queue_handle_t_(); +struct ur_queue_t_ { + virtual ~ur_queue_t_(); virtual void deferEventFree(ur_event_handle_t hEvent) = 0; diff --git a/source/adapters/level_zero/v2/queue_create.cpp b/source/adapters/level_zero/v2/queue_create.cpp index f397cd8747..7e2de8b1b6 100644 --- a/source/adapters/level_zero/v2/queue_create.cpp +++ b/source/adapters/level_zero/v2/queue_create.cpp @@ -12,6 +12,7 @@ #include "logger/ur_logger.hpp" #include "queue_api.hpp" +#include "queue_handle.hpp" #include "queue_immediate_in_order.hpp" #include @@ -27,8 +28,8 @@ ur_result_t urQueueCreate(ur_context_handle_t hContext, } // TODO: For now, always use immediate, in-order - *phQueue = - new v2::ur_queue_immediate_in_order_t(hContext, hDevice, pProperties); + *phQueue = ur_queue_handle_t_::create( + hContext, hDevice, pProperties); return UR_RESULT_SUCCESS; } catch (...) { return exceptionToResult(std::current_exception()); @@ -57,7 +58,7 @@ ur_result_t urQueueCreateWithNativeHandle( } } - *phQueue = new v2::ur_queue_immediate_in_order_t( + *phQueue = ur_queue_handle_t_::create( hContext, hDevice, hNativeQueue, flags, ownNativeHandle); return UR_RESULT_SUCCESS; diff --git a/source/adapters/level_zero/v2/queue_handle.hpp b/source/adapters/level_zero/v2/queue_handle.hpp new file mode 100644 index 0000000000..caeb7f30a1 --- /dev/null +++ b/source/adapters/level_zero/v2/queue_handle.hpp @@ -0,0 +1,32 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM + * Exceptions. See LICENSE.TXT + * + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file queue_handle.hpp + * + */ + +#pragma once + +#include "queue_immediate_in_order.hpp" +#include +#include + +struct ur_queue_handle_t_ { + using data_variant = std::variant; + data_variant queue_data; + + template + static ur_queue_handle_t_ *create(Args &&...args) { + return new ur_queue_handle_t_{data_variant{std::in_place_type, args...}}; + } + + ur_queue_t_ &get() { + return std::visit([&](auto &q) -> ur_queue_t_ & { return q; }, queue_data); + } +}; diff --git a/source/adapters/level_zero/v2/queue_immediate_in_order.hpp b/source/adapters/level_zero/v2/queue_immediate_in_order.hpp index 6cf8b0c51c..16ec86e22f 100644 --- a/source/adapters/level_zero/v2/queue_immediate_in_order.hpp +++ b/source/adapters/level_zero/v2/queue_immediate_in_order.hpp @@ -25,7 +25,7 @@ namespace v2 { using queue_group_type = ur_device_handle_t_::queue_group_info_t::type; -struct ur_queue_immediate_in_order_t : _ur_object, public ur_queue_handle_t_ { +struct ur_queue_immediate_in_order_t : _ur_object, public ur_queue_t_ { private: ur_context_handle_t hContext; ur_device_handle_t hDevice; diff --git a/test/adapters/level_zero/v2/event_pool_test.cpp b/test/adapters/level_zero/v2/event_pool_test.cpp index 2d04975f0c..d6719dc3ed 100644 --- a/test/adapters/level_zero/v2/event_pool_test.cpp +++ b/test/adapters/level_zero/v2/event_pool_test.cpp @@ -16,6 +16,7 @@ #include "event_provider.hpp" #include "event_provider_counter.hpp" #include "event_provider_normal.hpp" +#include "queue_handle.hpp" #include "uur/fixtures.h" #include "ze_api.h" @@ -164,7 +165,7 @@ TEST_P(EventPoolTest, Basic) { auto pool = cache->borrow(device->Id.value(), getParam().flags); first = pool->allocate(); - first->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH); + first->resetQueueAndCommand(&queue->get(), UR_COMMAND_KERNEL_LAUNCH); zeFirst = first->getZeEvent(); urEventRelease(first); @@ -175,7 +176,7 @@ TEST_P(EventPoolTest, Basic) { auto pool = cache->borrow(device->Id.value(), getParam().flags); second = pool->allocate(); - first->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH); + first->resetQueueAndCommand(&queue->get(), UR_COMMAND_KERNEL_LAUNCH); zeSecond = second->getZeEvent(); urEventRelease(second); @@ -195,7 +196,8 @@ TEST_P(EventPoolTest, Threaded) { std::vector events; for (int i = 0; i < 100; ++i) { events.push_back(pool->allocate()); - events.back()->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH); + events.back()->resetQueueAndCommand(&queue->get(), + UR_COMMAND_KERNEL_LAUNCH); } for (int i = 0; i < 100; ++i) { urEventRelease(events[i]); @@ -214,7 +216,7 @@ TEST_P(EventPoolTest, ProviderNormalUseMostFreePool) { std::list events; for (int i = 0; i < 128; ++i) { auto event = pool->allocate(); - event->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH); + event->resetQueueAndCommand(&queue->get(), UR_COMMAND_KERNEL_LAUNCH); events.push_back(event); } auto frontZeHandle = events.front()->getZeEvent(); @@ -224,7 +226,7 @@ TEST_P(EventPoolTest, ProviderNormalUseMostFreePool) { } for (int i = 0; i < 8; ++i) { auto e = pool->allocate(); - e->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH); + e->resetQueueAndCommand(&queue->get(), UR_COMMAND_KERNEL_LAUNCH); events.push_back(e); }