Skip to content

Commit

Permalink
register Metalium into backend
Browse files Browse the repository at this point in the history
  • Loading branch information
marty1885 committed Jun 25, 2024
1 parent f702a90 commit 8851d6d
Show file tree
Hide file tree
Showing 5 changed files with 470 additions and 1 deletion.
65 changes: 65 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ option(LLAMA_SYCL_F16 "llama: use 16 bit floats for sycl
set(LLAMA_SYCL_TARGET "INTEL" CACHE STRING "llama: sycl target device")
option(LLAMA_CPU_HBM "llama: use memkind for CPU HBM" OFF)
set(LLAMA_SCHED_MAX_COPIES "4" CACHE STRING "llama: max input copies for pipeline parallelism")
option(LLAMA_METALIUM "llama: use Metalium" OFF)

option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
Expand Down Expand Up @@ -864,6 +865,69 @@ if (LLAMA_KOMPUTE)
endif()
endif()

if(LLAMA_METALIUM)
if("$ENV{TT_METAL_HOME}" STREQUAL "")
message(FATAL_ERROR "TT_METAL_HOME is not set")
endif()
if("$ENV{ARCH_NAME}" STREQUAL "")
message(FATAL_ERROR "ARCH_NAME is not set")
endif()

set(METALIUM_INCLUDE_DIRS
# Metalium
$ENV{TT_METAL_HOME}
$ENV{TT_METAL_HOME}/tt_metal
$ENV{TT_METAL_HOME}/tt_metal/third_party/umd
$ENV{TT_METAL_HOME}/tt_metal/third_party/fmt
$ENV{TT_METAL_HOME}/tt_metal/hw/inc/$ENV{ARCH_NAME}
$ENV{TT_METAL_HOME}/tt_metal/hw/inc/
$ENV{TT_METAL_HOME}/tt_metal/third_party/umd/src/firmware/riscv/$ENV{ARCH_NAME}

# TTNN
$ENV{TT_METAL_HOME}/ttnn/cpp
$ENV{TT_METAL_HOME}/tt_eager
$ENV{TT_METAL_HOME}/tt_metal/third_party/magic_enum
)
set(METALIUM_LIB_DIRS
$ENV{TT_METAL_HOME}/build/lib
)

# TODO: In the future TTNN can optionally not depend on Python
find_package(Python REQUIRED Development)

set(METALIUM_LIBRARIES PRIVATE
# Metalium
tt_metal
yaml-cpp
c++abi
c++

# TTNN
tt_eager
Python::Python
$ENV{TT_METAL_HOME}/build/lib/_ttnn.so
)


set(METALIUM_DEDINES
FMT_HEADER_ONLY
)

set(METALIUM_OPTIONS
-stdlib=libc++
)

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${METALIUM_LIBRARIES})
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${METALIUM_INCLUDE_DIRS})
link_directories(ggml ${METALIUM_LIB_DIRS})
add_compile_definitions(${METALIUM_DEDINES})
#target_compile_options(ggml PUBLIC ${METALIUM_OPTIONS})
add_compile_definitions(GGML_USE_METALIUM)
add_compile_definitions(LLAMA_USE_METALIUM)
set(GGML_HEADERS_METALIUM ggml-metalium.h)
set(GGML_SOURCES_METALIUM ggml-metalium.cpp)
endif()

if (LLAMA_CPU_HBM)
find_library(memkind memkind REQUIRED)

Expand Down Expand Up @@ -1273,6 +1337,7 @@ add_library(ggml OBJECT
${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS}
${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
${GGML_SOURCES_METALIUM} ${GGML_HEADERS_METALIUM}
)

target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
Expand Down
6 changes: 6 additions & 0 deletions ggml-backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ GGML_CALL static void ggml_backend_registry_init(void) {
extern GGML_CALL void ggml_backend_kompute_reg_devices(void);
ggml_backend_kompute_reg_devices();
#endif

#ifdef GGML_USE_METALIUM
extern GGML_CALL ggml_backend_t ggml_backend_reg_metalium_init(const char * params, void * user_data);
extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metalium_buffer_type(void);
ggml_backend_register("Metalium", ggml_backend_reg_metalium_init, ggml_backend_cpu_buffer_type(), NULL);
#endif
}

GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
Expand Down
Loading

0 comments on commit 8851d6d

Please sign in to comment.