forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
introduce _to_dim_order_copy op to runtime (pytorch#1970)
Summary: Pull Request resolved: pytorch#1970 This diff creates a new, special operator, `_to_dim_order_copy`. The new operator introduces two critical capabilities to the runtime system: 1. Extract memory_format information from a tensor based on its dim_order, instead of its strides. 2. Support both channels_last and contiguous memory_format in the runtime. Please note that memory format here is a concept parallel to memory layout, and supporting a new format does not violate our contract of only supporting tensors with a contiguous memory layout. Details can be found [here](https://discuss.pytorch.org/t/contigious-vs-non-contigious-tensor/30107) and [here](https://pytorch.org/blog/tensor-memory-format-matters/). Furthermore, `_to_dim_order_copy` is a special operator: it does not have a native ATen variant but is needed by most models in the edge dialect, so it cannot be put directly in `kernels/portable/custom_ops.yaml` (which requires manual registration every time, and so does not work for an operator needed by many models), or in `kernels/portable/functions.yaml` (whose entries should have a native ATen variant). To overcome that, this diff puts `_to_dim_order_copy`'s ATen mode under `kernels/aten`, and its lean mode under `kernels/portable/functions.yaml`. Also update dependencies and utils. Differential Revision: https://internalfb.com/D53747744
- Loading branch information
1 parent
b9488fe
commit 6c698cf
Showing
19 changed files
with
1,190 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Any targets that should be shared between fbcode and xplat must be defined in | ||
# targets.bzl. This file can contain fbcode-only targets. | ||
|
||
# Bring in the macro that declares the targets shared between fbcode and xplat. | ||
load(":targets.bzl", "define_common_targets") | ||
|
||
# NOTE(review): presumably tags these targets with the owning oncall - confirm. | ||
oncall("executorch") | ||
|
||
# Instantiate the shared targets defined in targets.bzl. | ||
define_common_targets() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h> | ||
#include <executorch/runtime/kernel/kernel_includes.h> | ||
|
||
namespace torch { | ||
namespace executor { | ||
namespace native { | ||
|
||
using Tensor = exec_aten::Tensor; | ||
using SizesArrayRef = exec_aten::ArrayRef<exec_aten::SizesType>; | ||
using DimOrderArrayRef = exec_aten::ArrayRef<exec_aten::DimOrderType>; | ||
using MemoryFormat = exec_aten::MemoryFormat; | ||
|
||
template <typename T> | ||
using OptionalArrayRef = exec_aten::OptionalArrayRef<T>; | ||
|
||
template <typename T> | ||
using Optional = exec_aten::optional<T>; | ||
|
||
namespace { | ||
// Translates an optional dim order into the matching memory format.
//
// Returns nullopt when no dim order is supplied, MemoryFormat::Contiguous for
// a contiguous dim order, and MemoryFormat::ChannelsLast for a channels-last
// dim order. Any other dim order reaches ET_ASSERT_UNREACHABLE(), so callers
// are expected to validate the dim order beforehand.
Optional<MemoryFormat> get_memory_format(OptionalArrayRef<int64_t> dim_order) {
  if (!dim_order.has_value()) {
    return exec_aten::nullopt;
  }

  const exec_aten::ArrayRef<int64_t> order = dim_order.value();
  const auto* const order_data = order.data();
  const auto order_len = order.size();

  if (is_contiguous_dim_order(order_data, order_len)) {
    return MemoryFormat::Contiguous;
  }
  if (is_channels_last_dim_order(order_data, order_len)) {
    return MemoryFormat::ChannelsLast;
  }
  ET_ASSERT_UNREACHABLE();
}
|
||
bool check__to_dim_order_copy_args( | ||
const Tensor& input, | ||
bool non_blocking, | ||
exec_aten::OptionalArrayRef<int64_t> dim_order, | ||
Tensor& out) { | ||
// Right now we only support blocking data transfer | ||
ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false); | ||
|
||
// dim_order is set, the target dim_order will be either contiguous or | ||
// channels_last memory format | ||
if (dim_order.has_value()) { | ||
exec_aten::ArrayRef<int64_t> dim_order_ref = dim_order.value(); | ||
|
||
// dim order size shall equal to input dim | ||
ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim()); | ||
|
||
ET_LOG_AND_RETURN_IF_FALSE( | ||
is_channels_last_dim_order( | ||
dim_order.value().data(), dim_order.value().size()) || | ||
is_contiguous_dim_order( | ||
dim_order.value().data(), dim_order.value().size())); | ||
|
||
// Out Aten tensor shall have same memory format stride as dim_order | ||
const size_t kMaxNumOfDimensions = 16; | ||
ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim()); | ||
exec_aten::StridesType target_strides[kMaxNumOfDimensions]; | ||
dim_order_to_stride_nocheck( | ||
out.sizes().data(), | ||
dim_order_ref.data(), | ||
dim_order_ref.size(), | ||
target_strides); | ||
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size()); | ||
for (size_t i = 0; i < dim_order_ref.size(); i++) { | ||
ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]); | ||
} | ||
|
||
} else { // dim_order is not set, preserve the dim order of input | ||
|
||
auto out_strides = out.strides(); | ||
auto input_strides = input.strides(); | ||
ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size()); | ||
for (size_t i = 0; i < input_strides.size(); i++) { | ||
ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]); | ||
} | ||
} | ||
return true; | ||
} | ||
} // namespace | ||
|
||
// _to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]?
// dim_order=None, Tensor(a!) out) -> Tensor(a!)
//
// ATen-mode kernel: validates the arguments, converts `dim_order` to a memory
// format, and delegates the copy to at::_to_copy_outf. If validation fails,
// ET_KERNEL_CHECK is invoked with InvalidArgument and `out` as its return
// value. Returns `out`.
Tensor& _to_dim_order_copy_out(
    RuntimeContext& ctx,
    const Tensor& self,
    bool non_blocking,
    OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  // TODO(T181345875): enable sanity check in aten mode
  ET_KERNEL_CHECK(
      ctx,
      check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
      InvalidArgument,
      out);

  // The dim order has been validated above, so it maps onto a memory format
  // (or nullopt when absent) that the ATen copy understands.
  at::_to_copy_outf(self, non_blocking, get_memory_format(dim_order), out);
  return out;
}
|
||
// Convenience overload without a caller-supplied runtime context: builds a
// default-constructed context and forwards to the context-taking overload.
Tensor& _to_dim_order_copy_out(
    const Tensor& self,
    bool non_blocking,
    OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  exec_aten::RuntimeContext default_context{};
  return _to_dim_order_copy_out(
      default_context, self, non_blocking, dim_order, out);
}
|
||
} // namespace native | ||
} // namespace executor | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "define_op_target", "op_target") | ||
|
||
# Edge dialect operators that are defined by ExecuTorch rather than by core
# ATen (they have no native ATen variant) but still need a kernel in ATen mode.
# Every entry here will be backed by a cxx_library target with the given name
# and deps.
#
# Note that a single target (or single .cpp file) can't mix ATen and non-ATen
# ops, and must be split. They can, however, share common code via a library dep
# if necessary.
_EDGE_DIALECT_OPS = (
    op_target(
        name = "op__to_dim_order_copy",
        deps = [
            "//executorch/kernels/aten/cpu/util:copy_ops_util",
        ],
    ),
)
|
||
def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    # Declare one build target per operator in _EDGE_DIALECT_OPS and remember
    # the label of its "_aten" flavor for the aggregate library below.
    aten_op_labels = []
    for op in _EDGE_DIALECT_OPS:
        define_op_target(is_aten_op = False, is_et_op = False, **op)
        aten_op_labels.append(":{}_aten".format(op["name"]))

    # Source-less umbrella library that re-exports every "_aten" op target.
    runtime.cxx_library(
        name = "cpu",
        srcs = [],
        visibility = [
            "//executorch/kernels/aten/...",
            "//executorch/kernels/test/...",
        ],
        exported_deps = aten_op_labels,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Any targets that should be shared between fbcode and xplat must be defined in | ||
# targets.bzl. This file can contain fbcode-only targets. | ||
|
||
# Bring in the macro that declares the targets shared between fbcode and xplat. | ||
load(":targets.bzl", "define_common_targets") | ||
|
||
# NOTE(review): presumably tags these targets with the owning oncall - confirm. | ||
oncall("executorch") | ||
|
||
# Instantiate the shared targets defined in targets.bzl. | ||
define_common_targets() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <cstring> | ||
|
||
#include <executorch/kernels/aten/cpu/util/copy_ops_util.h> | ||
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h> | ||
|
||
namespace torch { | ||
namespace executor { | ||
|
||
using Tensor = exec_aten::Tensor; | ||
|
||
bool check__to_dim_order_copy_args( | ||
const Tensor& input, | ||
bool non_blocking, | ||
exec_aten::OptionalArrayRef<int64_t> dim_order, | ||
Tensor& out) { | ||
// Right now we only support blocking data transfer | ||
ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false); | ||
|
||
// dim_order is set, the target dim_order will be either contiguous or | ||
// channels_last memory format | ||
if (dim_order.has_value()) { | ||
exec_aten::ArrayRef<int64_t> dim_order_ref = dim_order.value(); | ||
|
||
// dim order size shall equal to input dim | ||
ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim()); | ||
|
||
ET_LOG_AND_RETURN_IF_FALSE( | ||
is_channels_last_dim_order( | ||
dim_order.value().data(), dim_order.value().size()) || | ||
is_contiguous_dim_order( | ||
dim_order.value().data(), dim_order.value().size())); | ||
|
||
// Out Aten tensor shall have same memory format stride as dim_order | ||
const size_t kMaxNumOfDimensions = 16; | ||
ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim()); | ||
exec_aten::StridesType target_strides[kMaxNumOfDimensions]; | ||
dim_order_to_stride_nocheck( | ||
out.sizes().data(), | ||
dim_order_ref.data(), | ||
dim_order_ref.size(), | ||
target_strides); | ||
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size()); | ||
for (size_t i = 0; i < dim_order_ref.size(); i++) { | ||
ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]); | ||
} | ||
|
||
} else { // dim_order is not set, preserve the dim order of input | ||
|
||
auto out_strides = out.strides(); | ||
auto input_strides = input.strides(); | ||
ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size()); | ||
for (size_t i = 0; i < input_strides.size(); i++) { | ||
ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]); | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
} // namespace executor | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <executorch/runtime/kernel/kernel_includes.h> | ||
|
||
namespace torch { | ||
namespace executor { | ||
|
||
bool check__to_dim_order_copy_args( | ||
const Tensor& input, | ||
bool non_blocking, | ||
exec_aten::OptionalArrayRef<int64_t> dim_order, | ||
Tensor& out); | ||
|
||
} // namespace executor | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
|
||
def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    # Argument-checking utilities shared by the ATen-mode copy operators
    # (e.g. `_to_dim_order_copy`). Built against the ATen library, hence the
    # exported USE_ATEN_LIB define and the `_aten` flavored deps.
    runtime.cxx_library(
        name = "copy_ops_util",
        srcs = ["copy_ops_util.cpp"],
        exported_headers = [
            "copy_ops_util.h",
        ],
        compiler_flags = ["-Wno-missing-prototypes"],
        deps = [
            "//executorch/runtime/kernel:kernel_includes_aten",
            "//executorch/runtime/core/exec_aten/util:tensor_util_aten",
        ],
        exported_preprocessor_flags = ["-DUSE_ATEN_LIB"],
        visibility = [
            "//executorch/kernels/aten/cpu/...",
            "//executorch/kernels/portable/cpu/...",
            "//executorch/kernels/optimized/cpu/...",
        ],
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This yaml file contains operators that are defined by ExecuTorch and used in ATen mode. | ||
|
||
- func: dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!) | ||
kernels: | ||
- arg_meta: null | ||
kernel_name: torch::executor::_to_dim_order_copy_out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.