diff --git a/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp b/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp index 3ed9a03dde4..cc594f63ddc 100644 --- a/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp +++ b/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp @@ -93,75 +93,19 @@ struct StridedSliceContext params.new_axis_mask = node->new_axis_mask(); params.shrink_axis_mask = node->shrink_axis_mask(); - auto begin_node = loco::must_cast(node->begin()); - auto end_node = loco::must_cast(node->end()); - auto strides_node = loco::must_cast(node->strides()); - - LUCI_ASSERT(begin_node->rank() == 1, "Only support rank 1 for begin_node"); - LUCI_ASSERT(end_node->rank() == 1, "Only support rank 1 for end_node"); - LUCI_ASSERT(strides_node->rank() == 1, "Only support rank 1 for strides_node"); - input = loco::must_cast(node->input()); - begin = dynamic_cast(node->begin()); - end = dynamic_cast(node->end()); - strides = dynamic_cast(node->strides()); - dummy.dtype(loco::DataType::S32); - dummy.rank(1); - dummy.shape_status(luci::ShapeStatus::VALID); + begin = loco::must_cast(node->begin()); + end = loco::must_cast(node->end()); + strides = loco::must_cast(node->strides()); loco::TensorShape input_shape = circle_shape(input); - - if (begin == nullptr) - { - begin = &dummy; - begin->dim(0).set(begin_node->dim(0).value()); - int32_t unknown_range = begin_node->dim(0).known() ? begin_node->dim(0).value() : input_dims; - - for (int32_t i = 0; i < unknown_range; ++i) - { - input_shape.dim(i).unset(); - } - } - if (end == nullptr) - { - end = &dummy; - end->dim(0).set(end_node->dim(0).value()); - int32_t unknown_range = end_node->dim(0).known() ? end_node->dim(0).value() : input_dims; - - for (int32_t i = 0; i < unknown_range; ++i) - { - input_shape.dim(i).unset(); - } - } - if (strides == nullptr) - { - strides = &dummy; - strides->dim(0).set(strides_node->dim(0).value()); - int32_t unknown_range = - strides_node->dim(0).known() ? strides_node->dim(0).value() : input_dims; - - for (int32_t i = 0; i < unknown_range; ++i) - { - input_shape.dim(i).unset(); - } - } - - LUCI_ASSERT(begin->dtype() == S32, "Only support S32 for begin_node"); - LUCI_ASSERT(end->dtype() == S32, "Only support S32 for end_node"); - LUCI_ASSERT(strides->dtype() == S32, "Only support S32 for strides_node"); - input_dims = input_shape.rank(); - - assert(begin->size() <= input_dims); - assert(end->size() <= input_dims); - assert(strides->size() <= input_dims); } StridedSliceParams params; luci::CircleNode *input = nullptr; luci::CircleConst *begin = nullptr; luci::CircleConst *end = nullptr; luci::CircleConst *strides = nullptr; - luci::CircleConst dummy; // Equivalent input shape after adding axis according to new_axis_mask. loco::TensorShape effective_input_shape; @@ -440,11 +384,35 @@ loco::TensorShape Algorithm::visit(const luci::CircleStridedSlice *node) { loco::TensorShape output_shape; + auto input_node = loco::must_cast(node->input()); + + auto begin_node = dynamic_cast(node->begin()); + auto end_node = dynamic_cast(node->end()); + auto strides_node = dynamic_cast(node->strides()); + // TODO support non-const case + if (begin_node == nullptr || end_node == nullptr || strides_node == nullptr) + { + INTERNAL_EXN("StridedSlice begin/end/strides nodes are not Constant"); + } + + LUCI_ASSERT(begin_node->dtype() == S32, "Only support S32 for begin_node"); + LUCI_ASSERT(end_node->dtype() == S32, "Only support S32 for end_node"); + LUCI_ASSERT(strides_node->dtype() == S32, "Only support S32 for strides_node"); + + LUCI_ASSERT(begin_node->rank() == 1, "Only support rank 1 for begin_node"); + LUCI_ASSERT(end_node->rank() == 1, "Only support rank 1 for end_node"); + LUCI_ASSERT(strides_node->rank() == 1, "Only support rank 1 for strides_node"); + + loco::TensorShape input_shape = circle_shape(input_node); + + assert(begin_node->size() <= input_shape.rank()); + assert(end_node->size() <= input_shape.rank()); + assert(strides_node->size() <= input_shape.rank()); + StridedSliceContext op_context(node); auto op_params = BuildStridedSliceParams(&op_context); auto &effective_input_shape = op_context.effective_input_shape; std::vector output_shape_vector; - std::vector output_known_vector; for (int32_t idx = effective_input_shape.rank() - 1; idx >= 0; --idx) { @@ -470,7 +438,6 @@ loco::TensorShape Algorithm::visit(const luci::CircleStridedSlice *node) if (!shrink_axis) { output_shape_vector.push_back(dim_shape); - output_known_vector.push_back(effective_input_shape.dim(idx).known()); } } @@ -478,9 +445,6 @@ loco::TensorShape Algorithm::visit(const luci::CircleStridedSlice *node) output_shape.rank(shape_size); for (uint32_t idx = 0; idx < shape_size; ++idx) { - bool known = output_known_vector[shape_size - 1u - idx]; - if (not known) - continue; int64_t dim = output_shape_vector.at(shape_size - 1u - idx); LUCI_ASSERT(0 <= dim && dim < 0xfffffffL, "Dimension size exceeds limit"); // reverse copy