Skip to content

Commit

Permalink
introduce _to_dim_order_copy op to runtime (pytorch#1970)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1970

This diff creates a new and special operator, `_to_dim_order_copy`. This new operator introduces two critical attributes to runtime system:
1. Extract memory_format information from tensor based on dim_order, instead of stride.
2. Support both channal_last and contiguous memory_format in runtime. Please note that memory format here is a parallel concept with memory layout, and supporting new format does not violate our contract on only supporting contiguous memory layout tensor. Details can be found in [here](https://discuss.pytorch.org/t/contigious-vs-non-contigious-tensor/30107) and [here](https://pytorch.org/blog/tensor-memory-format-matters/).

Furthermore, dim order is a specifial operator, which does not have a native aten variant but is needed by most model in edge dialect, so it can not directly be put in `kernels/portable/custom_ops.yaml` (need manually registered everytime, not work for an operator needed by many models), or `kernels/portable/functions.yaml` (should have native aten variant). To overcome that, this diff puts   `_to_dim_order_copy`'s aten mode under `kernels/aten`, while lean mode under `kernels/portable/functions.yaml`.

Also update dependencies and utils.

Differential Revision: https://internalfb.com/D53747744
  • Loading branch information
Songhao Jia authored and facebook-github-bot committed May 3, 2024
1 parent b9488fe commit 6c698cf
Show file tree
Hide file tree
Showing 19 changed files with 1,190 additions and 12 deletions.
8 changes: 8 additions & 0 deletions kernels/aten/cpu/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
124 changes: 124 additions & 0 deletions kernels/aten/cpu/op__to_dim_order_copy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using SizesArrayRef = exec_aten::ArrayRef<exec_aten::SizesType>;
using DimOrderArrayRef = exec_aten::ArrayRef<exec_aten::DimOrderType>;
using MemoryFormat = exec_aten::MemoryFormat;

template <typename T>
using OptionalArrayRef = exec_aten::OptionalArrayRef<T>;

template <typename T>
using Optional = exec_aten::optional<T>;

namespace {
Optional<MemoryFormat> get_memory_format(OptionalArrayRef<int64_t> dim_order) {
if (!dim_order.has_value()) {
return exec_aten::nullopt;
}
if (is_contiguous_dim_order(
dim_order.value().data(), dim_order.value().size())) {
return MemoryFormat::Contiguous;
} else if (is_channels_last_dim_order(
dim_order.value().data(), dim_order.value().size())) {
return MemoryFormat::ChannelsLast;
} else {
ET_ASSERT_UNREACHABLE();
}
}

bool check__to_dim_order_copy_args(
const Tensor& input,
bool non_blocking,
exec_aten::OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
// Right now we only support blocking data transfer
ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);

// dim_order is set, the target dim_order will be either contiguous or
// channels_last memory format
if (dim_order.has_value()) {
exec_aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();

// dim order size shall equal to input dim
ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());

ET_LOG_AND_RETURN_IF_FALSE(
is_channels_last_dim_order(
dim_order.value().data(), dim_order.value().size()) ||
is_contiguous_dim_order(
dim_order.value().data(), dim_order.value().size()));

// Out Aten tensor shall have same memory format stride as dim_order
const size_t kMaxNumOfDimensions = 16;
ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
exec_aten::StridesType target_strides[kMaxNumOfDimensions];
dim_order_to_stride_nocheck(
out.sizes().data(),
dim_order_ref.data(),
dim_order_ref.size(),
target_strides);
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
for (size_t i = 0; i < dim_order_ref.size(); i++) {
ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
}

} else { // dim_order is not set, preserve the dim order of input

auto out_strides = out.strides();
auto input_strides = input.strides();
ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
for (size_t i = 0; i < input_strides.size(); i++) {
ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
}
}
return true;
}
} // namespace

// _to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]?
// dim_order=None, Tensor(a!) out) -> Tensor(a!)
Tensor& _to_dim_order_copy_out(
RuntimeContext& ctx,
const Tensor& self,
bool non_blocking,
OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
// TODO(T181345875): enable sanity check in aten mode
ET_KERNEL_CHECK(
ctx,
check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
InvalidArgument,
out);

Optional<MemoryFormat> memory_format = get_memory_format(dim_order);
at::_to_copy_outf(self, non_blocking, memory_format, out);

return out;
}

Tensor& _to_dim_order_copy_out(
const Tensor& self,
bool non_blocking,
OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
exec_aten::RuntimeContext ctx{};
return _to_dim_order_copy_out(ctx, self, non_blocking, dim_order, out);
}

} // namespace native
} // namespace executor
} // namespace torch
41 changes: 41 additions & 0 deletions kernels/aten/cpu/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "define_op_target", "op_target")

# Operators that are listed in `functions.yaml`, and are thus compatible with
# the core ATen operators. Every entry here will be backed by a cxx_library
# target with the given name and deps.
#
# Note that a single target (or single .cpp file) can't mix ATen and non-ATen
# ops, and must be split. They can, however, share common code via a library dep
# if necessary.
_EDGE_DIALECT_OPS = (
op_target(
name = "op__to_dim_order_copy",
deps = [
"//executorch/kernels/aten/cpu/util:copy_ops_util",
],
),
)

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

# Define build targets for all operators registered in the tables above.
for op in _EDGE_DIALECT_OPS:
define_op_target(is_aten_op = False, is_et_op = False, **op)

all_op_targets = [":{}".format(op["name"]) for op in _EDGE_DIALECT_OPS]

runtime.cxx_library(
name = "cpu",
srcs = [],
visibility = [
"//executorch/kernels/aten/...",
"//executorch/kernels/test/...",
],
exported_deps = [t + "_aten" for t in all_op_targets],
)
8 changes: 8 additions & 0 deletions kernels/aten/cpu/util/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
68 changes: 68 additions & 0 deletions kernels/aten/cpu/util/copy_ops_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cstring>

#include <executorch/kernels/aten/cpu/util/copy_ops_util.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

bool check__to_dim_order_copy_args(
const Tensor& input,
bool non_blocking,
exec_aten::OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
// Right now we only support blocking data transfer
ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);

// dim_order is set, the target dim_order will be either contiguous or
// channels_last memory format
if (dim_order.has_value()) {
exec_aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();

// dim order size shall equal to input dim
ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());

ET_LOG_AND_RETURN_IF_FALSE(
is_channels_last_dim_order(
dim_order.value().data(), dim_order.value().size()) ||
is_contiguous_dim_order(
dim_order.value().data(), dim_order.value().size()));

// Out Aten tensor shall have same memory format stride as dim_order
const size_t kMaxNumOfDimensions = 16;
ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
exec_aten::StridesType target_strides[kMaxNumOfDimensions];
dim_order_to_stride_nocheck(
out.sizes().data(),
dim_order_ref.data(),
dim_order_ref.size(),
target_strides);
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
for (size_t i = 0; i < dim_order_ref.size(); i++) {
ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
}

} else { // dim_order is not set, preserve the dim order of input

auto out_strides = out.strides();
auto input_strides = input.strides();
ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
for (size_t i = 0; i < input_strides.size(); i++) {
ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
}
}
return true;
}

} // namespace executor
} // namespace torch
23 changes: 23 additions & 0 deletions kernels/aten/cpu/util/copy_ops_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

bool check__to_dim_order_copy_args(
const Tensor& input,
bool non_blocking,
exec_aten::OptionalArrayRef<int64_t> dim_order,
Tensor& out);

} // namespace executor
} // namespace torch
28 changes: 28 additions & 0 deletions kernels/aten/cpu/util/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

# Utility functions that can be used by operators that perform reduction
runtime.cxx_library(
name = "copy_ops_util",
srcs = ["copy_ops_util.cpp"],
exported_headers = [
"copy_ops_util.h",
],
compiler_flags = ["-Wno-missing-prototypes"],
deps = [
"//executorch/runtime/kernel:kernel_includes_aten",
"//executorch/runtime/core/exec_aten/util:tensor_util_aten",
],
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"],
visibility = [
"//executorch/kernels/aten/cpu/...",
"//executorch/kernels/portable/cpu/...",
"//executorch/kernels/optimized/cpu/...",
],
)
8 changes: 8 additions & 0 deletions kernels/aten/edge_dialect_aten_op.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This yaml file contains operators that are defined by ExecuTorch and used in ATen mode.

- func: dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: torch::executor::_to_dim_order_copy_out
33 changes: 32 additions & 1 deletion kernels/aten/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,51 @@ def define_common_targets():
],
)

runtime.export_file(
name = "edge_dialect_aten_op.yaml",
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
)

et_operator_library(
name = "executorch_aten_ops",
ops_schema_yaml_target = ":functions.yaml",
define_static_targets = True,
)

runtime.cxx_library(
name = "operators_edge_dialect_aten",
srcs = [],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
"//executorch/kernels/aten/cpu:cpu",
],
)

et_operator_library(
name = "edge_dialect_aten_ops",
ops_schema_yaml_target = ":edge_dialect_aten_op.yaml",
define_static_targets = True,
)

executorch_generated_lib(
name = "generated_lib",
aten_mode = True,
deps = [
":executorch_aten_ops",
":edge_dialect_aten_ops",
],
kernel_deps = [
":operators_edge_dialect_aten",
],
functions_yaml_target = None,
custom_ops_yaml_target = "//executorch/kernels/aten:edge_dialect_aten_op.yaml",
define_static_targets = True,
custom_ops_requires_aot_registration = False,
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
Expand Down
Loading

0 comments on commit 6c698cf

Please sign in to comment.