diff --git a/kernels/portable/cpu/op__to_dim_order_copy.cpp b/kernels/portable/cpu/op__to_dim_order_copy.cpp index 31dd4fbb9d..bcbf6cc132 100644 --- a/kernels/portable/cpu/op__to_dim_order_copy.cpp +++ b/kernels/portable/cpu/op__to_dim_order_copy.cpp @@ -96,13 +96,17 @@ Tensor& _to_dim_order_copy_out( InvalidArgument, out); - ET_SWITCH_REALHB_TYPES( + if (self.numel() == 0) { + return out; + } + + ET_SWITCH_REALHBBF16_TYPES( self.scalar_type(), ctx, "dim_order_ops::_to_dim_order_copy.out", CTYPE_IN, [&] { - ET_SWITCH_REALHB_TYPES( + ET_SWITCH_REALHBBF16_TYPES( out.scalar_type(), ctx, "dim_order_ops::_to_dim_order_copy.out", diff --git a/kernels/test/op__to_dim_order_copy_test.cpp b/kernels/test/op__to_dim_order_copy_test.cpp index e888e0fc7f..073225a7d6 100644 --- a/kernels/test/op__to_dim_order_copy_test.cpp +++ b/kernels/test/op__to_dim_order_copy_test.cpp @@ -36,7 +36,9 @@ typedef std::map< std::type_index, std::variant< std::vector, - std::vector>> + std::vector, + std::vector, + std::vector>> FloatingTypeToDataMap; typedef std::map< @@ -381,9 +383,9 @@ TEST_F(OpToDimOrderCopyTest, NanInfSupported) { ScalarType::OUTPUT_DTYPE>(test_cases); #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ - ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY #undef TEST_KERNEL @@ -413,6 +415,13 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) { -0.30919688936285893988}; // clang-format on + std::vector half_data; + std::vector bf16_data; + for (auto d : double_data) { + half_data.emplace_back(d); + bf16_data.emplace_back(d); + } + std::vector int64_data = { -1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0}; std::vector int32_data = { @@ -426,6 +435,8 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) { FloatingTypeToDataMap floating_point_data; floating_point_data[typeid(float)] = float_data; floating_point_data[typeid(double)] = double_data; + floating_point_data[typeid(exec_aten::Half)] = half_data; + floating_point_data[typeid(exec_aten::BFloat16)] = bf16_data; // Gathering all int data together for better traversial IntTypeToDataMap int_data; @@ -444,7 +455,7 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) { #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ ET_FORALL_INT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); } TEST_F(OpToDimOrderCopyTest, MismatchedSizesDie) {