Merge branch 'metalium-support' into metalium-mesh
marty1885 committed Nov 21, 2024
2 parents 3d8b673 + 9ff79cb commit fb59355
Showing 2 changed files with 28 additions and 16 deletions.
ggml/src/ggml-metalium/ggml-metalium.cpp (28 changes: 16 additions & 12 deletions)
@@ -17,8 +17,6 @@
#include "impl/dispatch/dispatch_core_manager.hpp"
#include "ttnn/distributed/types.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/data_movement/reshape_on_device/reshape.hpp"
#include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.hpp"
@@ -62,7 +60,6 @@

#include <memory>
#include <type_traits>
#include <unordered_map>
#include <variant>
#include <vector>

@@ -427,7 +424,8 @@ void tensor2ggml(const tt::tt_metal::Tensor& tensor, void* dst, ggml_type dst_gg
}
// Just putting the integer types here as a reminder that TT tensors can have integer types,
// but they are not supported on Grayskull.
else if ((std::is_same_v<SrcType, bfloat16> && dst_ggtype == GGML_TYPE_BF16) ||
else if ((std::is_same_v<SrcType, float> && dst_ggtype == GGML_TYPE_F32) ||
(std::is_same_v<SrcType, bfloat16> && dst_ggtype == GGML_TYPE_BF16) ||
(std::is_same_v<SrcType, int32_t> && dst_ggtype == GGML_TYPE_I32) ||
(std::is_same_v<SrcType, int16_t> && dst_ggtype == GGML_TYPE_I16) ||
(std::is_same_v<SrcType, int8_t> && dst_ggtype == GGML_TYPE_I8)) {
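
A minimal sketch of the compile-time gate this hunk extends: the merge adds the previously missing float to GGML_TYPE_F32 pairing, so F32 tensors no longer fall through to the unsupported-type path. The enum values and helper name below are illustrative, not the backend's real API, and the TT-specific bfloat16 case is omitted because that type lives in tt-metal's own headers.

#include <cstdint>
#include <type_traits>

// Illustrative stand-ins for the real GGML type enum.
enum sketch_ggml_type { SK_F32, SK_I32, SK_I16, SK_I8 };

// Mirrors the is_same_v chain above: accept a (source element type,
// destination GGML type) pair only when the two actually match.
template <typename SrcType>
static bool pair_supported(sketch_ggml_type dst) {
    return (std::is_same_v<SrcType, float>   && dst == SK_F32) ||
           (std::is_same_v<SrcType, int32_t> && dst == SK_I32) ||
           (std::is_same_v<SrcType, int16_t> && dst == SK_I16) ||
           (std::is_same_v<SrcType, int8_t>  && dst == SK_I8);
}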
@@ -641,11 +639,21 @@ static std::shared_ptr<tt::tt_metal::Tensor> realize_ggml_view_impl(const ggml_t
else if(ggml_n_dims(src0) == 1 && ggml_n_dims(tensor) > 1) {
// slow: grab the source tensor and unpad it
uint32_t offset_elements = offset / ggml_type_size(src0->type);
ttnn::SimpleShape start{0, 0, 0, offset_elements};

auto dst_volume = ggml_nelements(tensor);
ttnn::SimpleShape end({1, 1, 1, uint32_t(dst_volume) + offset_elements});
auto t = ttnn::untilize(*parent).cpu().unpad(start, end);
res = reshape_host_tt_tensor_into_ggml(t, parent->device(), tensor);
if(offset_elements % tt::constants::TILE_WIDTH == 0 && dst_volume % tt::constants::TILE_HEIGHT == 0) {
std::array<uint32_t, GGML_MAX_DIMS> step = {1, 1, 1, 1};
auto t = ttnn::slice(*parent, start, end, step, tt::tt_metal::MemoryConfig());
res = reshape_tt_tensor_into_ggml(t, tensor);
}
else {
// THIS is EXTREMELY SLOW. But it works
ttnn::SimpleShape start{0, 0, 0, offset_elements};
ttnn::SimpleShape end({1, 1, 1, uint32_t(dst_volume) + offset_elements});
tt::tt_metal::Tensor tmp = ttnn::untilize(*parent).cpu().unpad(start, end);
tmp = ttnn::tilize_with_zero_padding(tmp.to(bufctx->device));
res = reshape_host_tt_tensor_into_ggml(tmp, parent->device(), tensor);
}
}
// The fast path, this is what TTNN is designed for
else if(dst_size[0] % tt::constants::TILE_WIDTH == 0 && dst_size[1] % tt::constants::TILE_HEIGHT == 0 &&
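
The rewrite above splits view realization into two paths: when the element offset and the destination volume both land on tile boundaries, the view is cut out on device with ttnn::slice; otherwise the tensor is untilized, pulled to host, unpadded, and re-tilized, which the author flags as extremely slow. A minimal sketch of the alignment test, assuming the usual 32x32 Tenstorrent tile and an illustrative helper name:

#include <cstdint>

// Device tensors are stored as 32x32 tiles, so a slice can stay on device
// only when it starts and ends on tile boundaries.
static constexpr uint32_t kTileWidth  = 32;  // stands in for tt::constants::TILE_WIDTH
static constexpr uint32_t kTileHeight = 32;  // stands in for tt::constants::TILE_HEIGHT

static bool view_is_tile_aligned(uint32_t offset_elements, uint64_t dst_volume) {
    return offset_elements % kTileWidth == 0 && dst_volume % kTileHeight == 0;
}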
@@ -1472,8 +1480,6 @@ static void ggml_backend_metalium_outer_product(ggml_backend_metalium_context *
auto src0 = realize_ggml_view(dst->src[0]);
auto src1 = realize_ggml_view(dst->src[1]);

std::cout << "src0: " << src0->shape() << " src1: " << src1->shape() << std::endl;

auto res = ttnn::outer(*src0, *src1);
*dst_meta = {
.tensor = std::make_shared<tt::tt_metal::Tensor>(res),
@@ -1645,8 +1651,6 @@ static void ggml_backend_metalium_buffer_set_tensor(ggml_backend_buffer_t buffer
tt::tt_metal::Tensor t(std::move(storage), ttnn::Shape(shape)
, tt::tt_metal::DataType::BFLOAT16, tt::tt_metal::Layout::ROW_MAJOR);

// I think we can allow this.. right?
// GGML_ASSERT(!bufctx->tensors.contains(offset));
tt::ARCH processor_class = bufctx->device->arch();
t = ttnn::tilize_with_zero_padding(t.to(bufctx->device));
tt::tt_metal::DataType final_type = ggml2tt_type(ggtype, processor_class);
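
For context, the upload path in this hunk wraps the incoming host data in a row-major BFLOAT16 tensor, moves it onto the device, and tilizes it with zero padding so shapes that are not multiples of 32 still fit the tile grid; the arch-dependent type cast follows. Condensed from the lines above (the final typecast call is elided because its exact API is not shown in this diff):

// 1. Wrap host data in a ROW_MAJOR bfloat16 tensor.
tt::tt_metal::Tensor t(std::move(storage), ttnn::Shape(shape),
                       tt::tt_metal::DataType::BFLOAT16,
                       tt::tt_metal::Layout::ROW_MAJOR);
// 2. Move it to the device, then tilize with zero padding.
t = ttnn::tilize_with_zero_padding(t.to(bufctx->device));
// 3. ggml2tt_type(ggtype, processor_class) then selects the on-device type.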
ggml/src/ggml-metalium/metalium-pch.hpp (16 changes: 12 additions & 4 deletions)
@@ -1,7 +1,9 @@
#include <unistd.h>
#ifdef __cplusplus
#include "common/base_types.hpp"
#include "common/bfloat16.hpp"
#include "common/constants.hpp"
#include "common/logger.hpp"
#include "device/tt_arch_types.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
@@ -12,18 +14,24 @@

#include "host_api.hpp"
#include "impl/dispatch/command_queue.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/experimental/auto_format/auto_format.hpp"
#include "ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.hpp"
#include "ttnn/operations/normalization/softmax/device/softmax_op.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <optional>
#include <string_view>
#include <ttnn/core.hpp>
#include <ttnn/device.hpp>
#include <ttnn/operations/eltwise/binary/binary.hpp>
@@ -36,13 +44,12 @@
#include <ttnn/operations/normalization/rmsnorm/rmsnorm.hpp>
#include <ttnn/operations/data_movement/untilize/untilize.hpp>
#include <ttnn/operations/experimental/transformer/nlp_kv_cache_load_slice/nlp_kv_cache_load_slice.hpp>
#include <ttnn/operations/creation.hpp>
#include <ttnn/operations/eltwise/unary/unary_composite.hpp>
#include <ttnn/operations/data_movement/transpose/transpose.hpp>
#include <ttnn/operations/data_movement/permute/permute.hpp>
#include <ttnn/operations/data_movement/concat/concat.hpp>
#include <ttnn/operations/data_movement/repeat/repeat.hpp>
#include <ttnn/operations/eltwise/unary/unary.hpp>
#include <ttnn/operations/eltwise/unary/unary_composite.hpp>
#include <ttnn/operations/data_movement/concat/concat.hpp>
#include <ttnn/operations/experimental/copy/typecast/typecast.hpp>
#include <tt_metal/detail/persistent_kernel_cache.hpp>
#include <ttnn/operations/normalization/softmax/softmax.hpp>
@@ -51,4 +58,5 @@
#include <memory>
#include <type_traits>
#include <variant>
#include <vector>
#endif
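
One detail worth noting in this header: everything except <unistd.h> sits behind __cplusplus, so the file stays harmless if it is ever processed as C (for example, when a build system applies the same precompiled header to every language in a target). A stripped-down skeleton of the pattern (illustrative, not the real file):

#include <unistd.h>  // visible to both C and C++ translation units
#ifdef __cplusplus
// Heavy, rarely changing C++ headers are grouped here so the compiler
// parses them once into the precompiled header instead of in every TU.
#include <memory>
#include <vector>
#endif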
