Skip to content

Commit

Permalink
[luci/service] Remove Unintended Changes for StridedSlice (#13958)
Browse files Browse the repository at this point in the history
This removes unintended changes for StridedSlice from
#13955

ONE-DCO-1.0-Signed-off-by: sunki <[email protected]>
  • Loading branch information
qsunki authored Sep 9, 2024
1 parent c3f130e commit 12941da
Showing 1 changed file with 28 additions and 64 deletions.
92 changes: 28 additions & 64 deletions compiler/luci/service/src/Nodes/CircleStridedSlice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,75 +93,19 @@ struct StridedSliceContext
params.new_axis_mask = node->new_axis_mask();
params.shrink_axis_mask = node->shrink_axis_mask();

auto begin_node = loco::must_cast<luci::CircleNode *>(node->begin());
auto end_node = loco::must_cast<luci::CircleNode *>(node->end());
auto strides_node = loco::must_cast<luci::CircleNode *>(node->strides());

LUCI_ASSERT(begin_node->rank() == 1, "Only support rank 1 for begin_node");
LUCI_ASSERT(end_node->rank() == 1, "Only support rank 1 for end_node");
LUCI_ASSERT(strides_node->rank() == 1, "Only support rank 1 for strides_node");

input = loco::must_cast<luci::CircleNode *>(node->input());
begin = dynamic_cast<luci::CircleConst *>(node->begin());
end = dynamic_cast<luci::CircleConst *>(node->end());
strides = dynamic_cast<luci::CircleConst *>(node->strides());
dummy.dtype(loco::DataType::S32);
dummy.rank(1);
dummy.shape_status(luci::ShapeStatus::VALID);
begin = loco::must_cast<luci::CircleConst *>(node->begin());
end = loco::must_cast<luci::CircleConst *>(node->end());
strides = loco::must_cast<luci::CircleConst *>(node->strides());

loco::TensorShape input_shape = circle_shape(input);

if (begin == nullptr)
{
begin = &dummy;
begin->dim(0).set(begin_node->dim(0).value());
int32_t unknown_range = begin_node->dim(0).known() ? begin_node->dim(0).value() : input_dims;

for (int32_t i = 0; i < unknown_range; ++i)
{
input_shape.dim(i).unset();
}
}
if (end == nullptr)
{
end = &dummy;
end->dim(0).set(end_node->dim(0).value());
int32_t unknown_range = end_node->dim(0).known() ? end_node->dim(0).value() : input_dims;

for (int32_t i = 0; i < unknown_range; ++i)
{
input_shape.dim(i).unset();
}
}
if (strides == nullptr)
{
strides = &dummy;
strides->dim(0).set(strides_node->dim(0).value());
int32_t unknown_range =
strides_node->dim(0).known() ? strides_node->dim(0).value() : input_dims;

for (int32_t i = 0; i < unknown_range; ++i)
{
input_shape.dim(i).unset();
}
}

LUCI_ASSERT(begin->dtype() == S32, "Only support S32 for begin_node");
LUCI_ASSERT(end->dtype() == S32, "Only support S32 for end_node");
LUCI_ASSERT(strides->dtype() == S32, "Only support S32 for strides_node");

input_dims = input_shape.rank();

assert(begin->size<S32>() <= input_dims);
assert(end->size<S32>() <= input_dims);
assert(strides->size<S32>() <= input_dims);
}
StridedSliceParams params;
luci::CircleNode *input = nullptr;
luci::CircleConst *begin = nullptr;
luci::CircleConst *end = nullptr;
luci::CircleConst *strides = nullptr;
luci::CircleConst dummy;

// Equivalent input shape after adding axis according to new_axis_mask.
loco::TensorShape effective_input_shape;
Expand Down Expand Up @@ -440,11 +384,35 @@ loco::TensorShape Algorithm::visit(const luci::CircleStridedSlice *node)
{
loco::TensorShape output_shape;

auto input_node = loco::must_cast<luci::CircleNode *>(node->input());

auto begin_node = dynamic_cast<luci::CircleConst *>(node->begin());
auto end_node = dynamic_cast<luci::CircleConst *>(node->end());
auto strides_node = dynamic_cast<luci::CircleConst *>(node->strides());
// TODO support non-const case
if (begin_node == nullptr || end_node == nullptr || strides_node == nullptr)
{
INTERNAL_EXN("StridedSlice begin/end/strides nodes are not Constant");
}

LUCI_ASSERT(begin_node->dtype() == S32, "Only support S32 for begin_node");
LUCI_ASSERT(end_node->dtype() == S32, "Only support S32 for end_node");
LUCI_ASSERT(strides_node->dtype() == S32, "Only support S32 for strides_node");

LUCI_ASSERT(begin_node->rank() == 1, "Only support rank 1 for begin_node");
LUCI_ASSERT(end_node->rank() == 1, "Only support rank 1 for end_node");
LUCI_ASSERT(strides_node->rank() == 1, "Only support rank 1 for strides_node");

loco::TensorShape input_shape = circle_shape(input_node);

assert(begin_node->size<S32>() <= input_shape.rank());
assert(end_node->size<S32>() <= input_shape.rank());
assert(strides_node->size<S32>() <= input_shape.rank());

StridedSliceContext op_context(node);
auto op_params = BuildStridedSliceParams(&op_context);
auto &effective_input_shape = op_context.effective_input_shape;
std::vector<int64_t> output_shape_vector;
std::vector<bool> output_known_vector;

for (int32_t idx = effective_input_shape.rank() - 1; idx >= 0; --idx)
{
Expand All @@ -470,17 +438,13 @@ loco::TensorShape Algorithm::visit(const luci::CircleStridedSlice *node)
if (!shrink_axis)
{
output_shape_vector.push_back(dim_shape);
output_known_vector.push_back(effective_input_shape.dim(idx).known());
}
}

auto shape_size = output_shape_vector.size();
output_shape.rank(shape_size);
for (uint32_t idx = 0; idx < shape_size; ++idx)
{
bool known = output_known_vector[shape_size - 1u - idx];
if (not known)
continue;
int64_t dim = output_shape_vector.at(shape_size - 1u - idx);
LUCI_ASSERT(0 <= dim && dim < 0xfffffffL, "Dimension size exceeds limit");
// reverse copy
Expand Down

0 comments on commit 12941da

Please sign in to comment.