Skip to content

Commit

Permalink
Remove virtual methods from ur_mem_handle_t_
Browse files Browse the repository at this point in the history
We want to transition to handle pointers containing the ddi table as the
first element. For this to work, handle object must not have a vtable.

Since ur_mem_handle_t_ is relatively simple, it's easy enough to roll
out our own version of dynamic dispatch.
  • Loading branch information
RossBrunton committed Jan 27, 2025
1 parent 0bb6789 commit 893743b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 50 deletions.
52 changes: 40 additions & 12 deletions source/adapters/level_zero/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2081,10 +2081,11 @@ static ur_result_t ZeDeviceMemAllocHelper(void **ResultPtr,
return UR_RESULT_SUCCESS;
}

ur_result_t _ur_buffer::getZeHandle(char *&ZeHandle, access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
ur_result_t _ur_buffer::getBufferZeHandle(char *&ZeHandle,
access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {

// NOTE: There might be no valid allocation at all yet and we get
// here from piEnqueueKernelLaunch that would be doing the buffer
Expand Down Expand Up @@ -2393,7 +2394,7 @@ ur_result_t _ur_buffer::free() {
// Buffer constructor
_ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr,
bool ImportedHostPtr = false)
: ur_mem_handle_t_(Context), Size(Size) {
: ur_mem_handle_t_(mem_type_t::buffer, Context), Size(Size) {

// We treat integrated devices (physical memory shared with the CPU)
// differently from discrete devices (those with distinct memories).
Expand Down Expand Up @@ -2422,13 +2423,13 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr,

_ur_buffer::_ur_buffer(ur_context_handle_t Context, ur_device_handle_t Device,
size_t Size)
: ur_mem_handle_t_(Context, Device), Size(Size) {}
: ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {}

// Interop-buffer constructor
_ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size,
ur_device_handle_t Device, char *ZeMemHandle,
bool OwnZeMemHandle)
: ur_mem_handle_t_(Context, Device), Size(Size) {
: ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {

// Device == nullptr means host allocation
Allocations[Device].ZeHandle = ZeMemHandle;
Expand All @@ -2449,11 +2450,38 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size,
LastDeviceWithValidAllocation = Device;
}

ur_result_t _ur_buffer::getZeHandlePtr(char **&ZeHandlePtr,
access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
ur_result_t ur_mem_handle_t_::getZeHandle(char *&ZeHandle, access_mode_t mode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
switch (mem_type) {
case ur_mem_handle_t_::image:
return reinterpret_cast<_ur_image *>(this)->getImageZeHandle(
ZeHandle, mode, Device, phWaitEvents, numWaitEvents);
case ur_mem_handle_t_::buffer:
return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandle(
ZeHandle, mode, Device, phWaitEvents, numWaitEvents);
}
abort();
}

ur_result_t ur_mem_handle_t_::getZeHandlePtr(
char **&ZeHandlePtr, access_mode_t mode, ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) {
switch (mem_type) {
case ur_mem_handle_t_::image:
return reinterpret_cast<_ur_image *>(this)->getImageZeHandlePtr(
ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents);
case ur_mem_handle_t_::buffer:
return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandlePtr(
ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents);
}
abort();
}

ur_result_t _ur_buffer::getBufferZeHandlePtr(
char **&ZeHandlePtr, access_mode_t AccessMode, ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) {
char *ZeHandle;
UR_CALL(
getZeHandle(ZeHandle, AccessMode, Device, phWaitEvents, numWaitEvents));
Expand Down
76 changes: 38 additions & 38 deletions source/adapters/level_zero/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,34 +70,37 @@ struct ur_mem_handle_t_ : _ur_object {
// Keeps device of this memory handle
ur_device_handle_t UrDevice;

// Whether this is an image or buffer
enum mem_type_t { image, buffer };
mem_type_t mem_type;

// Enumerates all possible types of accesses.
enum access_mode_t { unknown, read_write, read_only, write_only };

// Interface of the _ur_mem object

// Get the Level Zero handle of the current memory object
virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) = 0;
ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

// Get a pointer to the Level Zero handle of the current memory object
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) = 0;
ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

// Method to get type of the derived object (image or buffer)
virtual bool isImage() const = 0;

virtual ~ur_mem_handle_t_() = default;
bool isImage() const { return mem_type == mem_type_t::image; }

protected:
ur_mem_handle_t_(ur_context_handle_t Context)
: UrContext{Context}, UrDevice{nullptr} {}
ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context)
: UrContext{Context}, UrDevice{nullptr}, mem_type(type) {}

ur_mem_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device)
: UrContext{Context}, UrDevice(Device) {}
ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context,
ur_device_handle_t Device)
: UrContext{Context}, UrDevice(Device), mem_type(type) {}
};

struct _ur_buffer final : ur_mem_handle_t_ {
Expand All @@ -110,8 +113,8 @@ struct _ur_buffer final : ur_mem_handle_t_ {

// Sub-buffer constructor
_ur_buffer(_ur_buffer *Parent, size_t Origin, size_t Size)
: ur_mem_handle_t_(Parent->UrContext), Size(Size),
SubBuffer{{Parent, Origin}} {
: ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext),
Size(Size), SubBuffer{{Parent, Origin}} {
// Retain the Parent Buffer due to the Creation of the SubBuffer.
Parent->RefCount.increment();
}
Expand All @@ -127,16 +130,15 @@ struct _ur_buffer final : ur_mem_handle_t_ {
// up-to-date and any data copies needed for that are performed under
// the hood.
//
virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override;
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override;
ur_result_t getBufferZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);
ur_result_t getBufferZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

bool isImage() const override { return false; }
bool isSubBuffer() const { return SubBuffer != std::nullopt; }

// Frees all allocations made for the buffer.
Expand Down Expand Up @@ -206,35 +208,33 @@ struct _ur_buffer final : ur_mem_handle_t_ {
struct _ur_image final : ur_mem_handle_t_ {
// Image constructor
_ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage)
: ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {}
: ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {}

_ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage,
bool OwnZeMemHandle)
: ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {
: ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {
OwnNativeHandle = OwnZeMemHandle;
}

virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override {
ur_result_t getImageZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
std::ignore = phWaitEvents;
std::ignore = numWaitEvents;
ZeHandle = reinterpret_cast<char *>(ZeImage);
return UR_RESULT_SUCCESS;
}
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override {
ur_result_t getImageZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
std::ignore = phWaitEvents;
std::ignore = numWaitEvents;
ZeHandlePtr = reinterpret_cast<char **>(&ZeImage);
return UR_RESULT_SUCCESS;
}

bool isImage() const override { return true; }

// Keep the descriptor of the image
ZeStruct<ze_image_desc_t> ZeImageDesc;

Expand Down

0 comments on commit 893743b

Please sign in to comment.