Skip to content

Commit

Permalink
hook up matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
marty1885 committed Jun 26, 2024
1 parent f3faa9c commit bc3b051
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions ggml-metalium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <tt_eager/tensor/tensor.hpp>
#include <ttnn/core.hpp>
#include <ttnn/operations/eltwise/binary/binary.hpp>
#include <tt_eager/tt_dnn/op_library/transpose/transpose_op.hpp>
#include <ttnn/device.hpp>
#include <tt_dnn/op_library/fully_connected/fully_connected_op.hpp>
#include <tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp>
Expand Down Expand Up @@ -81,10 +82,11 @@ static void ggml_backend_metalium_mul_mat(ggml_backend_metalium_context * ctx, s
GGML_ASSERT(a.storage_type() == tt::tt_metal::StorageType::DEVICE || a.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE);
GGML_ASSERT(b.storage_type() == tt::tt_metal::StorageType::DEVICE || b.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE);

auto t = tt::tt_metal::fully_connected(a, b);
fprintf(stderr, "Metalium: %s starting\n", __func__);
tt::tt_metal::Finish(ctx->device->command_queue());
fprintf(stderr, "Metalium: %s done\n", __func__);
printf("A shape = %u %u %u %u\n", a.shape()[0], a.shape()[1], a.shape()[2], a.shape()[3]);
printf("B shape = %u %u %u %u\n", b.shape()[0], b.shape()[1], b.shape()[2], b.shape()[3]);

// TODO: Support matmul of pre-transposed tensors. Calling transpose here is slow
c = tt::tt_metal::fully_connected(a, tt::tt_metal::transpose(b, 2, 3));
}

static void ggml_backend_metalium_out_prod(ggml_backend_metalium_context * ctx, struct ggml_tensor * dst) {
Expand Down Expand Up @@ -213,6 +215,7 @@ static void ggml_backend_metalium_buffer_set_tensor(ggml_backend_buffer_t buffer

// I think we can allow this.. right?
// GGML_ASSERT(!bufctx->tensors.contains(offset));
// TODO: Make sure this is the correct tilize we want to use
*meta = TensorWithMetadata {
.tensor = std::make_shared<tt::tt_metal::Tensor>(tt::tt_metal::tilize_with_zero_padding(t.to(bufctx->device))),
.ggtype = ggtype,
Expand Down

0 comments on commit bc3b051

Please sign in to comment.