Skip to content

Commit

Permalink
Change urPlatformGet to retrieve the static adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
omarahmed1111 committed Oct 21, 2024
1 parent 250aab9 commit f8810c6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 47 deletions.
6 changes: 3 additions & 3 deletions source/adapters/opencl/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1131,16 +1131,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_adapter_handle_t hAdapter,
ur_native_handle_t hNativeDevice, ur_adapter_handle_t,
const ur_device_native_properties_t *pProperties,
ur_device_handle_t *phDevice) {
cl_device_id NativeHandle = reinterpret_cast<cl_device_id>(hNativeDevice);

uint32_t NumPlatforms = 0;
UR_RETURN_ON_FAILURE(urPlatformGet(&hAdapter, 1, 0, nullptr, &NumPlatforms));
UR_RETURN_ON_FAILURE(urPlatformGet(nullptr, 0, 0, nullptr, &NumPlatforms));
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);
UR_RETURN_ON_FAILURE(
urPlatformGet(&hAdapter, 1, NumPlatforms, Platforms.data(), nullptr));
urPlatformGet(nullptr, 0, NumPlatforms, Platforms.data(), nullptr));

for (uint32_t i = 0; i < NumPlatforms; i++) {
uint32_t NumDevices = 0;
Expand Down
84 changes: 40 additions & 44 deletions source/adapters/opencl/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,52 +63,48 @@ urPlatformGetApiVersion([[maybe_unused]] ur_platform_handle_t hPlatform,
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
ur_adapter_handle_t *phAdapters, uint32_t NumAdapters, uint32_t NumEntries,
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {
for (uint32_t idx = 0; idx < NumAdapters; idx++) {
if (!(phAdapters[idx]->NumPlatforms)) {
uint32_t NumPlatforms = 0;
cl_int Res = clGetPlatformIDs(0, nullptr, &NumPlatforms);

std::vector<cl_platform_id> CLPlatforms(NumPlatforms);
Res = clGetPlatformIDs(static_cast<cl_uint>(NumPlatforms),
CLPlatforms.data(), nullptr);

/* Absorb the CL_PLATFORM_NOT_FOUND_KHR and just return 0 in num_platforms
*/
if (Res == CL_PLATFORM_NOT_FOUND_KHR) {
if (pNumPlatforms) {
*pNumPlatforms = 0;
return UR_RESULT_SUCCESS;
}
UR_APIEXPORT ur_result_t UR_APICALL
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {

ur_adapter_handle_t Adapter = nullptr;
UR_RETURN_ON_FAILURE(urAdapterGet(1, &Adapter, nullptr));
if (Adapter && !(Adapter->NumPlatforms)) {
uint32_t NumPlatforms = 0;
cl_int Res = clGetPlatformIDs(0, nullptr, &NumPlatforms);

std::vector<cl_platform_id> CLPlatforms(NumPlatforms);
Res = clGetPlatformIDs(static_cast<cl_uint>(NumPlatforms),
CLPlatforms.data(), nullptr);

/* Absorb the CL_PLATFORM_NOT_FOUND_KHR and just return 0 in num_platforms
*/
if (Res == CL_PLATFORM_NOT_FOUND_KHR) {
if (pNumPlatforms) {
*pNumPlatforms = 0;
return UR_RESULT_SUCCESS;
}
CL_RETURN_ON_FAILURE(Res);
try {
for (uint32_t i = 0; i < NumPlatforms; i++) {
auto URPlatform =
std::make_unique<ur_platform_handle_t_>(CLPlatforms[i]);
phAdapters[idx]->URPlatforms.emplace_back(URPlatform.release());
}
phAdapters[idx]->NumPlatforms = NumPlatforms;
} catch (std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
} catch (...) {
return UR_RESULT_ERROR_INVALID_PLATFORM;
}
CL_RETURN_ON_FAILURE(Res);
try {
for (uint32_t i = 0; i < NumPlatforms; i++) {
auto URPlatform =
std::make_unique<ur_platform_handle_t_>(CLPlatforms[i]);
Adapter->URPlatforms.emplace_back(URPlatform.release());
}
Adapter->NumPlatforms = NumPlatforms;
} catch (std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
} catch (...) {
return UR_RESULT_ERROR_INVALID_PLATFORM;
}
}
if (pNumPlatforms) {
*pNumPlatforms = 0;
if (pNumPlatforms != nullptr) {
*pNumPlatforms = Adapter->NumPlatforms;
}
for (uint32_t idx = 0; idx < NumAdapters; idx++) {
if (pNumPlatforms != nullptr) {
*pNumPlatforms += phAdapters[idx]->NumPlatforms;
}
if (NumEntries && phPlatforms) {
for (uint32_t i = 0; i < NumEntries; i++) {
phPlatforms[i] = phAdapters[idx]->URPlatforms[i].get();
}
if (NumEntries && phPlatforms) {
for (uint32_t i = 0; i < NumEntries; i++) {
phPlatforms[i] = Adapter->URPlatforms[i].get();
}
}
return UR_RESULT_SUCCESS;
Expand All @@ -122,16 +118,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle(
ur_native_handle_t hNativePlatform, ur_adapter_handle_t hAdapter,
ur_native_handle_t hNativePlatform, ur_adapter_handle_t,
const ur_platform_native_properties_t *, ur_platform_handle_t *phPlatform) {
cl_platform_id NativeHandle =
reinterpret_cast<cl_platform_id>(hNativePlatform);

uint32_t NumPlatforms = 0;
UR_RETURN_ON_FAILURE(urPlatformGet(&hAdapter, 1, 0, nullptr, &NumPlatforms));
UR_RETURN_ON_FAILURE(urPlatformGet(nullptr, 0, 0, nullptr, &NumPlatforms));
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);
UR_RETURN_ON_FAILURE(
urPlatformGet(&hAdapter, 1, NumPlatforms, Platforms.data(), nullptr));
urPlatformGet(nullptr, 0, NumPlatforms, Platforms.data(), nullptr));

for (uint32_t i = 0; i < NumPlatforms; i++) {
if (Platforms[i]->CLPlatform == NativeHandle) {
Expand Down

0 comments on commit f8810c6

Please sign in to comment.