From 99263f14d1a9285e3e2dd7c0f0e8e94a9d7b4f1f Mon Sep 17 00:00:00 2001 From: marty1885 Date: Sat, 20 Jul 2024 04:55:44 -0400 Subject: [PATCH] update for latest ttnn --- ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-metalium.cpp | 23 +++++++++++------------ ggml/src/metalium-pch.hpp | 18 ++++++------------ 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 735274db25ccd..60965ac36828c 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -782,6 +782,7 @@ if(GGML_METALIUM) # TTNN $ENV{TT_METAL_HOME}/ttnn/cpp + $ENV{TT_METAL_HOME}/ttnn/cpp/ttnn/experimental/ $ENV{TT_METAL_HOME}/tt_eager $ENV{TT_METAL_HOME}/tt_metal/third_party/magic_enum ) diff --git a/ggml/src/ggml-metalium.cpp b/ggml/src/ggml-metalium.cpp index a40fb90f3e213..fb87654445d5f 100644 --- a/ggml/src/ggml-metalium.cpp +++ b/ggml/src/ggml-metalium.cpp @@ -8,12 +8,6 @@ #include "host_api.hpp" #include "impl/dispatch/command_queue.hpp" -#include "tensor/host_buffer/functions.hpp" -#include "tensor/host_buffer/types.hpp" -#include "tensor/types.hpp" -#include "tt_dnn/op_library/auto_format.hpp" -#include "tt_dnn/op_library/composite/composite_ops.hpp" -#include "tt_dnn/op_library/untilize/untilize_op.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/operations/normalization/softmax/device/softmax_op.hpp" #include @@ -24,20 +18,19 @@ #include #include #include -#include #include -#include #include -#include -#include -#include -#include #include #include #include +#include #include #include #include +#include +#include +#include +#include #include #include #include @@ -646,6 +639,10 @@ static bool ggml_backend_metalium_activations(ggml_backend_metalium_context * ct case GGML_UNARY_OP_HARDSIGMOID: ret = tt::tt_metal::hardsigmoid(*src_tensor); break; + case GGML_UNARY_OP_STEP: + // TODO: Make sure the resulting data type matches the input + ret = tt::tt_metal::where(ttnn::gtz(*src_tensor), 1.f, 0.f); + break; default: return false; } @@ -1416,6 +1413,7 @@ GGML_CALL static enum ggml_status ggml_backend_metalium_graph_compute(ggml_backe case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_STEP: ok = ggml_backend_metalium_activations(ctx, node, unary_op); break; default: @@ -1572,6 +1570,7 @@ GGML_CALL static bool ggml_backend_metalium_supports_op(ggml_backend_t backend, case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_STEP: return true; default: return false; diff --git a/ggml/src/metalium-pch.hpp b/ggml/src/metalium-pch.hpp index f4f4c5f77bdcb..9abffe036d05e 100644 --- a/ggml/src/metalium-pch.hpp +++ b/ggml/src/metalium-pch.hpp @@ -10,12 +10,6 @@ #include "host_api.hpp" #include "impl/dispatch/command_queue.hpp" -#include "tensor/host_buffer/functions.hpp" -#include "tensor/host_buffer/types.hpp" -#include "tensor/types.hpp" -#include "tt_dnn/op_library/auto_format.hpp" -#include "tt_dnn/op_library/composite/composite_ops.hpp" -#include "tt_dnn/op_library/untilize/untilize_op.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/operations/normalization/softmax/device/softmax_op.hpp" #include @@ -26,20 +20,19 @@ #include #include #include -#include #include -#include #include -#include -#include -#include -#include #include #include #include +#include #include #include #include +#include +#include +#include +#include #include #include #include @@ -48,4 +41,5 @@ #include #include #include + #endif