Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[one-optimize] Remove unnecessary Cast Op #13760

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1c035b0
[draft][luci/pass] Introduce RemoveUnnecessaryCastPass
jiwaszki Aug 23, 2024
525eb61
[draft][circle2circle] Add an option for RemoveUnnecessaryCastPass
jiwaszki Aug 23, 2024
3f4c8be
[draft][one-cmds] Add an option for RemoveUnnecessaryCastPass
jiwaszki Aug 26, 2024
92b1fb9
[draft][res/tfl_recipes] Add Net_Unnecessary_Cast
jiwaszki Aug 26, 2024
558d6f4
[draft][luci/pass] Value test for RemoveUnnecessaryCastPass
jiwaszki Aug 26, 2024
11a6140
[draft][circle2circle] Dredd test for RemoveUnnecessaryCastPass
jiwaszki Aug 26, 2024
a749c2f
Cleanup for [luci/pass]
jiwaszki Aug 26, 2024
ea1287e
Cleanup for [circle2circle]
jiwaszki Aug 26, 2024
d5e2d16
[circle2circle] Add an option for RemoveUnnecessaryCastPass
jiwaszki Aug 23, 2024
41f573a
Update to new C2C
jiwaszki Sep 2, 2024
8f4c4a9
[luci/pass] Introduce RemoveUnnecessaryCastPass
jiwaszki Aug 23, 2024
6678149
Cleanup for [luci/pass]
jiwaszki Aug 26, 2024
5b0116a
Update to match with new CircleOptimizer
jiwaszki Sep 2, 2024
36b932f
[one-cmds] Add an option for RemoveUnnecessaryCastPass
jiwaszki Aug 26, 2024
21a8efb
[res/tfl_recipes] Add Net_Unnecessary_Cast
jiwaszki Aug 26, 2024
bb636b6
[luci/pass] Value test for RemoveUnnecessaryCastPass
jiwaszki Aug 26, 2024
c56d50c
[circle2circle] Dredd test for RemoveUnnecessaryCastPass
jiwaszki Aug 26, 2024
a2a0fd9
Merge branch 'jiwaszki/remove_cast_c2c_dredd' into jiwaszki/remove_cast
jiwaszki Sep 2, 2024
69e285e
Merge branch 'jiwaszki/remove_cast_c2c_opt' into jiwaszki/remove_cast
jiwaszki Sep 2, 2024
7fd1d97
Merge branch 'jiwaszki/remove_cast_luci_pass' into jiwaszki/remove_cast
jiwaszki Sep 2, 2024
7af1569
Merge branch 'jiwaszki/remove_cast_luci_test' into jiwaszki/remove_cast
jiwaszki Sep 2, 2024
cd9bd0a
Merge branch 'jiwaszki/remove_cast_one_cmds' into jiwaszki/remove_cast
jiwaszki Sep 2, 2024
b157df3
Merge branch 'jiwaszki/remove_cast_tfl_recipes' into jiwaszki/remove_…
jiwaszki Sep 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/circle2circle-dredd-recipe-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Add(Net_TConv_Slice_001 PASS fuse_slice_with_tconv)
Add(Net_TConv_Slice_002 PASS fuse_slice_with_tconv)
Add(Net_TConv_Slice_003 PASS fuse_slice_with_tconv)
Add(Net_Trans_Reshape_Trans_000 PASS remove_unnecessary_transpose)
Add(Net_Unnecessary_Cast_000 PASS remove_unnecessary_cast)
Add(PadV2_001 PASS substitute_padv2_to_pad)
Add(Softmax_001 PASS decompose_softmax)
Add(Softmax_002 PASS decompose_softmax)
Expand Down
4 changes: 4 additions & 0 deletions compiler/circle2circle/src/Circle2Circle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ int entry(int argc, char **argv)
"This will fuse or remove subsequent Transpose operators");
add_switch(arser, "--remove_unnecessary_add",
"This will remove unnecessary add of zero constant");
add_switch(arser, "--remove_unnecessary_cast",
"This will remove unnecessary cast with the same input and output type.");
add_switch(arser, "--remove_unnecessary_reshape",
"This will remove unnecessary reshape operators");
add_switch(arser, "--remove_unnecessary_slice", "This will remove unnecessary slice operators");
Expand Down Expand Up @@ -310,6 +312,7 @@ int entry(int argc, char **argv)
option_str_to_enum["remove_redundant_reshape"] = Algorithms::RemoveRedundantReshape;
option_str_to_enum["remove_redundant_transpose"] = Algorithms::RemoveRedundantTranspose;
option_str_to_enum["remove_unnecessary_add"] = Algorithms::RemoveUnnecessaryAdd;
option_str_to_enum["remove_unnecessary_cast"] = Algorithms::RemoveUnnecessaryCast;
option_str_to_enum["remove_unnecessary_reshape"] = Algorithms::RemoveUnnecessaryReshape;
option_str_to_enum["remove_unnecessary_slice"] = Algorithms::RemoveUnnecessarySlice;
option_str_to_enum["remove_unnecessary_strided_slice"] = Algorithms::RemoveUnnecessaryStridedSlice;
Expand Down Expand Up @@ -348,6 +351,7 @@ int entry(int argc, char **argv)
// If REPLACE is zero, it does not overwrite an existing value.
setenv("LUCI_LOG", "100", 0);
}

