From 11fcfc9417f7a2d344aca364ccfc0a95f359065c Mon Sep 17 00:00:00 2001 From: Jiyoung Yun Date: Thu, 11 Jan 2024 17:52:43 +0900 Subject: [PATCH 1/3] [cker] Introduce DepthwiseConv2D gradient kernel This commit introduces DepthwiseConv2D gradient kernel. The depthwise_conv_op.h file came from TensorFlow code. This commit transforms the depthwise_conv related code to fit onert and execute this kernel function. ONE-DCO-1.0-Signed-off-by: Jiyoung Yun --- .../include/cker/eigen/depthwise_conv_op.h | 956 ++++++++++++++++++ .../cker/train/operation/DepthwiseConv.h | 112 ++ compute/cker/src/train/DepthwiseConv.test.cc | 410 ++++++++ 3 files changed, 1478 insertions(+) create mode 100644 compute/cker/include/cker/eigen/depthwise_conv_op.h create mode 100644 compute/cker/include/cker/train/operation/DepthwiseConv.h create mode 100644 compute/cker/src/train/DepthwiseConv.test.cc diff --git a/compute/cker/include/cker/eigen/depthwise_conv_op.h b/compute/cker/include/cker/eigen/depthwise_conv_op.h new file mode 100644 index 00000000000..bf029972a34 --- /dev/null +++ b/compute/cker/include/cker/eigen/depthwise_conv_op.h @@ -0,0 +1,956 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2015 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_CKER_EIGEN_DEPTHWISE_CONV_OP_H__ +#define __NNFW_CKER_EIGEN_DEPTHWISE_CONV_OP_H__ + +// From tensorflow/core/kernels/depthwise_conv_grad_op.cc +#define EIGEN_USE_THREADS + +#include +#include "unsupported/Eigen/CXX11/Tensor" +#include "cker/operation/Helper/Tensor.h" + +// From tensorflow/core/kernels/depthwise_conv_op.h +namespace nnfw +{ +namespace cker +{ +namespace depthwise_conv_op +{ + +template struct LaunchDepthwiseConvBackpropInputOp +{ + void operator()(int batch, int in_rows, int in_cols, int in_depth, int filter_rows, + int filter_cols, int depth_multiplier, int stride, int pad_rows, int pad_cols, + int out_rows, int out_cols, int out_depth, const T *out_backprop, const T *filter, + T *in_backprop); +}; + +template struct LaunchDepthwiseConvBackpropFilterOp +{ + void operator()(int batch, int in_rows, int in_cols, int in_depth, int filter_rows, + int filter_cols, int depth_multiplier, int stride, int pad_rows, int pad_cols, + int out_rows, int out_cols, int out_depth, const T *out_backprop, const T *input, + T *filter_backprop); +}; + +namespace functor +{ + +// Pads 'filter' to vector-register boundary along its inner dimension: +// filter_inner_dim_size = in_depth * depth_multiplier +// Requires 'filter' to have the following storage order: +// [filter_rows, filter_cols, in_depth, depth_multiplier] +// Returns zero-padded filter in 'padded_filter'. +// +// EX: +// in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4 +// So we have a total of 3 * 2 = 6 filters, each of spatial size 2 x 2. 
+//
+// filter [rows, cols, in_depth, depth_multiplier]
+//   [u0, v0, w0, x0] [y0, z0, u1, v1] [w1, x1, y1, z1]
+//   [u2, v2, w2, x2] [y2, z2, u3, v3] [w3, x3, y3, z3]
+//
+// padded_filter [rows, cols, in_depth, depth_multiplier]
+//   [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0]
+//   [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0]
+
+template <typename T> struct DepthwiseFilterPadOp
+{
+  void operator()(int, int, int, int, int filter_rows, int filter_cols, int, int, int, int, int,
+                  int, int out_depth, const T *filter, T *padded_filter)
+  {
+    typedef typename Eigen::internal::packet_traits<T>::type Packet;
+    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
+
+    // Calculate vectorized and scalar lengths of filter's inner dimension.
+    const int64_t filter_inner_dim_size = out_depth;
+    const int64_t vectorized_size = (filter_inner_dim_size / kPacketSize) * kPacketSize;
+    const int64_t scalar_size = filter_inner_dim_size - vectorized_size;
+    // Calculate required padding and padded output buffer stride.
+    const int64_t pad_size = scalar_size > 0 ? kPacketSize - scalar_size : 0;
+    const int64_t padded_filter_stride = vectorized_size + kPacketSize;
+
+    const int64_t filter_spatial_size = filter_rows * filter_cols;
+    for (int64_t i = 0; i < filter_spatial_size; ++i)
+    {
+      const int64_t input_base = i * filter_inner_dim_size;
+      const int64_t output_base = i * padded_filter_stride;
+      // Write vectorized length of filter's inner dimension to output.
+      for (int64_t j = 0; j < vectorized_size; j += kPacketSize)
+      {
+        const auto v = Eigen::internal::ploadu<Packet>(filter + input_base + j);
+        Eigen::internal::pstoreu<T>(padded_filter + output_base + j, v);
+      }
+      // Write scalar length of filter's inner dimension to output.
+      for (int64_t j = 0; j < scalar_size; ++j)
+      {
+        padded_filter[output_base + vectorized_size + j] = filter[input_base + vectorized_size + j];
+      }
+      // Pad the remainder of output to vector-register boundary.
+      for (int64_t j = 0; j < pad_size; ++j)
+      {
+        padded_filter[output_base + vectorized_size + scalar_size + j] = static_cast<T>(0);
+      }
+    }
+  }
+};
+
+// Copies data from local region in 'input' specified by 'out_r' and 'out_c'
+// to 'input_buffer'. The copied data is replicated by factor
+// 'depth_multiplier', and padded to vector register-width boundaries so
+// that it is aligned for efficient traversal and vector multiply-add by the
+// depthwise kernel.
+//
+// EX:
+//   in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
+//
+//   input: [batch, in_rows, in_cols, in_depth]
+//
+//     [a0, a1, a2, b0, b1, b2, ..., e0, e1, e2, f0, f1, f2, ...]
+//
+//   input_buffer (register boundaries shown):
+//     [a0, a0, a1, a1] [a2, a2, 0, 0]   in_row = 0, in_col = 0
+//     [b0, b0, b1, b1] [b2, b2, 0, 0]   in_row = 0, in_col = 1
+//     [e0, e0, e1, e1] [e2, e2, 0, 0]   in_row = 1, in_col = 0
+//     [f0, f0, f1, f1] [f2, f2, 0, 0]   in_row = 1, in_col = 1
+//
+// Returns replicated and padded data from specified input region in
+// 'input_buffer'.
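+//
+// For the example above the padded inner dimension is
+//   ceil((in_depth * depth_multiplier) / register_width) * register_width
+//     = ceil(6 / 4) * 4 = 8
+// elements per filter tap, which is what callers pass below as
+// 'padded_filter_inner_dim_size'.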
+
+template <typename T> struct DepthwiseInputCopyOp
+{
+  void operator()(int, int in_rows, int in_cols, int in_depth, int filter_rows, int filter_cols,
+                  int depth_multiplier, int stride, int pad_rows, int pad_cols, int, int,
+                  int out_depth, const int64_t padded_filter_inner_dim_size, const int64_t out_r,
+                  const int64_t out_c, const T *input, T *input_buffer)
+  {
+    typedef typename Eigen::internal::packet_traits<T>::type Packet;
+    static const int64_t kPacketSize = Eigen::internal::packet_traits<T>::size;
+
+    const int64_t kDepth = depth_multiplier;
+    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
+    const int64_t input_vectorized_size = (in_depth / kPacketSize) * kPacketSize;
+    const int64_t input_scalar_size = in_depth - input_vectorized_size;
+
+    // Calculate output padding length.
+    const int64_t output_scalar_size = out_depth % kPacketSize;
+    const int64_t output_pad_size = output_scalar_size > 0 ? kPacketSize - output_scalar_size : 0;
+
+    // Iterate through all rows x cols reading 'in_depth' from 'input' and
+    // replicating by 'depth_multiplier' into 'input_buffer' (otherwise
+    // zero-padding input buffer as needed).
+    auto *in_buf = input_buffer;
+    const int64_t in_r_start = out_r * stride - pad_rows;
+    const int64_t in_c_start = out_c * stride - pad_cols;
+
+    // TODO: add a ploaddup variant for depth == 2 if needed.
+    if (kDepth > 1 && kDepth <= kPacketSize)
+    {
+      for (int64_t f_r = 0; f_r < filter_rows; ++f_r)
+      {
+        const int64_t in_r = in_r_start + f_r;
+
+        for (int64_t f_c = 0; f_c < filter_cols; ++f_c)
+        {
+          const int64_t in_c = in_c_start + f_c;
+
+          if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols)
+          {
+            const auto *in = input + (in_r * in_cols + in_c) * in_depth;
+            int64_t limit = in_depth;
+            // This will overwrite up to kPacketSize next elements,
+            // this is ok on all iterations except the last one, since
+            // we will write correct values on a next iteration.
+            if (f_c == filter_cols - 1)
+            {
+              limit -= (kPacketSize - kDepth) / kDepth + 1;
+              if (limit < 0)
+              {
+                limit = 0;
+              }
+            }
+            // Copy vectorized portion of inner dimension.
+            for (int64_t d = 0; d < limit; d++)
+            {
+              const auto p = Eigen::internal::pset1<Packet>(in[d]);
+              Eigen::internal::pstoreu<T>(in_buf, p);
+              in_buf += kDepth;
+            }
+
+            // Copy the scalar portion.
+            for (int64_t d = limit; d < in_depth; d++)
+            {
+              const auto value = in[d];
+              for (int64_t dm = 0; dm < kDepth; dm++)
+              {
+                in_buf[dm] = value;
+              }
+              in_buf += kDepth;
+            }
+
+            // Pad the remainder of the output to vector register boundary.
+            for (int64_t d = 0; d < output_pad_size; ++d)
+            {
+              in_buf[d] = static_cast<T>(0);
+            }
+            in_buf += output_pad_size;
+          }
+          else
+          {
+            // Zero pad.
+            memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size);
+            in_buf += padded_filter_inner_dim_size;
+          }
+        }
+      }
+    }
+    else if (kDepth > kPacketSize)
+    {
+      // Calculate vectorized and scalar (residual) lengths for
+      // 'depth_multiplier'. This is used to efficiently replicate data for
+      // when 'depth_multiplier' > kPacketSize.
+      const int64_t dm_vectorized_size = (kDepth / kPacketSize) * kPacketSize;
+
+      for (int64_t f_r = 0; f_r < filter_rows; ++f_r)
+      {
+        const int64_t in_r = in_r_start + f_r;
+
+        for (int64_t f_c = 0; f_c < filter_cols; ++f_c)
+        {
+          const int64_t in_c = in_c_start + f_c;
+
+          if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols)
+          {
+            const auto *in = input + (in_r * in_cols + in_c) * in_depth;
+            // Copy vectorized portion of inner dimension.
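+            // Each input value is broadcast into a full register with pset1
+            // and stored kPacketSize lanes at a time; the final store below
+            // deliberately overlaps the previous one to cover the
+            // kDepth % kPacketSize tail without a scalar loop.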
+            for (int64_t d = 0; d < in_depth; d++)
+            {
+              const auto p = Eigen::internal::pset1<Packet>(in[d]);
+              for (int64_t dm = 0; dm < dm_vectorized_size; dm += kPacketSize)
+              {
+                Eigen::internal::pstoreu<T>(in_buf + dm, p);
+              }
+              // Overlapping store for the remainder.
+              Eigen::internal::pstoreu<T>(in_buf + kDepth - kPacketSize, p);
+              in_buf += kDepth;
+            }
+            // Pad the remainder of the output to vector register boundary.
+            for (int64_t d = 0; d < output_pad_size; ++d)
+            {
+              in_buf[d] = static_cast<T>(0);
+            }
+            in_buf += output_pad_size;
+          }
+          else
+          {
+            // Zero pad.
+            memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size);
+            in_buf += padded_filter_inner_dim_size;
+          }
+        }
+      }
+    }
+    else if (kDepth == 1)
+    {
+      for (int64_t f_r = 0; f_r < filter_rows; ++f_r)
+      {
+        const int64_t in_r = in_r_start + f_r;
+
+        for (int64_t f_c = 0; f_c < filter_cols; ++f_c)
+        {
+          const int64_t in_c = in_c_start + f_c;
+
+          if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols)
+          {
+            const auto *in = input + (in_r * in_cols + in_c) * in_depth;
+            for (int64_t d = 0; d < input_vectorized_size; d += kPacketSize)
+            {
+              const auto p = Eigen::internal::ploadu<Packet>(in + d);
+              Eigen::internal::pstoreu<T>(in_buf, p);
+              in_buf += kPacketSize;
+            }
+            for (int64_t d = 0; d < input_scalar_size; ++d)
+            {
+              T v = in[input_vectorized_size + d];
+              in_buf[d] = v;
+            }
+            in_buf += input_scalar_size;
+
+            // Pad the remainder of the output to vector register boundary.
+            for (int64_t d = 0; d < output_pad_size; ++d)
+            {
+              in_buf[d] = static_cast<T>(0);
+            }
+            in_buf += output_pad_size;
+          }
+          else
+          {
+            // Zero pad.
+            memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size);
+            in_buf += padded_filter_inner_dim_size;
+          }
+        }
+      }
+    }
+  }
+};
+
+} // namespace functor
+} // namespace depthwise_conv_op
+} // namespace cker
+} // namespace nnfw
+
+// From tensorflow/core/kernels/depthwise_conv_grad_op.cc
+namespace nnfw
+{
+namespace cker
+{
+namespace depthwise_conv_op
+{
+
+// Enable CPUDevice only for depthwise_conv_op
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+// Copies data from local region in 'out_backprop' into 'buffer'.
+// The local region coordinates are calculated as the set of output points
+// which used the input point ('in_r', 'in_c') as input during the forward
+// pass. Rather than spatially reversing the filter, the input is reversed
+// during the copy. The copied data is padded to vector register-width
+// boundaries so that it is aligned for efficient traversal and vector
+// multiply-add by the depthwise input kernel.
+//
+// EX:
+//   in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
+//
+//   'out_backprop': [batch, out_rows, out_cols, out_depth]
+//
+//     [a00, a01, a10, a11] [a20, a21, b00, b01]
+//     [b10, b11, b20, b21] [...]
+//     [e00, e01, e10, e11] [e20, e21, f00, f01]
+//     [f10, f11, f20, f21] [...]
+// +// 'buffer' (register boundaries shown): +// +// [f00, f01, f10, f11] [f20, f21, 0, 0] in_row = 0, in_col = 0 +// [e00, e01, e10, e11] [e20, e21, 0, 0] in_row = 0, in_col = 1 +// [b00, b01, b10, b11] [b20, b21, 0, 0] in_row = 1, in_col = 0 +// [a00, a01, a10, a11] [a20, a21, 0, 0] in_row = 1, in_col = 1 +// +template +static void CopyOutputBackpropRegion(int, int, int, int, int filter_rows_, int filter_cols_, int, + int stride_, int pad_rows_, int pad_cols_, int out_rows_, + int out_cols_, int out_depth, + const int64_t padded_filter_inner_dim_size, const int64_t in_r, + const int64_t in_c, const T *out_backprop, T *buffer) +{ + typedef typename Eigen::internal::packet_traits::type Packet; + static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T)); + + const int64_t stride = stride_; + const int64_t filter_rows = filter_rows_; + const int64_t filter_cols = filter_cols_; + const int64_t pad_rows = pad_rows_; + const int64_t pad_cols = pad_cols_; + const int64_t out_rows = out_rows_; + const int64_t out_cols = out_cols_; + + // Calculate the output spatial region which used point (in_r, in_c) as input. + const int64_t out_r_start = + std::max(static_cast(0), (in_r - filter_rows + pad_rows + stride) / stride); + const int64_t out_r_end = std::min(out_rows - 1, (in_r + pad_rows) / stride); + const int64_t out_c_start = + std::max(static_cast(0), (in_c - filter_cols + pad_cols + stride) / stride); + const int64_t out_c_end = std::min(out_cols - 1, (in_c + pad_cols) / stride); + + // Zero-pad 'buffer' if output region is smaller than filter spatial size. + const int64_t filter_spatial_size = filter_rows * filter_cols; + if ((out_r_end - out_r_start + 1) < filter_rows || (out_c_end - out_c_start + 1) < filter_cols) + { + memset(buffer, 0, filter_spatial_size * padded_filter_inner_dim_size * sizeof(T)); + } + + // Calculate vectorized and scalar (residual) lengths for 'in_depth'. + const int64_t vectorized_size = (out_depth / kPacketSize) * kPacketSize; + const int64_t scalar_size = out_depth % kPacketSize; + const int64_t pad_size = scalar_size > 0 ? kPacketSize - scalar_size : 0; + + for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) + { + const int64_t f_r = in_r + pad_rows - out_r * stride; + for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) + { + const int64_t f_c = in_c + pad_cols - out_c * stride; + const int64_t buf_base = (f_r * filter_cols + f_c) * padded_filter_inner_dim_size; + // Calculate index into 'out_backprop' for coordinate (out_r, out_c). + auto *out_bprop = out_backprop + (out_r * out_cols + out_c) * out_depth; + + // Copy vectorized portion of inner dimension into 'buffer'. + for (int64_t d = 0; d < vectorized_size; d += kPacketSize) + { + auto v = Eigen::internal::ploadu(out_bprop + d); + Eigen::internal::pstoreu(buffer + buf_base + d, v); + } + // Copy scalar portion of out_bprop to 'buffer' + for (int64_t d = 0; d < scalar_size; ++d) + { + buffer[buf_base + vectorized_size + d] = out_bprop[vectorized_size + d]; + } + // Pad to vector-register width (if needed). + for (int64_t d = 0; d < pad_size; ++d) + { + buffer[buf_base + vectorized_size + scalar_size + d] = static_cast(0); + } + } + } +} + +// Computes the vectorized product of 'buffer' and 'filter' and stores +// result in 'output' at location computed from 'in_r' and 'in_c'. +// If depth_multiplier is > 1, the intermediate output is reduced along +// the depth_multiplier dimension. 
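+// (The in_depth * depth_multiplier partial products are summed in groups of
+// 'depth_multiplier', yielding one gradient value per input channel.)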
+// +// EX: +// in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4 +// Both 'input_buffer' and 'filter' are padded to register-width boundaries. +// +// 'buffer' [rows, cols, in_depth, depth_multiplier] +// +// [f00, f01, f10, f11] [f20, f21, 0, 0] in_row = 0, in_col = 0 +// [e00, e01, e10, e11] [e20, e21, 0, 0] in_row = 0, in_col = 1 +// [b00, b01, b10, b11] [b20, b21, 0, 0] in_row = 1, in_col = 0 +// [a00, a01, a10, a11] [a20, a21, 0, 0] in_row = 1, in_col = 1 +// +// filter [rows, cols, in_depth, depth_multiplier] +// [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0] +// [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0] +// +// First output register [in_depth, depth_multiplier] +// [q00, q01, q10, q11] = ([f00, f01, f10, f11] x [u0, v0, w0, x0]) + +// ([e00, e01, e10, e11] x [u1, v1, w1, x1]) + +// ([b00, b01, b10, b11] x [u2, v2, w2, x2]) + +// ([a00, a01, a10, a11] x [u3, v3, w3, x3]) +// +// Reduction step along depth-multiplier dimension: +// +// [q00, q01, q10, q11] [q20, q21, 0, 0] -> [r0, r1, r2, 0] +// + +template +static void ComputeBackpropInput(int, int, int in_cols, int in_depth_, int filter_rows, + int filter_cols, int depth_multiplier_, int, int, int, int, int, + int out_depth_, const int64_t padded_filter_inner_dim_size, + const int64_t in_r, const int64_t in_c, const T *filter, + const T *buffer, T *out_buffer, T *output) +{ + typedef typename Eigen::internal::packet_traits::type Packet; + static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T)); + + const int64_t in_depth = in_depth_; + const int64_t depth_multiplier = depth_multiplier_; + const int64_t out_depth = out_depth_; + const int64_t filter_spatial_size = filter_rows * filter_cols; + + // Calculate vectorized and scalar lengths of 'out_depth'. + const int64_t output_vectorized_size = (out_depth / kPacketSize) * kPacketSize; + const int64_t output_scalar_size = out_depth % kPacketSize; + + // Calculate base index at which to begin writing output. + const int64_t base_output_index = (in_r * in_cols + in_c) * in_depth; + + // Calculate vectorized and scalar lengths for 'depth_multiplier'. This is + // used to efficiently reduce output when 'depth_multiplier' > kPacketSize. + const int64_t dm_vectorized_size = (depth_multiplier / kPacketSize) * kPacketSize; + const int64_t dm_scalar_size = depth_multiplier % kPacketSize; + + for (int i = 0; i < output_vectorized_size; i += kPacketSize) + { + // Reset accumulator. + auto vaccum = Eigen::internal::pset1(static_cast(0)); + for (int j = 0; j < filter_spatial_size; ++j) + { + // Calculate index. + const int64_t index = i + j * padded_filter_inner_dim_size; + // Load filter. + const auto filter_block = Eigen::internal::ploadu(filter + index); + // Load input. + const auto data_block = Eigen::internal::ploadu(buffer + index); + // Vector multiply-add. + vaccum = Eigen::internal::pmadd(filter_block, data_block, vaccum); + } + if (depth_multiplier == 1) + { + // Write directly to the output. + Eigen::internal::pstoreu(output + base_output_index + i, vaccum); + } + else + { + // Buffer output for subsequent reduction step. 
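+      // ('out_buffer' accumulates in_depth * depth_multiplier products here;
+      // the reduction to in_depth values happens at the end of this function.)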
+ Eigen::internal::pstoreu(out_buffer + i, vaccum); + } + } + + if (output_scalar_size > 0) + { + auto vaccum = Eigen::internal::pset1(static_cast(0)); + for (int j = 0; j < filter_spatial_size; ++j) + { + const int64_t index = output_vectorized_size + j * padded_filter_inner_dim_size; + const auto filter_block = Eigen::internal::ploadu(filter + index); + const auto data_block = Eigen::internal::ploadu(buffer + index); + vaccum = Eigen::internal::pmadd(filter_block, data_block, vaccum); + } + // Load accumulator into an array and loop through output. + T out_buf[kPacketSize]; + Eigen::internal::pstoreu(out_buf, vaccum); + if (depth_multiplier == 1) + { + // Write directly to the output. + for (int j = 0; j < output_scalar_size; ++j) + { + output[base_output_index + output_vectorized_size + j] = out_buf[j]; + } + } + else + { + // Buffer output for subsequent reduction step. + for (int j = 0; j < output_scalar_size; ++j) + { + out_buffer[output_vectorized_size + j] = out_buf[j]; + } + } + } + + // Iterate over 'in_depth', reduce over 'depth_multiplier', write 'output'. + if (depth_multiplier > 1) + { + for (int64_t d = 0; d < in_depth; ++d) + { + const int64_t index = d * depth_multiplier; + T accum = static_cast(0); + for (int64_t dm = 0; dm < dm_vectorized_size; dm += kPacketSize) + { + const auto v = Eigen::internal::ploadu(out_buffer + index + dm); + accum += Eigen::internal::predux(v); + } + // Copy scalar portion of replicated output. + for (int64_t dm = 0; dm < dm_scalar_size; ++dm) + { + accum += out_buffer[index + dm_vectorized_size + dm]; + } + // Copy to output. + output[base_output_index + d] = accum; + } + } +} + +// Computes the depthwise conv2d backprop input of 'out_backprop' by +// 'depthwise_filter' and stores the result in 'in_backprop'. +template struct LaunchDepthwiseConvBackpropInputOp +{ + typedef typename Eigen::internal::packet_traits::type Packet; + + void operator()(int batch, int in_rows, int in_cols, int in_depth, int filter_rows, + int filter_cols, int depth_multiplier, int stride, int pad_rows, int pad_cols, + int out_rows, int out_cols, int out_depth, const T *out_backprop, + const T *depthwise_filter, T *padded_filter_data, T *in_backprop, bool pad_filter, + T *out_bprop, T *in_bprop) + { + const Eigen::ThreadPoolDevice &d = *eigen_support::GetThreadPoolDevice(); + + // Pad 'depthwise_filter' to vector register width (if needed). + if (pad_filter) + { + // Write out padded filter. + functor::DepthwiseFilterPadOp()( + batch, in_rows, in_cols, in_depth, filter_rows, filter_cols, depth_multiplier, stride, + pad_rows, pad_cols, out_rows, out_cols, out_depth, depthwise_filter, padded_filter_data); + } + const T *filter_data = pad_filter ? padded_filter_data : depthwise_filter; + + // Computes one shard of depthwise conv2d backprop input. 
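+    // Shards are batch images: Eigen's parallelFor hands each worker a
+    // [start, limit) range of 'b'. Scratch space is carved from the shared
+    // 'out_bprop'/'in_bprop' allocations by thread id; currentThreadId()
+    // returns -1 on the calling thread, hence the '+ 1' below.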
+ auto shard = [d, in_rows, in_cols, in_depth, out_rows, out_cols, out_depth, batch, filter_rows, + filter_cols, depth_multiplier, stride, pad_rows, pad_cols, out_backprop, + filter_data, in_backprop, out_bprop, in_bprop](int64_t start, int64_t limit) { + static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T)); + + const int64_t input_image_size = in_rows * in_cols * in_depth; + const int64_t output_image_size = out_rows * out_cols * out_depth; + const int64_t filter_spatial_size = filter_rows * filter_cols; + const int64_t padded_filter_inner_dim_size = + ((out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize; + const int64_t out_bprop_size = filter_spatial_size * padded_filter_inner_dim_size; + + int cur_id = d.currentThreadId() + 1; + assert(cur_id >= 0 && cur_id < d.numThreads() + 1); + + // Use out_bprop buffer to copy regions from 'out_backprop'. + T *out_bprop_buf = out_bprop + cur_id * out_bprop_size; + + // Use in_bprop buffer for intermediate results. + T *in_bprop_buf = in_bprop + cur_id * padded_filter_inner_dim_size; + + for (int64_t b = start; b < limit; ++b) + { + for (int64_t in_r = 0; in_r < in_rows; ++in_r) + { + for (int64_t in_c = 0; in_c < in_cols; ++in_c) + { + // Populate 'out_bprop_buf' from local 'out_backprop' region. + CopyOutputBackpropRegion(batch, in_rows, in_cols, in_depth, filter_rows, filter_cols, + depth_multiplier, stride, pad_rows, pad_cols, out_rows, + out_cols, out_depth, padded_filter_inner_dim_size, in_r, + in_c, out_backprop + b * output_image_size, out_bprop_buf); + + // Compute depthwise backprop input. + ComputeBackpropInput( + batch, in_rows, in_cols, in_depth, filter_rows, filter_cols, depth_multiplier, stride, + pad_rows, pad_cols, out_rows, out_cols, out_depth, padded_filter_inner_dim_size, in_r, + in_c, filter_data, out_bprop_buf, in_bprop_buf, in_backprop + b * input_image_size); + } + } + } + }; + + const int64_t input_bytes = out_rows * out_cols * out_depth * sizeof(T); + const int64_t output_bytes = in_rows * in_cols * in_depth * sizeof(T); + const int64_t compute_cycles = in_rows * in_cols * out_depth * batch; + const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles); + d.parallelFor(batch, cost, shard); + } +}; + +template +static void +DepthwiseConvBackpropInputReference(int batch, int in_rows, int in_cols, int in_depth, int out_rows, + int out_cols, int out_depth, int stride, int depth_multiplier, + int filter_rows, int filter_cols, int pad_rows, int pad_cols, + const T *out_backprop, const T *filter, T *in_backprop) +{ + // Naive for loop as a reference point without concerns about performance. 
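+  // For each input element x[b, in_r, in_c, in_d], the reference computes
+  //
+  //   in_backprop[b, in_r, in_c, in_d] =
+  //     sum_{dm, out_r, out_c} out_backprop[b, out_r, out_c,
+  //                                         in_d * depth_multiplier + dm]
+  //                            * filter[f_r, f_c, in_d, dm]
+  //
+  // with f_r = in_r + pad_rows - out_r * stride (f_c likewise), i.e. it
+  // visits exactly the output points that consumed x in the forward pass.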
+ for (int b = 0; b < batch; ++b) + { + for (int in_r = 0; in_r < in_rows; ++in_r) + { + for (int in_c = 0; in_c < in_cols; ++in_c) + { + for (int in_d = 0; in_d < in_depth; ++in_d) + { + T sum = 0; + const int out_d_start = in_d * depth_multiplier; + const int out_d_end = out_d_start + depth_multiplier; + + for (int out_d = out_d_start; out_d < out_d_end; ++out_d) + { + const int out_r_start = std::max(0, (in_r - filter_rows + pad_rows + stride) / stride); + const int out_r_end = std::min(out_rows - 1, (in_r + pad_rows) / stride); + + for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) + { + const int out_c_start = + std::max(0, (in_c - filter_cols + pad_cols + stride) / stride); + const int out_c_end = std::min(out_cols - 1, (in_c + pad_cols) / stride); + + for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) + { + int f_r = in_r + pad_rows - out_r * stride; + int f_c = in_c + pad_cols - out_c * stride; + int filter_dm = out_d - out_d_start; + int out_backprop_offset = + out_d + out_depth * (out_c + out_cols * (out_r + out_rows * b)); + int filter_offset = + filter_dm + depth_multiplier * (in_d + in_depth * (f_c + filter_cols * f_r)); + sum += out_backprop[out_backprop_offset] * filter[filter_offset]; + } + } + } + + int in_backprop_offset = in_d + in_depth * (in_c + in_cols * (in_r + in_rows * b)); + in_backprop[in_backprop_offset] = sum; + } + } + } + } +} + +// Kernels to compute the gradients of the filters for depthwise convolution. + +// Computes filter backprop using 'out_backprop' and 'input_buffer', storing the +// result in 'output_buffer' at an index computed from 'out_r' and 'out_c'. +// +// EX: +// in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4 +// Both 'input_buffer' and 'filter' are padded to register-width boundaries. +// +// 'input_buffer' [rows, cols, in_depth, depth_multiplier] +// +// [f00, f01, f10, f11] [f20, f21, 0, 0] in_row = 0, in_col = 0 +// [e00, e01, e10, e11] [e20, e21, 0, 0] in_row = 0, in_col = 1 +// [b00, b01, b10, b11] [b20, b21, 0, 0] in_row = 1, in_col = 0 +// [a00, a01, a10, a11] [a20, a21, 0, 0] in_row = 1, in_col = 1 +// +// 'out_backprop' [out_rows, out_cols, in_depth, depth_multiplier] +// +// [q00, q01, q10, q11] [q20, q21, r00, r01] +// [r10, r11, r20, r21] [s00, s01, s10, s11] +// [s20, s21, t00, t01] [t10, t11, t20, a21] +// +// First output register of 'filter_backprop' +// [u0, v0, w0, x0] += ([f00, f01, f10, f11] x [q00, q01, q10, q11]) +// +template +static void ComputeBackpropFilter(int, int, int, int, int filter_rows, int filter_cols, int, int, + int, int, int out_rows, int out_cols, int out_depth_, + const int64_t padded_out_depth_size, const int64_t out_r, + const int64_t out_c, const T *out_backprop, const T *input_buffer, + T *output_buffer) +{ + typedef typename Eigen::internal::packet_traits::type Packet; + static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T)); + // Calculate vectorized size of 'padded_out_depth_size'. + const int64_t out_depth = out_depth_; + const int64_t filter_spatial_size = filter_rows * filter_cols; + const int64_t output_vectorized_size = (padded_out_depth_size / kPacketSize) * kPacketSize; + const int64_t base_output_index = (out_r * out_cols + out_c) * out_depth; + // Determine whether we can execute fast or slow code path. 
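+  // The fast path requires that vector loads from 'out_backprop' starting at
+  // 'base_output_index' cannot run past the end of the (unpadded) output
+  // image; 'output_last_vector_index' is a conservative bound that reserves a
+  // full filter_spatial_size * padded_out_depth_size window of headroom.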
+  const int64_t output_image_size = out_rows * out_cols * out_depth;
+  const int64_t output_last_vector_index =
+    output_image_size - (filter_spatial_size * padded_out_depth_size);
+  const bool fast_path = base_output_index <= output_last_vector_index;
+
+  if (fast_path)
+  {
+    // TODO(andydavis) Process multiple inputs in 'input_buffer' so we can
+    // amortize the cost of 'output_buffer' load store in the loop below.
+    for (int i = 0; i < output_vectorized_size; i += kPacketSize)
+    {
+      // Load vector register from 'out_backprop'.
+      const auto out_bprop_block =
+        Eigen::internal::ploadu<Packet>(out_backprop + base_output_index + i);
+      for (int j = 0; j < filter_spatial_size; ++j)
+      {
+        const int64_t index = i + j * padded_out_depth_size;
+        // Load vector register from 'input_buffer'.
+        const auto input_block = Eigen::internal::ploadu<Packet>(input_buffer + index);
+        // Load output block into vector register.
+        auto out_block_data = output_buffer + index;
+        auto out_block = Eigen::internal::ploadu<Packet>(out_block_data);
+        // Vector multiply-add.
+        out_block = Eigen::internal::pmadd<Packet>(out_bprop_block, input_block, out_block);
+        // Store 'out_block' back to memory.
+        Eigen::internal::pstoreu<T>(out_block_data, out_block);
+      }
+    }
+  }
+  else
+  {
+    // Slow path (can't do vector reads from non-padded 'out_backprop').
+    for (int i = 0; i < output_vectorized_size; i += kPacketSize)
+    {
+      // Calculate safe read size from 'out_backprop'.
+      const int64_t out_bprop_index = base_output_index + i;
+      const int64_t out_bprop_limit = std::min(output_image_size, out_bprop_index + kPacketSize);
+      T out_buf[kPacketSize];
+      memset(&out_buf, 0, kPacketSize * sizeof(T));
+      const int64_t scalar_size = out_bprop_limit - out_bprop_index;
+      for (int64_t j = 0; j < scalar_size; ++j)
+      {
+        out_buf[j] = out_backprop[out_bprop_index + j];
+      }
+      // Load vector register from 'out_buf'.
+      const auto out_bprop_block = Eigen::internal::ploadu<Packet>(out_buf);
+      for (int j = 0; j < filter_spatial_size; ++j)
+      {
+        const int64_t index = i + j * padded_out_depth_size;
+        // Load vector register from 'input_buffer'.
+        const auto input_block = Eigen::internal::ploadu<Packet>(input_buffer + index);
+        // Load output block into vector register.
+        auto out_block_data = output_buffer + index;
+        auto out_block = Eigen::internal::ploadu<Packet>(out_block_data);
+        // Vector multiply-add.
+        out_block = Eigen::internal::pmadd<Packet>(out_bprop_block, input_block, out_block);
+        // Store 'out_block' back to memory.
+        Eigen::internal::pstoreu<T>(out_block_data, out_block);
+      }
+    }
+  }
+}
+
+template <typename T> struct LaunchDepthwiseConvBackpropFilterOp<CPUDevice, T>
+{
+  typedef typename Eigen::internal::packet_traits<T>::type Packet;
+
+  void operator()(int batch, int in_rows, int in_cols, int in_depth, int filter_rows,
+                  int filter_cols, int depth_multiplier, int stride, int pad_rows, int pad_cols,
+                  int out_rows, int out_cols, int out_depth, const T *out_backprop, const T *input,
+                  T *filter_backprop, T *padded_filter_data, T *in_bprop)
+  {
+    const Eigen::ThreadPoolDevice &d = *eigen_support::GetThreadPoolDevice();
+
+    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
+
+    const int64_t filter_spatial_size = filter_rows * filter_cols;
+    const int64_t padded_out_depth_size =
+      ((out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
+
+    T *output_buffer_data = padded_filter_data;
+
+    // Computes one shard of depthwise conv2d backprop filter.
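+    // As in the input-backprop launcher, sharding is over the batch: every
+    // image accumulates into its own padded slice of 'output_buffer_data',
+    // and the per-image partial filter gradients are only reduced into
+    // 'filter_backprop' after parallelFor() returns, so shards never contend.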
+ // auto shard = [&ctx, &args, &out_backprop, &input, &output_buffer_data]( + auto shard = [&](int64_t start, int64_t limit) { + static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T)); + const int64_t filter_spatial_size = filter_rows * filter_cols; + const int64_t padded_out_depth_size = + ((out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize; + + int cur_id = d.currentThreadId() + 1; + assert(cur_id >= 0 && cur_id < d.numThreads() + 1); + + const int64_t input_image_size = in_rows * in_cols * in_depth; + const int64_t output_image_size = out_rows * out_cols * out_depth; + const int64_t padded_filter_size = filter_spatial_size * padded_out_depth_size; + + T *input_buffer_data = in_bprop + cur_id * padded_filter_size; + + for (int b = start; b < limit; ++b) + { + // Initialize 'output_buffer' for 'b'. + auto *output_buffer = output_buffer_data + b * padded_filter_size; + memset(output_buffer, 0, padded_filter_size * sizeof(T)); + + for (int out_r = 0; out_r < out_rows; ++out_r) + { + for (int out_c = 0; out_c < out_cols; ++out_c) + { + // Populate 'input_buffer_data' with data from local input region. + functor::DepthwiseInputCopyOp()( + batch, in_rows, in_cols, in_depth, filter_rows, filter_cols, depth_multiplier, stride, + pad_rows, pad_cols, out_rows, out_cols, out_depth, padded_out_depth_size, out_r, + out_c, input + b * input_image_size, input_buffer_data); + // Compute depthwise backprop filter. + ComputeBackpropFilter( + batch, in_rows, in_cols, in_depth, filter_rows, filter_cols, depth_multiplier, stride, + pad_rows, pad_cols, out_rows, out_cols, out_depth, padded_out_depth_size, out_r, + out_c, out_backprop + b * output_image_size, input_buffer_data, output_buffer); + } + } + } + }; + + const int64_t input_bytes = in_rows * in_cols * in_depth * sizeof(T); + const int64_t output_bytes = out_rows * out_cols * out_depth * sizeof(T); + const int64_t compute_cycles = out_rows * out_cols * out_depth * batch; + const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles); + d.parallelFor(batch, cost, shard); + + // Accumulate 'output_buffer' from each shard into 'output'. + // const int64_t out_depth = out_depth; + const int64_t vectorized_size = (out_depth / kPacketSize) * kPacketSize; + const int64_t scalar_size = out_depth - vectorized_size; + const int64_t padded_filter_size = filter_spatial_size * padded_out_depth_size; + memset(filter_backprop, 0, filter_spatial_size * out_depth * sizeof(T)); + + for (int64_t i = 0; i < filter_spatial_size; ++i) + { + const int64_t buffer_base = i * padded_out_depth_size; + const int64_t output_base = i * out_depth; + // Write vectorized length of filter's inner dimension to output. + for (int64_t j = 0; j < vectorized_size; j += kPacketSize) + { + // Load data from 'filter_backprop' into vector register. + auto out_block_data = filter_backprop + output_base + j; + auto out_block = Eigen::internal::ploadu(out_block_data); + for (int b = 0; b < batch; ++b) + { + // Load data from 'output_buffer' for 'b'. + const auto *output_buffer = output_buffer_data + b * padded_filter_size; + const auto v = Eigen::internal::ploadu(output_buffer + buffer_base + j); + // Add 'v' to 'out_block'. + out_block = Eigen::internal::padd(out_block, v); + } + // Store 'out_block' back to memory. + Eigen::internal::pstoreu(out_block_data, out_block); + } + // Write scalar length of filter's inner dimension to output. 
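+      // The scalar tail below cannot use vector stores: 'filter_backprop' is
+      // exactly filter_spatial_size * out_depth elements with no register
+      // padding, unlike the per-batch scratch buffers being accumulated.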
+ for (int64_t j = 0; j < scalar_size; ++j) + { + for (int b = 0; b < batch; ++b) + { + const auto *output_buffer = output_buffer_data + b * padded_filter_size; + filter_backprop[output_base + vectorized_size + j] += + output_buffer[buffer_base + vectorized_size + j]; + } + } + } + } +}; + +template +static void DepthwiseConvBackpropFilterReference(int batch, int in_rows, int in_cols, int in_depth, + int out_rows, int out_cols, int out_depth, + int stride, int depth_multiplier, int filter_rows, + int filter_cols, int pad_rows, int pad_cols, + const T *out_backprop, const T *input, + T *filter_backprop) +{ + int num_filter_backprop = filter_rows * filter_cols * in_depth * depth_multiplier; + memset(filter_backprop, 0, num_filter_backprop * sizeof(T)); + // Naive for loop as a reference point without concerns about performance. + for (int b = 0; b < batch; ++b) + { + for (int out_r = 0; out_r < out_rows; ++out_r) + { + for (int out_c = 0; out_c < out_cols; ++out_c) + { + for (int out_d = 0; out_d < out_depth; ++out_d) + { + const int in_d = out_d / depth_multiplier; + const int dm = out_d % depth_multiplier; + const int in_r_start = out_r * stride - pad_rows; + const int in_c_start = out_c * stride - pad_cols; + + for (int f_r = 0; f_r < filter_rows; ++f_r) + { + for (int f_c = 0; f_c < filter_cols; ++f_c) + { + const int in_r = in_r_start + f_r; + const int in_c = in_c_start + f_c; + + if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) + { + int out_backprop_offset = + out_d + out_depth * (out_c + out_cols * (out_r + out_rows * b)); + int input_offset = in_d + in_depth * (in_c + in_cols * (in_r + in_rows * b)); + int filter_backprop_offset = + dm + depth_multiplier * (in_d + in_depth * (f_c + filter_cols * f_r)); + filter_backprop[filter_backprop_offset] += + input[input_offset] * out_backprop[out_backprop_offset]; + } + } + } + } + } + } + } +} + +} // namespace depthwise_conv_op +} // namespace cker +} // namespace nnfw + +#endif // __NNFW_CKER_EIGEN_DEPTHWISE_CONV_OP_H__ diff --git a/compute/cker/include/cker/train/operation/DepthwiseConv.h b/compute/cker/include/cker/train/operation/DepthwiseConv.h new file mode 100644 index 00000000000..6546253f70a --- /dev/null +++ b/compute/cker/include/cker/train/operation/DepthwiseConv.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __NNFW_CKER_TRAIN_OPERATION_DEPTHWISECONV_H__ +#define __NNFW_CKER_TRAIN_OPERATION_DEPTHWISECONV_H__ + +#include "cker/eigen/depthwise_conv_op.h" +#include "cker/Shape.h" +#include "cker/Types.h" + +namespace nnfw +{ +namespace cker +{ +namespace train +{ + +class DepthwiseConv +{ +public: + DepthwiseConv() = default; + + template int64_t kPacketSize() const + { + typedef typename Eigen::internal::packet_traits::type Packet; + return sizeof(Packet) / sizeof(T); + } + + int getThreadCount() const + { + const Eigen::ThreadPoolDevice &d = *eigen_support::GetThreadPoolDevice(); + return d.numThreads(); + } + + template + void backpropInput(const DepthwiseConvParams ¶ms, const Shape &incoming_shape, + const T *incoming_data, const Shape &filter_shape, const T *filter_data, + T *padded_filter_data, const Shape &grad_shape, T *grad_data, bool pad_filter, + T *filter_buffers_data, T *filter_dim_buffers_data) + { + if (params.stride_height != params.stride_width) + throw std::runtime_error("Not support different length strides"); + + const int batch = MatchingDim(incoming_shape, 0, grad_shape, 0); + const int input_depth = grad_shape.Dims(3); + const int output_depth = incoming_shape.Dims(3); + const int incoming_height = incoming_shape.Dims(1); + const int incoming_width = incoming_shape.Dims(2); + const int grad_height = grad_shape.Dims(1); + const int grad_width = grad_shape.Dims(2); + const int stride = params.stride_height; + const int depth_multiplier = params.depth_multiplier; + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int pad_height = params.padding_values.height; + const int pad_width = params.padding_values.width; + + depthwise_conv_op::LaunchDepthwiseConvBackpropInputOp()( + batch, grad_height, grad_width, input_depth, filter_height, filter_width, depth_multiplier, + stride, pad_height, pad_width, incoming_height, incoming_width, output_depth, incoming_data, + filter_data, padded_filter_data, grad_data, pad_filter, filter_buffers_data, + filter_dim_buffers_data); + } + + template + void backpropFilter(const DepthwiseConvParams ¶ms, const Shape &incoming_shape, + const T *incoming_data, const Shape &input_shape, const T *input_data, + const Shape &filter_grad_shape, T *filter_grad_data, T *padded_filter_data, + T *filter_buffers_data) + { + if (params.stride_height != params.stride_width) + throw std::runtime_error("Not support different length strides"); + + const int batch = MatchingDim(incoming_shape, 0, input_shape, 0); + const int input_depth = input_shape.Dims(3); + const int output_depth = incoming_shape.Dims(3); + const int incoming_height = incoming_shape.Dims(1); + const int incoming_width = incoming_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int stride = params.stride_height; + const int depth_multiplier = params.depth_multiplier; + const int filter_height = filter_grad_shape.Dims(1); + const int filter_width = filter_grad_shape.Dims(2); + const int pad_height = params.padding_values.height; + const int pad_width = params.padding_values.width; + + depthwise_conv_op::LaunchDepthwiseConvBackpropFilterOp()( + batch, input_width, input_height, input_depth, filter_width, filter_height, depth_multiplier, + stride, pad_width, pad_height, incoming_width, incoming_height, output_depth, incoming_data, + input_data, filter_grad_data, padded_filter_data, filter_buffers_data); + } +}; + +} // namespace train +} // namespace cker +} // 
+
+#endif // __NNFW_CKER_TRAIN_OPERATION_DEPTHWISECONV_H__
diff --git a/compute/cker/src/train/DepthwiseConv.test.cc b/compute/cker/src/train/DepthwiseConv.test.cc
new file mode 100644
index 00000000000..a85ff43eda8
--- /dev/null
+++ b/compute/cker/src/train/DepthwiseConv.test.cc
@@ -0,0 +1,410 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cker/train/operation/DepthwiseConv.h>
+
+#include <gtest/gtest.h>
+#include <memory>
+#include <vector>
+
+namespace
+{
+
+template <typename T> class DepthwiseConvVerifier
+{
+public:
+  DepthwiseConvVerifier() { _dconv_kernel = std::make_unique<nnfw::cker::train::DepthwiseConv>(); }
+
+  void prepare(const nnfw::cker::Shape &incoming_shape, const nnfw::cker::Shape &filter_shape)
+  {
+    const int kPacketSize = _dconv_kernel->kPacketSize<T>();
+    const int batch = incoming_shape.Dims(0);
+    const int out_depth = incoming_shape.Dims(3);
+    const int filter_rows = filter_shape.Dims(1);
+    const int filter_cols = filter_shape.Dims(2);
+    const int filter_spatial_size = filter_rows * filter_cols;
+    const int padded_filter_inner_dim_size =
+      ((out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
+
+    _use_padded_filter = (out_depth % kPacketSize) == 0 ? false : true;
+    {
+      nnfw::cker::Shape padded_filter_shape(
+        {batch, filter_spatial_size, padded_filter_inner_dim_size});
+      _padded_filter.resize(padded_filter_shape.FlatSize());
+    }
+
+    {
+      const int thread_count = _dconv_kernel->getThreadCount() + 1;
+
+      nnfw::cker::Shape filter_buffer_shape(
+        {thread_count, filter_spatial_size, padded_filter_inner_dim_size});
+      _filter_buffers.resize(filter_buffer_shape.FlatSize());
+
+      nnfw::cker::Shape filter_dim_buffer_shape({thread_count, padded_filter_inner_dim_size});
+      _filter_dim_buffers.resize(filter_dim_buffer_shape.FlatSize());
+    }
+  }
+
+  void verifyInputGradExpected(const nnfw::cker::DepthwiseConvParams &params,
+                               const nnfw::cker::Shape &incoming_shape, const T *incoming_data,
+                               const nnfw::cker::Shape &filter_shape, const T *filter_data,
+                               const nnfw::cker::Shape &grad_shape)
+  {
+    std::vector<T> gradient(grad_shape.FlatSize(), static_cast<T>(0));
+    std::vector<T> expected(grad_shape.FlatSize(), static_cast<T>(0));
+
+    calculateInputGradExpected(params, incoming_shape, incoming_data, filter_shape, filter_data,
+                               grad_shape, expected.data());
+
+    _dconv_kernel->backpropInput(params, incoming_shape, incoming_data, filter_shape, filter_data,
+                                 _padded_filter.data(), grad_shape, gradient.data(),
+                                 _use_padded_filter, _filter_buffers.data(),
+                                 _filter_dim_buffers.data());
+
+    for (size_t i = 0; i < gradient.size(); ++i)
+      EXPECT_NEAR(gradient[i], expected[i], 1e-3f);
+  }
+
+  void throwInputGradExpected(const nnfw::cker::DepthwiseConvParams &params,
+                              const nnfw::cker::Shape &incoming_shape, const T *incoming_data,
+                              const nnfw::cker::Shape &filter_shape, const T *filter_data,
+                              const nnfw::cker::Shape &grad_shape)
+  {
+    std::vector<T> gradient(grad_shape.FlatSize(), static_cast<T>(0));
+
+    EXPECT_ANY_THROW(_dconv_kernel->backpropInput(
+      params, incoming_shape, incoming_data, filter_shape, filter_data, _padded_filter.data(),
+      grad_shape, gradient.data(), _use_padded_filter, _filter_buffers.data(),
+      _filter_dim_buffers.data()));
+  }
+
+  void verifyFilterGradExpected(const nnfw::cker::DepthwiseConvParams &params,
+                                const nnfw::cker::Shape &incoming_shape, const T *incoming_data,
+                                const nnfw::cker::Shape &input_shape, const T *input_data,
+                                const nnfw::cker::Shape &filter_grad_shape)
+  {
+    std::vector<T> gradient(filter_grad_shape.FlatSize(), static_cast<T>(0));
+    std::vector<T> expected(filter_grad_shape.FlatSize(), static_cast<T>(0));
+
+    calculateFilterGradExpected(params, incoming_shape, incoming_data, input_shape, input_data,
+                                filter_grad_shape, expected.data());
+
+    _dconv_kernel->backpropFilter(params, incoming_shape, incoming_data, input_shape, input_data,
+                                  filter_grad_shape, gradient.data(), _padded_filter.data(),
+                                  _filter_buffers.data());
+
+    for (size_t i = 0; i < gradient.size(); ++i)
+      EXPECT_NEAR(gradient[i], expected[i], 1e-3f);
+  }
+
+  void throwFilterGradExpected(const nnfw::cker::DepthwiseConvParams &params,
+                               const nnfw::cker::Shape &incoming_shape, const T *incoming_data,
+                               const nnfw::cker::Shape &input_shape, const T *input_data,
+                               const nnfw::cker::Shape &filter_grad_shape)
+  {
+    std::vector<T> gradient(filter_grad_shape.FlatSize(), static_cast<T>(0));
+
+    EXPECT_ANY_THROW(_dconv_kernel->backpropFilter(
+      params, incoming_shape, incoming_data, input_shape, input_data, filter_grad_shape,
+      gradient.data(), _padded_filter.data(), _filter_buffers.data()));
+  }
+
+private:
+  void calculateInputGradExpected(const nnfw::cker::DepthwiseConvParams &params,
+                                  const nnfw::cker::Shape &incoming_shape, const T *incoming_data,
+                                  const nnfw::cker::Shape &filter_shape, const T *filter_data,
+                                  const nnfw::cker::Shape &grad_shape, T *expected)
+  {
+    assert(incoming_shape.DimensionsCount() == 4);
+    assert(filter_shape.DimensionsCount() == 4);
+    assert(grad_shape.DimensionsCount() == 4);
+    assert(params.stride_height == params.stride_width);
+
+    const int batch = MatchingDim(incoming_shape, 0, grad_shape, 0);
+    const int input_depth = grad_shape.Dims(3);
+    const int output_depth = incoming_shape.Dims(3);
+    const int incoming_height = incoming_shape.Dims(1);
+    const int incoming_width = incoming_shape.Dims(2);
+    const int grad_height = grad_shape.Dims(1);
+    const int grad_width = grad_shape.Dims(2);
+    const int stride = params.stride_height;
+    const int depth_multiplier = params.depth_multiplier;
+    const int filter_height = filter_shape.Dims(1);
+    const int filter_width = filter_shape.Dims(2);
+    const int pad_height = params.padding_values.height;
+    const int pad_width = params.padding_values.width;
+
+    nnfw::cker::depthwise_conv_op::DepthwiseConvBackpropInputReference(
+      batch, grad_height, grad_width, input_depth, incoming_height, incoming_width, output_depth,
+      stride, depth_multiplier, filter_height, filter_width, pad_height, pad_width, incoming_data,
+      filter_data, expected);
+  }
+
+  void calculateFilterGradExpected(const nnfw::cker::DepthwiseConvParams &params,
+                                   const nnfw::cker::Shape &incoming_shape, const T *incoming_data,
+                                   const nnfw::cker::Shape &input_shape, const T *input_data,
+                                   const nnfw::cker::Shape &filter_grad_shape, T *expected)
+  {
+    assert(incoming_shape.DimensionsCount() == 4);
+    assert(input_shape.DimensionsCount() == 4);
+    assert(filter_grad_shape.DimensionsCount() == 4);
+    assert(params.stride_height == params.stride_width);
+
+    const int batch = MatchingDim(incoming_shape, 0, input_shape, 0);
+    const int input_depth =
input_shape.Dims(3); + const int output_depth = incoming_shape.Dims(3); + const int incoming_height = incoming_shape.Dims(1); + const int incoming_width = incoming_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int stride = params.stride_height; + const int depth_multiplier = params.depth_multiplier; + const int filter_height = filter_grad_shape.Dims(1); + const int filter_width = filter_grad_shape.Dims(2); + const int pad_height = params.padding_values.height; + const int pad_width = params.padding_values.width; + + nnfw::cker::depthwise_conv_op::DepthwiseConvBackpropFilterReference( + batch, input_height, input_width, input_depth, incoming_height, incoming_width, output_depth, + stride, depth_multiplier, filter_height, filter_width, pad_height, pad_width, incoming_data, + input_data, expected); + } + +private: + std::unique_ptr _dconv_kernel; + bool _use_padded_filter; + std::vector _padded_filter; + std::vector _filter_buffers; + std::vector _filter_dim_buffers; +}; + +} // namespace + +TEST(CKer_Operation, DepthwiseConvGrad) +{ + // No pad, No stride + { + nnfw::cker::DepthwiseConvParams params; + params.padding_type = nnfw::cker::PaddingType::kNone; + params.padding_values.width = 0; + params.padding_values.height = 0; + params.stride_width = 1; + params.stride_height = 1; + params.dilation_width_factor = 1; + params.dilation_height_factor = 1; + params.depth_multiplier = 1; + + nnfw::cker::Shape incoming_shape{1, 2, 2, 3}; // n, h, w, c + std::vector incoming = {-0.1, 0.2, -0.3, 0.4, 0.5, -0.6, + -0.7, 0.8, 0.9, -1.0, 1.1, -1.2}; + nnfw::cker::Shape filter_shape{1, 2, 2, 3}; // 1, h, w, c + std::vector filter = {-1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12}; + nnfw::cker::Shape input_shape{1, 3, 3, 3}; // n, h, w, c + std::vector input = {-1, 2, -3, 4, 5, -6, -7, 8, -9, -10, 11, -12, 13, -14, + 15, -16, -17, 18, 19, -20, 21, -22, -23, 24, -25, -26, -27}; + + DepthwiseConvVerifier verifier; + verifier.prepare(incoming_shape, filter_shape); + verifier.verifyInputGradExpected(params, incoming_shape, incoming.data(), filter_shape, + filter.data(), input_shape); + verifier.verifyFilterGradExpected(params, incoming_shape, incoming.data(), input_shape, + input.data(), filter_shape); + } + + // 2 depth_multiplier, use_padded_filter false + { + nnfw::cker::DepthwiseConvParams params; + params.padding_type = nnfw::cker::PaddingType::kNone; + params.padding_values.width = 0; + params.padding_values.height = 0; + params.stride_width = 1; + params.stride_height = 1; + params.dilation_width_factor = 1; + params.dilation_height_factor = 1; + params.depth_multiplier = 2; + + nnfw::cker::Shape incoming_shape{1, 2, 2, 4}; // n, h, w, c + std::vector incoming = {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, + -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}; + nnfw::cker::Shape filter_shape{1, 2, 2, 4}; // 1, h, w, c * depth_multiplier + std::vector filter = {-1, 2, -3, 4, 5, -6, 7, -8, 9, -10, -11, 12, -13, 14, -15, 16}; + nnfw::cker::Shape input_shape{1, 3, 3, 2}; // n, h, w, c + std::vector input = {-1, 2, -3, 4, 5, -6, -7, 8, -9, + -10, 11, -12, 13, -14, 15, -16, -17, 18}; + + DepthwiseConvVerifier verifier; + verifier.prepare(incoming_shape, filter_shape); + verifier.verifyInputGradExpected(params, incoming_shape, incoming.data(), filter_shape, + filter.data(), input_shape); + verifier.verifyFilterGradExpected(params, incoming_shape, incoming.data(), input_shape, + input.data(), filter_shape); + } + + // pad valid, stride 2 + { + 
nnfw::cker::DepthwiseConvParams params; + params.padding_type = nnfw::cker::PaddingType::kValid; + params.stride_width = 2; + params.stride_height = 2; + params.dilation_width_factor = 1; + params.dilation_height_factor = 1; + params.depth_multiplier = 1; + + nnfw::cker::Shape incoming_shape{1, 3, 3, 2}; // n, h, w, c + std::vector incoming = {-0.1, 0.2, -0.3, 0.4, 0.5, -0.6, -0.7, 0.8, 0.9, + -1.0, 1.1, -1.2, 1.3, -1.4, -1.5, 1.6, 1.7, -1.8}; + nnfw::cker::Shape filter_shape{1, 3, 3, 2}; // 1, h, w, c + std::vector filter = {-1, 2, -3, 4, 5, -6, -7, 8, 9, + -10, -11, 12, -13, 14, -15, 16, 17, -18}; + nnfw::cker::Shape input_shape{1, 3, 3, 2}; // n, h, w, c + std::vector input = {-1, 2, -3, 4, 5, -6, -7, 8, -9, + -10, 11, -12, 13, -14, 15, -16, -17, 18}; + + DepthwiseConvVerifier verifier; + verifier.prepare(incoming_shape, filter_shape); + verifier.verifyInputGradExpected(params, incoming_shape, incoming.data(), filter_shape, + filter.data(), input_shape); + verifier.verifyFilterGradExpected(params, incoming_shape, incoming.data(), input_shape, + input.data(), filter_shape); + } + + // pad same, stride 2 + { + nnfw::cker::DepthwiseConvParams params; + params.padding_type = nnfw::cker::PaddingType::kSame; + params.stride_width = 2; + params.stride_height = 2; + params.dilation_width_factor = 1; + params.dilation_height_factor = 1; + params.depth_multiplier = 1; + + nnfw::cker::Shape incoming_shape{1, 1, 1, 2}; // n, h, w, c + std::vector incoming = {-0.1, 0.2}; + nnfw::cker::Shape filter_shape{1, 2, 2, 2}; // 1, h, w, c + std::vector filter = {-1, 2, -3, 4, 5, -6, -7, 8}; + nnfw::cker::Shape input_shape{1, 3, 3, 2}; // n, h, w, c + std::vector input = {-1, 2, -3, 4, 5, -6, -7, 8, -9, + -10, 11, -12, 13, -14, 15, -16, -17, 18}; + + DepthwiseConvVerifier verifier; + verifier.prepare(incoming_shape, filter_shape); + verifier.verifyInputGradExpected(params, incoming_shape, incoming.data(), filter_shape, + filter.data(), input_shape); + verifier.verifyFilterGradExpected(params, incoming_shape, incoming.data(), input_shape, + input.data(), filter_shape); + } + + // multi thread case + { + nnfw::cker::DepthwiseConvParams params; + params.padding_type = nnfw::cker::PaddingType::kNone; + params.padding_values.width = 0; + params.padding_values.height = 0; + params.stride_width = 1; + params.stride_height = 1; + params.dilation_width_factor = 1; + params.dilation_height_factor = 1; + params.depth_multiplier = 1; + + nnfw::cker::Shape incoming_shape{10, 112, 112, 32}; // n, h, w, c + std::vector incoming; + for (int i = 0; i < incoming_shape.FlatSize(); ++i) + { + incoming.push_back(static_cast(i) / static_cast(RAND_MAX)); + } + nnfw::cker::Shape filter_shape{1, 3, 3, 32}; // 1, h, w, c + std::vector filter; + for (int i = 0; i < filter_shape.FlatSize(); ++i) + { + filter.push_back(static_cast(i) / static_cast(RAND_MAX)); + } + nnfw::cker::Shape input_shape{10, 112, 112, 32}; // n, h, w, c + std::vector input; + const int input_size = input_shape.FlatSize(); + for (int i = 0; i < input_size; ++i) + { + input.push_back(static_cast(input_size - i) * 0.001f / static_cast(RAND_MAX)); + } + + DepthwiseConvVerifier verifier; + verifier.prepare(incoming_shape, filter_shape); + verifier.verifyInputGradExpected(params, incoming_shape, incoming.data(), filter_shape, + filter.data(), input_shape); + verifier.verifyFilterGradExpected(params, incoming_shape, incoming.data(), input_shape, + input.data(), filter_shape); + } +} + +TEST(CKer_Operation, neg_DepthwiseConvGrad) +{ + // Not matched stride, InputGrad test 
case + { + nnfw::cker::DepthwiseConvParams params; + params.padding_type = nnfw::cker::PaddingType::kNone; + params.padding_values.width = 0; + params.padding_values.height = 0; + params.stride_width = 1; + params.stride_height = 2; + params.dilation_width_factor = 1; + params.dilation_height_factor = 1; + params.depth_multiplier = 1; + + nnfw::cker::Shape incoming_shape{1, 2, 2, 4}; // n, h, w, c + std::vector incoming = {-0.1, 0.2, -0.3, 0.4, 0.5, -0.6, -0.7, 0.8, + 0.9, -1.0, 1.1, -1.2, -1.3, 1.4, 1.5, -1.6}; + nnfw::cker::Shape filter_shape{1, 2, 2, 4}; // 1, h, w, c + std::vector filter = {-1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12}; + nnfw::cker::Shape input_shape{1, 3, 3, 4}; // n, h, w, c + std::vector input = {-1, 2, -3, 4, 5, -6, -7, 8, -9, -10, 11, -12, + 13, -14, 15, -16, -17, 18, 19, -20, 21, -22, -23, 24, + -25, -26, -27, 28, -29, -30, 31, 32, -33, -34, -35, 36}; + + DepthwiseConvVerifier verifier; + verifier.prepare(incoming_shape, filter_shape); + verifier.throwInputGradExpected(params, incoming_shape, incoming.data(), filter_shape, + filter.data(), input_shape); + } + + // Not matched stride, FilterGrad test case + { + nnfw::cker::DepthwiseConvParams params; + params.padding_type = nnfw::cker::PaddingType::kNone; + params.padding_values.width = 0; + params.padding_values.height = 0; + params.stride_width = 1; + params.stride_height = 2; + params.dilation_width_factor = 1; + params.dilation_height_factor = 1; + params.depth_multiplier = 1; + + nnfw::cker::Shape incoming_shape{1, 2, 2, 4}; // n, h, w, c + std::vector incoming = {-0.1, 0.2, -0.3, 0.4, 0.5, -0.6, -0.7, 0.8, + 0.9, -1.0, 1.1, -1.2, -1.3, 1.4, 1.5, -1.6}; + nnfw::cker::Shape filter_shape{1, 2, 2, 4}; // 1, h, w, c + std::vector filter = {-1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12}; + nnfw::cker::Shape input_shape{1, 3, 3, 4}; // n, h, w, c + std::vector input = {-1, 2, -3, 4, 5, -6, -7, 8, -9, -10, 11, -12, + 13, -14, 15, -16, -17, 18, 19, -20, 21, -22, -23, 24, + -25, -26, -27, 28, -29, -30, 31, 32, -33, -34, -35, 36}; + + DepthwiseConvVerifier verifier; + verifier.prepare(incoming_shape, filter_shape); + verifier.throwFilterGradExpected(params, incoming_shape, incoming.data(), filter_shape, + filter.data(), input_shape); + } +} From 3ff29260acc5e67148d87df8e27bd41d2d92ac5d Mon Sep 17 00:00:00 2001 From: Jiyoung Giuliana Yun Date: Mon, 15 Jan 2024 11:28:26 +0900 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Jang Jiseob --- compute/cker/src/train/DepthwiseConv.test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compute/cker/src/train/DepthwiseConv.test.cc b/compute/cker/src/train/DepthwiseConv.test.cc index a85ff43eda8..414398c5c04 100644 --- a/compute/cker/src/train/DepthwiseConv.test.cc +++ b/compute/cker/src/train/DepthwiseConv.test.cc @@ -32,16 +32,16 @@ template class DepthwiseConvVerifier void prepare(const nnfw::cker::Shape &incoming_shape, const nnfw::cker::Shape &filter_shape) { - const int kPacketSize = _dconv_kernel->kPacketSize(); + const int k_packet_size = _dconv_kernel->kPacketSize(); const int batch = incoming_shape.Dims(0); const int out_depth = incoming_shape.Dims(3); const int filter_rows = filter_shape.Dims(1); const int filter_cols = filter_shape.Dims(2); const int filter_spatial_size = filter_rows * filter_cols; const int padded_filter_inner_dim_size = - ((out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize; + ((out_depth + k_packet_size - 1) / k_packet_size) * k_packet_size; - _use_padded_filter = (out_depth % kPacketSize) == 0 
From 3ff29260acc5e67148d87df8e27bd41d2d92ac5d Mon Sep 17 00:00:00 2001
From: Jiyoung Giuliana Yun
Date: Mon, 15 Jan 2024 11:28:26 +0900
Subject: [PATCH 2/3] Apply suggestions from code review

Co-authored-by: Jang Jiseob
---
 compute/cker/src/train/DepthwiseConv.test.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/compute/cker/src/train/DepthwiseConv.test.cc b/compute/cker/src/train/DepthwiseConv.test.cc
index a85ff43eda8..414398c5c04 100644
--- a/compute/cker/src/train/DepthwiseConv.test.cc
+++ b/compute/cker/src/train/DepthwiseConv.test.cc
@@ -32,16 +32,16 @@ template <typename T> class DepthwiseConvVerifier
   void prepare(const nnfw::cker::Shape &incoming_shape, const nnfw::cker::Shape &filter_shape)
   {
-    const int kPacketSize = _dconv_kernel->kPacketSize();
+    const int k_packet_size = _dconv_kernel->kPacketSize();
     const int batch = incoming_shape.Dims(0);
     const int out_depth = incoming_shape.Dims(3);
     const int filter_rows = filter_shape.Dims(1);
     const int filter_cols = filter_shape.Dims(2);
     const int filter_spatial_size = filter_rows * filter_cols;
     const int padded_filter_inner_dim_size =
-      ((out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
+      ((out_depth + k_packet_size - 1) / k_packet_size) * k_packet_size;

-    _use_padded_filter = (out_depth % kPacketSize) == 0 ? false : true;
+    _use_padded_filter = (out_depth % k_packet_size) == 0 ? false : true;

     {
       nnfw::cker::Shape padded_filter_shape(
         {batch, filter_spatial_size, padded_filter_inner_dim_size});
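The renamed `k_packet_size` makes the padding arithmetic above easier to follow: the filter's inner dimension is rounded up to the vector-register width, and the padded filter is needed only when `out_depth` is not already a multiple of the packet size. A worked example, assuming a packet of 4 floats (one 128-bit register):

#include <cassert>

int main()
{
  const int k_packet_size = 4; // assumed register width, in floats
  const int out_depth = 6;     // i.e. in_depth * depth_multiplier
  // Round up to the next multiple of the packet size: 6 -> 8.
  const int padded_filter_inner_dim_size =
    ((out_depth + k_packet_size - 1) / k_packet_size) * k_packet_size;
  assert(padded_filter_inner_dim_size == 8);
  // Same predicate as `(out_depth % k_packet_size) == 0 ? false : true` above.
  const bool use_padded_filter = (out_depth % k_packet_size) != 0;
  assert(use_padded_filter);
  return 0;
}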
From f6af2f78ff9af6460ea217b2dd46094345e4c8ae Mon Sep 17 00:00:00 2001
From: Jiyoung Yun
Date: Mon, 15 Jan 2024 17:05:25 +0900
Subject: [PATCH 3/3] Remove static keyword from functions in header file

Add comments for thread count

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun
---
 .../include/cker/eigen/depthwise_conv_op.h   | 51 +++++++++----------
 .../cker/train/operation/DepthwiseConv.h     |  4 +-
 compute/cker/src/train/DepthwiseConv.test.cc |  2 +-
 3 files changed, 29 insertions(+), 28 deletions(-)

diff --git a/compute/cker/include/cker/eigen/depthwise_conv_op.h b/compute/cker/include/cker/eigen/depthwise_conv_op.h
index bf029972a34..d99ace07170 100644
--- a/compute/cker/include/cker/eigen/depthwise_conv_op.h
+++ b/compute/cker/include/cker/eigen/depthwise_conv_op.h
@@ -353,11 +353,11 @@ using CPUDevice = Eigen::ThreadPoolDevice;
 // [a00, a01, a10, a11] [a20, a21, 0, 0] in_row = 1, in_col = 1
 //
 template <typename T>
-static void CopyOutputBackpropRegion(int, int, int, int, int filter_rows_, int filter_cols_, int,
-                                     int stride_, int pad_rows_, int pad_cols_, int out_rows_,
-                                     int out_cols_, int out_depth,
-                                     const int64_t padded_filter_inner_dim_size, const int64_t in_r,
-                                     const int64_t in_c, const T *out_backprop, T *buffer)
+void CopyOutputBackpropRegion(int, int, int, int, int filter_rows_, int filter_cols_, int,
+                              int stride_, int pad_rows_, int pad_cols_, int out_rows_,
+                              int out_cols_, int out_depth,
+                              const int64_t padded_filter_inner_dim_size, const int64_t in_r,
+                              const int64_t in_c, const T *out_backprop, T *buffer)
 {
   typedef typename Eigen::internal::packet_traits<T>::type Packet;
   static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
@@ -452,11 +452,11 @@

 //
 template <typename T>
-static void ComputeBackpropInput(int, int, int in_cols, int in_depth_, int filter_rows,
-                                 int filter_cols, int depth_multiplier_, int, int, int, int, int,
-                                 int out_depth_, const int64_t padded_filter_inner_dim_size,
-                                 const int64_t in_r, const int64_t in_c, const T *filter,
-                                 const T *buffer, T *out_buffer, T *output)
+void ComputeBackpropInput(int, int, int in_cols, int in_depth_, int filter_rows, int filter_cols,
+                          int depth_multiplier_, int, int, int, int, int, int out_depth_,
+                          const int64_t padded_filter_inner_dim_size, const int64_t in_r,
+                          const int64_t in_c, const T *filter, const T *buffer, T *out_buffer,
+                          T *output)
 {
   typedef typename Eigen::internal::packet_traits<T>::type Packet;
   static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
@@ -636,11 +636,11 @@ template <typename T> struct LaunchDepthwiseConvBackpropInputOp
 };

 template <typename T>
-static void
-DepthwiseConvBackpropInputReference(int batch, int in_rows, int in_cols, int in_depth, int out_rows,
-                                    int out_cols, int out_depth, int stride, int depth_multiplier,
-                                    int filter_rows, int filter_cols, int pad_rows, int pad_cols,
-                                    const T *out_backprop, const T *filter, T *in_backprop)
+void DepthwiseConvBackpropInputReference(int batch, int in_rows, int in_cols, int in_depth,
+                                         int out_rows, int out_cols, int out_depth, int stride,
+                                         int depth_multiplier, int filter_rows, int filter_cols,
+                                         int pad_rows, int pad_cols, const T *out_backprop,
+                                         const T *filter, T *in_backprop)
 {
   // Naive for loop as a reference point without concerns about performance.
   for (int b = 0; b < batch; ++b)
@@ -714,11 +714,11 @@ DepthwiseConvBackpropInputReference(int batch, int in_rows, int in_cols, int in_
 // [u0, v0, w0, x0] += ([f00, f01, f10, f11] x [q00, q01, q10, q11])
 //
 template <typename T>
-static void ComputeBackpropFilter(int, int, int, int, int filter_rows, int filter_cols, int, int,
-                                  int, int, int out_rows, int out_cols, int out_depth_,
-                                  const int64_t padded_out_depth_size, const int64_t out_r,
-                                  const int64_t out_c, const T *out_backprop, const T *input_buffer,
-                                  T *output_buffer)
+void ComputeBackpropFilter(int, int, int, int, int filter_rows, int filter_cols, int, int, int, int,
+                           int out_rows, int out_cols, int out_depth_,
+                           const int64_t padded_out_depth_size, const int64_t out_r,
+                           const int64_t out_c, const T *out_backprop, const T *input_buffer,
+                           T *output_buffer)
 {
   typedef typename Eigen::internal::packet_traits<T>::type Packet;
   static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
@@ -901,12 +901,11 @@ template <typename T> struct LaunchDepthwiseConvBackpropFilterOp
 };

 template <typename T>
-static void DepthwiseConvBackpropFilterReference(int batch, int in_rows, int in_cols, int in_depth,
-                                                 int out_rows, int out_cols, int out_depth,
-                                                 int stride, int depth_multiplier, int filter_rows,
-                                                 int filter_cols, int pad_rows, int pad_cols,
-                                                 const T *out_backprop, const T *input,
-                                                 T *filter_backprop)
+void DepthwiseConvBackpropFilterReference(int batch, int in_rows, int in_cols, int in_depth,
+                                          int out_rows, int out_cols, int out_depth, int stride,
+                                          int depth_multiplier, int filter_rows, int filter_cols,
+                                          int pad_rows, int pad_cols, const T *out_backprop,
+                                          const T *input, T *filter_backprop)
 {
   int num_filter_backprop = filter_rows * filter_cols * in_depth * depth_multiplier;
   memset(filter_backprop, 0, num_filter_backprop * sizeof(T));
diff --git a/compute/cker/include/cker/train/operation/DepthwiseConv.h b/compute/cker/include/cker/train/operation/DepthwiseConv.h
index 6546253f70a..05a937166ed 100644
--- a/compute/cker/include/cker/train/operation/DepthwiseConv.h
+++ b/compute/cker/include/cker/train/operation/DepthwiseConv.h
@@ -41,8 +41,10 @@ class DepthwiseConv

   int getThreadCount() const
   {
+    // NOTE The Eigen library uses both the main thread and a thread pool.
+    // Therefore, an additional memory buffer is needed for the main thread.
     const Eigen::ThreadPoolDevice &d = *eigen_support::GetThreadPoolDevice();
-    return d.numThreads();
+    return d.numThreads() + 1;
   }

   template <typename T>
diff --git a/compute/cker/src/train/DepthwiseConv.test.cc b/compute/cker/src/train/DepthwiseConv.test.cc
index 414398c5c04..34513734f45 100644
--- a/compute/cker/src/train/DepthwiseConv.test.cc
+++ b/compute/cker/src/train/DepthwiseConv.test.cc
@@ -49,7 +49,7 @@ template <typename T> class DepthwiseConvVerifier
     }

     {
-      const int thread_count = _dconv_kernel->getThreadCount() + 1;
+      const int thread_count = _dconv_kernel->getThreadCount();
       nnfw::cker::Shape filter_buffer_shape(
         {thread_count, filter_spatial_size, padded_filter_inner_dim_size});
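The last two hunks are two sides of one change: getThreadCount() now returns d.numThreads() + 1, so the extra buffer for the main thread is accounted for inside the kernel, and the test drops its own "+ 1" accordingly. A short sketch of the resulting scratch-buffer sizing, assuming a pool of four worker threads (the numbers are illustrative, not from this patch):

#include <cassert>

int main()
{
  const int pool_threads = 4;                // assumed Eigen d.numThreads()
  const int thread_count = pool_threads + 1; // getThreadCount() after this patch
  // One scratch slot per pool thread plus one for the main thread, which
  // Eigen also uses to run shards; per-thread buffers are shaped
  // {thread_count, filter_spatial_size, padded_filter_inner_dim_size}.
  assert(thread_count == 5);
  return 0;
}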