introduce _to_dim_order_copy op to runtime (pytorch#1970)
Summary: Pull Request resolved: pytorch#1970

Differential Revision: D53747744
Gasoonjia authored and facebook-github-bot committed Mar 1, 2024
1 parent 31c4837 commit 0238c2c
Showing 11 changed files with 910 additions and 17 deletions.
26 changes: 10 additions & 16 deletions exir/passes/dim_order_ops_registry.py
@@ -10,14 +10,10 @@

 from torch.library import impl, Library

-lib = Library("dim_order_ops", "DEF")
-lib.define(
-    "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
-)

-# Out variant drops TensorOptions
+lib = Library("dim_order_ops", "FRAGMENT")
 lib.define(
-    "_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
+    "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
 )

@@ -34,14 +30,12 @@ def _to_dim_order_copy_impl(*args, **kwargs):
     return _op_impl(torch.ops.aten._to_copy, *args, **kwargs)


-@impl(lib, "_to_dim_order_copy.out", "CompositeImplicitAutograd")
-def _to_dim_order_copy_out_impl(*args, **kwargs):
-    return _op_impl(torch.ops.aten._to_copy.out, *args, **kwargs)
-

-"""
-Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
-"""
-DimOrderOpsMap = {
-    "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
-}
+def get_dim_order_ops_map():
+    """
+    Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
+    """
+    DimOrderOpsMap = {
+        "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    }
+    return DimOrderOpsMap
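A note on the dim_order argument in the schema above: it encodes the same information as an ATen memory format, listing the tensor's dimensions from outermost (largest stride) to innermost (stride 1). A minimal illustration in plain PyTorch (dim_order_from_strides is a hypothetical helper written for this note, not part of the commit):

import torch

def dim_order_from_strides(t):
    # Sort dimensions from largest stride to smallest; the result is the dim order.
    return sorted(range(t.dim()), key=lambda d: t.stride(d), reverse=True)

x = torch.randn(2, 3, 4, 5)                  # contiguous NCHW
y = x.to(memory_format=torch.channels_last)  # same data, NHWC memory layout
print(dim_order_from_strides(x))             # [0, 1, 2, 3]
print(dim_order_from_strides(y))             # [0, 2, 3, 1]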
3 changes: 2 additions & 1 deletion exir/passes/memory_format_ops_pass.py
@@ -11,7 +11,7 @@
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.dim_order_utils import get_dim_order
 from executorch.exir.pass_base import ExportPass, ProxyValue
-from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
+from executorch.exir.passes.dim_order_ops_registry import get_dim_order_ops_map

 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
@@ -26,6 +26,7 @@ class MemoryFormatOpsPass(ExportPass):
"""

def call_operator(self, op, args, kwargs, meta):
DimOrderOpsMap = get_dim_order_ops_map()
if not (isinstance(op, EdgeOpOverload) and op.__name__ in DimOrderOpsMap):
return super().call_operator(
op,
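The pass looks up each op in this map and rewrites aten._to_copy into dim_order_ops._to_dim_order_copy, translating the memory_format keyword into an explicit dim_order via get_dim_order. A rough, self-contained sketch of that translation (illustrative only, covering the two memory formats currently supported; not the committed call_operator code):

import torch

def memory_format_to_dim_order(memory_format, ndim):
    # Mirrors the intent of exir.dim_order_utils.get_dim_order for the two
    # supported formats; anything else is rejected.
    if memory_format in (None, torch.contiguous_format):
        return list(range(ndim))
    if memory_format == torch.channels_last and ndim == 4:
        return [0, 2, 3, 1]
    raise NotImplementedError(f"unsupported memory format: {memory_format}")

# Conceptually, aten._to_copy(x, memory_format=torch.channels_last) becomes
# dim_order_ops._to_dim_order_copy(x, dim_order=[0, 2, 3, 1]).
print(memory_format_to_dim_order(torch.channels_last, 4))  # [0, 2, 3, 1]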
3 changes: 3 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
@@ -28,6 +28,9 @@ class MemoryFormatTestSet:

class TestMemoryFormatOpsPass(unittest.TestCase):
def memory_format_test_runner(self, test_set: MemoryFormatTestSet):
# TODO(T180746545): automatic load dim order operator
torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")

aten_op_str = "torch.ops.aten._to_copy.default"
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

177 changes: 177 additions & 0 deletions kernels/portable/cpu/op__to_dim_order_copy.cpp
@@ -0,0 +1,177 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#ifdef USE_ATEN_LIB
// #include <executorch/kernels/aten/Functions.h>
#endif

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using SizesArrayRef = exec_aten::ArrayRef<exec_aten::SizesType>;
using DimOrderArrayRef = exec_aten::ArrayRef<exec_aten::DimOrderType>;
using MemoryFormat = exec_aten::MemoryFormat;

template <typename T>
using OptionalArrayRef = exec_aten::OptionalArrayRef<T>;

template <typename T>
using Optional = exec_aten::optional<T>;

#ifdef USE_ATEN_LIB

namespace {
Optional<MemoryFormat> get_memory_format(OptionalArrayRef<int64_t> dim_order) {
if (!dim_order.has_value()) {
return exec_aten::nullopt;
}
if (is_contiguous_dim_order(
dim_order.value().data(), dim_order.value().size())) {
return MemoryFormat::Contiguous;
} else if (is_channels_last_dim_order(
dim_order.value().data(), dim_order.value().size())) {
return MemoryFormat::ChannelsLast;
} else {
ET_ASSERT_UNREACHABLE();
}
}
} // namespace

// TODO(T179434631) : enable aten mode if needed
// _to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]?
// dim_order=None, Tensor(a!) out) -> Tensor(a!)
Tensor& _to_dim_order_copy_out(
RuntimeContext& ctx,
const Tensor& self,
bool non_blocking,
OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
// ET_KERNEL_CHECK(
// ctx,
// check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
// InvalidArgument,
// out);

Optional<MemoryFormat> memory_format = get_memory_format(dim_order);
at::_to_copy_outf(self, non_blocking, memory_format, out);

return out;
}

Tensor& _to_dim_order_copy_out(
const Tensor& self,
bool non_blocking,
OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
exec_aten::RuntimeContext ctx{};
return _to_dim_order_copy_out(ctx, self, non_blocking, dim_order, out);
}

#else

namespace {

// TODO(T179241236): Update core/exec_aten/util/tensor_util.h to support dim
// order other than contiguous.
int64_t coordinateToIndexWithDimOrder(
const Tensor& self,
const size_t* cur_indices) {
int64_t index = 0;
exec_aten::StridesType strides[kTensorDimensionLimit];
SizesArrayRef sizes = self.sizes();
DimOrderArrayRef dim_order = self.dim_order();

dim_order_to_stride_nocheck(
sizes.data(), dim_order.data(), sizes.size(), strides);
for (size_t i = 0; i < self.dim(); ++i) {
index += cur_indices[i] * strides[i];
}
return index;
}

template <typename SELF_CTYPE, typename OUT_CTYPE>
void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
auto out_data = out.mutable_data_ptr<OUT_CTYPE>();

size_t coordinate[kTensorDimensionLimit] = {0};

// Copy data from self to out index by index. Same index in self and out
// should have same value, no matter the order of dimensions.
for (ssize_t i = 0; i < self.numel(); i++) {
// Update the current indices.
for (ssize_t j = self.dim() - 1; j >= 0; j--) {
if (coordinate[j] + 1 < self.size(j)) {
coordinate[j]++;
break;
} else {
coordinate[j] = 0;
}
}
// Get the corresponding index of self_data and out_data by stride.
int64_t self_data_index = coordinateToIndexWithDimOrder(self, coordinate);
int64_t out_data_index = coordinateToIndexWithDimOrder(out, coordinate);

out_data[out_data_index] =
static_cast<OUT_CTYPE>(self_data[self_data_index]);
}
}
} // namespace

// _to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]?
// dim_order=None, Tensor(a!) out) -> Tensor(a!)
Tensor& _to_dim_order_copy_out(
RuntimeContext& ctx,
const Tensor& self,
bool non_blocking,
OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
(void)ctx;
ET_KERNEL_CHECK(
ctx,
check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
InvalidArgument,
out);

ET_KERNEL_CHECK(
ctx,
resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
InvalidArgument,
out);

ET_SWITCH_REALHB_TYPES(
self.scalar_type(), ctx, "_to_dim_order_copy_out", CTYPE_IN, [&] {
ET_SWITCH_REALHB_TYPES(
out.scalar_type(), ctx, "_to_dim_order_copy_out", CTYPE_OUT, [&] {
_to_dim_order_copy_impl<CTYPE_IN, CTYPE_OUT>(self, out);
});
});

return out;
}

Tensor& _to_dim_order_copy_out(
const Tensor& self,
bool non_blocking,
OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
exec_aten::RuntimeContext context{};
return _to_dim_order_copy_out(context, self, non_blocking, dim_order, out);
}

#endif

} // namespace native
} // namespace executor
} // namespace torch
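The copy loop above maps each logical coordinate to a flat index using strides derived from the tensor's dim order (via dim_order_to_stride_nocheck). A standalone illustration of that stride derivation, in plain Python rather than the ExecuTorch utility, assuming the outermost-to-innermost dim_order convention:

def strides_from_dim_order(sizes, dim_order):
    # Walk the dim order from the innermost dimension outwards, accumulating sizes.
    strides = [0] * len(sizes)
    running = 1
    for d in reversed(dim_order):
        strides[d] = running
        running *= sizes[d]
    return strides

print(strides_from_dim_order([2, 3, 4, 5], [0, 1, 2, 3]))  # [60, 20, 5, 1] (contiguous)
print(strides_from_dim_order([2, 3, 4, 5], [0, 2, 3, 1]))  # [60, 1, 15, 3] (channels_last)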
6 changes: 6 additions & 0 deletions kernels/portable/cpu/targets.bzl
@@ -951,6 +951,12 @@ _ATEN_OPS = (
# ops, and must be split. They can, however, share common code via a library dep
# if necessary.
_CUSTOM_OPS = (
op_target(
name = "op__to_dim_order_copy",
deps = [
"//executorch/kernels/portable/cpu/util:copy_ops_util",
],
),
op_target(
name = "op_allclose",
),
43 changes: 43 additions & 0 deletions kernels/portable/cpu/util/copy_ops_util.cpp
@@ -9,6 +9,7 @@
#include <cstring>

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
@@ -733,6 +734,48 @@ bool check_to_copy_args(
return true;
}

bool check__to_dim_order_copy_args(
const Tensor& input,
bool non_blocking,
exec_aten::OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
// Right now we only support blocking data transfer
ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);

if (dim_order.has_value()) {
exec_aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();

// dim order size shall equal to input dim
ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());

// Right now we only focus on contiguous and channels_last memory format
// TODO(T179248280): rename default_dim_order to contiguous_dim_order.
// Default here is ambiguous and easy to confuse with preserve.
ET_LOG_AND_RETURN_IF_FALSE(
is_channels_last_dim_order(
dim_order.value().data(), dim_order.value().size()) ||
is_contiguous_dim_order(
dim_order.value().data(), dim_order.value().size()));

// Out tensor shall have same dim order as dim_order
auto out_dim_order = out.dim_order();
ET_LOG_AND_RETURN_IF_FALSE(out_dim_order.size() == dim_order_ref.size());
for (size_t i = 0; i < dim_order_ref.size(); i++) {
ET_LOG_AND_RETURN_IF_FALSE(out_dim_order[i] == dim_order_ref[i]);
}
} else { // dim_order is not set, preserve the dim order of input

// Out tensor shall have same dim order as input dim_order
auto out_dim_order = out.dim_order();
auto input_dim_order = input.dim_order();
ET_LOG_AND_RETURN_IF_FALSE(out_dim_order.size() == input_dim_order.size());
for (size_t i = 0; i < input_dim_order.size(); i++) {
ET_LOG_AND_RETURN_IF_FALSE(out_dim_order[i] == input_dim_order[i]);
}
}
return true;
}

bool check_unsqueeze_copy_args(
const Tensor input,
int64_t dim,
6 changes: 6 additions & 0 deletions kernels/portable/cpu/util/copy_ops_util.h
@@ -149,6 +149,12 @@ bool check_to_copy_args(
exec_aten::optional<exec_aten::MemoryFormat> memory_format,
Tensor& out);

bool check__to_dim_order_copy_args(
const Tensor& input,
bool non_blocking,
exec_aten::OptionalArrayRef<int64_t> dim_order,
Tensor& out);

bool check_unsqueeze_copy_args(
const Tensor input,
int64_t dim,
10 changes: 10 additions & 0 deletions kernels/portable/custom_ops.yaml
@@ -35,3 +35,13 @@
kernels:
- arg_meta: null
kernel_name: torch::executor::linear_scratch_example

- func: dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: torch::executor::_to_dim_order_copy_out

# - func: dim_order_ops::_to_dim_order_copy(Tensor self, bool non_blocking=False, ScalarType? dtype=None, int[]? dim_order=None) -> Tensor
# kernels:
# - arg_meta: null
# kernel_name: torch::executor::_to_dim_order_copy_tensor