Skip to content

Commit

Permalink
reduce use of g_device_map
Browse files Browse the repository at this point in the history
  • Loading branch information
marty1885 committed Nov 17, 2024
1 parent b8f902a commit 96273dd
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions ggml/src/ggml-metalium/ggml-metalium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1838,37 +1838,33 @@ static ggml_backend_buffer_type_i ggml_backend_metalium_buffer_type_interface =
};

static ggml_backend_buffer_type_t ggml_backend_metalium_buffer_type(ggml_backend_dev_t dev, ggml_backend_metalium_device_context* dev_ctx) {
auto device = dev_ctx->device_id;
GGML_ASSERT((size_t)device < tt::tt_metal::GetNumAvailableDevices());
static std::map<int, ggml_backend_buffer_type> buffer_type_map;
static std::set<std::unique_ptr<ggml_backend_metalium_buffer_type_context>> buffer_type_context_deleter;
auto device_id = dev_ctx->device_id;
ggml_backend_metalium_reg_context* regctx = (ggml_backend_metalium_reg_context*)(dev->reg->context);

if(!g_device_map.contains(device)) {
ggml_backend_metalium_init(device);
GGML_ASSERT(g_device_map.contains(device));
}
GGML_ASSERT((size_t)device_id < tt::tt_metal::GetNumAvailableDevices());
GGML_ASSERT((size_t)device_id < regctx->devices.size());

if(buffer_type_map.contains(device)) {
return &buffer_type_map[device];
static std::map<int, ggml_backend_buffer_type> buffer_type_map;
static std::set<std::unique_ptr<ggml_backend_metalium_buffer_type_context>> buffer_type_context_deleter;
auto it = buffer_type_map.find(device_id);
if(it != buffer_type_map.end()) {
return &it->second;
}


auto bufctx = std::make_unique<ggml_backend_metalium_buffer_type_context>(
ggml_backend_metalium_buffer_type_context{
.device = g_device_map[device],
.name = "Metalium " + std::to_string(device),
.device = dev_ctx->device,
.name = "Metalium " + std::to_string(device_id),
});
auto* bufctx_ptr = bufctx.get();
buffer_type_context_deleter.insert(std::move(bufctx));

ggml_backend_metalium_reg_context* regctx = (ggml_backend_metalium_reg_context*)(dev->reg->context);
GGML_ASSERT((size_t)device < regctx->devices.size());
buffer_type_map[device] = {
buffer_type_map[device_id] = {
/* .iface = */ ggml_backend_metalium_buffer_type_interface,
/* .device = */ regctx->devices[device],
/* .device = */ regctx->devices[device_id],
/* .context = */ bufctx_ptr,
};
return &buffer_type_map[device];
return &buffer_type_map[device_id];
}

static enum ggml_status ggml_backend_metalium_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
Expand Down

0 comments on commit 96273dd

Please sign in to comment.