diff --git a/ggml-metalium.cpp b/ggml-metalium.cpp index cb0b4d9f0842f..b78feacaee27d 100644 --- a/ggml-metalium.cpp +++ b/ggml-metalium.cpp @@ -1,4 +1,5 @@ #include "common/bfloat16.hpp" +#include "device/tt_arch_types.h" #include "ggml-backend-impl.h" #include "ggml.h" #include "ggml-metalium.h" @@ -8,8 +9,10 @@ #include "tensor/host_buffer/types.hpp" #include "tensor/types.hpp" #include "tt_dnn/op_library/auto_format.hpp" +#include "tt_dnn/op_library/tilize/tilize_op.hpp" #include #include +#include #include #include #include @@ -27,6 +30,11 @@ struct ggml_backend_metalium_context { std::string name; }; +struct TensorWithMetadata +{ + std::shared_ptr tensor; + ggml_type ggtype = (ggml_type)-1; +}; /////////////////////////////////////////////////////////////////////////////////////////////////////// // Backend internal state tracking because GGML API does not allow @@ -40,9 +48,43 @@ static std::map g_device_map; /////////////////////////////////////////////////////////////////////////////////////////////////////// static void ggml_backend_metalium_mul_mat(ggml_backend_metalium_context * ctx, struct ggml_tensor * dst) { - GGML_UNUSED(ctx); - GGML_UNUSED(dst); - abort(); + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(src0->extra != NULL); + GGML_ASSERT(src1->extra != NULL); + GGML_ASSERT(dst->extra != NULL); + + tt::tt_metal::Tensor& a = *reinterpret_cast(src0->extra)->tensor; + tt::tt_metal::Tensor& b = *reinterpret_cast(src1->extra)->tensor; + tt::tt_metal::Tensor& c = *reinterpret_cast(dst->extra)->tensor; + + GGML_ASSERT(a.storage_type() == tt::tt_metal::StorageType::DEVICE || a.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE); + GGML_ASSERT(b.storage_type() == tt::tt_metal::StorageType::DEVICE || b.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE); + + auto t = tt::tt_metal::fully_connected(a, b); + fprintf(stderr, "Metalium: %s starting\n", __func__); + tt::tt_metal::Finish(ctx->device->command_queue()); + fprintf(stderr, "Metalium: %s done\n", __func__); } static void ggml_backend_metalium_out_prod(ggml_backend_metalium_context * ctx, struct ggml_tensor * dst) { @@ -78,7 +120,8 @@ GGML_CALL static const char * ggml_backend_metalium_buffer_type_name(ggml_backen } GGML_CALL static size_t ggml_backend_metalium_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 4096; // assume the wosre, BFP16 on tile boundary + // Not using this. Metalium's allication model is not compatible with GGML's allocator + return 128; GGML_UNUSED(buft); } @@ -90,15 +133,8 @@ static size_t ggml_backend_metalium_buffer_type_get_max_size(ggml_backend_buffer } GGML_CALL static size_t ggml_backend_metalium_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - // TODO: Make sure this is correct - if(ggml_is_quantized(tensor->type)) { - return ggml_nbytes(tensor); - } - intmax_t nelements = 1; - for(int i = 0; i < 4; i++) { - nelements *= i < 2 ? 
tensor->ne[i] / 32 + (tensor->ne[i] % 32 != 0) : tensor->ne[i]; - } - return nelements * ggml_type_size(tensor->type); + // Not using this. Metalium's allication model is not compatible with GGML's allocator + return ggml_nbytes(tensor); GGML_UNUSED(buft); } @@ -106,9 +142,10 @@ struct ggml_backend_metalium_buffer_context { size_t ggml_buffer_size_bytes = 0; std::string name; - - // These initializations are deferred due to GGML API limitations - tt::tt_metal::Tensor tensor; + ttnn::device::Device* device = nullptr; + + // Tracking our own allocations because Metalium limitations and GGML assuming them + std::vector> metadata_to_free; }; GGML_CALL static const char * ggml_backend_metalium_buffer_get_name(ggml_backend_buffer_t buffer) { @@ -127,34 +164,82 @@ static void ggml_backend_metalium_buffer_set_tensor(ggml_backend_buffer_t buffer const void *data, size_t offset, size_t size) { - ggml_backend_metalium_buffer_context * ctx = (ggml_backend_metalium_buffer_context *)buffer->context; + // Must be setting the entire tensor at once + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(tensor->extra != NULL); + + ggml_backend_metalium_buffer_context * bufctx = (ggml_backend_metalium_buffer_context *)buffer->context; ggml_type ggtype = tensor->type; + TensorWithMetadata * meta = (TensorWithMetadata *)tensor->extra; + + tt::ARCH processor_class = bufctx->device->arch(); + // only grayskull is supported for now. + GGML_ASSERT(processor_class == tt::ARCH::GRAYSKULL); - // TODO: Support other types - GGML_ASSERT(ggtype == GGML_TYPE_BF16); - std::vector bfloat16_data(size / sizeof(bfloat16)); - std::memcpy(bfloat16_data.data(), data, size); - auto owned = tt::tt_metal::owned_buffer::create(std::move(bfloat16_data)); + // TODO: See if we can use BorrowedStorage to avoid copying the data + OwnedStorage storage; + + if (ggtype == GGML_TYPE_BF16) { + std::vector bfloat16_data(size / sizeof(bfloat16)); + std::memcpy(bfloat16_data.data(), data, size); + auto owned = tt::tt_metal::owned_buffer::create(std::move(bfloat16_data)); + storage = OwnedStorage{std::move(owned)}; + } + else if (ggtype == GGML_TYPE_F32) { + // For now we cast F32 to BF16. Need a scalable way to handle this as WORMHOLD_B0 have native support for F32 + std::vector bfloat16_data(size / sizeof(float)); + const float* f32_data = (const float*)data; + for(size_t i = 0; i < size / sizeof(float); i++) { + bfloat16_data[i] = bfloat16(f32_data[i]); + } + auto owned = tt::tt_metal::owned_buffer::create(std::move(bfloat16_data)); + storage = OwnedStorage{std::move(owned)}; + } + else { + // TODO: Support other types + GGML_ASSERT(false && "Unsupported data type"); + } + // TODO: Make sure this is correct std::vector shape(GGML_MAX_DIMS); for(int i = 0; i < GGML_MAX_DIMS; i++) { // GGML stores the shape in reverse order shape[i] = tensor->ne[GGML_MAX_DIMS - i - 1]; } - ctx->tensor = tt::tt_metal::Tensor(OwnedStorage{owned}, tt::tt_metal::Shape(shape), tt::tt_metal::DataType::BFLOAT16, tt::tt_metal::Layout::ROW_MAJOR); - // HACK: Need to save device pointer - ctx->tensor = ctx->tensor.to(g_device_map[0]); - ctx->tensor = tilize_with_zero_padding(ctx->tensor); + tt::tt_metal::Tensor t(std::move(storage), tt::tt_metal::Shape(shape) + , tt::tt_metal::DataType::BFLOAT16, tt::tt_metal::Layout::ROW_MAJOR); + + // I think we can allow this.. right? 
+ // GGML_ASSERT(!bufctx->tensors.contains(offset)); + *meta = TensorWithMetadata { + .tensor = std::make_shared(tt::tt_metal::tilize_with_zero_padding(t.to(bufctx->device))), + .ggtype = ggtype, + }; + tt::tt_metal::Finish(bufctx->device->command_queue()); +} - GGML_ASSERT(offset == 0); +static void * ggml_backend_metalium_buffer_get_base(ggml_backend_buffer_t buffer) { + // Not using this. Metalium's allication model is not compatible with GGML's allocator + return (void*)0x10000; +} + +GGML_CALL static void +ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor) +{ + ggml_backend_metalium_buffer_context * bufctx = (ggml_backend_metalium_buffer_context *)buffer->context; + bufctx->metadata_to_free.push_back(std::make_unique()); + tensor->extra = bufctx->metadata_to_free.back().get(); + GGML_UNUSED(buffer); } static struct ggml_backend_buffer_i ggml_backend_metalium_buffer_interface = { /* .get_name = */ ggml_backend_metalium_buffer_get_name, /* .free_buffer = */ ggml_backend_metalium_buffer_free_buffer, - /* .get_base = */ nullptr, //ggml_backend_metalium_buffer_get_base, - /* .init_tensor = */ nullptr, //ggml_backend_metalium_buffer_init_tensor, + /* .get_base = */ ggml_backend_metalium_buffer_get_base, + /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor, /* .set_tensor = */ ggml_backend_metalium_buffer_set_tensor, /* .get_tensor = */ nullptr, //ggml_backend_metalium_buffer_get_tensor, /* .cpy_tensor = */ nullptr, //ggml_backend_metalium_buffer_cpy_tensor, @@ -166,13 +251,14 @@ static struct ggml_backend_buffer_i ggml_backend_metalium_buffer_interface = { GGML_CALL static ggml_backend_buffer_t ggml_backend_metalium_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - // ggml_backend_metalium_buffer_type_context * buft_ctx = (ggml_backend_metalium_buffer_type_context *)buft->context; + ggml_backend_metalium_buffer_type_context * buft_ctx = (ggml_backend_metalium_buffer_type_context *)buft->context; ggml_backend_metalium_buffer_context* ctx = new ggml_backend_metalium_buffer_context; // real allocation is deferred until the first tensor is set because we don't know the underlying tensor type yet // TODO: Use a constructor ctx->ggml_buffer_size_bytes = size; ctx->name = ctx->name; + ctx->device = buft_ctx->device; return ggml_backend_buffer_init(buft, ggml_backend_metalium_buffer_interface, ctx, size); } @@ -212,7 +298,6 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metalium_get_default_bu } GGML_CALL static enum ggml_status ggml_backend_metalium_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - abort(); // nothing supported yet ggml_backend_metalium_context * ctx = (ggml_backend_metalium_context *)backend->context; for (int i = 0; i < cgraph->n_nodes; i++) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cfa7073153486..bfd1474a2f616 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -63,6 +63,8 @@ function(llama_target_and_test source) ${LLAMA_TEST_ARGS}) set_property(TEST ${TEST_TARGET} PROPERTY LABELS ${LLAMA_TEST_LABEL}) + target_compile_options(${TEST_TARGET} PRIVATE -stdlib=libc++) + target_link_libraries(${TEST_TARGET} PRIVATE c++ c++abi) endfunction() # build test-tokenizer-0 target once and add many tests @@ -139,3 +141,7 @@ endif() get_filename_component(TEST_TARGET test-c.c NAME_WE) add_executable(${TEST_TARGET} test-c.c) target_link_libraries(${TEST_TARGET} PRIVATE llama) + +# HACK: Test files so I can debug metalium. 
+# TODO: Remove these tests when done. +llama_target_and_test(test-mul-mat.cpp) \ No newline at end of file diff --git a/tests/test-mul-mat.cpp b/tests/test-mul-mat.cpp new file mode 100644 index 0000000000000..f8271643c9fa3 --- /dev/null +++ b/tests/test-mul-mat.cpp @@ -0,0 +1,371 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +//#define GGML_USE_CUBLAS // uncomment this to use cuda backend, make sure build ggml lib with GGML_CUBLAS=ON + +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#ifdef GGML_USE_METALIUM +#include "ggml-metalium.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + +void load_model(test_model & model, float* a, float* b, int M, int N, int K, bool use_gpu = false) { + size_t buffer_size = 0; + { + buffer_size += (M * N) * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += (N * K) * ggml_type_size(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + printf("%s: backend buffer size = %d bytes\n", __func__, (int) buffer_size); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUBLAS + if (use_gpu) { + fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METALIUM + if (use_gpu) { + fprintf(stderr, "%s: using Metalium backend\n", __func__); + model.backend = ggml_backend_metalium_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metalium_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, K, M); + printf("Matrix A: [%i, %i]\n", K, M); + model.b = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, K, N); + printf("Matrix B: [%i, %i]\n", K, N); + + // create a allocator + struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); + + // alloc memory + ggml_tallocr_alloc(&alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.a->data, a, ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, a, 0, ggml_nbytes(model.a)); // cuda requires copy the data directly to device + } + + // alloc memory + ggml_tallocr_alloc(&alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, b, ggml_nbytes(model.b)); + } else { + 
ggml_backend_tensor_set(model.b, b, 0, ggml_nbytes(model.b)); // cuda requires copy the data directly to device + } +} + +struct ggml_cgraph * build_graph(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + // zT = x @ yT + /* + struct ggml_tensor * result = ggml_mul_mat(ctx0, model.a, ggml_cont(ctx0, model.b)); + + // z = (zT)T + ggml_build_forward_expand(gf, ggml_cont(ctx0, ggml_transpose(ctx0, result))); + */ + struct ggml_tensor * result = ggml_mul_mat(ctx0, model.a, model.b); + ggml_build_forward_expand(gf, result); + + // delete the temporally context used to build the graph + ggml_free(ctx0); + return gf; +} + +struct ggml_tensor* compute(const test_model & model, ggml_gallocr_t allocr) { + struct ggml_cgraph * gf = build_graph(model); + + // allocate tensors + ggml_gallocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + ggml_backend_graph_compute(model.backend, gf); + + //ggml_graph_print(gf); + + // in this case, the output tensor is the last one in the graph + return gf->nodes[gf->n_nodes - 1]; +} + + +static void ggml_vec_dot_f16(const int n, float * s, float * x, float * y) { + float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += x[i] * y[i]; + } + *s = sumf; +} + +static void gemm_f16_out_f32(int m, int n, int k, + float * A, + float * B, + float * C, + const int ith, const int nth) { + // does not seem to make a difference + int m0, m1, n0, n1; + // patches per thread + if (m > n) { + n0 = 0; + n1 = n; + + // total patches in dst + const int np = m; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + m0 = dp*ith; + m1 = std::min(m0 + dp, np); + } else { + m0 = 0; + m1 = m; + + // total patches in dst + const int np = n; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + n0 = dp*ith; + n1 = std::min(n0 + dp, np); + } + + // block-tiling attempt + int64_t blck_n = 16; + int64_t blck_m = 16; + + for (int j = n0; j < n1; j+=blck_n) { + for (int i = m0; i < m1; i+=blck_m) { + // printf("i j k => %d %d %d\n", i, j, K); + for (int ii = i; ii < i + blck_m && ii < m1; ii++) { + for (int jj = j; jj < j + blck_n && jj < n1; jj++) { + ggml_vec_dot_f16(k, + C + ii*n + jj, + A + ii * k, + B + jj * k); + } + } + } + } +} + + +void perform_gemm_test(float* a, float* b, float* expected, int M, int N, int K) { + printf("\nPerforming gemm_f16_out_f32 test:\n"); + + float* gemm_out = new float[M * N]; + gemm_f16_out_f32(M, N, K, a, b, gemm_out, 0, 1); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + printf("%.1ff,", gemm_out[i * N + j]); + } + printf("\n"); + } + + bool passed = true; + + for(int i = 0; i < M * N; i++) { + if(gemm_out[i] != expected[i]) { + passed = false; + break; + } + } + + printf("gemm_mult (%i): %s\n", (M * N), passed ? 
"\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); +} + +int main(void) +{ + ggml_time_init(); + const int M = 4, N = 16, K = 36; // a conv2d expected matrix multiplication + + // matrix A (4 X 36) + float matrixA[M * K] = { + 2.0f, 9.0f, 2.0f, 10.0f, 6.0f, 4.0f, 3.0f, 6.0f, 3.0f, 6.0f, 9.0f, 7.0f, 8.0f, 8.0f, 3.0f, 3.0f, 10.0f, 5.0f, 2.0f, 10.0f, 7.0f, 10.0f, 9.0f, 3.0f, 6.0f, 6.0f, 5.0f, 10.0f, 2.0f, 3.0f, 6.0f, 1.0f, 9.0f, 4.0f, 10.0f, 4.0f, + 10.0f, 7.0f, 8.0f, 10.0f, 10.0f, 8.0f, 7.0f, 10.0f, 4.0f, 6.0f, 8.0f, 7.0f, 7.0f, 6.0f, 9.0f, 3.0f, 6.0f, 5.0f, 5.0f, 2.0f, 7.0f, 2.0f, 7.0f, 4.0f, 4.0f, 6.0f, 6.0f, 4.0f, 3.0f, 9.0f, 3.0f, 6.0f, 4.0f, 7.0f, 2.0f, 9.0f, + 7.0f, 3.0f, 2.0f, 5.0f, 7.0f, 3.0f, 10.0f, 2.0f, 6.0f, 1.0f, 4.0f, 7.0f, 5.0f, 10.0f, 3.0f, 10.0f, 4.0f, 5.0f, 5.0f, 1.0f, 6.0f, 10.0f, 7.0f, 4.0f, 5.0f, 3.0f, 9.0f, 9.0f, 8.0f, 6.0f, 9.0f, 2.0f, 3.0f, 6.0f, 8.0f, 5.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 3.0f, 10.0f, 4.0f, 1.0f, 8.0f, 8.0f, 9.0f, 8.0f, 4.0f, 1.0f, 4.0f, 9.0f, 3.0f, 6.0f, 3.0f, 1.0f, 4.0f, 8.0f, 3.0f, 10.0f, 8.0f, 6.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f, 2.0f, 4.0f, 3.0f, 6.0f, 4.0f, + }; + + // matrix B (16 X 36) + float matrixB[N * K] = { + 9.0f, 7.0f, 1.0f, 3.0f, 5.0f, 9.0f, 7.0f, 6.0f, 1.0f, 10.0f, 1.0f, 1.0f, 7.0f, 2.0f, 4.0f, 9.0f, 10.0f, 4.0f, 5.0f, 5.0f, 7.0f, 1.0f, 7.0f, 7.0f, 2.0f, 9.0f, 5.0f, 10.0f, 7.0f, 4.0f, 8.0f, 9.0f, 9.0f, 3.0f, 10.0f, 2.0f, + 4.0f, 6.0f, 10.0f, 9.0f, 5.0f, 1.0f, 8.0f, 7.0f, 4.0f, 7.0f, 2.0f, 6.0f, 5.0f, 3.0f, 1.0f, 10.0f, 8.0f, 4.0f, 8.0f, 3.0f, 7.0f, 1.0f, 2.0f, 7.0f, 6.0f, 8.0f, 6.0f, 5.0f, 2.0f, 3.0f, 1.0f, 1.0f, 2.0f, 5.0f, 7.0f, 1.0f, + 8.0f, 2.0f, 8.0f, 8.0f, 8.0f, 8.0f, 4.0f, 4.0f, 6.0f, 10.0f, 10.0f, 9.0f, 2.0f, 9.0f, 3.0f, 7.0f, 7.0f, 1.0f, 4.0f, 9.0f, 1.0f, 2.0f, 3.0f, 6.0f, 1.0f, 10.0f, 5.0f, 8.0f, 9.0f, 4.0f, 6.0f, 2.0f, 3.0f, 1.0f, 2.0f, 7.0f, + 5.0f, 1.0f, 7.0f, 2.0f, 9.0f, 10.0f, 9.0f, 5.0f, 2.0f, 5.0f, 4.0f, 10.0f, 9.0f, 9.0f, 1.0f, 9.0f, 8.0f, 8.0f, 9.0f, 4.0f, 9.0f, 4.0f, 8.0f, 2.0f, 1.0f, 8.0f, 4.0f, 5.0f, 10.0f, 7.0f, 6.0f, 2.0f, 1.0f, 10.0f, 10.0f, 7.0f, + 9.0f, 4.0f, 5.0f, 9.0f, 5.0f, 10.0f, 10.0f, 3.0f, 6.0f, 6.0f, 4.0f, 4.0f, 4.0f, 8.0f, 5.0f, 4.0f, 9.0f, 1.0f, 9.0f, 9.0f, 1.0f, 7.0f, 9.0f, 2.0f, 10.0f, 9.0f, 10.0f, 8.0f, 3.0f, 3.0f, 9.0f, 3.0f, 9.0f, 10.0f, 1.0f, 8.0f, + 9.0f, 2.0f, 6.0f, 9.0f, 7.0f, 2.0f, 3.0f, 5.0f, 3.0f, 6.0f, 9.0f, 7.0f, 3.0f, 7.0f, 6.0f, 4.0f, 10.0f, 3.0f, 5.0f, 7.0f, 2.0f, 9.0f, 3.0f, 2.0f, 2.0f, 10.0f, 8.0f, 7.0f, 3.0f, 10.0f, 6.0f, 3.0f, 1.0f, 1.0f, 4.0f, 10.0f, + 2.0f, 9.0f, 2.0f, 10.0f, 6.0f, 4.0f, 3.0f, 6.0f, 3.0f, 6.0f, 9.0f, 7.0f, 8.0f, 8.0f, 3.0f, 3.0f, 10.0f, 5.0f, 2.0f, 10.0f, 7.0f, 10.0f, 9.0f, 3.0f, 6.0f, 6.0f, 5.0f, 10.0f, 2.0f, 3.0f, 6.0f, 1.0f, 9.0f, 4.0f, 10.0f, 4.0f, + 10.0f, 7.0f, 8.0f, 10.0f, 10.0f, 8.0f, 7.0f, 10.0f, 4.0f, 6.0f, 8.0f, 7.0f, 7.0f, 6.0f, 9.0f, 3.0f, 6.0f, 5.0f, 5.0f, 2.0f, 7.0f, 2.0f, 7.0f, 4.0f, 4.0f, 6.0f, 6.0f, 4.0f, 3.0f, 9.0f, 3.0f, 6.0f, 4.0f, 7.0f, 2.0f, 9.0f, + 7.0f, 3.0f, 2.0f, 5.0f, 7.0f, 3.0f, 10.0f, 2.0f, 6.0f, 1.0f, 4.0f, 7.0f, 5.0f, 10.0f, 3.0f, 10.0f, 4.0f, 5.0f, 5.0f, 1.0f, 6.0f, 10.0f, 7.0f, 4.0f, 5.0f, 3.0f, 9.0f, 9.0f, 8.0f, 6.0f, 9.0f, 2.0f, 3.0f, 6.0f, 8.0f, 5.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 3.0f, 10.0f, 4.0f, 1.0f, 8.0f, 8.0f, 9.0f, 8.0f, 4.0f, 1.0f, 4.0f, 9.0f, 3.0f, 6.0f, 3.0f, 1.0f, 4.0f, 8.0f, 3.0f, 10.0f, 8.0f, 6.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f, 2.0f, 4.0f, 3.0f, 6.0f, 4.0f, + 6.0f, 2.0f, 3.0f, 3.0f, 3.0f, 7.0f, 5.0f, 1.0f, 8.0f, 1.0f, 4.0f, 5.0f, 1.0f, 1.0f, 6.0f, 4.0f, 2.0f, 1.0f, 7.0f, 8.0f, 6.0f, 1.0f, 1.0f, 5.0f, 6.0f, 5.0f, 
10.0f, 6.0f, 7.0f, 5.0f, 9.0f, 3.0f, 2.0f, 7.0f, 9.0f, 4.0f, + 2.0f, 5.0f, 9.0f, 5.0f, 10.0f, 3.0f, 1.0f, 8.0f, 1.0f, 7.0f, 1.0f, 8.0f, 1.0f, 6.0f, 7.0f, 8.0f, 4.0f, 9.0f, 5.0f, 10.0f, 3.0f, 7.0f, 6.0f, 8.0f, 8.0f, 5.0f, 6.0f, 8.0f, 10.0f, 9.0f, 4.0f, 1.0f, 3.0f, 3.0f, 4.0f, 7.0f, + 8.0f, 2.0f, 6.0f, 6.0f, 5.0f, 1.0f, 3.0f, 7.0f, 1.0f, 7.0f, 2.0f, 2.0f, 2.0f, 8.0f, 4.0f, 1.0f, 1.0f, 5.0f, 9.0f, 4.0f, 1.0f, 2.0f, 3.0f, 10.0f, 1.0f, 4.0f, 9.0f, 9.0f, 6.0f, 8.0f, 8.0f, 1.0f, 9.0f, 10.0f, 4.0f, 1.0f, + 8.0f, 5.0f, 8.0f, 9.0f, 4.0f, 8.0f, 2.0f, 1.0f, 1.0f, 9.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 5.0f, 6.0f, 7.0f, 3.0f, 1.0f, 4.0f, 6.0f, 7.0f, 7.0f, 7.0f, 8.0f, 7.0f, 8.0f, 8.0f, 2.0f, 10.0f, 2.0f, 7.0f, 3.0f, 8.0f, 3.0f, + 8.0f, 7.0f, 6.0f, 2.0f, 4.0f, 10.0f, 10.0f, 6.0f, 10.0f, 3.0f, 7.0f, 6.0f, 4.0f, 3.0f, 5.0f, 5.0f, 5.0f, 3.0f, 8.0f, 10.0f, 3.0f, 4.0f, 8.0f, 4.0f, 2.0f, 6.0f, 8.0f, 9.0f, 6.0f, 9.0f, 4.0f, 3.0f, 5.0f, 2.0f, 2.0f, 6.0f, + 10.0f, 6.0f, 2.0f, 1.0f, 7.0f, 5.0f, 6.0f, 4.0f, 1.0f, 9.0f, 10.0f, 2.0f, 4.0f, 5.0f, 8.0f, 5.0f, 7.0f, 4.0f, 7.0f, 6.0f, 3.0f, 9.0f, 2.0f, 1.0f, 4.0f, 2.0f, 6.0f, 6.0f, 3.0f, 3.0f, 2.0f, 8.0f, 5.0f, 9.0f, 3.0f, 4.0f, + }; + + // matrix C (4 x 16) + float expected_result[M * N] = { + 1224.0f, 1023.0f, 1158.0f,1259.0f,1359.0f,1194.0f,1535.0f,1247.0f,1185.0f,1029.0f,889.0f,1182.0f,955.0f,1179.0f,1147.0f,1048.0f, + 1216.0f, 1087.0f, 1239.0f,1361.0f,1392.0f,1260.0f,1247.0f,1563.0f,1167.0f,1052.0f,942.0f,1214.0f,1045.0f,1134.0f,1264.0f,1126.0f, + 1125.0f, 966.0f, 1079.0f,1333.0f,1287.0f,1101.0f,1185.0f,1167.0f,1368.0f,990.0f,967.0f,1121.0f,971.0f,1086.0f,1130.0f,980.0f, + 999.0f, 902.0f, 1020.0f,1056.0f,1076.0f,929.0f,1029.0f,1052.0f,990.0f,1108.0f,823.0f,989.0f,759.0f,1041.0f,1003.0f,870.0f + }; + + bool passed = true; + + perform_gemm_test(matrixA, matrixB, expected_result, M, N, K); + + test_model model; + load_model(model, matrixA, matrixB, M, N, K, true); + + ggml_gallocr_t allocr = NULL; + + { + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0); + fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + } + + struct ggml_tensor * result = compute(model, allocr); + + float* out_data = new float[ggml_nelements(result)]; + + ggml_backend_tensor_get(result, out_data, 0, ggml_nbytes(result)); + + printf("\nPerforming ggml_mul_mat test:\n"); + + passed = true; + for(int i = 0; i < M * N; i++) { + if(out_data[i] != expected_result[i]) { + passed = false; + break; + } + } + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + printf("%.1f ", out_data[i * N + j]); + } + printf("\n"); + } + + printf("ggml_mul_mat (%d): %s\n", (int) ggml_nelements(result), passed && (ggml_nelements(result) == M * N) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + // free memory + ggml_free(model.ctx); + + ggml_backend_buffer_free(model.buffer); + ggml_backend_free(model.backend); + ggml_gallocr_free(allocr); + return 0; +}
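
The dummy base pointer (0x10000) and the 128-byte alignment above exist only to satisfy GGML's allocator, which hands out non-overlapping addresses inside a buffer as base plus an aligned offset; the Metalium backend never dereferences those addresses, since the actual data lives in tt_metal::Tensor objects reached through ggml_tensor::extra. A minimal sketch of the address arithmetic involved, assuming a plain bump allocator (the names below are illustrative, not GGML API):

    #include <cstddef>
    #include <cstdint>

    // Illustrative bump allocator: tensor addresses are base + an offset rounded
    // up to the buffer alignment. With the Metalium backend these addresses are
    // placeholders only and are never read or written directly.
    struct BumpAlloc {
        uintptr_t base;       // e.g. the dummy 0x10000 returned by get_base
        size_t    align;      // e.g. 128, as returned by get_alignment
        size_t    offset = 0;

        void * alloc(size_t nbytes) {
            offset = (offset + align - 1) & ~(align - 1);
            void * p = (void *)(base + offset);
            offset += nbytes;
            return p;
        }
    };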
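
The set_tensor path widens support to GGML_TYPE_F32 by narrowing every float to tt-metal's bfloat16 before building the owned buffer. A self-contained sketch of that narrowing, assuming plain truncation of the low 16 bits (tt-metal's bfloat16 constructor is presumed to do something equivalent, possibly with rounding):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Hypothetical helper: narrow an IEEE-754 float to a 16-bit bfloat16 pattern
    // by keeping the top 16 bits (sign, exponent, upper 7 mantissa bits).
    static uint16_t f32_to_bf16_truncate(float f) {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        return static_cast<uint16_t>(bits >> 16);
    }

    // Convert a contiguous F32 buffer into BF16, mirroring the loop in
    // ggml_backend_metalium_buffer_set_tensor (which uses tt-metal's bfloat16
    // type instead of a raw uint16_t).
    static std::vector<uint16_t> convert_f32_to_bf16(const float * src, size_t n) {
        std::vector<uint16_t> out(n);
        for (size_t i = 0; i < n; ++i) {
            out[i] = f32_to_bf16_truncate(src[i]);
        }
        return out;
    }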
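
The shape loop in set_tensor reverses GGML's ne[] order because GGML keeps ne[0] as the innermost (fastest-varying) dimension while the tt_metal::Shape built here is listed outermost-first; for example, a tensor from ggml_new_tensor_2d(ctx, GGML_TYPE_F32, K, M) has ne = {K, M, 1, 1} and becomes the shape {1, 1, M, K}. A small sketch of that mapping using only ggml.h:

    #include <array>
    #include <cstdint>
    #include "ggml.h"

    // Build an outermost-first dimension list from a GGML tensor, as done in
    // ggml_backend_metalium_buffer_set_tensor before constructing tt_metal::Shape.
    static std::array<uint32_t, GGML_MAX_DIMS> metalium_shape_from_ggml(const ggml_tensor * t) {
        std::array<uint32_t, GGML_MAX_DIMS> shape{};
        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
            // GGML: ne[0] is the innermost dimension; reverse for Metalium.
            shape[i] = (uint32_t) t->ne[GGML_MAX_DIMS - i - 1];
        }
        return shape;
    }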
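
The deleted get_alloc_size arithmetic was an attempt to account for tile granularity on the two innermost dimensions; with this patch the padding instead happens on device through tilize_with_zero_padding, which pads up to the 32-wide tiles the old /32 arithmetic assumed. A sketch of the padded element count for a 4-D GGML tensor under that 32x32-tile assumption:

    #include <cstdint>

    // Round a dimension up to the 32-element tile edge assumed by the removed
    // code (the same padding tilize_with_zero_padding applies on device).
    static int64_t pad_to_tile(int64_t n, int64_t tile = 32) {
        return ((n + tile - 1) / tile) * tile;
    }

    // Padded element count for a 4-D GGML tensor whose two innermost dims are tiled.
    static int64_t padded_elements(const int64_t ne[4]) {
        return pad_to_tile(ne[0]) * pad_to_tile(ne[1]) * ne[2] * ne[3];
    }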
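
The init_tensor hook (named ggml_backend_sycl_buffer_init_tensor in this patch) allocates one TensorWithMetadata per tensor and parks the owning unique_ptr in the buffer context, so the raw pointer stored in tensor->extra stays valid until the buffer is freed. A stripped-down sketch of that ownership pattern with simplified stand-in types:

    #include <memory>
    #include <vector>

    // Simplified stand-ins for the types in ggml-metalium.cpp.
    struct TensorWithMetadata {
        std::shared_ptr<void> tensor;   // holds the tt_metal::Tensor in the real code
        int ggtype = -1;                // ggml_type in the real code
    };

    struct BufferContextSketch {
        // The buffer context owns every metadata object it hands out via
        // ggml_tensor::extra, so the raw pointers stay valid for the buffer's lifetime.
        std::vector<std::unique_ptr<TensorWithMetadata>> metadata_to_free;

        TensorWithMetadata * init_tensor_metadata() {
            metadata_to_free.push_back(std::make_unique<TensorWithMetadata>());
            return metadata_to_free.back().get();
        }
    };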
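
In the test, ggml_mul_mat(model.a, model.b) takes a = [K, M] and b = [K, N] in GGML's ne order and yields an [M, N] result in which every element is the dot product of a row of A with a row of B; this is exactly what the gemm_f16_out_f32 reference computes (despite its name it works on float buffers). A tiny standalone check of that convention:

    #include <cassert>
    #include <cstdio>

    // out[i*N + j] = dot(A row i, B row j), matching gemm_f16_out_f32 in the test
    // and GGML's ggml_mul_mat convention (shared K along ne[0] of both operands).
    static void ref_mul_mat(const float * A, const float * B, float * out,
                            int M, int N, int K) {
        for (int i = 0; i < M; ++i) {
            for (int j = 0; j < N; ++j) {
                float acc = 0.0f;
                for (int k = 0; k < K; ++k) {
                    acc += A[i*K + k] * B[j*K + k];
                }
                out[i*N + j] = acc;
            }
        }
    }

    int main() {
        // 2x2 toy case: rows of A dotted with rows of B.
        const float A[4] = {1, 2, 3, 4};   // M=2, K=2
        const float B[4] = {5, 6, 7, 8};   // N=2, K=2
        float C[4];
        ref_mul_mat(A, B, C, 2, 2, 2);
        assert(C[0] == 1*5 + 2*6);         // row0(A) . row0(B) = 17
        assert(C[3] == 3*7 + 4*8);         // row1(A) . row1(B) = 53
        printf("ok\n");
        return 0;
    }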