diff --git a/ggml/src/ggml-metalium/ggml-metalium.cpp b/ggml/src/ggml-metalium/ggml-metalium.cpp index f762f0709d7f5..d4f300fe0dc95 100644 --- a/ggml/src/ggml-metalium/ggml-metalium.cpp +++ b/ggml/src/ggml-metalium/ggml-metalium.cpp @@ -2347,14 +2347,13 @@ static std::vector> g_backend_device_holder static std::vector> g_backend_device_context_holder; GGML_API ggml_backend_reg_t ggml_backend_metalium_reg() { - if(getenv("TT_METAL_HOME") == NULL || getenv("ARCH_NAME") == NULL) { - tt::log_fatal(tt::LogType::LogAlways, "TT_METAL_HOME and ARCH_NAME environment variables must be set to use the Metalium backend"); - abort(); - } - static ggml_backend_reg reg; static std::once_flag once; std::call_once(once, [&]() { + if(getenv("TT_METAL_HOME") == NULL || getenv("ARCH_NAME") == NULL) { + tt::log_fatal(tt::LogType::LogAlways, "TT_METAL_HOME and ARCH_NAME environment variables must be set to use the Metalium backend"); + abort(); + } tt::tt_metal::detail::EnablePersistentKernelCache(); // TODO: Support multiple devices (TT supports mesh configuration so it's going to be tricky) // but for now we just work on 1 device at a time diff --git a/tests/test-metalium.cpp b/tests/test-metalium.cpp index 1c4c28a3bd591..c7e0ceb0bad4b 100644 --- a/tests/test-metalium.cpp +++ b/tests/test-metalium.cpp @@ -320,12 +320,10 @@ int main() ggml_backend_t metalium = NULL; for (size_t i = 0; i < ggml_backend_reg_count(); i++) { - ggml_backend_reg_t reg = ggml_backend_reg_get(i); - if (std::strcmp(ggml_backend_reg_name(reg), "Metalium") == 0) { - ggml_backend_reg_t reg = ggml_backend_reg_get(i); - metalium = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), NULL); - break; - } + ggml_backend_reg_t reg = ggml_backend_reg_by_name("Metalium"); + GGML_ASSERT(ggml_backend_reg_dev_count(reg) > 0); + metalium = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), NULL); + break; } if(metalium == NULL) { printf("Cannot find Metalium backend\n");