Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove virtual methods from ur_mem_handle_t_
Browse files Browse the repository at this point in the history
We want to transition to handle pointers containing the ddi table as the
first element. For this to work, handle object must not have a vtable.

Since ur_mem_handle_t_ is relatively simple, it's easy enough to roll
out our own version of dynamic dispatch.
RossBrunton committed Jan 27, 2025
1 parent 0bb6789 commit aea4883
Showing 2 changed files with 77 additions and 49 deletions.
52 changes: 40 additions & 12 deletions source/adapters/level_zero/memory.cpp
Original file line number Diff line number Diff line change
@@ -2081,10 +2081,11 @@ static ur_result_t ZeDeviceMemAllocHelper(void **ResultPtr,
return UR_RESULT_SUCCESS;
}

ur_result_t _ur_buffer::getZeHandle(char *&ZeHandle, access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
ur_result_t _ur_buffer::getBufferZeHandle(char *&ZeHandle,
access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {

// NOTE: There might be no valid allocation at all yet and we get
// here from piEnqueueKernelLaunch that would be doing the buffer
@@ -2393,7 +2394,7 @@ ur_result_t _ur_buffer::free() {
// Buffer constructor
_ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr,
bool ImportedHostPtr = false)
: ur_mem_handle_t_(Context), Size(Size) {
: ur_mem_handle_t_(mem_type_t::buffer, Context), Size(Size) {

// We treat integrated devices (physical memory shared with the CPU)
// differently from discrete devices (those with distinct memories).
@@ -2422,13 +2423,13 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr,

_ur_buffer::_ur_buffer(ur_context_handle_t Context, ur_device_handle_t Device,
size_t Size)
: ur_mem_handle_t_(Context, Device), Size(Size) {}
: ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {}

// Interop-buffer constructor
_ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size,
ur_device_handle_t Device, char *ZeMemHandle,
bool OwnZeMemHandle)
: ur_mem_handle_t_(Context, Device), Size(Size) {
: ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {

// Device == nullptr means host allocation
Allocations[Device].ZeHandle = ZeMemHandle;
@@ -2449,11 +2450,38 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size,
LastDeviceWithValidAllocation = Device;
}

ur_result_t _ur_buffer::getZeHandlePtr(char **&ZeHandlePtr,
access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
ur_result_t ur_mem_handle_t_::getZeHandle(char *&ZeHandle, access_mode_t mode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
switch (mem_type) {
case ur_mem_handle_t_::image:
return reinterpret_cast<_ur_image *>(this)->getImageZeHandle(
ZeHandle, mode, Device, phWaitEvents, numWaitEvents);
case ur_mem_handle_t_::buffer:
return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandle(
ZeHandle, mode, Device, phWaitEvents, numWaitEvents);
}
abort();
}

ur_result_t ur_mem_handle_t_::getZeHandlePtr(
char **&ZeHandlePtr, access_mode_t mode, ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) {
switch (mem_type) {
case ur_mem_handle_t_::image:
return reinterpret_cast<_ur_image *>(this)->getImageZeHandlePtr(
ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents);
case ur_mem_handle_t_::buffer:
return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandlePtr(
ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents);
}
abort();
}

ur_result_t _ur_buffer::getBufferZeHandlePtr(
char **&ZeHandlePtr, access_mode_t AccessMode, ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) {
char *ZeHandle;
UR_CALL(
getZeHandle(ZeHandle, AccessMode, Device, phWaitEvents, numWaitEvents));
74 changes: 37 additions & 37 deletions source/adapters/level_zero/memory.hpp
Original file line number Diff line number Diff line change
@@ -70,34 +70,37 @@ struct ur_mem_handle_t_ : _ur_object {
// Keeps device of this memory handle
ur_device_handle_t UrDevice;

// Whether this is an image or buffer
enum mem_type_t { image, buffer };
mem_type_t mem_type;

// Enumerates all possible types of accesses.
enum access_mode_t { unknown, read_write, read_only, write_only };

// Interface of the _ur_mem object

// Get the Level Zero handle of the current memory object
virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) = 0;
ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

// Get a pointer to the Level Zero handle of the current memory object
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) = 0;
ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

// Method to get type of the derived object (image or buffer)
virtual bool isImage() const = 0;

virtual ~ur_mem_handle_t_() = default;
bool isImage() const { return mem_type == mem_type_t::image; }

protected:
ur_mem_handle_t_(ur_context_handle_t Context)
: UrContext{Context}, UrDevice{nullptr} {}
ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context)
: UrContext{Context}, UrDevice{nullptr}, mem_type(type) {}

ur_mem_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device)
: UrContext{Context}, UrDevice(Device) {}
ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context,
ur_device_handle_t Device)
: UrContext{Context}, UrDevice(Device), mem_type(type) {}
};

struct _ur_buffer final : ur_mem_handle_t_ {
@@ -110,7 +113,7 @@ struct _ur_buffer final : ur_mem_handle_t_ {

// Sub-buffer constructor
_ur_buffer(_ur_buffer *Parent, size_t Origin, size_t Size)
: ur_mem_handle_t_(Parent->UrContext), Size(Size),
: ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext), Size(Size),
SubBuffer{{Parent, Origin}} {
// Retain the Parent Buffer due to the Creation of the SubBuffer.
Parent->RefCount.increment();
@@ -127,16 +130,15 @@ struct _ur_buffer final : ur_mem_handle_t_ {
// up-to-date and any data copies needed for that are performed under
// the hood.
//
virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override;
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override;
ur_result_t getBufferZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);
ur_result_t getBufferZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

bool isImage() const override { return false; }
bool isSubBuffer() const { return SubBuffer != std::nullopt; }

// Frees all allocations made for the buffer.
@@ -206,35 +208,33 @@ struct _ur_buffer final : ur_mem_handle_t_ {
struct _ur_image final : ur_mem_handle_t_ {
// Image constructor
_ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage)
: ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {}
: ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {}

_ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage,
bool OwnZeMemHandle)
: ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {
: ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {
OwnNativeHandle = OwnZeMemHandle;
}

virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override {
ur_result_t getImageZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
std::ignore = phWaitEvents;
std::ignore = numWaitEvents;
ZeHandle = reinterpret_cast<char *>(ZeImage);
return UR_RESULT_SUCCESS;
}
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override {
ur_result_t getImageZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
std::ignore = phWaitEvents;
std::ignore = numWaitEvents;
ZeHandlePtr = reinterpret_cast<char **>(&ZeImage);
return UR_RESULT_SUCCESS;
}

bool isImage() const override { return true; }

// Keep the descriptor of the image
ZeStruct<ze_image_desc_t> ZeImageDesc;

0 comments on commit aea4883

Please sign in to comment.