From a13a394cd3a8f9df95a670fa48ffdd1a1d8c09b6 Mon Sep 17 00:00:00 2001
From: "jing.tang" 
Date: Thu, 21 Mar 2024 03:01:48 +0000
Subject: [PATCH] Add some trigonometric ops

Add cos/tan/arctan/arctanh/arccosh

Type: New Feature

Signed-off-by: Tang Jing 
---
 docs/Operators.md                        | 225 +++++++++++++++++-
 include/tim/experimental/trace/tvx/ops.h |   5 +
 include/tim/vx/ops/simple_operations.h   |  23 ++
 src/tim/transform/layout_inference.cc    |   7 +-
 .../ops/simple_ops_layout_inference.h    |   7 +-
 src/tim/vx/ops/simple_operations.cc      |   3 +
 6 files changed, 258 insertions(+), 12 deletions(-)

diff --git a/docs/Operators.md b/docs/Operators.md
index b38233838..949d93127 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -7,11 +7,14 @@
   - [ArgMin/ArgMax](#argminargmax)
   - [Batch2Space](#batch2space)
   - [BatchNorm](#batchnorm)
+  - [bidirectional sequence rnn](#bidirectional-sequence-rnn)
+  - [Bidirectional sequence rnn for onnx](#bidirectional-sequence-rnn-for-onnx)
   - [Broadcast](#broadcast)
   - [Clip](#clip)
   - [Concat](#concat)
   - [Conv2d](#conv2d)
   - [Conv3d](#conv3d)
+  - [Cumsum](#cumsum)
   - [DeConv2d](#deconv2d)
   - [DeConv1d](#deconv1d)
   - [DepthToSpace](#depthtospace)
@@ -24,6 +27,7 @@
   - [Minimum](#minimum)
   - [Maximum](#maximum)
   - [FloorDiv](#floordiv)
+  - [EmbeddingLookup](#embeddinglookup)
   - [Erf](#erf)
   - [FullyConnected](#fullyconnected)
   - [Gather](#gather)
@@ -31,20 +35,29 @@
   - [GatherNd](#gathernd)
   - [GroupedConv1d](#groupedconv1d)
   - [GroupedConv2d](#groupedconv2d)
+  - [GRUCell](#grucell)
+  - [HashtableLookup](#hashtablelookup)
   - [L2Normalization](#l2normalization)
   - [LocalResponseNormalization](#localresponsenormalization)
   - [And](#and)
   - [Or](#or)
   - [LogSoftmax](#logsoftmax)
   - [Matmul](#matmul)
+  - [Max_pool3d](#max_pool3d)
   - [MaxpooGrad](#maxpoograd)
   - [MaxpoolWithArgmax](#maxpoolwithargmax)
   - [MaxpoolWithArgmax2](#maxpoolwithargmax2)
   - [MaxUnpool2d](#maxunpool2d)
+  - [Mod](#mod)
   - [Moments](#moments)
   - [NBG](#nbg)
   - [OneHot](#onehot)
   - [Pad](#pad)
+  - [PadV2](#padv2)
+  - [Pool1d](#pool1d)
+    - [Classic Pool1d](#classic-pool1d)
+    - [Global Pool1d](#global-pool1d)
+    - [Adaptive Pool1d](#adaptive-pool1d)
   - [Pool2d](#pool2d)
     - [Classic Pool2d](#classic-pool2d)
     - [Global Pool2d](#global-pool2d)
@@ -70,11 +83,17 @@
   - [RoiAlign](#roialign)
   - [RoiPool](#roipool)
   - [ScatterND](#scatternd)
+  - [ScatterND_ONNX_V16](#scatternd_onnx_v16)
   - [Select](#select)
   - [DataConvert](#dataconvert)
   - [Neg](#neg)
   - [Abs](#abs)
   - [Sin](#sin)
+  - [Cos](#cos)
+  - [Tan](#tan)
+  - [ATan](#atan)
+  - [ACosh](#acosh)
+  - [ATanh](#atanh)
   - [Exp](#exp)
   - [Log](#log)
   - [Sqrt](#sqrt)
@@ -84,6 +103,7 @@
   - [Floor](#floor)
   - [Ceil](#ceil)
   - [Cast](#cast)
+  - [Rcp](#rcp)
   - [Slice](#slice)
   - [Softmax](#softmax)
   - [Space2Batch](#space2batch)
@@ -96,7 +116,10 @@
   - [Tile](#tile)
   - [Topk](#topk)
   - [Transpose](#transpose)
+  - [UnidirectionalSequenceGRU](#unidirectionalsequencegru)
   - [Unidirectional sequence lstm](#unidirectional-sequence-lstm)
+  - [Unidirectional sequence rnn](#unidirectional-sequence-rnn)
+  - [Unidirectional sequence rnn for onnx](#unidirectional-sequence-rnn-for-onnx)
   - [Unstack](#unstack)
@@ -177,6 +200,14 @@
 $$\hat x_i\leftarrow \frac{x_i-\mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2+\epsilon}}$$
 
 $$y_i=\gamma\hat x_i+\beta\equiv BN_{\gamma,\beta}(x_i)$$
 
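+A minimal scalar sketch of the two formulas above (illustrative C++, not the
+TIM-VX API; mean, var, gamma, beta and eps are example inputs):
+
+```
+#include <cmath>
+#include <cstddef>
+#include <vector>
+
+// y_i = gamma * (x_i - mean) / sqrt(var + eps) + beta
+std::vector<float> BatchNormRef(const std::vector<float>& x, float mean,
+                                float var, float gamma, float beta,
+                                float eps = 1e-5f) {
+  std::vector<float> y(x.size());
+  for (std::size_t i = 0; i < x.size(); ++i) {
+    float x_hat = (x[i] - mean) / std::sqrt(var + eps);  // normalize
+    y[i] = gamma * x_hat + beta;                         // scale and shift
+  }
+  return y;
+}
+```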
+
+## bidirectional sequence rnn
+how to bind input/output: take bidirectional_sequence_rnn_test.cc as an example.
+
+
+## Bidirectional sequence rnn for onnx
+how to bind input/output: take unidirectional_sequence_rnn_ext_test.cc as an example.
 
 
 ## Broadcast
 
 Input:
 - input.
 
 Attribute:
 - shape: the shape to broadcast to.
-- dimensions(optional): Which dimension in the target shape each dimension
+- dimensions(optional): Which dimension in the target shape each dimension
 of the operand shape corresponds to. For BroadcastInDim.
@@ -210,7 +241,8 @@
 Depthwise Conv2D / Group Conv2D / Dilation Conv2D.
 
 Input:
 - input [WHCN or CWHN].
-- kernel [ WHIcOc ] (Ic: Input Channels. Oc: Output Channels).
+- kernel [ WHIcOc ] (Ic: Input Channels. Oc: Output Channels); normally
+[WHIc(Oc)1] for Depthwise Conv.
 - bias [ O ]. Optional.
 
 Attribute:
@@ -246,6 +278,19 @@
 but the value is different. multiplier = weights / group.
 - input_layout : WHDCN or WHCDN.
 - kernel_layout : WHDIcOc
+
+## Cumsum
+
+Compute the cumulative sum of the tensor along the given axis. By default the
+sum is inclusive, meaning the first element is copied as is. Through the
+exclusive attribute, this behavior can be changed to exclude the first
+element. The summation can also run in the opposite direction of the axis by
+setting the reverse attribute to 1.
+All the attributes can be combined.
+- axis : the axis along which the cumsum is performed. Default = 0.
+- exclusive : if exclusive = 1, perform an exclusive cumsum.
+- reverse : if reverse = 1, the cumsum is performed in the opposite direction.
+
 
 ## DeConv2d
@@ -276,11 +321,12 @@
 but is actually the transpose (gradient) of Conv2D rather than an actual
 deconvolution.
 
 - weights : the channel number for weight tensor.
 - ksize : the length for weight tensor.
-- padding : AUTO, VALID or SAME.
+- padtype : AUTO, VALID or SAME (see caution below).
 - pad : pad value for each spatial axis.
 - stride : stride along each spatial axis.
-- output_padding : specifying the amount of padding along the height and width of
-the output tensor.
+- output_padding : additional padding lines added to the output tensor.
+Default is zero.
+
+Caution: padtype is not really supported yet; it will be supported in the
+future.
 
 ## DepthToSpace
@@ -349,6 +395,11 @@
 Maximum(x, y) : max(x, y). This operation supports broadcasting.
 
 FloorDiv(x, y): floor( x / y ). This operation supports broadcasting.
+
+## EmbeddingLookup
+
+Looks up sub-tensors in the input tensor with the specified indices (idx).
+
 
 ## Erf
@@ -360,7 +411,7 @@
 Computes the Gauss error function of x element-wise.
 
 ## FullyConnected
 
 Denotes a fully (densely) connected layer, which connects all elements in the
-input tensor with each element in the output tensor.
+input tensor with each element in the output tensor.
 
 - axis: Describes the axis of the inputs when coerced to 2D.
 - weights: the output channel number for weight tensor.
@@ -369,6 +420,7 @@
 ## Gather
 
 Gather slices from input along **axis** according to **indices**.
+batch_dims gives the number of leading batch dimensions; the gather is applied
+independently within each batch according to **indices**.
 
 ## GatherElements
@@ -424,6 +476,20 @@
 Attribute:
 - group_number: Split conv to n group.
 - layout : WHCN or CWHN.
+
+## GRUCell
+
+- num_units : dimensionality of the output space.
+- activation : Activation function to use.
+- recurrent_activation : Activation function to use for the recurrent step.
+- reset_after : whether to apply the reset gate after or before the matrix
+multiplication. False = "before", True = "after".
+
+
+## HashtableLookup
+
+Looks up sub-tensors in the input tensor using a key-value map.
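+
+A minimal sketch of the lookup semantics (illustrative C++, not the TIM-VX
+API; the hits output follows NNAPI-style hashtable lookup and is an assumption
+here, since the description above does not spell it out):
+
+```
+#include <cstdint>
+#include <vector>
+
+// For each lookup key, find it in `keys` and copy the matching row of
+// `values` (row size = row_len) to the output; `hits` marks found keys.
+void HashtableLookupRef(const std::vector<int32_t>& lookups,
+                        const std::vector<int32_t>& keys,
+                        const std::vector<float>& values, size_t row_len,
+                        std::vector<float>& out, std::vector<uint8_t>& hits) {
+  out.assign(lookups.size() * row_len, 0.0f);
+  hits.assign(lookups.size(), 0);
+  for (size_t i = 0; i < lookups.size(); ++i) {
+    for (size_t k = 0; k < keys.size(); ++k) {
+      if (keys[k] == lookups[i]) {
+        for (size_t j = 0; j < row_len; ++j)
+          out[i * row_len + j] = values[k * row_len + j];
+        hits[i] = 1;
+        break;
+      }
+    }
+  }
+}
+```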
+
 
 ## L2Normalization
 
@@ -444,6 +510,11 @@
 Applies Local Response Normalization along the depth dimension:
 
 ```
 sqr_sum[a, b, c, d] = sum(
 pow(input[a, b, c, d - depth_radius : d + depth_radius + 1], 2))
 output = input / pow((bias + alpha * sqr_sum), beta)
+size : width of the 1-D normalization window.
+bias : An offset (usually positive to avoid dividing by 0).
+alpha : A scale factor.
+beta : An exponent.
 ```
 
 ## Matmul
@@ -475,11 +546,29 @@
 Multiplies matrix a by matrix b, producing a * b.
 
 - adjoint_a: If True, a is conjugated and transposed before multiplication.
 - adjoint_b: If True, b is conjugated and transposed before multiplication.
+
+
+## Max_pool3d
+
+Applies a 3D max pooling over an input Tensor, which can be regarded as a
+composition of 3D planes.
+
+Input:
+- input [WHDCN]
+- kernel [ WHD ]
+
+Attribute:
+- round_type : CEILING or FLOOR.
+- ksize : the kernel size for each spatial axis (width, height, depth).
+- stride : stride along each spatial axis.
+- pad : pad value for each spatial axis (left, right, top, bottom, front, rear).
+- pad_type : AUTO, VALID or SAME.
+
 
 ## MaxpooGrad
 
 Acquire the gradient of 2-D Max pooling operation's input tensor. \
-Like the tensorflow_XLA op SelectAndScatter, see https://tensorflow.google.cn/xla/operation_semantics?hl=en#selectandscatter.
+Like the tensorflow_XLA op SelectAndScatter, see \
+https://tensorflow.google.cn/xla/operation_semantics?hl=en#selectandscatter.
 
 - padding : AUTO, VALID or SAME.
 - ksize : filter size.
@@ -491,6 +580,10 @@
 - 0 : input tensor of 2-D Max pooling.
 - 1 : gradient of 2-D Max pooling output tensor.
 
+* Outputs:
+
+- 0 : updated tensor of 2-D Max pooling input.
+
 
 ## MaxpoolWithArgmax
@@ -519,6 +612,19 @@
 Performs a 2-D Max pooling operation upsample.
 
 - stride : stride along each spatial axis.
 - ksize : filter size.
+
+## Mod
+
+Mod performs element-wise binary modulus. By default, the sign of the
+remainder is the same as that of the divisor.
+
+When the input type is floating point, the Mod operator can also behave like
+C fmod() or numpy.fmod; in that case the sign of the remainder is the same as
+that of the dividend. The fmod attribute decides which behavior is used.
+
+- fmod : if the input type is floating point, fmod must be set to 1.
+Default = 0, which means integer mod.
+
 
 ## Moments
@@ -549,11 +655,54 @@
 Create a one-hot tensor.
 
 ## Pad
 
 Pads a tensor.
 
-- const_val : the value to pad.
+- const_val : the int32 value to pad.
 - pad_mode : the mode of pad.
 - front_size : Add pad values to the left and top.
 - back_size : Add pad values to the right and bottom.
+
+
+## PadV2
+
+Pads a tensor.
+
+- const_val : the float value to pad.
+- pad_mode : the mode of pad.
+- front_size : Add pad values to the left and top.
+- back_size : Add pad values to the right and bottom.
+
+
+## Pool1d
+
+
+### Classic Pool1d
+
+Performs a 1-D pooling operation.
+
+- type : MAX, AVG, L2 or AVG_ANDROID.
+- padding : AUTO, VALID or SAME.
+- pad : specify the number of pad values for left, right.
+- ksize : filter size.
+- stride : stride along each spatial axis.
+- round_type : CEILING or FLOOR.
+
+
+### Global Pool1d
+
+- type : MAX, AVG, L2 or AVG_ANDROID.
+- input_size : input size (only [W]).
+- round_type : CEILING or FLOOR.
+
+
+### Adaptive Pool1d
+
+Same as torch.nn.AdaptiveXXXPool1d. A sketch of the window mapping follows
+this list.
+
+- type : MAX, AVG, L2 or AVG_ANDROID.
+- input_size : input size (only [W]).
+- output_size : output size (only [W]).
+- round_type : CEILING or FLOOR.
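+
+A minimal sketch of the adaptive window mapping (illustrative C++, not the
+TIM-VX API; the floor/ceil index mapping mirrors torch.nn.AdaptiveAvgPool1d
+and is an assumption here):
+
+```
+#include <cstddef>
+#include <vector>
+
+// Output element i averages input[start, end) with torch-style bounds:
+// start = floor(i * W / outW), end = ceil((i + 1) * W / outW).
+std::vector<float> AdaptiveAvgPool1dRef(const std::vector<float>& in,
+                                        std::size_t out_w) {
+  const std::size_t w = in.size();
+  std::vector<float> out(out_w, 0.0f);
+  for (std::size_t i = 0; i < out_w; ++i) {
+    std::size_t start = (i * w) / out_w;
+    std::size_t end = ((i + 1) * w + out_w - 1) / out_w;  // ceiling division
+    for (std::size_t j = start; j < end; ++j) out[i] += in[j];
+    out[i] /= static_cast<float>(end - start);
+  }
+  return out;
+}
+```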
+
+
 ## Pool2d
@@ -758,7 +907,7 @@
 Select and scale the feature map of each region of interest to a unified output
 size by max-pooling.
 
 pool_type : only support max-pooling (MAX)
-scale : The ratio of image to feature map (Range: 0 < scale <= 1)
+scale : The ratio of image to feature map (Range: 0 < scale <= 1)
 size : The size of roi pooling (height/width)
@@ -769,6 +918,13 @@
 ## ScatterND
 
 Scatter updates into a new tensor according to indices.
 
 - shape : The shape of the resulting tensor.
+
+
+## ScatterND_ONNX_V16
+
+Scatter updates into a new tensor according to indices.
+
+- reduction : type of reduction to apply: none (default), add, mul, max, min.
 
 ## Select
@@ -795,6 +951,31 @@
 Abs(x) : x if x >= 0; -x if x < 0.
 
 ## Sin
 
 Sin(x) : sin(x)
+
+
+## Cos
+
+Cos(x) : cos(x)
+
+
+## Tan
+
+Tan(x) : tan(x)
+
+
+## ATan
+
+ATan(x) : arctan(x)
+
+
+## ACosh
+
+ACosh(x) : arccosh(x)
+
+
+## ATanh
+
+ATanh(x) : arctanh(x)
 
 ## Exp
 
 Exp(x) : e^x
@@ -841,6 +1022,10 @@
 returns the smallest integer greater than or equal to a given number.
 
 ## Cast
 
 Change the format from input tensor to output tensor. This operation ignores
 the scale and zeroPoint of quantized tensors.
+
+## Rcp
+
+Computes the reciprocal of input element-wise.
 
 ## Slice
@@ -948,6 +1133,7 @@ Length must be the same as the number of dimensions in input.
 
 Finds values and indices of the k largest entries for the last dimension.
 
 - k : Number of top elements to look for along the last dimension.
+- axis : the dimension along which to sort. Default is 0.
 
 
 ## Transpose
@@ -960,10 +1146,31 @@
 If perm is not given, it is set to (n-1...0), where n is the rank of the input
 tensor. Hence by default, this operation performs a regular matrix transpose on
 2-D input Tensors.
+
+
+## UnidirectionalSequenceGRU
+
+- num_units : dimensionality of the output space.
+- activation : Activation function to use.
+- recurrent_activation : Activation function to use for the recurrent step.
+- reset_after : whether to apply the reset gate after or before the matrix
+multiplication. False = "before", True = "after".
+- return_sequences : whether to return the last output in the output sequence
+or the full sequence. Default: False.
+- time_major : if True, the inputs and outputs are in shape
+[feature, batch, timesteps]; if False, [feature, timesteps, batch]. A shape
+example follows this list.
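+
+A minimal illustration of the attribute combinations above (illustrative
+values; feature = 8, timesteps = 10, batch = 4):
+
+```
+time_major = False : input shape [feature, timesteps, batch] = [8, 10, 4]
+time_major = True  : input shape [feature, batch, timesteps] = [8, 4, 10]
+
+return_sequences = True  : the output keeps the timesteps dimension.
+return_sequences = False : only the output of the last timestep is returned.
+```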
+
 ## Unidirectional sequence lstm
 how to bind input/output: take unidirectional_sequence_lstm_test.cc as an example.
+
+
+## Unidirectional sequence rnn
+how to bind input/output: take unidirectional_sequence_rnn_test.cc as an example.
+
+
+## Unidirectional sequence rnn for onnx
+how to bind input/output: take unidirectional_sequence_rnn_ext_test.cc as an example.
+
 
 ## Unstack
diff --git a/include/tim/experimental/trace/tvx/ops.h b/include/tim/experimental/trace/tvx/ops.h
index 9972cef11..35a84eff2 100755
--- a/include/tim/experimental/trace/tvx/ops.h
+++ b/include/tim/experimental/trace/tvx/ops.h
@@ -124,6 +124,11 @@
   (Neg) \
   (Abs) \
   (Sin) \
+  (Cos) \
+  (Tan) \
+  (ACosh) \
+  (ATan) \
+  (ATanh) \
   (Exp) \
   (Log) \
   (Sqrt) \
diff --git a/include/tim/vx/ops/simple_operations.h b/include/tim/vx/ops/simple_operations.h
index 7f8cc98a2..ea26dcfbf 100644
--- a/include/tim/vx/ops/simple_operations.h
+++ b/include/tim/vx/ops/simple_operations.h
@@ -54,6 +54,26 @@ namespace ops {
  *
  * Sin(x) : sin(x)
  *
+ * ## Cos
+ *
+ * Cos(x) : cos(x)
+ *
+ * ## Tan
+ *
+ * Tan(x) : tan(x)
+ *
+ * ## ATan
+ *
+ * ATan(x) : arctan(x)
+ *
+ * ## ACosh
+ *
+ * ACosh(x) : arccosh(x)
+ *
+ * ## ATanh
+ *
+ * ATanh(x) : arctanh(x)
+ *
  * ## Exp
  *
  * Exp(x) : e^x
@@ -101,6 +121,9 @@
 DECLARE_SIMPLE_OP(Abs)
 DECLARE_SIMPLE_OP(Sin)
 DECLARE_SIMPLE_OP(Cos)
 DECLARE_SIMPLE_OP(Tan)
+DECLARE_SIMPLE_OP(ATan)
+DECLARE_SIMPLE_OP(ATanh)
+DECLARE_SIMPLE_OP(ACosh)
 DECLARE_SIMPLE_OP(Exp)
 DECLARE_SIMPLE_OP(Log)
 DECLARE_SIMPLE_OP(Sqrt)
diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc
index 5bb11d5ba..a464c4d29 100644
--- a/src/tim/transform/layout_inference.cc
+++ b/src/tim/transform/layout_inference.cc
@@ -249,7 +249,6 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_HARD_SIGMOID, HardSigmoid);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SOFTRELU, SoftRelu);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SWISH, HardSwish);
-  REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LEAKY_RELU, LeakyRelu);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_CONCAT, Concat);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ADD, Add);
@@ -263,6 +262,12 @@
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_NEG, Neg);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ABS, Abs);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SIN, Sin);
+  REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_COS, Cos);
+  REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TAN, Tan);
+  REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh);
+  REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ATAN, ATan);
+  REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ATANH, ATanh);
+  REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ACOSH, ACosh);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_EXP, Exp);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LOG, Log);
   REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SQRT, Sqrt);
diff --git a/src/tim/transform/ops/simple_ops_layout_inference.h b/src/tim/transform/ops/simple_ops_layout_inference.h
index 28ae75983..7100fc6d0 100644
--- a/src/tim/transform/ops/simple_ops_layout_inference.h
+++ b/src/tim/transform/ops/simple_ops_layout_inference.h
@@ -60,8 +60,11 @@
 using DataConvertLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_DATACONVERT>;
 using NegLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_NEG>;
 using AbsLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_ABS>;
 using SinLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_SIN>;
-// TODO(yzw): enable it when TIM-VX support 'Cos'
-// using CosLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_COS>;
+using CosLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_COS>;
+using TanLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_TAN>;
+using ATanLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_ATAN>;
+using ATanhLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_ATANH>;
+using ACoshLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_ACOSH>;
 using ExpLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_EXP>;
 using LogLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_LOG>;
 using SqrtLayoutInfer = SimpleOpsLayoutInfer<VSI_NN_OP_SQRT>;
diff --git a/src/tim/vx/ops/simple_operations.cc b/src/tim/vx/ops/simple_operations.cc
index e2250151f..d6b881aae 100644
--- a/src/tim/vx/ops/simple_operations.cc
+++ b/src/tim/vx/ops/simple_operations.cc
@@ -42,6 +42,9 @@
 DEFINE_SIMPLE_OP(Abs, VSI_NN_OP_ABS)
 DEFINE_SIMPLE_OP(Sin, VSI_NN_OP_SIN)
 DEFINE_SIMPLE_OP(Cos, VSI_NN_OP_COS)
 DEFINE_SIMPLE_OP(Tan, VSI_NN_OP_TAN)
+DEFINE_SIMPLE_OP(ATan, VSI_NN_OP_ATAN)
+DEFINE_SIMPLE_OP(ATanh, VSI_NN_OP_ATANH)
+DEFINE_SIMPLE_OP(ACosh, VSI_NN_OP_ACOSH)
 DEFINE_SIMPLE_OP(Exp, VSI_NN_OP_EXP)
 DEFINE_SIMPLE_OP(Log, VSI_NN_OP_LOG)
 DEFINE_SIMPLE_OP(Sqrt, VSI_NN_OP_SQRT)
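
Usage note: a minimal sketch of driving one of the new ops through the public
API (based on the standard TIM-VX context/graph/tensor flow; shapes and data
are illustrative):

```
#include <vector>

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/simple_operations.h"
#include "tim/vx/tensor.h"

int main() {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType shape({4});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, shape,
                                  tim::vx::TensorAttribute::OUTPUT);
  auto input = graph->CreateTensor(input_spec);
  auto output = graph->CreateTensor(output_spec);

  // One of the ops added by this patch; ATan(x) : arctan(x).
  auto atan = graph->CreateOperation<tim::vx::ops::ATan>();
  (*atan).BindInput(input).BindOutput(output);

  std::vector<float> in_data = {-1.0f, 0.0f, 0.5f, 1.0f};
  std::vector<float> out_data(in_data.size());
  if (!graph->Compile()) return -1;
  input->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float));
  if (!graph->Run()) return -1;
  output->CopyDataFromTensor(out_data.data());  // approx. atan of each input
  return 0;
}
```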