for (auto const &x : option_str_to_enum)
{
if (arser.get<bool>("--" + x.first))
Expand Down
1 change: 1 addition & 0 deletions compiler/luci-pass-value-py-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ eval(Net_TConv_Slice_003 fuse_slice_with_tconv)
eval(Net_Trans_Reshape_Trans_000 remove_unnecessary_transpose)
eval(Net_Transpose_Add_000 forward_transpose_op)
eval(Net_Transpose_Abs_000 forward_transpose_op)
eval(Net_Unnecessary_Cast_000 remove_unnecessary_cast)
eval(Softmax_001 decompose_softmax)
eval(Softmax_002 decompose_softmax)
eval(UnidirectionalSequenceLSTM_003 unroll_unidirseqlstm)
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/pass/include/luci/CircleOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class CircleOptimizer final
ConvertNCHWToNHWC,
CommonSubExpressionElimination,
RemoveUnnecessaryAdd,
RemoveUnnecessaryCast,
RemoveUnnecessarySlice,
RemoveUnnecessaryStridedSlice,
RemoveUnnecessarySplit,
Expand Down
39 changes: 39 additions & 0 deletions compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryCastPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_REMOVE_UNNECESSARY_CAST_PASS_H__
#define __LUCI_REMOVE_UNNECESSARY_CAST_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
* @brief Class to remove unnecessary Cast nodes.
* @details This class will remove unnecessary Cast nodes.
* See https://github.com/Samsung/ONE/issues/13623 for more details.
*/
struct RemoveUnnecessaryCastPass final : public logo::Pass
{
const char *name(void) const final { return "luci::RemoveUnnecessaryCastPass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_REMOVE_UNNECESSARY_CAST_PASS_H__
2 changes: 2 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
#include "luci/Pass/RemoveRedundantTransposePass.h"
#include "luci/Pass/RemoveRedundantQuantizePass.h"
#include "luci/Pass/RemoveUnnecessaryAddPass.h"
#include "luci/Pass/RemoveUnnecessaryCastPass.h"
#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
#include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
#include "luci/Pass/RemoveUnnecessarySlicePass.h"
Expand Down Expand Up @@ -365,6 +366,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const
option_to_pass[Options::Algorithm::RemoveQDQForMixedPrecisionOp] = &createPassInstance<luci::RemoveQDQForMixedPrecisionOpPass>;
option_to_pass[Options::Algorithm::RemoveQuantDequantSeq] = &createPassInstance<luci::RemoveQuantDequantSeqPass>;
option_to_pass[Options::Algorithm::RemoveUnnecessaryAdd] = &createPassInstance<luci::RemoveUnnecessaryAddPass>;
option_to_pass[Options::Algorithm::RemoveUnnecessaryCast] = &createPassInstance<luci::RemoveUnnecessaryCastPass>;
option_to_pass[Options::Algorithm::RemoveUnnecessarySlice] = &createPassInstance<luci::RemoveUnnecessarySlicePass>;
option_to_pass[Options::Algorithm::RemoveUnnecessaryStridedSlice] = &createPassInstance<luci::RemoveUnnecessaryStridedSlicePass>;
option_to_pass[Options::Algorithm::RemoveUnnecessarySplit] = &createPassInstance<luci::RemoveUnnecessarySplitPass>;
Expand Down
56 changes: 56 additions & 0 deletions compiler/luci/pass/src/RemoveUnnecessaryCastPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/RemoveUnnecessaryCastPass.h"

#include <luci/IR/CircleNodes.h>

namespace
{

#define RETURN_FALSE_UNLESS(cond) \
if (not(cond)) \
return false;

bool remove_unnecessary_cast(luci::CircleCast *cast)
{
RETURN_FALSE_UNLESS(cast->in_data_type() == cast->out_data_type());

loco::replace(cast).with(cast->x());

return true;
}

} // namespace

namespace luci
{

bool RemoveUnnecessaryCastPass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
if (auto cast_node = dynamic_cast<luci::CircleCast *>(node))
{
if (remove_unnecessary_cast(cast_node))
changed = true;
}
}
return changed;
}

} // namespace luci
145 changes: 145 additions & 0 deletions compiler/luci/pass/src/RemoveUnnecessaryCastPass.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/RemoveUnnecessaryCastPass.h"
#include "helpers/CreateCircleConst.h"

#include <luci/IR/CircleNodes.h>
#include <luci/test/TestIOGraph.h>

#include <gtest/gtest.h>

namespace
{

using namespace luci::test;

template <typename T>
luci::CircleConst *const_node_of_dtype(loco::Graph *g, const loco::DataType dtype,
const std::vector<uint32_t> &shape, T value)
{
switch (dtype)
{
case loco::DataType::S32:
return luci::create_const_node(g, dtype, shape, static_cast<int32_t>(value));
case loco::DataType::FLOAT32:
return luci::create_const_node(g, dtype, shape, static_cast<float>(value));
default:
throw std::runtime_error("Unsupported dtype!");
}
}

/**
* Graph for this test
*
* BEFORE
*
* |
* [CircleAdd]
* |
* [CircleCast]
* |
* [CircleAdd]
* |
*
* AFTER
*
* |
* [CircleAdd]
* | [CircleCast removed]
* [CircleAdd]
* |
*
*/
class CastGraphlet
{
public:
void init(loco::Graph *g, loco::DataType in_type, loco::DataType out_type)
{
_const_a = const_node_of_dtype(g, in_type, {1}, 1);

_add_a = g->nodes()->create<luci::CircleAdd>();
// _add_a->x(input_of_the_net);
_add_a->y(_const_a);
_add_a->dtype(in_type);
_add_a->shape({1});
_add_a->name("add_a");

_cast = g->nodes()->create<luci::CircleCast>();
_cast->in_data_type(in_type);
_cast->out_data_type(out_type);
_cast->x(_add_a);

_const_a = const_node_of_dtype(g, out_type, {1}, 2);

_add_b = g->nodes()->create<luci::CircleAdd>();
_add_b->x(_cast);
_add_b->y(_const_b);
_add_b->dtype(out_type);
_add_b->shape({1});
_add_b->name("add_b");
}

protected:
luci::CircleCast *_cast = nullptr;
luci::CircleAdd *_add_a = nullptr;
luci::CircleConst *_const_a = nullptr;
luci::CircleAdd *_add_b = nullptr;
luci::CircleConst *_const_b = nullptr;
};

class RemoveUnnecessaryCastTestGraph : public TestIOGraph, public CastGraphlet
{
public:
void init(loco::DataType in_type, loco::DataType out_type)
{
TestIOGraph::init({1}, {1});
CastGraphlet::init(g(), in_type, out_type);

_add_a->x(input());

output()->from(_add_b);
}
};

class RemoveUnnecessaryCastPassTest : public ::testing::Test
{
public:
RemoveUnnecessaryCastTestGraph g;
luci::RemoveUnnecessaryCastPass pass;
};

} // namespace

TEST_F(RemoveUnnecessaryCastPassTest, cast_remove)
{
g.init(loco::DataType::FLOAT32 /* in_type */, loco::DataType::FLOAT32 /* out_type */);

EXPECT_EQ(true, pass.run(g.g()));

auto last_add = dynamic_cast<luci::CircleAdd *>(g.output()->from());
EXPECT_NE(nullptr, last_add);
// Check if the cast was removed:
auto first_add = dynamic_cast<luci::CircleAdd *>(last_add->x());
EXPECT_NE(nullptr, first_add);
}

TEST_F(RemoveUnnecessaryCastPassTest, different_data_types_NEG)
{
g.init(loco::DataType::FLOAT32 /* in_type */, loco::DataType::S32 /* out_type */);

EXPECT_EQ(false, pass.run(g.g()));
}
1 change: 1 addition & 0 deletions compiler/one-cmds/how-to-use-one-commands.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ Current transformation options are
- remove_redundant_reshape : This fuses or removes redundant reshape operators.
- remove_redundant_transpose : This fuses or removes redundant transpose operators.
- remove_unnecessary_add : This removes unnecessary add operators.
- remove_unnecessary_cast : This will remove unnecessary cast with the same input and output type.
- remove_unnecessary_reshape : This removes unnecessary reshape operators.
- remove_unnecessary_slice : This removes unnecessary slice operators.
- remove_unnecessary_strided_slice : This removes unnecessary strided slice operators.
Expand Down
2 changes: 2 additions & 0 deletions compiler/one-cmds/onelib/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class CONSTANT:
'remove_redundant_reshape',
'remove_redundant_transpose',
'remove_unnecessary_add',
'remove_unnecessary_cast'
'remove_unnecessary_reshape',
'remove_unnecessary_slice',
'remove_unnecessary_strided_slice',
Expand Down Expand Up @@ -157,6 +158,7 @@ class CONSTANT:
('remove_redundant_reshape', 'fuse or remove subsequent Reshape ops'),
('remove_redundant_transpose', 'fuse or remove subsequent Transpose ops'),
('remove_unnecessary_add', 'remove unnecessary add ops'),
('remove_unnecessary_cast', 'remove unnecessary cast ops'),
('remove_unnecessary_reshape', 'remove unnecessary reshape ops'),
('remove_unnecessary_slice', 'remove unnecessary slice ops'),
('remove_unnecessary_strided_slice', 'remove unnecessary strided slice ops'),
Expand Down
Loading