From b24eb57d99e5180ce851dc726984c7dc117e36b9 Mon Sep 17 00:00:00 2001 From: Ben Tracy Date: Wed, 20 Nov 2024 13:27:14 +0000 Subject: [PATCH] [CMDBUF] Implement kernel binary update for L0 adapter - Implement binary update in L0 adapter - Platform and driver query support for new extension functions - Update command buffer documentation to specify arg requirements for binary update --- scripts/core/EXP-COMMAND-BUFFER.rst | 5 + source/adapters/level_zero/command_buffer.cpp | 143 ++++++++++++++---- source/adapters/level_zero/command_buffer.hpp | 21 ++- source/adapters/level_zero/device.cpp | 4 + source/adapters/level_zero/platform.cpp | 16 ++ source/adapters/level_zero/platform.hpp | 6 + 6 files changed, 158 insertions(+), 37 deletions(-) diff --git a/scripts/core/EXP-COMMAND-BUFFER.rst b/scripts/core/EXP-COMMAND-BUFFER.rst index d6ef76c7bc..0baec66a66 100644 --- a/scripts/core/EXP-COMMAND-BUFFER.rst +++ b/scripts/core/EXP-COMMAND-BUFFER.rst @@ -256,6 +256,11 @@ ${x}CommandBufferAppendKernelLaunchExp. The command can then be updated to use the new kernel handle by passing it to ${x}CommandBufferUpdateKernelLaunchExp. +.. important:: + When updating the kernel handle of a command you must also provide all + required arguments to the new kernel in the update descriptor. Failure to + do so will result in undefined behavior. + .. parsed-literal:: // Create a command-buffer with update enabled. 
diff --git a/source/adapters/level_zero/command_buffer.cpp b/source/adapters/level_zero/command_buffer.cpp index 56c53b5331..eccdc5e4d2 100644 --- a/source/adapters/level_zero/command_buffer.cpp +++ b/source/adapters/level_zero/command_buffer.cpp @@ -476,21 +476,14 @@ void ur_exp_command_buffer_handle_t_::cleanupCommandBufferResources() { ur_exp_command_buffer_command_handle_t_:: ur_exp_command_buffer_command_handle_t_( - ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId, - uint32_t WorkDim, bool UserDefinedLocalSize, - ur_kernel_handle_t Kernel = nullptr) - : CommandBuffer(CommandBuffer), CommandId(CommandId), WorkDim(WorkDim), - UserDefinedLocalSize(UserDefinedLocalSize), Kernel(Kernel) { + ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId) + : CommandBuffer(CommandBuffer), CommandId(CommandId) { ur::level_zero::urCommandBufferRetainExp(CommandBuffer); - if (Kernel) - ur::level_zero::urKernelRetain(Kernel); } ur_exp_command_buffer_command_handle_t_:: ~ur_exp_command_buffer_command_handle_t_() { ur::level_zero::urCommandBufferReleaseExp(CommandBuffer); - if (Kernel) - ur::level_zero::urKernelRelease(Kernel); } void ur_exp_command_buffer_handle_t_::registerSyncPoint( @@ -527,6 +520,31 @@ ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue( return UR_RESULT_SUCCESS; } +kernel_command_handle::kernel_command_handle( + ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel, + uint64_t CommandId, uint32_t WorkDim, bool UserDefinedLocalSize, + uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives) + : ur_exp_command_buffer_command_handle_t_(CommandBuffer, CommandId), + WorkDim(WorkDim), UserDefinedLocalSize(UserDefinedLocalSize), + Kernel(Kernel) { + // Add the default kernel to the list of valid kernels + ur::level_zero::urKernelRetain(Kernel); + ValidKernelHandles.insert(Kernel); + // Add alternative kernels, retaining each unique handle exactly once + if (KernelAlternatives) { + for (size_t i = 0; i < NumKernelAlternatives; 
i++) { + if (ValidKernelHandles.insert(KernelAlternatives[i]).second) + ur::level_zero::urKernelRetain(KernelAlternatives[i]); + } + } +} + +kernel_command_handle::~kernel_command_handle() { + for (const ur_kernel_handle_t &KernelHandle : ValidKernelHandles) { + ur::level_zero::urKernelRelease(KernelHandle); + } +} + namespace ur::level_zero { /** @@ -906,7 +924,8 @@ setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer, ur_result_t createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel, uint32_t WorkDim, - const size_t *LocalWorkSize, + const size_t *LocalWorkSize, uint32_t NumKernelAlternatives, + ur_kernel_handle_t *KernelAlternatives, ur_exp_command_buffer_command_handle_t &Command) { assert(CommandBuffer->IsUpdatable); @@ -923,14 +942,41 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer, ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET; auto Platform = CommandBuffer->Context->getPlatform(); - ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp, - (CommandBuffer->ZeComputeCommandListTranslated, - &ZeMutableCommandDesc, &CommandId)); + if (NumKernelAlternatives > 0) { + ZeMutableCommandDesc.flags |= + ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION; + + std::vector<ze_kernel_handle_t> TranslatedKernelHandles( + NumKernelAlternatives + 1, nullptr); + + // Translate main kernel first + ZE2UR_CALL(zelLoaderTranslateHandle, + (ZEL_HANDLE_KERNEL, Kernel->ZeKernel, + (void **)&TranslatedKernelHandles[0])); + + for (size_t i = 0; i < NumKernelAlternatives; i++) { + ZE2UR_CALL(zelLoaderTranslateHandle, + (ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel, + (void **)&TranslatedKernelHandles[i + 1])); + } + + ZE2UR_CALL(Platform->ZeMutableCmdListExt + .zexCommandListGetNextCommandIdWithKernelsExp, + (CommandBuffer->ZeComputeCommandListTranslated, + &ZeMutableCommandDesc, NumKernelAlternatives + 1, + TranslatedKernelHandles.data(), &CommandId)); + + } else { + 
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp, + (CommandBuffer->ZeComputeCommandListTranslated, + &ZeMutableCommandDesc, &CommandId)); + } DEBUG_LOG(CommandId); try { - Command = new ur_exp_command_buffer_command_handle_t_( - CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr, Kernel); + Command = new kernel_command_handle( + CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr, + NumKernelAlternatives, KernelAlternatives); } catch (const std::bad_alloc &) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } catch (...) { @@ -944,8 +990,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp( ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel, uint32_t WorkDim, const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize, const size_t *LocalWorkSize, - uint32_t /*numKernelAlternatives*/, - ur_kernel_handle_t * /*phKernelAlternatives*/, + uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives, uint32_t NumSyncPointsInWaitList, const ur_exp_command_buffer_sync_point_t *SyncPointWaitList, uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList, @@ -960,6 +1005,10 @@ ur_result_t urCommandBufferAppendKernelLaunchExp( UR_ASSERT(!(Command && !CommandBuffer->IsUpdatable), UR_RESULT_ERROR_INVALID_OPERATION); + for (uint32_t i = 0; i < NumKernelAlternatives; ++i) { + UR_ASSERT(KernelAlternatives[i] != Kernel, UR_RESULT_ERROR_INVALID_VALUE); + } + // Lock automatically releases when this goes out of scope. 
std::scoped_lock Lock( Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex); @@ -983,18 +1032,21 @@ ur_result_t urCommandBufferAppendKernelLaunchExp( ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2])); CommandBuffer->KernelsList.push_back(Kernel); + for (size_t i = 0; i < NumKernelAlternatives; i++) { + CommandBuffer->KernelsList.push_back(KernelAlternatives[i]); + } - // Increment the reference count of the Kernel and indicate that the Kernel - // is in use. Once the event has been signaled, the code in - // CleanupCompletedEvent(Event) will do a urKernelRelease to update the - // reference count on the kernel, using the kernel saved in CommandData. - UR_CALL(ur::level_zero::urKernelRetain(Kernel)); + ur::level_zero::urKernelRetain(Kernel); + // Retain alternative kernels if provided + for (size_t i = 0; i < NumKernelAlternatives; i++) { + ur::level_zero::urKernelRetain(KernelAlternatives[i]); + } if (Command) { UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, LocalWorkSize, + NumKernelAlternatives, KernelAlternatives, *Command)); } - std::vector<ze_event_handle_t> ZeEventList; ze_event_handle_t ZeLaunchEvent = nullptr; UR_CALL(createSyncPointAndGetZeEvents( @@ -1690,7 +1742,7 @@ ur_result_t urCommandBufferReleaseCommandExp( * @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t validateCommandDesc( - ur_exp_command_buffer_command_handle_t Command, + kernel_command_handle *Command, const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) { auto CommandBuffer = Command->CommandBuffer; @@ -1699,9 +1751,14 @@ ur_result_t validateCommandDesc( ->mutableCommandFlags; logger::debug("Mutable features supported by device {}", SupportedFeatures); - // Kernel handle updates are not yet supported. 
- if (CommandDesc->hNewKernel && CommandDesc->hNewKernel != Command->Kernel) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + UR_ASSERT( + !CommandDesc->hNewKernel || CommandDesc->hNewKernel == Command->Kernel || + (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION), + UR_RESULT_ERROR_UNSUPPORTED_FEATURE); + // Check if the provided new kernel is in the list of valid alternatives. + if (CommandDesc->hNewKernel && + !Command->ValidKernelHandles.count(CommandDesc->hNewKernel)) { + return UR_RESULT_ERROR_INVALID_VALUE; } if (CommandDesc->newWorkDim != Command->WorkDim && @@ -1754,7 +1811,7 @@ ur_result_t validateCommandDesc( * @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t updateKernelCommand( - ur_exp_command_buffer_command_handle_t Command, + kernel_command_handle *Command, const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) { // We need the created descriptors to live till the point when @@ -1769,12 +1826,29 @@ ur_result_t updateKernelCommand( const auto CommandBuffer = Command->CommandBuffer; const void *NextDesc = nullptr; + auto Platform = CommandBuffer->Context->getPlatform(); uint32_t Dim = CommandDesc->newWorkDim; size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset; size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize; size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize; + // Kernel handle must be updated first for a given CommandId if required + ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel; + if (NewKernel && Command->Kernel != NewKernel) { + ze_kernel_handle_t ZeKernelTranslated = nullptr; + ZE2UR_CALL( + zelLoaderTranslateHandle, + (ZEL_HANDLE_KERNEL, NewKernel->ZeKernel, (void **)&ZeKernelTranslated)); + + ZE2UR_CALL(Platform->ZeMutableCmdListExt + .zexCommandListUpdateMutableCommandKernelsExp, + (CommandBuffer->ZeComputeCommandListTranslated, 1, + &Command->CommandId, &ZeKernelTranslated)); + // Set current kernel to be the new kernel + Command->Kernel = NewKernel; + } + // Check if a new global offset is 
provided. if (NewGlobalWorkOffset && Dim > 0) { auto MutableGroupOffestDesc = @@ -1973,7 +2047,6 @@ ur_result_t updateKernelCommand( MutableCommandDesc.pNext = NextDesc; MutableCommandDesc.flags = 0; - auto Platform = CommandBuffer->Context->getPlatform(); ZE2UR_CALL( Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp, (CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc)); @@ -2009,18 +2082,22 @@ ur_result_t urCommandBufferUpdateKernelLaunchExp( const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) { UR_ASSERT(Command->CommandBuffer->IsUpdatable, UR_RESULT_ERROR_INVALID_OPERATION); - UR_ASSERT(Command->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); + + auto KernelCommandHandle = static_cast<kernel_command_handle *>(Command); + + UR_ASSERT(KernelCommandHandle->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); // Lock command, kernel and command buffer for update. std::scoped_lock Guard( - Command->Mutex, Command->CommandBuffer->Mutex, Command->Kernel->Mutex); + Command->Mutex, Command->CommandBuffer->Mutex, + KernelCommandHandle->Kernel->Mutex); UR_ASSERT(Command->CommandBuffer->IsFinalized, UR_RESULT_ERROR_INVALID_OPERATION); - UR_CALL(validateCommandDesc(Command, CommandDesc)); + UR_CALL(validateCommandDesc(KernelCommandHandle, CommandDesc)); UR_CALL(waitForOngoingExecution(Command->CommandBuffer)); - UR_CALL(updateKernelCommand(Command, CommandDesc)); + UR_CALL(updateKernelCommand(KernelCommandHandle, CommandDesc)); ZE2UR_CALL(zeCommandListClose, (Command->CommandBuffer->ZeComputeCommandList)); diff --git a/source/adapters/level_zero/command_buffer.hpp b/source/adapters/level_zero/command_buffer.hpp index 156e0e5c24..d069f301fb 100644 --- a/source/adapters/level_zero/command_buffer.hpp +++ b/source/adapters/level_zero/command_buffer.hpp @@ -145,18 +145,31 @@ struct ur_exp_command_buffer_handle_t_ : public _ur_object { struct ur_exp_command_buffer_command_handle_t_ : public _ur_object { 
ur_exp_command_buffer_command_handle_t_(ur_exp_command_buffer_handle_t, - uint64_t, uint32_t, bool, - ur_kernel_handle_t); + uint64_t); - ~ur_exp_command_buffer_command_handle_t_(); + virtual ~ur_exp_command_buffer_command_handle_t_(); // Command-buffer of this command. ur_exp_command_buffer_handle_t CommandBuffer; - + // L0 command ID identifying this command uint64_t CommandId; +}; + +struct kernel_command_handle : public ur_exp_command_buffer_command_handle_t_ { + kernel_command_handle(ur_exp_command_buffer_handle_t CommandBuffer, + ur_kernel_handle_t Kernel, uint64_t CommandId, + uint32_t WorkDim, bool UserDefinedLocalSize, + uint32_t NumKernelAlternatives, + ur_kernel_handle_t *KernelAlternatives); + + ~kernel_command_handle(); + // Work-dimension the command was originally created with. uint32_t WorkDim; // Set to true if the user set the local work size on command creation. bool UserDefinedLocalSize; + // Currently active kernel handle ur_kernel_handle_t Kernel; + // Storage for valid kernel alternatives for this command. 
+ std::unordered_set<ur_kernel_handle_t> ValidKernelHandles; }; diff --git a/source/adapters/level_zero/device.cpp b/source/adapters/level_zero/device.cpp index 865edebc08..3997c837f8 100644 --- a/source/adapters/level_zero/device.cpp +++ b/source/adapters/level_zero/device.cpp @@ -1048,6 +1048,10 @@ ur_result_t urDeviceGetInfo( UpdateCapabilities |= UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET; } + if (supportsFlags(ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION)) { + UpdateCapabilities |= + UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE; + } return ReturnValue(UpdateCapabilities); } case UR_DEVICE_INFO_COMMAND_BUFFER_EVENT_SUPPORT_EXP: diff --git a/source/adapters/level_zero/platform.cpp b/source/adapters/level_zero/platform.cpp index 1e65e55048..dbcd755c43 100644 --- a/source/adapters/level_zero/platform.cpp +++ b/source/adapters/level_zero/platform.cpp @@ -320,6 +320,22 @@ ur_result_t ur_platform_handle_t_::initialize() { &ZeMutableCmdListExt .zexCommandListUpdateMutableCommandWaitEventsExp))) == 0); + ZeMutableCmdListExt.Supported &= + (ZE_CALL_NOCHECK( + zeDriverGetExtensionFunctionAddress, + (ZeDriver, "zeCommandListUpdateMutableCommandKernelsExp", + reinterpret_cast<void **>( + &ZeMutableCmdListExt + .zexCommandListUpdateMutableCommandKernelsExp))) == 0); + + ZeMutableCmdListExt.Supported &= + (ZE_CALL_NOCHECK( + zeDriverGetExtensionFunctionAddress, + (ZeDriver, "zeCommandListGetNextCommandIdWithKernelsExp", + reinterpret_cast<void **>( + &ZeMutableCmdListExt + .zexCommandListGetNextCommandIdWithKernelsExp))) == 0); + return UR_RESULT_SUCCESS; } diff --git a/source/adapters/level_zero/platform.hpp b/source/adapters/level_zero/platform.hpp index 515f71b8c4..a3087e69f7 100644 --- a/source/adapters/level_zero/platform.hpp +++ b/source/adapters/level_zero/platform.hpp @@ -106,5 +106,11 @@ struct ur_platform_handle_t_ : public _ur_platform { ze_result_t (*zexCommandListUpdateMutableCommandWaitEventsExp)( ze_command_list_handle_t, uint64_t, uint32_t, 
ze_event_handle_t *) = nullptr; + ze_result_t (*zexCommandListUpdateMutableCommandKernelsExp)( + ze_command_list_handle_t, uint32_t, uint64_t *, + ze_kernel_handle_t *) = nullptr; + ze_result_t (*zexCommandListGetNextCommandIdWithKernelsExp)( + ze_command_list_handle_t, const ze_mutable_command_id_exp_desc_t *, + uint32_t, ze_kernel_handle_t *, uint64_t *) = nullptr; } ZeMutableCmdListExt; };