diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index 11c26788840..cb3a440926c 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -10,14 +10,10 @@ from torch.library import impl, Library -lib = Library("dim_order_ops", "DEF") -lib.define( - "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor" -) -# Out variant drops TensorOptions +lib = Library("dim_order_ops", "FRAGMENT") lib.define( - "_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" + "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor" ) @@ -34,14 +30,12 @@ def _to_dim_order_copy_impl(*args, **kwargs): return _op_impl(torch.ops.aten._to_copy, *args, **kwargs) -@impl(lib, "_to_dim_order_copy.out", "CompositeImplicitAutograd") -def _to_dim_order_copy_out_impl(*args, **kwargs): - return _op_impl(torch.ops.aten._to_copy.out, *args, **kwargs) - +def get_dim_order_ops_map(): + """ + Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup + """ -""" -Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup -""" -DimOrderOpsMap = { - "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default, -} + DimOrderOpsMap = { + "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + } + return DimOrderOpsMap diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py index 5a3c0f3a912..56d22e0afd5 100644 --- a/exir/passes/memory_format_ops_pass.py +++ b/exir/passes/memory_format_ops_pass.py @@ -11,7 +11,7 @@ from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.dim_order_utils import get_dim_order from executorch.exir.pass_base import ExportPass, ProxyValue -from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap +from executorch.exir.passes.dim_order_ops_registry import get_dim_order_ops_map logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) @@ -26,6 +26,7 @@ class MemoryFormatOpsPass(ExportPass): """ def call_operator(self, op, args, kwargs, meta): + DimOrderOpsMap = get_dim_order_ops_map() if not (isinstance(op, EdgeOpOverload) and op.__name__ in DimOrderOpsMap): return super().call_operator( op, diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 15e73dd413c..da6f3e630c3 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -28,6 +28,9 @@ class MemoryFormatTestSet: class TestMemoryFormatOpsPass(unittest.TestCase): def memory_format_test_runner(self, test_set: MemoryFormatTestSet): + # TODO(T180746545): automatic load dim order operator + torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib") + aten_op_str = "torch.ops.aten._to_copy.default" edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" diff --git a/kernels/portable/cpu/op__to_dim_order_copy.cpp b/kernels/portable/cpu/op__to_dim_order_copy.cpp new file mode 100644 index 00000000000..bd4cdacc459 --- /dev/null +++ b/kernels/portable/cpu/op__to_dim_order_copy.cpp @@ -0,0 +1,177 @@ +/* + * Copyright (c) Meta Platforms, 
Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+#ifdef USE_ATEN_LIB
+// #include
+#endif
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+using SizesArrayRef = exec_aten::ArrayRef<exec_aten::SizesType>;
+using DimOrderArrayRef = exec_aten::ArrayRef<exec_aten::DimOrderType>;
+using MemoryFormat = exec_aten::MemoryFormat;
+
+template <typename T>
+using OptionalArrayRef = exec_aten::OptionalArrayRef<T>;
+
+template <typename T>
+using Optional = exec_aten::optional<T>;
+
+#ifdef USE_ATEN_LIB
+
+namespace {
+Optional<MemoryFormat> get_memory_format(OptionalArrayRef<int64_t> dim_order) {
+  if (!dim_order.has_value()) {
+    return exec_aten::nullopt;
+  }
+  if (is_contiguous_dim_order(
+          dim_order.value().data(), dim_order.value().size())) {
+    return MemoryFormat::Contiguous;
+  } else if (is_channels_last_dim_order(
+                 dim_order.value().data(), dim_order.value().size())) {
+    return MemoryFormat::ChannelsLast;
+  } else {
+    ET_ASSERT_UNREACHABLE();
+  }
+}
+} // namespace
+
+// TODO(T179434631) : enable aten mode if needed
+// _to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]?
+// dim_order=None, Tensor(a!) out) -> Tensor(a!)
+Tensor& _to_dim_order_copy_out(
+    RuntimeContext& ctx,
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  // ET_KERNEL_CHECK(
+  //     ctx,
+  //     check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
+  //     InvalidArgument,
+  //     out);
+
+  Optional<MemoryFormat> memory_format = get_memory_format(dim_order);
+  at::_to_copy_outf(self, non_blocking, memory_format, out);
+
+  return out;
+}
+
+Tensor& _to_dim_order_copy_out(
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  exec_aten::RuntimeContext ctx{};
+  return _to_dim_order_copy_out(ctx, self, non_blocking, dim_order, out);
+}
+
+#else
+
+namespace {
+
+// TODO(T179241236): Update core/exec_aten/util/tensor_util.h to support dim
+// order other than contiguous.
+int64_t coordinateToIndexWithDimOrder(
+    const Tensor& self,
+    const size_t* cur_indices) {
+  int64_t index = 0;
+  exec_aten::StridesType strides[kTensorDimensionLimit];
+  SizesArrayRef sizes = self.sizes();
+  DimOrderArrayRef dim_order = self.dim_order();
+
+  dim_order_to_stride_nocheck(
+      sizes.data(), dim_order.data(), sizes.size(), strides);
+  for (size_t i = 0; i < self.dim(); ++i) {
+    index += cur_indices[i] * strides[i];
+  }
+  return index;
+}
+
+template <typename SELF_CTYPE, typename OUT_CTYPE>
+void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
+  auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
+  auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
+
+  size_t coordinate[kTensorDimensionLimit] = {0};
+
+  // Copy data from self to out index by index. Same index in self and out
+  // should have same value, no matter the order of dimensions.
+  for (ssize_t i = 0; i < self.numel(); i++) {
+    // Update the current indices.
+    for (ssize_t j = self.dim() - 1; j >= 0; j--) {
+      if (coordinate[j] + 1 < self.size(j)) {
+        coordinate[j]++;
+        break;
+      } else {
+        coordinate[j] = 0;
+      }
+    }
+    // Get the corresponding index of self_data and out_data by stride.
+    int64_t self_data_index = coordinateToIndexWithDimOrder(self, coordinate);
+    int64_t out_data_index = coordinateToIndexWithDimOrder(out, coordinate);
+
+    out_data[out_data_index] =
+        static_cast<OUT_CTYPE>(self_data[self_data_index]);
+  }
+}
+} // namespace
+
+// _to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]?
+// dim_order=None, Tensor(a!) out) -> Tensor(a!)
+Tensor& _to_dim_order_copy_out(
+    RuntimeContext& ctx,
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  (void)ctx;
+  ET_KERNEL_CHECK(
+      ctx,
+      check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
+      InvalidArgument,
+      out);
+
+  ET_SWITCH_REALHB_TYPES(
+      self.scalar_type(), ctx, "_to_dim_order_copy_out", CTYPE_IN, [&] {
+        ET_SWITCH_REALHB_TYPES(
+            out.scalar_type(), ctx, "_to_dim_order_copy_out", CTYPE_OUT, [&] {
+              _to_dim_order_copy_impl<CTYPE_IN, CTYPE_OUT>(self, out);
+            });
+      });
+
+  return out;
+}
+
+Tensor& _to_dim_order_copy_out(
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  exec_aten::RuntimeContext context{};
+  return _to_dim_order_copy_out(context, self, non_blocking, dim_order, out);
+}
+
+#endif
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl
index 2b8805d2484..904e2ef7e66 100644
--- a/kernels/portable/cpu/targets.bzl
+++ b/kernels/portable/cpu/targets.bzl
@@ -951,6 +951,12 @@ _ATEN_OPS = (
 # ops, and must be split. They can, however, share common code via a library dep
 # if necessary.
 _CUSTOM_OPS = (
+    op_target(
+        name = "op__to_dim_order_copy",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:copy_ops_util",
+        ],
+    ),
     op_target(
         name = "op_allclose",
     ),
diff --git a/kernels/portable/cpu/util/copy_ops_util.cpp b/kernels/portable/cpu/util/copy_ops_util.cpp
index 69f7dc94c3a..556d55a3692 100644
--- a/kernels/portable/cpu/util/copy_ops_util.cpp
+++ b/kernels/portable/cpu/util/copy_ops_util.cpp
@@ -9,6 +9,7 @@
 #include <cstring>
 
 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
 
 namespace torch {
@@ -733,6 +734,48 @@ bool check_to_copy_args(
   return true;
 }
 
+bool check__to_dim_order_copy_args(
+    const Tensor& input,
+    bool non_blocking,
+    exec_aten::OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  // Right now we only support blocking data transfer
+  ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);
+
+  if (dim_order.has_value()) {
+    exec_aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();
+
+    // The size of dim_order shall equal the input dim
+    ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());
+
+    // Right now we only focus on contiguous and channels_last memory formats
+    // TODO(T179248280): rename default_dim_order to contiguous_dim_order.
+    // Default here is ambiguous and easy to confuse with preserve.
+ ET_LOG_AND_RETURN_IF_FALSE( + is_channels_last_dim_order( + dim_order.value().data(), dim_order.value().size()) || + is_contiguous_dim_order( + dim_order.value().data(), dim_order.value().size())); + + // Out tensor shall have same dim order as dim_order + auto out_dim_order = out.dim_order(); + ET_LOG_AND_RETURN_IF_FALSE(out_dim_order.size() == dim_order_ref.size()); + for (size_t i = 0; i < dim_order_ref.size(); i++) { + ET_LOG_AND_RETURN_IF_FALSE(out_dim_order[i] == dim_order_ref[i]); + } + } else { // dim_order is not set, preserve the dim order of input + + // Out tensor shall have same dim order as input dim_order + auto out_dim_order = out.dim_order(); + auto input_dim_order = input.dim_order(); + ET_LOG_AND_RETURN_IF_FALSE(out_dim_order.size() == input_dim_order.size()); + for (size_t i = 0; i < input_dim_order.size(); i++) { + ET_LOG_AND_RETURN_IF_FALSE(out_dim_order[i] == input_dim_order[i]); + } + } + return true; +} + bool check_unsqueeze_copy_args( const Tensor input, int64_t dim, diff --git a/kernels/portable/cpu/util/copy_ops_util.h b/kernels/portable/cpu/util/copy_ops_util.h index dc8e5902aec..88e8500bf54 100644 --- a/kernels/portable/cpu/util/copy_ops_util.h +++ b/kernels/portable/cpu/util/copy_ops_util.h @@ -149,6 +149,12 @@ bool check_to_copy_args( exec_aten::optional memory_format, Tensor& out); +bool check__to_dim_order_copy_args( + const Tensor& input, + bool non_blocking, + exec_aten::OptionalArrayRef dim_order, + Tensor& out); + bool check_unsqueeze_copy_args( const Tensor input, int64_t dim, diff --git a/kernels/portable/custom_ops.yaml b/kernels/portable/custom_ops.yaml index e8ae0812674..e5468dd6ea9 100644 --- a/kernels/portable/custom_ops.yaml +++ b/kernels/portable/custom_ops.yaml @@ -35,3 +35,13 @@ kernels: - arg_meta: null kernel_name: torch::executor::linear_scratch_example + +- func: dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: torch::executor::_to_dim_order_copy_out + +# - func: dim_order_ops::_to_dim_order_copy(Tensor self, bool non_blocking=False, ScalarType? dtype=None, int[]? dim_order=None) -> Tensor +# kernels: +# - arg_meta: null +# kernel_name: torch::executor::_to_dim_order_copy_tensor diff --git a/kernels/portable/test/op__to_dim_order_copy_test.cpp b/kernels/portable/test/op__to_dim_order_copy_test.cpp new file mode 100644 index 00000000000..96806b72bf8 --- /dev/null +++ b/kernels/portable/test/op__to_dim_order_copy_test.cpp @@ -0,0 +1,644 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +using namespace ::testing; +using exec_aten::ArrayRef; +using exec_aten::optional; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::testing::TensorFactory; + +Tensor& op__to_dim_order_copy_out( + const Tensor& self, + bool non_blocking, + exec_aten::optional> dim_order, + Tensor& out) { + exec_aten::RuntimeContext context{}; + return torch::executor::native::_to_dim_order_copy_out( + context, self, non_blocking, dim_order, out); +} + +/* Here we temporary not try to implement or test the behavior about casting a + * number can not be represented in some type to this type (e.g. 
inf to int32_t + * nan to int64_t or 2147483648 to int32_t), because + * - a. The result of such kind of cast is undefined according to c++ standard; + * - b. No explicit rules can be found in core pytorch for such transaction (not + * same as static_cast or any other casting function in c++); + * - c. If user tries to cast a unrepresentable value to certain type, they + * should take the risk; + * - d. Even though we can always use if/switch to cover these boundry cases, + * the code will be lengthy and jumbled. I believe using these disordered + * code to meet some undefine behavior is meaningless, and we can not + * cover all such cases. + */ + +namespace { + +// Cast float vector to OUTPUT_CTYPE vector +template +std::vector vector_type_cast(std::vector input) { + std::vector output(input.size()); + std::transform(input.begin(), input.end(), output.begin(), [](INPUT_CTYPE x) { + return static_cast(x); + }); + return output; +} +} // namespace + +template +struct ToTestCase { + const std::vector sizes; + const std::vector data_in; + const std::vector data_out; +}; + +// Each test has different combination of input and output types. Therefore it +// is a little bit mess if create template test case and custom data types for +// both input data and output data. +// We choose another way: for all test cases, their data are all in double. And +// we are gonna cast them into desired type when delievering them into tf.make +// function. +// Based on our experiments, type cast of core PyTorch is same as static_cast +// in c++ in the representable scope, so here we believe using static_cast to +// generate ground truth is reasonable. +template < + typename INPUT_CTYPE, + ScalarType INPUT_DTYPE, + typename OUTPUT_CTYPE, + ScalarType OUTPUT_DTYPE> +void test_runner_static_cast( + std::vector> test_cases) { + TensorFactory tf_in; + TensorFactory tf_out; + + for (auto test_case : test_cases) { + auto data_in = vector_type_cast(test_case.data_in); + auto data_out = vector_type_cast(data_in); + + Tensor input = tf_in.make(test_case.sizes, data_in); + Tensor output = tf_out.zeros_like(input); + + std::vector dim_order_vec; + for (int64_t i = 0; i < input.dim(); i++) { + dim_order_vec.push_back(i); + } + ArrayRef dim_order(dim_order_vec.data(), dim_order_vec.size()); + + Tensor ret = op__to_dim_order_copy_out( + /*self=*/input, + /*non_blocking=*/false, + dim_order, + output); + + Tensor expected = tf_out.make(test_case.sizes, data_out); + + // The original tensor a should share same value with the out variable and + // return variable of to function + EXPECT_TENSOR_EQ(ret, output); + EXPECT_TENSOR_EQ(ret, expected); + } +} + +// Regular test for to_copy.out +// Test if to_copy.out works well under all kinds of data pairs +TEST(OpToDimOrderCopyTest, AllDtypesSupported) { + std::vector> test_cases = { + { + /*sizes=*/{2, 4}, /*data_in=*/ + {2.11, 3.2, 2.3, 4.0, 1.1, 5.2, 1.1, 6.3}, /*data_out=*/ + {}, // data_out shouldn't be used in test_runner_static_cast + }, + { + /*sizes=*/{3, 4, 0, 5}, + /*data_in=*/{}, + /*data_out=*/{}, + }, + { + /*sizes=*/{}, + /*data_in=*/{10.0}, + /*data_out=*/{}, // data_out shouldn't be used in + // test_runner_static_cast + }, + }; + +#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \ + test_runner_static_cast< \ + INPUT_CTYPE, \ + ScalarType::INPUT_DTYPE, \ + OUTPUT_CTYPE, \ + ScalarType::OUTPUT_DTYPE>(test_cases); + +#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ + ET_FORALL_REAL_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + + 
ET_FORALL_REAL_TYPES(TEST_ENTRY); + +#undef TEST_ENTRY +#undef TEST_KERNEL +} + +template +void test_runner_to_bool( + std::vector test_case, + std::vector data_out) { + TensorFactory tf_in; + TensorFactory tf_out; + + auto data_in = vector_type_cast(test_case); + + Tensor input = tf_in.make({(int)test_case.size()}, data_in); + Tensor output = tf_out.zeros_like(input); + + std::vector dim_order_vec; + for (int i = 0; i < input.dim(); i++) { + dim_order_vec.push_back(i); + } + ArrayRef dim_order(dim_order_vec.data(), dim_order_vec.size()); + + Tensor ret = op__to_dim_order_copy_out( + /*self=*/input, + /*non_blocking=*/false, + dim_order, + output); + + Tensor expected = tf_out.make({(int)data_out.size()}, data_out); + + // The return value of op__to_dim_order_copy_out and the values written to + // output should be the same. + EXPECT_TENSOR_EQ(ret, output); + // The return value of op__to_dim_order_copy_out and the values in expected + // which are the reference values should be the same. + EXPECT_TENSOR_EQ(ret, expected); +} + +template +void test_runner_from_bool( + std::vector test_case, + std::vector out) { + TensorFactory tf_in; + TensorFactory tf_out; + + auto data_out = vector_type_cast(out); + + Tensor input = tf_in.make({(int)test_case.size()}, test_case); + Tensor output = tf_out.zeros_like(input); + + std::vector dim_order_vec; + for (int64_t i = 0; i < input.dim(); i++) { + dim_order_vec.push_back(i); + } + ArrayRef dim_order(dim_order_vec.data(), dim_order_vec.size()); + + Tensor ret = op__to_dim_order_copy_out( + /*self=*/input, + /*non_blocking=*/false, + dim_order, + output); + + Tensor expected = tf_out.make({(int)data_out.size()}, data_out); + + // The return value of op__to_dim_order_copy_out and the values written to + // output should be the same. + EXPECT_TENSOR_EQ(ret, output); + // The return value of op__to_dim_order_copy_out and the values in expected + // which are the reference values should be the same. + EXPECT_TENSOR_EQ(ret, expected); +} + +TEST(OpToDimOrderCopyTest, BoolTests) { + std::vector test_case_to_bool = {1.1, 2.2, 0}; + std::vector result_to_bool = {true, true, false}; +#define TEST_TO_BOOL(INPUT_CTYPE, INPUT_DTYPE) \ + test_runner_to_bool( \ + test_case_to_bool, result_to_bool); + ET_FORALL_REAL_TYPES(TEST_TO_BOOL); + + std::vector test_case_from_bool = {true, true, false}; + std::vector result_from_bool = {1.0, 1.0, 0}; +#define TEST_FROM_BOOL(OUTPUT_CTYPE, OUTPUT_DTYPE) \ + test_runner_from_bool( \ + test_case_from_bool, result_from_bool); + ET_FORALL_REAL_TYPES(TEST_FROM_BOOL); +} + +TEST(OpToDimOrderCopyTest, NanInfSupported) { + constexpr auto floatInfinity = std::numeric_limits::infinity(); + std::vector> test_cases = {{ + /*sizes=*/{2, 4}, + /*data_in=*/{2, 3, NAN, 4, floatInfinity, 5, -floatInfinity, 6}, + /*data_out=*/{2, 3, NAN, 4, floatInfinity, 5, -floatInfinity, 6}, + }}; + +#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \ + test_runner_static_cast< \ + INPUT_CTYPE, \ + ScalarType::INPUT_DTYPE, \ + OUTPUT_CTYPE, \ + ScalarType::OUTPUT_DTYPE>(test_cases); + +#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ + ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + + ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + +#undef TEST_ENTRY +#undef TEST_KERNEL +} + +// To further emphasize the accuracy of our op_to, we test the conversion +// from floating-point types to signed int types directly by the test cases +// generated by core Pytorch directly. Such data is random generated in [-5, 5]. 
+ +// clang-format off +typedef std::map< + std::type_index, + std::variant< + std::vector, + std::vector>> + FloatingTypeToDataMap; + +typedef std::map< + std::type_index, + std::variant< + std::vector, + std::vector, + std::vector, + std::vector, + std::vector>> + IntTypeToDataMap; +// clang-format on + +template < + typename INPUT_CTYPE, + ScalarType INPUT_DTYPE, + typename OUTPUT_CTYPE, + ScalarType OUTPUT_DTYPE> +void test_runner_hardcode_data( + FloatingTypeToDataMap floating_point_data, + IntTypeToDataMap int_data) { + TensorFactory tf_in; + TensorFactory tf_out; + + if (typeid(OUTPUT_CTYPE) == typeid(uint8_t)) { + // Would cause underflow when testing uint8_t. + return; + } + + ToTestCase test_case = { + /*sizes=*/{3, 5}, /*data_in=*/ + std::get>( + floating_point_data[typeid(INPUT_CTYPE)]), + /*data_out=*/ + std::get>(int_data[typeid(OUTPUT_CTYPE)])}; + + Tensor input = tf_in.make(test_case.sizes, test_case.data_in); + Tensor output = tf_out.zeros_like(input); + + std::vector dim_order_vec; + for (int64_t i = 0; i < input.dim(); i++) { + dim_order_vec.push_back(i); + } + ArrayRef dim_order(dim_order_vec.data(), dim_order_vec.size()); + + Tensor ret = op__to_dim_order_copy_out( + /*self=*/input, + /*non_blocking=*/false, + dim_order, + output); + + Tensor expected = tf_out.make(test_case.sizes, test_case.data_out); + + // The original tensor a should share same value with the out variable and + // return variable of to function + EXPECT_TENSOR_EQ(ret, output); + EXPECT_TENSOR_EQ(ret, expected); +} + +TEST(OpToDimOrderCopyTest, HardcodeFloatConvertInt) { + // Hardcode input and output generated from core PyTorch + // clang-format off + std::vector float_data = { + -1.47900056838989257812, -4.59277725219726562500, + 2.15365791320800781250, -2.55494546890258789062, + 3.06999135017395019531, 3.27460670471191406250, + -3.98865103721618652344, -4.81065988540649414062, + 3.67902207374572753906, 3.72226405143737792969, + 0.80567771196365356445, 2.23788332939147949219, + -0.52035576105117797852, -1.58493483066558837891, + -0.30919688940048217773}; + + std::vector double_data = { + -1.47900053955270172068, -4.59277735274143061872, + 2.15365796963871947156, -2.55494554556038755422, + 3.06999137834642255029, 3.27460679459944969949, + -3.98865109243288795682, -4.81065977167646074975, + 3.67902198302105531980, 3.72226414774102742911, + 0.80567768667100203572, 2.23788335717029518435, + -0.52035578832931150828, -1.58493480710766210251, + -0.30919688936285893988}; + // clang-format on + + std::vector int64_data = { + -1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0}; + std::vector int32_data = { + -1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0}; + std::vector int16_data = { + -1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0}; + std::vector int8_data = { + -1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0}; + + // Gathering all floating point data together for better traversial + FloatingTypeToDataMap floating_point_data; + floating_point_data[typeid(float)] = float_data; + floating_point_data[typeid(double)] = double_data; + + // Gathering all int data together for better traversial + IntTypeToDataMap int_data; + int_data[typeid(int64_t)] = int64_data; + int_data[typeid(int32_t)] = int32_data; + int_data[typeid(int16_t)] = int16_data; + int_data[typeid(int8_t)] = int8_data; + +#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \ + test_runner_hardcode_data< \ + INPUT_CTYPE, \ + ScalarType::INPUT_DTYPE, \ + OUTPUT_CTYPE, \ + ScalarType::OUTPUT_DTYPE>(floating_point_data, 
int_data); + +#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ + ET_FORALL_INT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + + ET_FORALL_FLOAT_TYPES(TEST_ENTRY); +} + +TEST(OpToDimOrderCopyTest, MismatchedSizesDie) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen kernel can handle mismatched sizes"; + } + TensorFactory tf; + Tensor input = tf.make(/*sizes=*/{3, 1, 1, 2}, /*data=*/{1, 2, 3, 4, 5, 6}); + Tensor out = tf.zeros({3, 2, 1, 1}); + std::vector dim_order_vec; + for (int64_t i = 0; i < input.dim(); i++) { + dim_order_vec.push_back(i); + } + ArrayRef dim_order(dim_order_vec.data(), dim_order_vec.size()); + + ET_EXPECT_KERNEL_FAILURE(op__to_dim_order_copy_out( + /*self=*/input, + /*non_blocking=*/false, + dim_order, + out)); +} + +// Only contiguous memory is supported, the memory type MemoryFormat::Contiguous +// should not be allowed. The function is expected death if using the illegal +// memory format. +TEST(OpToDimOrderCopyTest, MismatchedMemoryFormatDies) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen kernel can handle non contiguous memory formats"; + } + TensorFactory tf_in; + TensorFactory tf_out; + Tensor input = + tf_in.make(/*sizes=*/{3, 1, 1, 2}, /*data=*/{1, 2, 3, 4, 5, 6}); + Tensor out = tf_out.zeros({3, 1, 1, 2}); + + std::vector dim_order_vec; + for (int64_t i = 0; i < input.dim(); i++) { + dim_order_vec.push_back(i); + } + + // mutate dim_order_vec to create a illegal one. + dim_order_vec[1] = 3; + dim_order_vec[3] = 1; + ArrayRef dim_order(dim_order_vec.data(), dim_order_vec.size()); + + ET_EXPECT_KERNEL_FAILURE(op__to_dim_order_copy_out( + /*self=*/input, + /*non_blocking=*/false, + dim_order, + out)); +} + +// Only blocking data transfer supported +TEST(OpToDimOrderCopyTest, MismatchedBlockingDie) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen kernel can handle non blocking data transfer"; + } + TensorFactory tf; + Tensor input = tf.make(/*sizes=*/{3, 1, 1, 2}, /*data=*/{1, 2, 3, 4, 5, 6}); + Tensor out = tf.zeros(/*sizes=*/{3, 1, 1, 2}); + + std::vector dim_order_vec; + for (int64_t i = 0; i < input.dim(); i++) { + dim_order_vec.push_back(i); + } + ArrayRef dim_order(dim_order_vec.data(), dim_order_vec.size()); + + ET_EXPECT_KERNEL_FAILURE(op__to_dim_order_copy_out( + /*self=*/input, + /*non_blocking=*/true, + dim_order, + out)); +} + +/* %python +import torch +torch.manual_seed(0) +x = torch.rand(2, 3) +res = x.to(non_blocking = False, memory_format = torch.preserve_format) +op = "op__to_dim_order_copy_out" +opt_setup_params = """ + bool non_blocking = false; + optional memory_format; +""" +opt_extra_params = "non_blocking, memory_format," +out_args = "out_shape, dynamism" +dtype = "ScalarType::Float" +check = "EXPECT_TENSOR_EQ" */ + +void test_dynamic_shape( + const std::vector& out_shape, + enum torch::executor::TensorShapeDynamism dynamism) { + /* %python + %rewrite(unary_op) */ + + TensorFactory tf; + + Tensor x = tf.make( + {2, 3}, + {0.49625658988952637, + 0.7682217955589294, + 0.08847743272781372, + 0.13203048706054688, + 0.30742281675338745, + 0.6340786814689636}); + Tensor expected = tf.make( + {2, 3}, + {0.49625658988952637, + 0.7682217955589294, + 0.08847743272781372, + 0.13203048706054688, + 0.30742281675338745, + 0.6340786814689636}); + + bool non_blocking = false; + + Tensor out = tf.zeros(out_shape, dynamism); + + std::vector dim_order_vec; + for (int64_t i = 0; i < x.dim(); i++) { + 
dim_order_vec.push_back(i); + } + ArrayRef dim_order(dim_order_vec.data(), dim_order_vec.size()); + + Tensor ret = op__to_dim_order_copy_out( + /*self=*/x, non_blocking, dim_order, out); + + EXPECT_TENSOR_EQ(out, expected); + EXPECT_TENSOR_EQ(ret, expected); +} + +TEST(OpToDimOrderCopyTest, DynamicShapeUpperBoundSameAsExpected) { + test_dynamic_shape( + {2, 3}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); +} + +TEST(OpToDimOrderCopyTest, DynamicShapeUpperBoundLargerThanExpected) { + test_dynamic_shape( + {10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); +} + +TEST(OpToDimOrderCopyTest, DynamicShapeUnbound) { + if (!torch::executor::testing::SupportedFeatures::get()->output_resize) { + GTEST_SKIP() << "Dynamic shape unbound not supported"; + } + test_dynamic_shape( + {1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND); +} + +TEST(OpToDimOrderCopyTest, ContiguousToChannelsLast) { + TensorFactory tf; + + Tensor x = tf.make_with_dimorder( + {3, 5, 2, 2}, + {0.2432, 0.5248, 0.5361, 0.8513, 0.8184, 0.8206, 0.7357, 0.9655, 0.6138, + 0.1112, 0.2799, 0.1079, 0.9680, 0.2548, 0.0393, 0.6002, 0.2257, 0.8766, + 0.2715, 0.1595, 0.2029, 0.7026, 0.6982, 0.8529, 0.4405, 0.6560, 0.9217, + 0.6372, 0.2446, 0.6590, 0.3866, 0.7185, 0.4439, 0.5346, 0.3179, 0.4492, + 0.3491, 0.6970, 0.8456, 0.2516, 0.2345, 0.2924, 0.7695, 0.0911, 0.8530, + 0.8560, 0.6909, 0.7719, 0.8923, 0.5546, 0.6978, 0.8151, 0.3007, 0.3961, + 0.8416, 0.4296, 0.7203, 0.8963, 0.3597, 0.5552}); + + Tensor out = tf.full_channels_last({3, 5, 2, 2}, 0.0); + Tensor expected = tf.make_with_dimorder( + {3, 5, 2, 2}, + {0.2432, 0.8184, 0.6138, 0.9680, 0.2257, 0.5248, 0.8206, 0.1112, 0.2548, + 0.8766, 0.5361, 0.7357, 0.2799, 0.0393, 0.2715, 0.8513, 0.9655, 0.1079, + 0.6002, 0.1595, 0.2029, 0.4405, 0.2446, 0.4439, 0.3491, 0.7026, 0.6560, + 0.6590, 0.5346, 0.6970, 0.6982, 0.9217, 0.3866, 0.3179, 0.8456, 0.8529, + 0.6372, 0.7185, 0.4492, 0.2516, 0.2345, 0.8530, 0.8923, 0.3007, 0.7203, + 0.2924, 0.8560, 0.5546, 0.3961, 0.8963, 0.7695, 0.6909, 0.6978, 0.8416, + 0.3597, 0.0911, 0.7719, 0.8151, 0.4296, 0.5552}, + /*dim_order=*/{0, 2, 3, 1}); + + std::vector dim_order_vec = {0, 2, 3, 1}; + exec_aten::ArrayRef dim_order( + dim_order_vec.data(), dim_order_vec.size()); + Tensor ret = op__to_dim_order_copy_out( + /*self*/ x, /*non_blocking*/ false, /*dim_order*/ dim_order, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpToDimOrderCopyTest, ChannelsLastToContiguous) { + TensorFactory tf; + + Tensor out = tf.full({3, 5, 2, 2}, 0.0); + Tensor x = tf.make_with_dimorder( + {3, 5, 2, 2}, + {0.2432, 0.8184, 0.6138, 0.9680, 0.2257, 0.5248, 0.8206, 0.1112, 0.2548, + 0.8766, 0.5361, 0.7357, 0.2799, 0.0393, 0.2715, 0.8513, 0.9655, 0.1079, + 0.6002, 0.1595, 0.2029, 0.4405, 0.2446, 0.4439, 0.3491, 0.7026, 0.6560, + 0.6590, 0.5346, 0.6970, 0.6982, 0.9217, 0.3866, 0.3179, 0.8456, 0.8529, + 0.6372, 0.7185, 0.4492, 0.2516, 0.2345, 0.8530, 0.8923, 0.3007, 0.7203, + 0.2924, 0.8560, 0.5546, 0.3961, 0.8963, 0.7695, 0.6909, 0.6978, 0.8416, + 0.3597, 0.0911, 0.7719, 0.8151, 0.4296, 0.5552}, + /*dim_order=*/{0, 2, 3, 1}); + + Tensor expected = tf.make_with_dimorder( + {3, 5, 2, 2}, + {0.2432, 0.5248, 0.5361, 0.8513, 0.8184, 0.8206, 0.7357, 0.9655, 0.6138, + 0.1112, 0.2799, 0.1079, 0.9680, 0.2548, 0.0393, 0.6002, 0.2257, 0.8766, + 0.2715, 0.1595, 0.2029, 0.7026, 0.6982, 0.8529, 0.4405, 0.6560, 0.9217, + 0.6372, 0.2446, 0.6590, 0.3866, 0.7185, 0.4439, 0.5346, 0.3179, 0.4492, + 0.3491, 0.6970, 0.8456, 0.2516, 0.2345, 0.2924, 0.7695, 0.0911, 0.8530, + 
0.8560, 0.6909, 0.7719, 0.8923, 0.5546, 0.6978, 0.8151, 0.3007, 0.3961, + 0.8416, 0.4296, 0.7203, 0.8963, 0.3597, 0.5552}); + + std::vector dim_order_vec = {0, 1, 2, 3}; + exec_aten::ArrayRef dim_order( + dim_order_vec.data(), dim_order_vec.size()); + Tensor ret = op__to_dim_order_copy_out( + /*self*/ x, /*non_blocking*/ false, /*dim_order*/ dim_order, out); + + EXPECT_TENSOR_EQ(out, expected); + EXPECT_TENSOR_EQ(ret, expected); +} + +TEST(OpToDimOrderCopyTest, PreserveChanneslLast) { + TensorFactory tf; + + Tensor out = tf.full_channels_last({3, 5, 2, 2}, 0.0); + Tensor x = tf.make_with_dimorder( + {3, 5, 2, 2}, + {0.2432, 0.8184, 0.6138, 0.9680, 0.2257, 0.5248, 0.8206, 0.1112, 0.2548, + 0.8766, 0.5361, 0.7357, 0.2799, 0.0393, 0.2715, 0.8513, 0.9655, 0.1079, + 0.6002, 0.1595, 0.2029, 0.4405, 0.2446, 0.4439, 0.3491, 0.7026, 0.6560, + 0.6590, 0.5346, 0.6970, 0.6982, 0.9217, 0.3866, 0.3179, 0.8456, 0.8529, + 0.6372, 0.7185, 0.4492, 0.2516, 0.2345, 0.8530, 0.8923, 0.3007, 0.7203, + 0.2924, 0.8560, 0.5546, 0.3961, 0.8963, 0.7695, 0.6909, 0.6978, 0.8416, + 0.3597, 0.0911, 0.7719, 0.8151, 0.4296, 0.5552}, + /*dim_order=*/{0, 2, 3, 1}); + + Tensor expected = tf.make_with_dimorder( + {3, 5, 2, 2}, + {0.2432, 0.8184, 0.6138, 0.9680, 0.2257, 0.5248, 0.8206, 0.1112, 0.2548, + 0.8766, 0.5361, 0.7357, 0.2799, 0.0393, 0.2715, 0.8513, 0.9655, 0.1079, + 0.6002, 0.1595, 0.2029, 0.4405, 0.2446, 0.4439, 0.3491, 0.7026, 0.6560, + 0.6590, 0.5346, 0.6970, 0.6982, 0.9217, 0.3866, 0.3179, 0.8456, 0.8529, + 0.6372, 0.7185, 0.4492, 0.2516, 0.2345, 0.8530, 0.8923, 0.3007, 0.7203, + 0.2924, 0.8560, 0.5546, 0.3961, 0.8963, 0.7695, 0.6909, 0.6978, 0.8416, + 0.3597, 0.0911, 0.7719, 0.8151, 0.4296, 0.5552}, + /*dim_order=*/{0, 2, 3, 1}); + + Tensor ret = op__to_dim_order_copy_out( + /*self*/ x, + /*non_blocking*/ false, + /*dim_order*/ exec_aten::nullopt, + out); + + EXPECT_TENSOR_EQ(out, expected); + EXPECT_TENSOR_EQ(ret, expected); +} diff --git a/kernels/portable/test/targets.bzl b/kernels/portable/test/targets.bzl index ae0dbaef40d..f37a6155a94 100644 --- a/kernels/portable/test/targets.bzl +++ b/kernels/portable/test/targets.bzl @@ -8,6 +8,8 @@ def define_common_targets(): """ define_supported_features_lib() + # TODO(T179434631) : enable aten mode test for op__to_dim_order_copy + op_test(name = "op__to_dim_order_copy_test") op_test(name = "op_allclose_test", aten_compatible = False) op_test(name = "op_div_test") op_test(name = "op_gelu_test") diff --git a/runtime/core/exec_aten/exec_aten.h b/runtime/core/exec_aten/exec_aten.h index 9eb1f12cd18..919b5420b3a 100644 --- a/runtime/core/exec_aten/exec_aten.h +++ b/runtime/core/exec_aten/exec_aten.h @@ -82,6 +82,9 @@ using quint4x2 = c10::quint4x2; using quint2x4 = c10::quint2x4; using IntArrayRef = at::IntArrayRef; +template +using OptionalArrayRef = c10::OptionalArrayRef; + #else // Use executor types using Tensor = torch::executor::Tensor; @@ -118,6 +121,10 @@ using quint2x4 = torch::executor::quint2x4; using IntArrayRef = torch::executor::IntArrayRef; +template +using OptionalArrayRef = + torch::executor::optional>; + #endif // Use executor types } // namespace exec_aten
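
Reviewer note, not part of the patch: a minimal sketch of how the lazily built op map
introduced in exir/passes/dim_order_ops_registry.py is expected to be consumed. It
mirrors the unit test above (import the registry, load the generated custom-ops
library, then resolve the map). The Buck target string is copied verbatim from the
test and is only assumed to be loadable in a matching build environment; the mapping
shown is the one hard-coded in get_dim_order_ops_map().

    import torch

    from executorch.exir.passes.dim_order_ops_registry import get_dim_order_ops_map

    # Same workaround as TODO(T180746545) in the test: make the dim_order_ops
    # schema available before the map is built.
    torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")

    # The map is now constructed on demand (see MemoryFormatOpsPass.call_operator)
    # rather than at module import time, so op registration only has to happen
    # before the first lookup.
    dim_order_ops_map = get_dim_order_ops_map()

    # aten._to_copy.default is the op that MemoryFormatOpsPass rewrites into the
    # edge-dialect dim_order op.
    print(dim_order_ops_map["aten._to_copy.default"])
    # expected: exir_ops.edge.dim_order_ops._to_dim_order_copy.default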