-
Notifications
You must be signed in to change notification settings - Fork 128
/
Copy pathdl_module_concat.hpp
163 lines (145 loc) · 5.14 KB
/
dl_module_concat.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#pragma once
#include "dl_module_base.hpp"
namespace dl {
namespace module {
/**
* NOTE:
*
* @tparam feature_t supports int16_t and int8_t,
* - int16_t: stands for operation in int16_t quantize
* - int8_t: stands for operation in int8_t quantize
*/
/**
 * NOTE:
 * Concatenates multiple quantized input tensors along a single axis.
 * get_output_shape() must run before forward(): it normalizes the axis and
 * precomputes the copy plan (loop_times / copy_nums) that forward() replays.
 *
 * @tparam feature_t supports int16_t and int8_t,
 *         - int16_t: stands for operation in int16_t quantize
 *         - int8_t: stands for operation in int8_t quantize
 */
class Concat : public Module {
private:
    int axis;                   /*<! axis to concat; normalized to non-negative in get_output_shape() >*/
    int n_dims = 0;             /*<! num dimensions of each input >*/
    int n_inputs = 0;           /*<! num input tensors >*/
    int loop_times = 1;         /*<! product of dims before `axis`: number of interleaved copy rounds >*/
    std::vector<int> copy_nums; /*<! per input: contiguous elements copied in each round >*/

public:
    /**
     * @brief Construct a new Concat object.
     *
     * @param name       name of module
     * @param axis       axis to concatenate along; negative counts from the last dim
     * @param quant_type quantization type (8-bit or 16-bit symmetric)
     */
    Concat(const char *name = NULL, int axis = 0, quant_type_t quant_type = QUANT_TYPE_NONE) :
        Module(name, MODULE_NON_INPLACE, quant_type), axis(axis)
    {
    }

    /**
     * @brief Destroy the Concat object.
     */
    ~Concat() {}

    /**
     * @brief Infer the output shape and build the copy plan used by forward().
     *
     * All inputs must agree on every dimension except `axis`; the output keeps
     * the common shape with the axis dimension summed over all inputs.
     *
     * @param input_shapes shapes of all input tensors (at least two)
     * @return single-element vector holding the output shape
     */
    std::vector<std::vector<int>> get_output_shape(std::vector<std::vector<int>> &input_shapes)
    {
        this->n_inputs = static_cast<int>(input_shapes.size());
        assert(this->n_inputs > 1);
        this->n_dims = static_cast<int>(input_shapes[0].size());

        // Normalize a negative (count-from-the-back) axis once; idempotent on
        // repeated calls since the result is always non-negative.
        if (this->axis < 0)
            this->axis += this->n_dims;
        assert(this->axis >= 0 && this->axis < this->n_dims);

        int output_axis_dim = 0;
        this->loop_times = 1;
        this->copy_nums.assign(this->n_inputs, 1);
        for (int i = 0; i < this->n_inputs; i++) {
            assert(static_cast<int>(input_shapes[i].size()) == this->n_dims);
            for (int j = 0; j < this->n_dims; j++) {
                // Dims before the concat axis define how many copy rounds run.
                if (i == 0 && j < this->axis) {
                    this->loop_times *= input_shapes[0][j];
                }
                // Every non-axis dim must match across all inputs.
                if (i > 0 && j != this->axis) {
                    assert(input_shapes[i][j] == input_shapes[i - 1][j]);
                }
                // Dims from the axis onward form one contiguous chunk per round.
                if (j >= this->axis) {
                    this->copy_nums[i] *= input_shapes[i][j];
                }
            }
            output_axis_dim += input_shapes[i][this->axis];
        }

        std::vector<int> output_shape(input_shapes[0]);
        output_shape[this->axis] = output_axis_dim;
        std::vector<std::vector<int>> output_shapes(1, output_shape);
        return output_shapes;
    }

    /**
     * @brief Run concatenation, dispatching on the module's quantization type.
     *        Unsupported quant types are a silent no-op (only latency is logged).
     */
    void forward(std::vector<dl::TensorBase *> &tensors, runtime_mode_t mode)
    {
        DL_LOG_LAYER_LATENCY_INIT();
        DL_LOG_LAYER_LATENCY_START();
        if (quant_type == QUANT_TYPE_SYMM_8BIT) {
            forward_template<int8_t>(tensors, mode);
        } else if (quant_type == QUANT_TYPE_SYMM_16BIT) {
            forward_template<int16_t>(tensors, mode);
        }
        DL_LOG_LAYER_LATENCY_END(this->name, "Concat");
    }

    void forward_args(void *args) {}

    /**
     * @brief Interleave-copy the inputs into the output buffer.
     *
     * For each of the loop_times rounds, copy copy_nums[j] elements from each
     * input j in order. Relies on the plan built by get_output_shape().
     *
     * @tparam T element type (int8_t or int16_t)
     */
    template <typename T>
    void forward_template(std::vector<TensorBase *> &tensors, runtime_mode_t mode)
    {
        TensorBase *output = tensors[m_outputs_index[0]];
        T *output_ptr = (T *)output->get_element_ptr();
        std::vector<T *> inputs_ptr(this->n_inputs);
        for (int i = 0; i < this->n_inputs; i++) {
            TensorBase *input = tensors[m_inputs_index[i]];
            inputs_ptr[i] = (T *)input->get_element_ptr();
        }

        for (int i = 0; i < this->loop_times; i++) {
            for (int j = 0; j < this->n_inputs; j++) {
                tool::copy_memory(output_ptr, inputs_ptr[j], sizeof(T) * this->copy_nums[j]);
                output_ptr += this->copy_nums[j];
                inputs_ptr[j] += this->copy_nums[j];
            }
        }
    }

    /**
     * @brief deserialize Concat module instance by node serialization information
     *
     * @return new Concat module, or nullptr when the quant type is unsupported
     */
    static Module *deserialize(fbs::FbsModel *fbs_model, std::string node_name)
    {
        Module *op = nullptr;
        int axis;
        std::vector<int> output_shape;
        quant_type_t quant_type;
        fbs_model->get_operation_attribute(node_name, "axis", axis);
        fbs_model->get_operation_output_shape(node_name, 0, output_shape);
        fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
        // NOTE(review): a disabled NCHW->NHWC axis-remap table used to live here;
        // presumably that remap now happens during model conversion, so the
        // serialized axis is used as-is — confirm against the converter.
        // Create module
        if (quant_type == QUANT_TYPE_SYMM_8BIT || quant_type == QUANT_TYPE_SYMM_16BIT) {
            op = new Concat(node_name.c_str(), axis, quant_type);
        }
        return op;
    }

    /**
     * @brief Log this module's quantization type.
     */
    void print() { ESP_LOGI("Concat", "quant_type: %s.", quant_type_to_string(quant_type)); }
};
} // namespace module
} // namespace dl