Skip to content

Commit

Permalink
[luci-interpreter] Relax StrideSlice rank limitations (#14507)
Browse files Browse the repository at this point in the history
This commit removes assert that limit StrideSlice rank to 4. TF 2.8 allows to inference this operator with rank 5.

ONE-DCO-1.0-Signed-off-by: Mateusz Bencer <[email protected]>
  • Loading branch information
mbencer authored Jan 2, 2025
1 parent 518bd72 commit c93d12c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
2 changes: 1 addition & 1 deletion compiler/luci-interpreter/src/kernels/StridedSlice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void StridedSlice::configure()
assert(begin()->element_type() == DataType::S32);
assert(end()->element_type() == DataType::S32);
assert(strides()->element_type() == DataType::S32);
assert(input()->shape().num_dims() <= 4);
assert(input()->shape().num_dims() <= 5);
if (params().ellipsis_mask != 0)
{
throw std::runtime_error("ellipsis_mask is not implemented yet.");
Expand Down
44 changes: 44 additions & 0 deletions compiler/luci-interpreter/src/kernels/StridedSlice.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "kernels/TestUtils.h"
#include "luci_interpreter/TestMemoryManager.h"

#include <numeric>

namespace luci_interpreter
{
namespace kernels
Expand Down Expand Up @@ -107,6 +109,48 @@ TEST(StridedSliceTest, Uint8)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
}

TEST(StridedSliceTest, 5DCase)
{
std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();

Shape input_shape{2, 3, 2, 2, 3};
std::vector<float> input_data(input_shape.num_elements());
std::iota(std::begin(input_data), std::end(input_data), 0);
Shape begin_shape{5};
std::vector<int32_t> begin_data{0, 0, 0, 0, 0};
Shape end_shape{5};
std::vector<int32_t> end_data{2, 3, 2, 2, 1};
Shape strides_shape{5};
std::vector<int32_t> strides_data{1, 1, 1, 1, 1};
Tensor input_tensor =
makeInputTensor<DataType::U8>(input_shape, 1.0f, 0, input_data, memory_manager.get());
Tensor begin_tensor =
makeInputTensor<DataType::S32>(begin_shape, begin_data, memory_manager.get());
Tensor end_tensor = makeInputTensor<DataType::S32>(end_shape, end_data, memory_manager.get());
Tensor strides_tensor =
makeInputTensor<DataType::S32>(strides_shape, strides_data, memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::U8, 1.0f, 0);

StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.ellipsis_mask = 0;
params.new_axis_mask = 0;
params.shrink_axis_mask = 0;

StridedSlice kernel(&input_tensor, &begin_tensor, &end_tensor, &strides_tensor, &output_tensor,
params);
kernel.configure();
memory_manager->allocate_memory(output_tensor);
kernel.execute();

std::vector<int32_t> output_shape{2, 3, 2, 2, 1};
std::vector<float> output_data{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33,
36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69};
EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(output_data));
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
}

} // namespace
} // namespace kernels
} // namespace luci_interpreter

0 comments on commit c93d12c

Please sign in to comment.