Commit

fixed that transpose bug
marty1885 committed Jun 28, 2024
1 parent 976dfe1 commit d5e4db8
Showing 1 changed file with 10 additions and 23 deletions.
33 changes: 10 additions & 23 deletions ggml/src/ggml-metalium.cpp
@@ -348,17 +348,6 @@ static void ggml_backend_metalium_mul_mat(ggml_backend_metalium_context * ctx, s
GGML_ASSERT(cm != NULL);

auto aT = tt::tt_metal::transpose(a, -2, -1);
#if !defined(NDEBUG) || 1
// TODO: Remove this in the future. TTNN has buggy transpose implementation
std::cout << "a.shape: " << a.shape() << " aT.shape: " << aT.shape() << std::endl;
GGML_ASSERT(aT.shape()[0] == a.shape()[0]);
GGML_ASSERT(aT.shape()[1] == a.shape()[1]);
GGML_ASSERT(aT.shape()[3] == a.shape()[2]);
GGML_ASSERT(aT.shape()[2] == a.shape()[3]);
#endif

std::cout << "a.shape: " << a.shape() << " b.shape: " << b.shape() << std::endl;

// TODO: Ask TT to support multiplication of pre-transposed tensors. Calling transpose here is inefficient
// https://github.com/tenstorrent/tt-metal/issues/9709
cm->tensor = std::make_shared<tt::tt_metal::Tensor>(tt::tt_metal::fully_connected(b, aT));
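
For context on the block deleted above: the asserts only verified that tt::tt_metal::transpose(a, -2, -1) swaps the last two dimensions of a 4-D shape and leaves the first two alone. A minimal standalone sketch of that invariant, with a made-up shape type and illustrative values rather than TTNN's actual tensor API:

#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical 4-element shape; the backend uses TTNN's own shape type.
using Shape4 = std::array<uint32_t, 4>;

// What a transpose over the last two axes is expected to do to the shape:
// dims 0 and 1 stay put, dims 2 and 3 swap.
static Shape4 transpose_last_two(const Shape4& s) {
    return {s[0], s[1], s[3], s[2]};
}

int main() {
    const Shape4 a  = {1, 1, 32, 64};      // illustrative values only
    const Shape4 aT = transpose_last_two(a);
    assert(aT[0] == a[0]);                 // mirrors the asserts removed above
    assert(aT[1] == a[1]);
    assert(aT[3] == a[2]);
    assert(aT[2] == a[3]);
    return 0;
}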
@@ -573,7 +562,7 @@ static void ggml_backend_metalium_buffer_set_tensor(ggml_backend_buffer_t buffer
// 3. If the data is quantized, cast down to BFLOAT8_B or BFLOAT4_B
// There's a lot of things to do here.
// TODO: On grayskull the best I can do is BFLOAT16 so the final dimension must be a multiple of 2.
// But on Wormhole we can use FP32, and then the final dimension can be anything. But currently it
// is hard coded to BFLOAT16. Use FP32 as the intermediate when the hardware supports it and when
// it makes sense.
// TODO: Handle integer data type for Wormhole
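
A rough sketch of the conversion policy the comments above describe; the enum and helper here are purely illustrative placeholders, not TTNN's API or the backend's actual code:

// Hypothetical dtype choice as described in the TODOs: quantized data drops to a
// block-float format, float data stays FP32 where the hardware allows it, and
// Grayskull falls back to BFLOAT16 (so the final dimension must be a multiple of 2).
enum class HwArch { Grayskull, Wormhole };
enum class TargetDtype { BFLOAT16, FLOAT32, BFLOAT8_B };

static TargetDtype pick_target_dtype(HwArch arch, bool src_is_quantized) {
    if (src_is_quantized) {
        return TargetDtype::BFLOAT8_B;   // or BFLOAT4_B for more aggressive compression
    }
    if (arch == HwArch::Wormhole) {
        return TargetDtype::FLOAT32;     // any final dimension works
    }
    return TargetDtype::BFLOAT16;        // Grayskull: final dimension must be even
}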
@@ -616,7 +605,7 @@ static void ggml_backend_metalium_buffer_set_tensor(ggml_backend_buffer_t buffer
else {
GGML_ASSERT(false && "Unsupported data type");
}

// TODO: Make sure this is correct
std::vector<uint32_t> shape(GGML_MAX_DIMS, 1);
for(int i = 0; i < GGML_MAX_DIMS; i++) {
@@ -626,7 +615,7 @@ static void ggml_backend_metalium_buffer_set_tensor(ggml_backend_buffer_t buffer

tt::tt_metal::Tensor t(std::move(storage), tt::tt_metal::Shape(shape)
, tt::tt_metal::DataType::BFLOAT16, tt::tt_metal::Layout::ROW_MAJOR);

// I think we can allow this.. right?
// GGML_ASSERT(!bufctx->tensors.contains(offset));

@@ -819,11 +808,11 @@ GGML_CALL static enum ggml_status ggml_backend_metalium_graph_compute(ggml_backe
case GGML_OP_MUL_MAT:
ggml_backend_metalium_mul_mat(ctx, node);
break;

case GGML_OP_CPY:
ggml_backend_metalium_cpy(ctx, node);
break;

case GGML_OP_NONE:
break;

@@ -869,7 +858,7 @@ GGML_CALL static bool ggml_backend_metalium_supports_op(ggml_backend_t backend,
}
return tensor->ne[0] % 4 == 0;
};

GGML_ASSERT(op != NULL);
if(!output_supported(op)) {
return false;
@@ -903,10 +892,8 @@ GGML_CALL static bool ggml_backend_metalium_supports_op(ggml_backend_t backend,
// DIV does not support broadcasting on TTNN
return input_supported(src0) && input_supported(src1) &&
(memcmp(src0->ne, src1->ne, sizeof(src0->ne)) == 0 || (numpy_broadcast_rule(src0, src1) && op->op != GGML_OP_DIV));
// FIXME: This crashes for most shapes due to a bug in the TTNN transpose implementation. Unmask this
// when the bug is fixed
// case GGML_OP_MUL_MAT:
// return op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
case GGML_OP_MUL_MAT:
return true;
case GGML_OP_CPY:
return input_supported(src0);
default:
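
The broadcasting condition above calls a numpy_broadcast_rule helper that is defined elsewhere in the file and not shown in this diff. Assuming it follows NumPy's usual rule (per-dimension sizes must match or be 1), a standalone stand-in could look like the sketch below; this is an assumption about its semantics, not the file's implementation:

#include <cstdint>

// Hypothetical stand-in for numpy_broadcast_rule: two shapes are broadcast-
// compatible when, dimension by dimension, the sizes are equal or one of them
// is 1. ggml stores sizes in a fixed-length ne[] array, so a plain loop suffices.
static bool shapes_broadcastable(const int64_t* ne_a, const int64_t* ne_b, int n_dims) {
    for (int i = 0; i < n_dims; i++) {
        if (ne_a[i] != ne_b[i] && ne_a[i] != 1 && ne_b[i] != 1) {
            return false;
        }
    }
    return true;
}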
@@ -963,7 +950,7 @@ ggml_backend_t ggml_backend_metalium_init(void) {
/* name = */ "Metalium " + std::to_string(device_id),
};
AutoFormat::SetDefaultDevice(ctx->device);


// store the device in the global map because tensor creation uses device ID but Metalium disallows opening the same device twice
g_device_map[device_id] = ctx->device;
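
A generic sketch of the open-once pattern the comment above describes; DeviceHandle, open_device, and the map name here are placeholders rather than Metalium's real API:

#include <map>

// Placeholder types standing in for Metalium's device handle and open call.
struct DeviceHandle { int id; };
static DeviceHandle* open_device(int device_id) { return new DeviceHandle{device_id}; }

// One handle per device ID, mirroring the role of g_device_map in the backend.
static std::map<int, DeviceHandle*> g_devices;

static DeviceHandle* get_or_open_device(int device_id) {
    auto it = g_devices.find(device_id);
    if (it != g_devices.end()) {
        return it->second;               // already open: reuse instead of reopening
    }
    DeviceHandle* dev = open_device(device_id);
    g_devices[device_id] = dev;
    return dev;
}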
@@ -987,4 +974,4 @@ GGML_CALL ggml_backend_t ggml_backend_reg_metalium_init(const char * params, voi
GGML_UNUSED(params);
GGML_UNUSED(user_data);
return ggml_backend_metalium_init();
}
