Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CMDBUF] Implement kernel binary update for L0 adapter #2369

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions scripts/core/EXP-COMMAND-BUFFER.rst
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ ${x}CommandBufferAppendKernelLaunchExp. The command can then be updated
to use the new kernel handle by passing it to
${x}CommandBufferUpdateKernelLaunchExp.

.. important::
When updating the kernel handle of a command all required arguments to the
new kernel must be provided in the update descriptor. Failure to do so will
result in undefined behavior.

.. parsed-literal::

// Create a command-buffer with update enabled.
Expand Down
143 changes: 110 additions & 33 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,21 +476,14 @@ void ur_exp_command_buffer_handle_t_::cleanupCommandBufferResources() {

ur_exp_command_buffer_command_handle_t_::
ur_exp_command_buffer_command_handle_t_(
ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId,
uint32_t WorkDim, bool UserDefinedLocalSize,
ur_kernel_handle_t Kernel = nullptr)
: CommandBuffer(CommandBuffer), CommandId(CommandId), WorkDim(WorkDim),
UserDefinedLocalSize(UserDefinedLocalSize), Kernel(Kernel) {
ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId)
: CommandBuffer(CommandBuffer), CommandId(CommandId) {
ur::level_zero::urCommandBufferRetainExp(CommandBuffer);
if (Kernel)
ur::level_zero::urKernelRetain(Kernel);
}

ur_exp_command_buffer_command_handle_t_::
~ur_exp_command_buffer_command_handle_t_() {
ur::level_zero::urCommandBufferReleaseExp(CommandBuffer);
if (Kernel)
ur::level_zero::urKernelRelease(Kernel);
}

void ur_exp_command_buffer_handle_t_::registerSyncPoint(
Expand Down Expand Up @@ -527,6 +520,31 @@ ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue(
return UR_RESULT_SUCCESS;
}

kernel_command_handle::kernel_command_handle(
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
uint64_t CommandId, uint32_t WorkDim, bool UserDefinedLocalSize,
uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives)
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, CommandId),
WorkDim(WorkDim), UserDefinedLocalSize(UserDefinedLocalSize),
Kernel(Kernel) {
// Add the default kernel to the list of valid kernels
ur::level_zero::urKernelRetain(Kernel);
ValidKernelHandles.insert(Kernel);
// Add alternative kernels if provided
if (KernelAlternatives) {
for (size_t i = 0; i < NumKernelAlternatives; i++) {
ur::level_zero::urKernelRetain(KernelAlternatives[i]);
ValidKernelHandles.insert(KernelAlternatives[i]);
}
}
}

kernel_command_handle::~kernel_command_handle() {
for (const ur_kernel_handle_t &KernelHandle : ValidKernelHandles) {
ur::level_zero::urKernelRelease(KernelHandle);
}
}

namespace ur::level_zero {

/**
Expand Down Expand Up @@ -906,7 +924,8 @@ setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
ur_result_t
createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
ur_kernel_handle_t Kernel, uint32_t WorkDim,
const size_t *LocalWorkSize,
const size_t *LocalWorkSize, uint32_t NumKernelAlternatives,
ur_kernel_handle_t *KernelAlternatives,
ur_exp_command_buffer_command_handle_t &Command) {

assert(CommandBuffer->IsUpdatable);
Expand All @@ -923,14 +942,41 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;

auto Platform = CommandBuffer->Context->getPlatform();
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
(CommandBuffer->ZeComputeCommandListTranslated,
&ZeMutableCommandDesc, &CommandId));
if (NumKernelAlternatives > 0) {
ZeMutableCommandDesc.flags |=
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;

std::vector<ze_kernel_handle_t> TranslatedKernelHandles(
NumKernelAlternatives + 1, nullptr);

// Translate main kernel first
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, Kernel->ZeKernel,
(void **)&TranslatedKernelHandles[0]));

for (size_t i = 0; i < NumKernelAlternatives; i++) {
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel,
(void **)&TranslatedKernelHandles[i + 1]));
}

ZE2UR_CALL(Platform->ZeMutableCmdListExt
.zexCommandListGetNextCommandIdWithKernelsExp,
(CommandBuffer->ZeComputeCommandListTranslated,
&ZeMutableCommandDesc, NumKernelAlternatives + 1,
TranslatedKernelHandles.data(), &CommandId));

} else {
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
(CommandBuffer->ZeComputeCommandListTranslated,
&ZeMutableCommandDesc, &CommandId));
}
DEBUG_LOG(CommandId);

try {
Command = new ur_exp_command_buffer_command_handle_t_(
CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr, Kernel);
Command = new kernel_command_handle(
CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr,
NumKernelAlternatives, KernelAlternatives);
} catch (const std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (...) {
Expand All @@ -944,8 +990,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
uint32_t WorkDim, const size_t *GlobalWorkOffset,
const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
uint32_t /*numKernelAlternatives*/,
ur_kernel_handle_t * /*phKernelAlternatives*/,
uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives,
uint32_t NumSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
Expand All @@ -960,6 +1005,10 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
UR_ASSERT(!(Command && !CommandBuffer->IsUpdatable),
UR_RESULT_ERROR_INVALID_OPERATION);

for (uint32_t i = 0; i < NumKernelAlternatives; ++i) {
UR_ASSERT(KernelAlternatives[i] != Kernel, UR_RESULT_ERROR_INVALID_VALUE);
}

// Lock automatically releases when this goes out of scope.
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex);
Expand All @@ -983,18 +1032,21 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));

