Skip to content

Commit

Permalink
fix slice op
Browse files Browse the repository at this point in the history
  • Loading branch information
marty1885 committed Sep 29, 2024
1 parent 9542273 commit 588000d
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions ggml/src/ggml-metalium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
#include <ttnn/operations/data_movement/transpose/transpose.hpp>
#include <ttnn/operations/data_movement/permute/permute.hpp>
#include <ttnn/operations/data_movement/concat/concat.hpp>
#include <ttnn/operations/eltwise/unary/unary.hpp>
#include <ttnn/operations/eltwise/unary/unary_composite.hpp>
#include <ttnn/operations/experimental/copy/typecast/typecast.hpp>
#include <tt_metal/detail/persistent_kernel_cache.hpp>
#include <ttnn/operations/normalization/softmax/softmax.hpp>
Expand Down Expand Up @@ -375,7 +373,7 @@ static tt::tt_metal::Tensor reshape_tt_tensor_into_ggml(const tt::tt_metal::Tens
tensor.shape()[2] < tt::constants::TILE_HEIGHT || tensor.shape()[3] < tt::constants::TILE_WIDTH) {
// This path is SLOW. Reshape on a tilized tensor only works when the last two dimensions are tile aligned
tt::tt_metal::LegacyShape begin({0, 0, 0, 0});
tt::tt_metal::LegacyShape end({tensor.shape()[0]-1, tensor.shape()[1]-1, tensor.shape()[2]-1, tensor.shape()[3]-1});
tt::tt_metal::LegacyShape end({tensor.shape()[0], tensor.shape()[1], tensor.shape()[2], tensor.shape()[3]});

tt::tt_metal::Tensor row_major_tensor = ttnn::untilize(tensor).cpu().unpad(begin, end);
tt::tt_metal::Tensor reshaped = row_major_tensor.reshape(target_shape);
Expand Down Expand Up @@ -455,7 +453,7 @@ static std::shared_ptr<tt::tt_metal::Tensor> realize_ggml_view_impl(const ggml_t
size_t remaining_offset = offset;
for(size_t i = GGML_MAX_DIMS - 1; i < GGML_MAX_DIMS; i--) {
start[i] = remaining_offset / src_stride[i];
end[i] = dst_size[i] + start[i] - 1;
end[i] = dst_size[i] + start[i];
remaining_offset = remaining_offset % src_stride[i];
}
std::reverse(start.begin(), start.end());
Expand Down Expand Up @@ -498,7 +496,7 @@ static std::shared_ptr<tt::tt_metal::Tensor> realize_ggml_view_impl(const ggml_t
// slow: grab the source tensor and unpad it
tt::tt_metal::LegacyShape start{0, 0, 0, uint32_t(offset / ggml_type_size(src0->type))};
auto dst_volume = ggml_nelements(tensor);
tt::tt_metal::LegacyShape end({0, 0, 0, uint32_t(dst_volume - 1) + start[3]});
tt::tt_metal::LegacyShape end({1, 1, 1, uint32_t(dst_volume) + start[3]});
auto t = parent->cpu().to(tt::tt_metal::Layout::ROW_MAJOR).unpad(start, end);
// TODO: I'm lazy and this copy is completely unnecessary. Only here because reshape_tt_tensor_into_ggml() needs a device tensor
t = ttnn::tilize_with_zero_padding(t.to(bufctx->device));
Expand Down

0 comments on commit 588000d

Please sign in to comment.