diff --git a/src/tim/vx/ops/max_pool_test.cc b/src/tim/vx/ops/max_pool_test.cc index 5010ff85..2ed8a9d8 100644 --- a/src/tim/vx/ops/max_pool_test.cc +++ b/src/tim/vx/ops/max_pool_test.cc @@ -29,6 +29,66 @@ #include "gtest/gtest.h" #include "test_utils.h" +#include "third_party/half/half.hpp" + +TEST(MAX, shape_8_8_1_1_float16_kernel_3x3_stride_2x2) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + using namespace half_float::literal; + + tim::vx::ShapeType input_shape({8, 8, 1, 1}); //whcn + tim::vx::ShapeType output_shape( + {3, 3, 1, 1}); //whcn + + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT16, input_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT16, output_shape, + tim::vx::TensorAttribute::OUTPUT); + + // Input data nchw + std::vector input_data = { + 1.0_h, 2.0_h, 3.0_h, 4.0_h, 5.0_h, 6.0_h, 7.0_h, 8.0_h, + 11.0_h, 12.0_h, 13.0_h, 14.0_h, 15.0_h, 16.0_h, 17.0_h, 18.0_h, + 21.0_h, 22.0_h, 23.0_h, 24.0_h, 25.0_h, 26.0_h, 27.0_h, 28.0_h, + 31.0_h, 32.0_h, 33.0_h, 34.0_h, 35.0_h, 36.0_h, 37.0_h, 38.0_h, + 41.0_h, 42.0_h, 43.0_h, 44.0_h, 45.0_h, 46.0_h, 47.0_h, 48.0_h, + 51.0_h, 52.0_h, 53.0_h, 54.0_h, 55.0_h, 56.0_h, 57.0_h, 58.0_h, + 61.0_h, 62.0_h, 63.0_h, 64.0_h, 65.0_h, 66.0_h, 67.0_h, 68.0_h, + 71.0_h, 72.0_h, 73.0_h, 74.0_h, 75.0_h, 76.0_h, 77.0_h, 78.0_h, + }; + + std::vector golden = { + 23.0_h, 25.0_h, 27.0_h, + 43.0_h, 45.0_h, 47.0_h, + 63.0_h, 65.0_h, 67.0_h, + }; + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::array ksize = {3, 3}; + std::array stride = {2, 2}; + auto round_type = tim::vx::RoundType::FLOOR; + + auto op = graph->CreateOperation( + tim::vx::PoolType::MAX, tim::vx::PadType::VALID, ksize, stride, round_type); + (*op).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + + input_tensor->CopyDataToTensor(input_data.data()); + + EXPECT_TRUE(graph->Run()); + + uint32_t output_size = 1; + for (auto i : output_tensor->GetShape()) { + output_size *= i; + } + std::vector output(output_size); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_TRUE(ArraysMatch(golden, output, (half_float::half)0.1)); +} + TEST(MAX, shape_32_3_1_fp32_kernel_2_stride_1) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph();