Skip to content

Commit

Permalink
Cadence fusiong3 operators m2 (pytorch#7490)
Browse files Browse the repository at this point in the history
Summary:
Added new operators sub, div, exp, permute, slice, mean in backends/cadence/fusion_g3
For cycle reduction, disabled error checks in operators using macro "OP_ARG_CHECK"

Pull Request resolved: pytorch#7490

Differential Revision: D67870337

Pulled By: zonglinpeng
  • Loading branch information
ckmadhira authored and facebook-github-bot committed Jan 6, 2025
1 parent 68c0208 commit 200c48a
Show file tree
Hide file tree
Showing 18 changed files with 2,136 additions and 287 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
url = https://github.com/pybind/pybind11.git
[submodule "backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3"]
path = backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3
url = https://github.com/foss-xtensa/nnlib-FusionG3/
url = https://github.com/foss-xtensa/nnlib-FusionG3.git
[submodule "third-party/ao"]
path = third-party/ao
url = https://github.com/pytorch/ao.git
25 changes: 20 additions & 5 deletions backends/cadence/aot/functions_fusion_g3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@
- op: div.out
kernels:
- arg_meta: null
kernel_name: torch::executor::div_out
kernel_name: cadence::impl::G3::div_out

- op: div.out_mode
kernels:
- arg_meta: null
kernel_name: torch::executor::div_out_mode
kernel_name: cadence::impl::G3::div_out_mode

- op: embedding.out
kernels:
Expand All @@ -80,7 +80,7 @@
- op: permute_copy.out
kernels:
- arg_meta: null
kernel_name: torch::executor::permute_copy_out
kernel_name: cadence::impl::G3::permute_copy_out

- op: sigmoid.out
kernels:
Expand All @@ -90,7 +90,7 @@
- op: slice_copy.Tensor_out
kernels:
- arg_meta: null
kernel_name: torch::executor::slice_copy_Tensor_out
kernel_name: cadence::impl::G3::slice_copy_Tensor_out

- op: split_with_sizes_copy.out
kernels:
Expand All @@ -100,7 +100,12 @@
- op: sub.out
kernels:
- arg_meta: null
kernel_name: torch::executor::sub_out
kernel_name: cadence::impl::G3::sub_out

- op: sub.Scalar_out
kernels:
- arg_meta: null
kernel_name: cadence::impl::G3::sub_scalar_out

- op: view_copy.out
kernels:
Expand All @@ -117,6 +122,16 @@
- arg_meta: null
kernel_name: cadence::impl::G3::native_layer_norm_out

- op: mean.out
kernels:
- arg_meta: null
kernel_name: cadence::impl::G3::mean_dim_out

- op: exp.out
kernels:
- arg_meta: null
kernel_name: cadence::impl::G3::exp_out

# custom ops
- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
variants: function
Expand Down
7 changes: 7 additions & 0 deletions backends/cadence/fusion_g3/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ set(_aten_ops__srcs
"${CMAKE_CURRENT_SOURCE_DIR}/op_native_layer_norm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_quantize.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_dequantize.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_sub.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_div.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_mean.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_slice_copy.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_permute_copy.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_exp.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
Expand All @@ -51,6 +57,7 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp"
)
add_library(aten_ops_cadence ${_aten_ops__srcs})
target_link_libraries(aten_ops_cadence PUBLIC executorch)
Expand Down
21 changes: 12 additions & 9 deletions backends/cadence/fusion_g3/operators/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,20 @@

#include <xa_nnlib_kernels_api.h>

#include <executorch/backends/cadence/fusion_g3/operators/xt_macros.h>
#include <executorch/backends/cadence/fusion_g3/operators/tensor_util.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <xa_nnlib_kernels_api.h>

using ::executorch::aten::Scalar;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::runtime::canCast;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::canCast;
using torch::executor::Error;
using torch::executor::KernelRuntimeContext;

namespace cadence {
namespace impl {
Expand All @@ -39,6 +40,7 @@ Tensor& add_out(
ScalarType common_type =
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());

#ifdef OP_ARG_CHECK
// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
Expand All @@ -62,12 +64,12 @@ Tensor& add_out(
torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);
#endif

// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "add.out";

int kTensorDimensionLimit = 5;
Expand Down Expand Up @@ -253,6 +255,7 @@ Tensor& add_scalar_out(
torch::executor::native::utils::promote_type_with_scalar(
a.scalar_type(), b);

#ifdef OP_ARG_CHECK
// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
Expand All @@ -276,7 +279,7 @@ Tensor& add_scalar_out(
executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok,
InvalidArgument,
out);

#endif
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);
Expand Down
77 changes: 49 additions & 28 deletions backends/cadence/fusion_g3/operators/op_cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cstring>

#include <xa_nnlib_kernels_api.h>

#include <executorch/backends/cadence/fusion_g3/operators/tensor_util.h>
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <xa_nnlib_kernels_api.h>
#include <cstring>

using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::Error;
using torch::executor::KernelRuntimeContext;

/* ScalarType in Executorch do not have support for below data types.
* So, creating a placeholder for these data types. Once, ScalarTypes is
Expand All @@ -39,13 +39,15 @@ Tensor& cat_out(
dim += out.dim();
}

int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;

#ifdef OP_ARG_CHECK
ET_KERNEL_CHECK(
ctx,
torch::executor::check_cat_args(tensors, dim, out),
InvalidArgument,
out);

int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
size_t expected_out_dim = 0;
torch::executor::get_cat_out_target_size(
Expand All @@ -57,6 +59,20 @@ Tensor& cat_out(
out, {expected_out_size, expected_out_dim}) == Error::Ok,
InvalidArgument,
out);
#endif
// Special handling when all inputs are 1D-empty tensors for aten
// consistency In that case, just return an 1D-empty tensor without checking
// dim
bool all_1d_empty = true;
for (size_t i = 0; i < tensors.size(); ++i) {
if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
all_1d_empty = false;
break;
}
}
if (all_1d_empty) {
return out;
}

const signed char* inp_tensors[tensors.size()];
const int* inp_tensors_shapes[tensors.size()];
Expand Down Expand Up @@ -87,7 +103,10 @@ Tensor& cat_out(
}

if (out.scalar_type() == ScalarType::Int) {
xa_nn_cat(
XT_KERNEL_CHECK(
ctx,
out,
xa_nn_cat,
out_data,
out_shapes,
inp_tensors,
Expand All @@ -97,7 +116,10 @@ Tensor& cat_out(
(int)dim,
sizeof(int));
} else if (out.scalar_type() == ScalarType::Short) {
xa_nn_cat(
XT_KERNEL_CHECK(
ctx,
out,
xa_nn_cat,
out_data,
out_shapes,
inp_tensors,
Expand All @@ -107,7 +129,10 @@ Tensor& cat_out(
(int)dim,
sizeof(short));
} else if (out.scalar_type() == ScalarType::Char) {
xa_nn_cat(
XT_KERNEL_CHECK(
ctx,
out,
xa_nn_cat,
out_data,
out_shapes,
inp_tensors,
Expand All @@ -117,7 +142,10 @@ Tensor& cat_out(
(int)dim,
sizeof(char));
} else if (out.scalar_type() == (ScalarType)Uint) {
xa_nn_cat(
XT_KERNEL_CHECK(
ctx,
out,
xa_nn_cat,
out_data,
out_shapes,
inp_tensors,
Expand All @@ -127,7 +155,10 @@ Tensor& cat_out(
(int)dim,
sizeof(int));
} else if (out.scalar_type() == (ScalarType)Ushort) {
xa_nn_cat(
XT_KERNEL_CHECK(
ctx,
out,
xa_nn_cat,
out_data,
out_shapes,
inp_tensors,
Expand All @@ -137,7 +168,10 @@ Tensor& cat_out(
(int)dim,
sizeof(short));
} else if (out.scalar_type() == ScalarType::Byte) {
xa_nn_cat(
XT_KERNEL_CHECK(
ctx,
out,
xa_nn_cat,
out_data,
out_shapes,
inp_tensors,
Expand All @@ -148,19 +182,6 @@ Tensor& cat_out(
sizeof(char));

} else {
// Special handling when all inputs are 1D-empty tensors for aten
// consistency In that case, just return an 1D-empty tensor without checking
// dim
bool all_1d_empty = true;
for (size_t i = 0; i < tensors.size(); ++i) {
if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
all_1d_empty = false;
break;
}
}
if (all_1d_empty) {
return out;
}
const size_t outer = executorch::runtime::getLeadingDims(out, dim);
const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
const size_t ninputs = tensors.size();
Expand Down
Loading

0 comments on commit 200c48a

Please sign in to comment.