CommandBuffer->KernelsList.push_back(Kernel);
for (size_t i = 0; i < NumKernelAlternatives; i++) {
CommandBuffer->KernelsList.push_back(KernelAlternatives[i]);
}

// Increment the reference count of the Kernel and indicate that the Kernel
// is in use. Once the event has been signaled, the code in
// CleanupCompletedEvent(Event) will do a urKernelRelease to update the
// reference count on the kernel, using the kernel saved in CommandData.
UR_CALL(ur::level_zero::urKernelRetain(Kernel));
ur::level_zero::urKernelRetain(Kernel);
// Retain alternative kernels if provided
for (size_t i = 0; i < NumKernelAlternatives; i++) {
ur::level_zero::urKernelRetain(KernelAlternatives[i]);
}

if (Command) {
UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, LocalWorkSize,
NumKernelAlternatives, KernelAlternatives,
*Command));
}

std::vector<ze_event_handle_t> ZeEventList;
ze_event_handle_t ZeLaunchEvent = nullptr;
UR_CALL(createSyncPointAndGetZeEvents(
Expand Down Expand Up @@ -1690,7 +1742,7 @@ ur_result_t urCommandBufferReleaseCommandExp(
* @return UR_RESULT_SUCCESS or an error code on failure
*/
ur_result_t validateCommandDesc(
ur_exp_command_buffer_command_handle_t Command,
kernel_command_handle *Command,
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {

auto CommandBuffer = Command->CommandBuffer;
Expand All @@ -1699,9 +1751,14 @@ ur_result_t validateCommandDesc(
->mutableCommandFlags;
logger::debug("Mutable features supported by device {}", SupportedFeatures);

// Kernel handle updates are not yet supported.
if (CommandDesc->hNewKernel && CommandDesc->hNewKernel != Command->Kernel) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
UR_ASSERT(
!CommandDesc->hNewKernel ||
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION),
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
// Check if the provided new kernel is in the list of valid alternatives.
if (CommandDesc->hNewKernel &&
!Command->ValidKernelHandles.count(CommandDesc->hNewKernel)) {
return UR_RESULT_ERROR_INVALID_VALUE;
}

if (CommandDesc->newWorkDim != Command->WorkDim &&
Expand Down Expand Up @@ -1754,7 +1811,7 @@ ur_result_t validateCommandDesc(
* @return UR_RESULT_SUCCESS or an error code on failure
*/
ur_result_t updateKernelCommand(
ur_exp_command_buffer_command_handle_t Command,
kernel_command_handle *Command,
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {

// We need the created descriptors to live till the point when
Expand All @@ -1769,12 +1826,29 @@ ur_result_t updateKernelCommand(

const auto CommandBuffer = Command->CommandBuffer;
const void *NextDesc = nullptr;
auto Platform = CommandBuffer->Context->getPlatform();

uint32_t Dim = CommandDesc->newWorkDim;
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset;
size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize;
size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize;

// Kernel handle must be updated first for a given CommandId if required
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel;
if (NewKernel && Command->Kernel != NewKernel) {
ze_kernel_handle_t ZeKernelTranslated = nullptr;
ZE2UR_CALL(
zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, NewKernel->ZeKernel, (void **)&ZeKernelTranslated));

ZE2UR_CALL(Platform->ZeMutableCmdListExt
.zexCommandListUpdateMutableCommandKernelsExp,
(CommandBuffer->ZeComputeCommandListTranslated, 1,
&Command->CommandId, &ZeKernelTranslated));
// Set current kernel to be the new kernel
Command->Kernel = NewKernel;
}

// Check if a new global offset is provided.
if (NewGlobalWorkOffset && Dim > 0) {
auto MutableGroupOffestDesc =
Expand Down Expand Up @@ -1973,7 +2047,6 @@ ur_result_t updateKernelCommand(
MutableCommandDesc.pNext = NextDesc;
MutableCommandDesc.flags = 0;

auto Platform = CommandBuffer->Context->getPlatform();
ZE2UR_CALL(
Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp,
(CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc));
Expand Down Expand Up @@ -2009,18 +2082,22 @@ ur_result_t urCommandBufferUpdateKernelLaunchExp(
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
UR_ASSERT(Command->CommandBuffer->IsUpdatable,
UR_RESULT_ERROR_INVALID_OPERATION);
UR_ASSERT(Command->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);

auto KernelCommandHandle = static_cast<kernel_command_handle *>(Command);

UR_ASSERT(KernelCommandHandle->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);

// Lock command, kernel and command buffer for update.
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard(
Command->Mutex, Command->CommandBuffer->Mutex, Command->Kernel->Mutex);
Command->Mutex, Command->CommandBuffer->Mutex,
KernelCommandHandle->Kernel->Mutex);

UR_ASSERT(Command->CommandBuffer->IsFinalized,
UR_RESULT_ERROR_INVALID_OPERATION);

UR_CALL(validateCommandDesc(Command, CommandDesc));
UR_CALL(validateCommandDesc(KernelCommandHandle, CommandDesc));
UR_CALL(waitForOngoingExecution(Command->CommandBuffer));
UR_CALL(updateKernelCommand(Command, CommandDesc));
UR_CALL(updateKernelCommand(KernelCommandHandle, CommandDesc));

ZE2UR_CALL(zeCommandListClose,
(Command->CommandBuffer->ZeComputeCommandList));
Expand Down
21 changes: 17 additions & 4 deletions source/adapters/level_zero/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,31 @@ struct ur_exp_command_buffer_handle_t_ : public _ur_object {

struct ur_exp_command_buffer_command_handle_t_ : public _ur_object {
ur_exp_command_buffer_command_handle_t_(ur_exp_command_buffer_handle_t,
uint64_t, uint32_t, bool,
ur_kernel_handle_t);
uint64_t);

~ur_exp_command_buffer_command_handle_t_();
virtual ~ur_exp_command_buffer_command_handle_t_();

// Command-buffer of this command.
ur_exp_command_buffer_handle_t CommandBuffer;

// L0 command ID identifying this command
uint64_t CommandId;
};

struct kernel_command_handle : public ur_exp_command_buffer_command_handle_t_ {
kernel_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
ur_kernel_handle_t Kernel, uint64_t CommandId,
uint32_t WorkDim, bool UserDefinedLocalSize,
uint32_t NumKernelAlternatives,
ur_kernel_handle_t *KernelAlternatives);

~kernel_command_handle();

// Work-dimension the command was originally created with.
uint32_t WorkDim;
// Set to true if the user set the local work size on command creation.
bool UserDefinedLocalSize;
// Currently active kernel handle
ur_kernel_handle_t Kernel;
// Storage for valid kernel alternatives for this command.
std::unordered_set<ur_kernel_handle_t> ValidKernelHandles;
};
4 changes: 4 additions & 0 deletions source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,10 @@ ur_result_t urDeviceGetInfo(
UpdateCapabilities |=
UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET;
}
if (supportsFlags(ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION)) {
UpdateCapabilities |=
UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE;
}
return ReturnValue(UpdateCapabilities);
}
case UR_DEVICE_INFO_COMMAND_BUFFER_EVENT_SUPPORT_EXP:
Expand Down
31 changes: 31 additions & 0 deletions source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,22 @@ ur_result_t ur_platform_handle_t_::initialize() {
ZeMutableCmdListExt.Supported |=
ZeMutableCmdListExt.zexCommandListUpdateMutableCommandWaitEventsExp !=
nullptr;
ZeMutableCmdListExt.zexCommandListUpdateMutableCommandKernelsExp =
(ze_pfnCommandListUpdateMutableCommandKernelsExp_t)
ur_loader::LibLoader::getFunctionPtr(
GlobalAdapter->processHandle,
"zeCommandListUpdateMutableCommandKernelsExp");
ZeMutableCmdListExt.Supported |=
ZeMutableCmdListExt.zexCommandListUpdateMutableCommandKernelsExp !=
nullptr;
ZeMutableCmdListExt.zexCommandListGetNextCommandIdWithKernelsExp =
(ze_pfnCommandListGetNextCommandIdWithKernelsExp_t)
ur_loader::LibLoader::getFunctionPtr(
GlobalAdapter->processHandle,
"zeCommandListGetNextCommandIdWithKernelsExp");
ZeMutableCmdListExt.Supported |=
ZeMutableCmdListExt.zexCommandListGetNextCommandIdWithKernelsExp !=
nullptr;
} else {
ZeMutableCmdListExt.Supported |=
(ZE_CALL_NOCHECK(
Expand Down Expand Up @@ -353,6 +369,21 @@ ur_result_t ur_platform_handle_t_::initialize() {
&ZeMutableCmdListExt
.zexCommandListUpdateMutableCommandWaitEventsExp))) ==
0);
ZeMutableCmdListExt.Supported &=
(ZE_CALL_NOCHECK(
zeDriverGetExtensionFunctionAddress,
(ZeDriver, "zeCommandListUpdateMutableCommandKernelsExp",
reinterpret_cast<void **>(
&ZeMutableCmdListExt
.zexCommandListUpdateMutableCommandKernelsExp))) == 0);

ZeMutableCmdListExt.Supported &=
(ZE_CALL_NOCHECK(
zeDriverGetExtensionFunctionAddress,
(ZeDriver, "zeCommandListGetNextCommandIdWithKernelsExp",
reinterpret_cast<void **>(
&ZeMutableCmdListExt
.zexCommandListGetNextCommandIdWithKernelsExp))) == 0);
}
return UR_RESULT_SUCCESS;
}
Expand Down
6 changes: 6 additions & 0 deletions source/adapters/level_zero/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,11 @@ struct ur_platform_handle_t_ : public _ur_platform {
ze_result_t (*zexCommandListUpdateMutableCommandWaitEventsExp)(
ze_command_list_handle_t, uint64_t, uint32_t,
ze_event_handle_t *) = nullptr;
ze_result_t (*zexCommandListUpdateMutableCommandKernelsExp)(
ze_command_list_handle_t, uint32_t, uint64_t *,
ze_kernel_handle_t *) = nullptr;
ze_result_t (*zexCommandListGetNextCommandIdWithKernelsExp)(
ze_command_list_handle_t, const ze_mutable_command_id_exp_desc_t *,
uint32_t, ze_kernel_handle_t *, uint64_t *) = nullptr;
} ZeMutableCmdListExt;
};
Loading