Skip to content

Commit

Permalink
get rid of global device storage
Browse files Browse the repository at this point in the history
  • Loading branch information
marty1885 committed Nov 17, 2024
1 parent 37f9c52 commit 655cd6b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 40 deletions.
2 changes: 0 additions & 2 deletions ggml/include/ggml-metalium.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ extern "C" {
#endif

// backend API
// TODO: Need a way to specify we want a meshed device (TT has native support for combining multiple devices)
GGML_API ggml_backend_t ggml_backend_metalium_init(int device_id);

GGML_API bool ggml_backend_is_metalium(ggml_backend_t backend);

Expand Down
46 changes: 9 additions & 37 deletions ggml/src/ggml-metalium/ggml-metalium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,6 @@ static const ggml_backend_metalium_debug_flags g_debug_flags = []() {
// Backend internal state tracking because GGML API does not allow
///////////////////////////////////////////////////////////////////////////////////////////////////////

// maps device id to device
static std::map<int, ttnn::Device*> g_device_map;
static std::map<int, ggml_backend_t> g_backend_map;

// Ensure all base addresses are unique
// TODO: Do we still need this since we already removed the virtual address mapping hack?
static size_t g_metalium_base_offset = 0;
Expand Down Expand Up @@ -2204,33 +2200,16 @@ static ggml_guid_t ggml_backend_metalium_guid(void) {
return &guid;
}

ggml_backend_t ggml_backend_metalium_init(int device_id) {
static ggml_backend_t ggml_backend_metalium_init(ggml_backend_metalium_device_context* dev_ctx) {
int device_id = dev_ctx->device_id;
ttnn::Device* device = dev_ctx->device;
GGML_ASSERT(device_id >= 0 && (size_t)device_id < tt::tt_metal::GetNumAvailableDevices());
// TODO: Support multiple devices (do we even need to? TT supports merging diverse devices into a single device, at least the API suggests that)
static std::once_flag once;
std::call_once(once, [](){
tt::tt_metal::detail::EnablePersistentKernelCache();
});

auto it = g_backend_map.find(device_id);
if (it != g_backend_map.end()) {
return it->second;
}
GGML_ASSERT(device != nullptr);

ttnn::Device* device = nullptr;
if(g_device_map.contains(device_id)) {
device = g_device_map[device_id];
}
else {
device = &ttnn::device::open_device(device_id);
ttnn::enable_program_cache(*device);
// store the device in the global map because tensor creation uses device ID but Metalium disallows opening the same device twice
g_device_map[device_id] = device;
}
ggml_backend_metalium_context * ctx = new ggml_backend_metalium_context {
/* device = */ device,
/* device_id = */ device_id,
/* name = */ "METALIUM" + std::to_string(device_id),
/* name = */ dev_ctx->name,
};

ggml_backend_t backend = new ggml_backend {
Expand All @@ -2239,7 +2218,6 @@ ggml_backend_t ggml_backend_metalium_init(int device_id) {
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_metalium_reg(), device_id),
/* .context = */ ctx
};
g_backend_map[device_id] = backend;
return backend;
}

Expand Down Expand Up @@ -2297,7 +2275,7 @@ static enum ggml_backend_dev_type ggml_backend_metalium_get_type(ggml_backend_de
// Device-interface hook that creates a Metalium backend for a registry device.
//
// dev    - the registry device; its `context` is read as a
//          ggml_backend_metalium_device_context.
// params - ignored.
//
// The device context itself (rather than a bare device id) is handed to
// ggml_backend_metalium_init(). Returns the backend that call produces,
// asserting that it is non-null.
static ggml_backend_t ggml_backend_metalium_device_init(ggml_backend_dev_t dev, const char * params) {
    GGML_UNUSED(params);
    ggml_backend_metalium_device_context * ctx = (ggml_backend_metalium_device_context *)dev->context;
    ggml_backend_t backend = ggml_backend_metalium_init(ctx);
    GGML_ASSERT(backend != NULL);
    return backend;
}
Expand Down Expand Up @@ -2377,6 +2355,7 @@ GGML_API ggml_backend_reg_t ggml_backend_metalium_reg()
static ggml_backend_reg reg;
static std::once_flag once;
std::call_once(once, [&]() {
tt::tt_metal::detail::EnablePersistentKernelCache();
// TODO: Support multiple devices (TT supports mesh configuration so it's going to be tricky)
// but for now we just work on 1 device at a time
static std::unique_ptr<ggml_backend_metalium_reg_context> ctx = std::make_unique<ggml_backend_metalium_reg_context>();
Expand All @@ -2388,15 +2367,8 @@ GGML_API ggml_backend_reg_t ggml_backend_metalium_reg()
ctx->devices.reserve(num_devices);
for(size_t device_id = 0; device_id < num_devices; device_id++) {
ggml_backend_metalium_device_context * dev_ctx = new ggml_backend_metalium_device_context;
ttnn::Device* device = nullptr;
if(g_device_map.contains(device_id)) {
device = g_device_map[device_id];
GGML_ASSERT(device != nullptr);
} else {
device = &ttnn::device::open_device(device_id);
ttnn::enable_program_cache(*device);
g_device_map[device_id] = device;
}
ttnn::Device* device = &ttnn::device::open_device(device_id);
ttnn::enable_program_cache(*device);
// Limit device support to the ones I own
GGML_ASSERT(device->arch() == tt::ARCH::GRAYSKULL || device->arch() == tt::ARCH::WORMHOLE_B0);

Expand Down
15 changes: 14 additions & 1 deletion tests/test-metalium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,20 @@ std::string type_name(ggml_type type)
int main()
{
ggml_backend_t cpu = ggml_backend_cpu_init();
ggml_backend_t metalium = ggml_backend_metalium_init(0);

ggml_backend_t metalium = NULL;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
ggml_backend_reg_t reg = ggml_backend_reg_get(i);
if (std::strcmp(ggml_backend_reg_name(reg), "Metalium") == 0) {
ggml_backend_reg_t reg = ggml_backend_reg_get(i);
metalium = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), NULL);
break;
}
}
if(metalium == NULL) {
printf("Cannot find Metalium backend\n");
return 1;
}

std::vector<std::unique_ptr<test_case>> tests;

Expand Down

0 comments on commit 655cd6b

Please sign in to comment.