Skip to content

Commit

Permalink
reduce memory use during inference and fix some device memory leaks
Browse files Browse the repository at this point in the history
  • Loading branch information
marty1885 committed Nov 20, 2024
1 parent cb4bfd2 commit 9f3754d
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions ggml/src/ggml-metalium/ggml-metalium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <cstring>
#include <mutex>
#include <optional>
#include <string_view>
#include <ttnn/core.hpp>
#include <ttnn/device.hpp>
#include <ttnn/operations/eltwise/binary/binary.hpp>
Expand Down Expand Up @@ -1733,19 +1734,24 @@ ggml_backend_metalium_buffer_init_tensor(ggml_backend_buffer_t buffer,
ggml_tensor *tensor)
{
ggml_backend_metalium_buffer_context * bufctx = (ggml_backend_metalium_buffer_context *)buffer->context;
bufctx->metadata_to_free.push_back(std::make_unique<TensorWithMetadata>(TensorWithMetadata{

// Tensors can be initialized multiple times. We need overwrite the old metadata (and potentially free the old TTNN tensor)
if(tensor->extra == NULL) {
bufctx->metadata_to_free.push_back(std::make_unique<TensorWithMetadata>());
tensor->extra = bufctx->metadata_to_free.back().get();
}
TensorWithMetadata* meta = (TensorWithMetadata*)tensor->extra;
*meta = {
.tensor = nullptr,
.ggtype = GGML_TYPE_COUNT,
.bufctx = bufctx
}));
tensor->extra = bufctx->metadata_to_free.back().get();
};

// HACK: Make KV cache work
std::string name(tensor->name);
std::string_view name(tensor->name);
if(name.find("cache") != std::string::npos && tensor->op == GGML_OP_NONE) {
TensorWithMetadata* meta = (TensorWithMetadata*)tensor->extra;
std::vector<uint32_t> shape(tensor->ne, tensor->ne + GGML_MAX_DIMS);
std::reverse(shape.begin(), shape.end());
// TODO: Check if we can make TILE tensors and not pad on CPU
auto t = ttnn::zeros(ttnn::Shape(shape), ggml2tt_type(tensor->type, bufctx->device->arch()), tt::tt_metal::Layout::ROW_MAJOR);
t = ttnn::tilize_with_zero_padding(t.to(bufctx->device));
meta->tensor = std::make_shared<tt::tt_metal::Tensor>(std::move(t));
Expand Down

0 comments on commit 9f3754d

Please sign in to comment.