Skip to content

Commit

Permalink
[CMDBUF] Implement kernel binary update for L0 adapter
Browse files Browse the repository at this point in the history
- Implement binary update in L0 adapter
- Platform and driver query support for new extension functions
- Update command buffer documentation to specify arg requirements for binary update
  • Loading branch information
Bensuo committed Nov 21, 2024
1 parent 3609cd6 commit b24eb57
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 37 deletions.
5 changes: 5 additions & 0 deletions scripts/core/EXP-COMMAND-BUFFER.rst
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ ${x}CommandBufferAppendKernelLaunchExp. The command can then be updated
to use the new kernel handle by passing it to
${x}CommandBufferUpdateKernelLaunchExp.

.. important::
    When updating the kernel handle of a command, you must also provide all
    required arguments to the new kernel in the update descriptor. Failure to
    do so results in undefined behavior.

.. parsed-literal::
// Create a command-buffer with update enabled.
Expand Down
143 changes: 110 additions & 33 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,21 +476,14 @@ void ur_exp_command_buffer_handle_t_::cleanupCommandBufferResources() {

/**
 * Constructs the base handle shared by command-buffer commands.
 *
 * Records the owning command-buffer and the L0 command ID, and takes a
 * reference on the command-buffer. Kernel-specific state (work dimension,
 * local-size flag, kernel handles) is not part of the base handle.
 *
 * @param CommandBuffer Command-buffer this command belongs to.
 * @param CommandId L0 command ID identifying this command.
 */
ur_exp_command_buffer_command_handle_t_::
    ur_exp_command_buffer_command_handle_t_(
        ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId)
    : CommandBuffer(CommandBuffer), CommandId(CommandId) {
  // Keep the command-buffer alive for as long as this command handle exists.
  ur::level_zero::urCommandBufferRetainExp(CommandBuffer);
}

/**
 * Drops the reference on the owning command-buffer.
 *
 * The base handle declares only CommandBuffer and CommandId, so there is no
 * kernel to release here.
 */
ur_exp_command_buffer_command_handle_t_::
    ~ur_exp_command_buffer_command_handle_t_() {
  ur::level_zero::urCommandBufferReleaseExp(CommandBuffer);
}

void ur_exp_command_buffer_handle_t_::registerSyncPoint(
Expand Down Expand Up @@ -527,6 +520,31 @@ ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue(
return UR_RESULT_SUCCESS;
}

/**
 * Constructs the handle for a kernel-launch command.
 *
 * The default kernel and every alternative kernel are stored in
 * ValidKernelHandles, and one reference is taken per stored (unique) handle.
 * The destructor releases exactly one reference per stored handle, so a
 * kernel listed more than once must only be retained once.
 *
 * @param CommandBuffer Command-buffer this command belongs to.
 * @param Kernel Kernel the command is created with; becomes the current
 * kernel of the command.
 * @param CommandId L0 command ID identifying this command.
 * @param WorkDim Work-dimension the command is created with.
 * @param UserDefinedLocalSize True if the user set the local work size.
 * @param NumKernelAlternatives Number of entries in KernelAlternatives.
 * @param KernelAlternatives Alternative kernels; may be null, in which case
 * NumKernelAlternatives is ignored.
 */
kernel_command_handle::kernel_command_handle(
    ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
    uint64_t CommandId, uint32_t WorkDim, bool UserDefinedLocalSize,
    uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives)
    : ur_exp_command_buffer_command_handle_t_(CommandBuffer, CommandId),
      WorkDim(WorkDim), UserDefinedLocalSize(UserDefinedLocalSize),
      Kernel(Kernel) {
  // Insert first and retain only if the handle was not already present, so
  // the retain count always matches the single release per set element
  // performed on destruction.
  const auto RetainIfNew = [this](ur_kernel_handle_t Candidate) {
    if (ValidKernelHandles.insert(Candidate).second) {
      ur::level_zero::urKernelRetain(Candidate);
    }
  };

  try {
    // Add the default kernel to the list of valid kernels.
    RetainIfNew(Kernel);
    // Add alternative kernels if provided.
    if (KernelAlternatives) {
      for (uint32_t i = 0; i < NumKernelAlternatives; ++i) {
        RetainIfNew(KernelAlternatives[i]);
      }
    }
  } catch (...) {
    // The destructor does not run when the constructor throws, so give back
    // the references taken so far before propagating the exception.
    for (const ur_kernel_handle_t &Retained : ValidKernelHandles) {
      ur::level_zero::urKernelRelease(Retained);
    }
    throw;
  }
}

/**
 * Releases one reference for every kernel stored in ValidKernelHandles.
 */
kernel_command_handle::~kernel_command_handle() {
  for (auto It = ValidKernelHandles.cbegin(), End = ValidKernelHandles.cend();
       It != End; ++It) {
    ur::level_zero::urKernelRelease(*It);
  }
}

namespace ur::level_zero {

/**
Expand Down Expand Up @@ -906,7 +924,8 @@ setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
ur_result_t
createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
ur_kernel_handle_t Kernel, uint32_t WorkDim,
const size_t *LocalWorkSize,
const size_t *LocalWorkSize, uint32_t NumKernelAlternatives,
ur_kernel_handle_t *KernelAlternatives,
ur_exp_command_buffer_command_handle_t &Command) {

assert(CommandBuffer->IsUpdatable);
Expand All @@ -923,14 +942,41 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;

auto Platform = CommandBuffer->Context->getPlatform();
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
(CommandBuffer->ZeComputeCommandListTranslated,
&ZeMutableCommandDesc, &CommandId));
if (NumKernelAlternatives > 0) {
ZeMutableCommandDesc.flags |=
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;

std::vector<ze_kernel_handle_t> TranslatedKernelHandles(
NumKernelAlternatives + 1, nullptr);

// Translate main kernel first
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, Kernel->ZeKernel,
(void **)&TranslatedKernelHandles[0]));

for (size_t i = 0; i < NumKernelAlternatives; i++) {
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel,
(void **)&TranslatedKernelHandles[i + 1]));
}

ZE2UR_CALL(Platform->ZeMutableCmdListExt
.zexCommandListGetNextCommandIdWithKernelsExp,
(CommandBuffer->ZeComputeCommandListTranslated,
&ZeMutableCommandDesc, NumKernelAlternatives + 1,
TranslatedKernelHandles.data(), &CommandId));

} else {
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
(CommandBuffer->ZeComputeCommandListTranslated,
&ZeMutableCommandDesc, &CommandId));
}
DEBUG_LOG(CommandId);

try {
Command = new ur_exp_command_buffer_command_handle_t_(
CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr, Kernel);
Command = new kernel_command_handle(
CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr,
NumKernelAlternatives, KernelAlternatives);
} catch (const std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (...) {
Expand All @@ -944,8 +990,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
uint32_t WorkDim, const size_t *GlobalWorkOffset,
const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
uint32_t /*numKernelAlternatives*/,
ur_kernel_handle_t * /*phKernelAlternatives*/,
uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives,
uint32_t NumSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
Expand All @@ -960,6 +1005,10 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
UR_ASSERT(!(Command && !CommandBuffer->IsUpdatable),
UR_RESULT_ERROR_INVALID_OPERATION);

for (uint32_t i = 0; i < NumKernelAlternatives; ++i) {
UR_ASSERT(KernelAlternatives[i] != Kernel, UR_RESULT_ERROR_INVALID_VALUE);
}

// Lock automatically releases when this goes out of scope.
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex);
Expand All @@ -983,18 +1032,21 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));

CommandBuffer->KernelsList.push_back(Kernel);
for (size_t i = 0; i < NumKernelAlternatives; i++) {
CommandBuffer->KernelsList.push_back(KernelAlternatives[i]);
}

// Increment the reference count of the Kernel and indicate that the Kernel
// is in use. Once the event has been signaled, the code in
// CleanupCompletedEvent(Event) will do a urKernelRelease to update the
// reference count on the kernel, using the kernel saved in CommandData.
UR_CALL(ur::level_zero::urKernelRetain(Kernel));
ur::level_zero::urKernelRetain(Kernel);
// Retain alternative kernels if provided
for (size_t i = 0; i < NumKernelAlternatives; i++) {
ur::level_zero::urKernelRetain(KernelAlternatives[i]);
}

if (Command) {
UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, LocalWorkSize,
NumKernelAlternatives, KernelAlternatives,
*Command));
}

std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
Expand Down Expand Up @@ -1690,7 +1742,7 @@ ur_result_t urCommandBufferReleaseCommandExp(
* @return UR_RESULT_SUCCESS or an error code on failure
*/
ur_result_t validateCommandDesc(
ur_exp_command_buffer_command_handle_t Command,
kernel_command_handle *Command,
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {

auto CommandBuffer = Command->CommandBuffer;
Expand All @@ -1699,9 +1751,14 @@ ur_result_t validateCommandDesc(
->mutableCommandFlags;
logger::debug("Mutable features supported by device {}", SupportedFeatures);

// Kernel handle updates are not yet supported.
if (CommandDesc->hNewKernel && CommandDesc->hNewKernel != Command->Kernel) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
UR_ASSERT(
!CommandDesc->hNewKernel ||
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION),
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
// Check if the provided new kernel is in the list of valid alternatives.
if (CommandDesc->hNewKernel &&
!Command->ValidKernelHandles.count(CommandDesc->hNewKernel)) {
return UR_RESULT_ERROR_INVALID_VALUE;
}

if (CommandDesc->newWorkDim != Command->WorkDim &&
Expand Down Expand Up @@ -1754,7 +1811,7 @@ ur_result_t validateCommandDesc(
* @return UR_RESULT_SUCCESS or an error code on failure
*/
ur_result_t updateKernelCommand(
ur_exp_command_buffer_command_handle_t Command,
kernel_command_handle *Command,
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {

// We need the created descriptors to live till the point when
Expand All @@ -1769,12 +1826,29 @@ ur_result_t updateKernelCommand(

const auto CommandBuffer = Command->CommandBuffer;
const void *NextDesc = nullptr;
auto Platform = CommandBuffer->Context->getPlatform();

uint32_t Dim = CommandDesc->newWorkDim;
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset;
size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize;
size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize;

// Kernel handle must be updated first for a given CommandId if required
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel;
if (NewKernel && Command->Kernel != NewKernel) {
ze_kernel_handle_t ZeKernelTranslated = nullptr;
ZE2UR_CALL(
zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, NewKernel->ZeKernel, (void **)&ZeKernelTranslated));

ZE2UR_CALL(Platform->ZeMutableCmdListExt
.zexCommandListUpdateMutableCommandKernelsExp,
(CommandBuffer->ZeComputeCommandListTranslated, 1,
&Command->CommandId, &ZeKernelTranslated));
// Set current kernel to be the new kernel
Command->Kernel = NewKernel;
}

// Check if a new global offset is provided.
if (NewGlobalWorkOffset && Dim > 0) {
auto MutableGroupOffestDesc =
Expand Down Expand Up @@ -1973,7 +2047,6 @@ ur_result_t updateKernelCommand(
MutableCommandDesc.pNext = NextDesc;
MutableCommandDesc.flags = 0;

auto Platform = CommandBuffer->Context->getPlatform();
ZE2UR_CALL(
Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp,
(CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc));
Expand Down Expand Up @@ -2009,18 +2082,22 @@ ur_result_t urCommandBufferUpdateKernelLaunchExp(
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
UR_ASSERT(Command->CommandBuffer->IsUpdatable,
UR_RESULT_ERROR_INVALID_OPERATION);
UR_ASSERT(Command->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);

auto KernelCommandHandle = static_cast<kernel_command_handle *>(Command);

UR_ASSERT(KernelCommandHandle->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);

// Lock command, kernel and command buffer for update.
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard(
Command->Mutex, Command->CommandBuffer->Mutex, Command->Kernel->Mutex);
Command->Mutex, Command->CommandBuffer->Mutex,
KernelCommandHandle->Kernel->Mutex);

UR_ASSERT(Command->CommandBuffer->IsFinalized,
UR_RESULT_ERROR_INVALID_OPERATION);

UR_CALL(validateCommandDesc(Command, CommandDesc));
UR_CALL(validateCommandDesc(KernelCommandHandle, CommandDesc));
UR_CALL(waitForOngoingExecution(Command->CommandBuffer));
UR_CALL(updateKernelCommand(Command, CommandDesc));
UR_CALL(updateKernelCommand(KernelCommandHandle, CommandDesc));

ZE2UR_CALL(zeCommandListClose,
(Command->CommandBuffer->ZeComputeCommandList));
Expand Down
21 changes: 17 additions & 4 deletions source/adapters/level_zero/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,31 @@ struct ur_exp_command_buffer_handle_t_ : public _ur_object {

/**
 * Base type for handles to commands recorded in a command-buffer.
 *
 * Holds only the state common to all commands: the owning command-buffer and
 * the L0 command ID.
 */
struct ur_exp_command_buffer_command_handle_t_ : public _ur_object {
  // Takes the owning command-buffer and the L0 command ID of the command.
  ur_exp_command_buffer_command_handle_t_(ur_exp_command_buffer_handle_t,
                                          uint64_t);

  // Virtual because this type is used as a base class for command handles.
  virtual ~ur_exp_command_buffer_command_handle_t_();

  // Command-buffer of this command.
  ur_exp_command_buffer_handle_t CommandBuffer;

  // L0 command ID identifying this command
  uint64_t CommandId;
};

struct kernel_command_handle : public ur_exp_command_buffer_command_handle_t_ {
kernel_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
ur_kernel_handle_t Kernel, uint64_t CommandId,
uint32_t WorkDim, bool UserDefinedLocalSize,
uint32_t NumKernelAlternatives,
ur_kernel_handle_t *KernelAlternatives);

~kernel_command_handle();

// Work-dimension the command was originally created with.
uint32_t WorkDim;
// Set to true if the user set the local work size on command creation.
bool UserDefinedLocalSize;
// Currently active kernel handle
ur_kernel_handle_t Kernel;
// Storage for valid kernel alternatives for this command.
std::unordered_set<ur_kernel_handle_t> ValidKernelHandles;
};
4 changes: 4 additions & 0 deletions source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,10 @@ ur_result_t urDeviceGetInfo(
UpdateCapabilities |=
UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET;
}
if (supportsFlags(ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION)) {
UpdateCapabilities |=
UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE;
}
return ReturnValue(UpdateCapabilities);
}
case UR_DEVICE_INFO_COMMAND_BUFFER_EVENT_SUPPORT_EXP:
Expand Down
16 changes: 16 additions & 0 deletions source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,22 @@ ur_result_t ur_platform_handle_t_::initialize() {
&ZeMutableCmdListExt
.zexCommandListUpdateMutableCommandWaitEventsExp))) == 0);

ZeMutableCmdListExt.Supported &=
(ZE_CALL_NOCHECK(
zeDriverGetExtensionFunctionAddress,
(ZeDriver, "zeCommandListUpdateMutableCommandKernelsExp",
reinterpret_cast<void **>(
&ZeMutableCmdListExt
.zexCommandListUpdateMutableCommandKernelsExp))) == 0);

ZeMutableCmdListExt.Supported &=
(ZE_CALL_NOCHECK(
zeDriverGetExtensionFunctionAddress,
(ZeDriver, "zeCommandListGetNextCommandIdWithKernelsExp",
reinterpret_cast<void **>(
&ZeMutableCmdListExt
.zexCommandListGetNextCommandIdWithKernelsExp))) == 0);

return UR_RESULT_SUCCESS;
}

Expand Down
6 changes: 6 additions & 0 deletions source/adapters/level_zero/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,11 @@ struct ur_platform_handle_t_ : public _ur_platform {
ze_result_t (*zexCommandListUpdateMutableCommandWaitEventsExp)(
ze_command_list_handle_t, uint64_t, uint32_t,
ze_event_handle_t *) = nullptr;
ze_result_t (*zexCommandListUpdateMutableCommandKernelsExp)(
ze_command_list_handle_t, uint32_t, uint64_t *,
ze_kernel_handle_t *) = nullptr;
ze_result_t (*zexCommandListGetNextCommandIdWithKernelsExp)(
ze_command_list_handle_t, const ze_mutable_command_id_exp_desc_t *,
uint32_t, ze_kernel_handle_t *, uint64_t *) = nullptr;
} ZeMutableCmdListExt;
};

0 comments on commit b24eb57

Please sign in to comment.