diff --git a/compiler/circle2circle-dredd-recipe-test/test.lst b/compiler/circle2circle-dredd-recipe-test/test.lst index 1e7b3de58cf..0a8d893c43b 100644 --- a/compiler/circle2circle-dredd-recipe-test/test.lst +++ b/compiler/circle2circle-dredd-recipe-test/test.lst @@ -91,6 +91,7 @@ Add(Net_TConv_Slice_001 PASS fuse_slice_with_tconv) Add(Net_TConv_Slice_002 PASS fuse_slice_with_tconv) Add(Net_TConv_Slice_003 PASS fuse_slice_with_tconv) Add(Net_Trans_Reshape_Trans_000 PASS remove_unnecessary_transpose) +Add(Net_Unnecessary_Cast_000 PASS remove_unnecessary_cast) Add(PadV2_001 PASS substitute_padv2_to_pad) Add(Softmax_001 PASS decompose_softmax) Add(Softmax_002 PASS decompose_softmax) diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index c32060bd8af..9100fc3a38f 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -148,6 +148,8 @@ int entry(int argc, char **argv) "This will fuse or remove subsequent Transpose operators"); add_switch(arser, "--remove_unnecessary_add", "This will remove unnecessary add of zero constant"); + add_switch(arser, "--remove_unnecessary_cast", + "This will remove unnecessary cast with the same input and output type."); add_switch(arser, "--remove_unnecessary_reshape", "This will remove unnecessary reshape operators"); add_switch(arser, "--remove_unnecessary_slice", "This will remove unnecessary slice operators"); @@ -310,6 +312,7 @@ int entry(int argc, char **argv) option_str_to_enum["remove_redundant_reshape"] = Algorithms::RemoveRedundantReshape; option_str_to_enum["remove_redundant_transpose"] = Algorithms::RemoveRedundantTranspose; option_str_to_enum["remove_unnecessary_add"] = Algorithms::RemoveUnnecessaryAdd; + option_str_to_enum["remove_unnecessary_cast"] = Algorithms::RemoveUnnecessaryCast; option_str_to_enum["remove_unnecessary_reshape"] = Algorithms::RemoveUnnecessaryReshape; 
option_str_to_enum["remove_unnecessary_slice"] = Algorithms::RemoveUnnecessarySlice; option_str_to_enum["remove_unnecessary_strided_slice"] = Algorithms::RemoveUnnecessaryStridedSlice; @@ -348,6 +351,7 @@ int entry(int argc, char **argv) // If REPLACE is zero, it does not overwrite an existing value. setenv("LUCI_LOG", "100", 0); } + for (auto const &x : option_str_to_enum) { if (arser.get("--" + x.first)) diff --git a/compiler/luci-pass-value-py-test/test.lst b/compiler/luci-pass-value-py-test/test.lst index ebf84e02660..8328948b937 100644 --- a/compiler/luci-pass-value-py-test/test.lst +++ b/compiler/luci-pass-value-py-test/test.lst @@ -74,6 +74,7 @@ eval(Net_TConv_Slice_003 fuse_slice_with_tconv) eval(Net_Trans_Reshape_Trans_000 remove_unnecessary_transpose) eval(Net_Transpose_Add_000 forward_transpose_op) eval(Net_Transpose_Abs_000 forward_transpose_op) +eval(Net_Unnecessary_Cast_000 remove_unnecessary_cast) eval(Softmax_001 decompose_softmax) eval(Softmax_002 decompose_softmax) eval(UnidirectionalSequenceLSTM_003 unroll_unidirseqlstm) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 8a1eb6d4f78..d4f675f36fe 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -92,6 +92,7 @@ class CircleOptimizer final ConvertNCHWToNHWC, CommonSubExpressionElimination, RemoveUnnecessaryAdd, + RemoveUnnecessaryCast, RemoveUnnecessarySlice, RemoveUnnecessaryStridedSlice, RemoveUnnecessarySplit, diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryCastPass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryCastPass.h new file mode 100644 index 00000000000..c603239ac54 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryCastPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_UNNECESSARY_CAST_PASS_H__ +#define __LUCI_REMOVE_UNNECESSARY_CAST_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to remove unnecessary Cast nodes. + * @details This class will remove unnecessary Cast nodes. + * See https://github.com/Samsung/ONE/issues/13623 for more details. + */ +struct RemoveUnnecessaryCastPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveUnnecessaryCastPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_UNNECESSARY_CAST_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 90060253080..bf18b973d6d 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -66,6 +66,7 @@ #include "luci/Pass/RemoveRedundantTransposePass.h" #include "luci/Pass/RemoveRedundantQuantizePass.h" #include "luci/Pass/RemoveUnnecessaryAddPass.h" +#include "luci/Pass/RemoveUnnecessaryCastPass.h" #include "luci/Pass/RemoveUnnecessaryReshapePass.h" #include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h" #include "luci/Pass/RemoveUnnecessarySlicePass.h" @@ -365,6 +366,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const option_to_pass[Options::Algorithm::RemoveQDQForMixedPrecisionOp] = &createPassInstance; 
option_to_pass[Options::Algorithm::RemoveQuantDequantSeq] = &createPassInstance; option_to_pass[Options::Algorithm::RemoveUnnecessaryAdd] = &createPassInstance; + option_to_pass[Options::Algorithm::RemoveUnnecessaryCast] = &createPassInstance; option_to_pass[Options::Algorithm::RemoveUnnecessarySlice] = &createPassInstance; option_to_pass[Options::Algorithm::RemoveUnnecessaryStridedSlice] = &createPassInstance; option_to_pass[Options::Algorithm::RemoveUnnecessarySplit] = &createPassInstance; diff --git a/compiler/luci/pass/src/RemoveUnnecessaryCastPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryCastPass.cpp new file mode 100644 index 00000000000..82056a96c77 --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryCastPass.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/RemoveUnnecessaryCastPass.h" + +#include + +namespace +{ + +#define RETURN_FALSE_UNLESS(cond) \ + if (not(cond)) \ + return false; + +bool remove_unnecessary_cast(luci::CircleCast *cast) +{ + RETURN_FALSE_UNLESS(cast->in_data_type() == cast->out_data_type()); + + loco::replace(cast).with(cast->x()); + + return true; +} + +} // namespace + +namespace luci +{ + +bool RemoveUnnecessaryCastPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto cast_node = dynamic_cast(node)) + { + if (remove_unnecessary_cast(cast_node)) + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveUnnecessaryCastPass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryCastPass.test.cpp new file mode 100644 index 00000000000..d2765d47986 --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryCastPass.test.cpp @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/RemoveUnnecessaryCastPass.h" +#include "helpers/CreateCircleConst.h" + +#include +#include + +#include + +namespace +{ + +using namespace luci::test; + +template +luci::CircleConst *const_node_of_dtype(loco::Graph *g, const loco::DataType dtype, + const std::vector &shape, T value) +{ + switch (dtype) + { + case loco::DataType::S32: + return luci::create_const_node(g, dtype, shape, static_cast(value)); + case loco::DataType::FLOAT32: + return luci::create_const_node(g, dtype, shape, static_cast(value)); + default: + throw std::runtime_error("Unsupported dtype!"); + } +} + +/** + * Graph for this test + * + * BEFORE + * + * | + * [CircleAdd] + * | + * [CircleCast] + * | + * [CircleAdd] + * | + * + * AFTER + * + * | + * [CircleAdd] + * | [CircleCast removed] + * [CircleAdd] + * | + * + */ +class CastGraphlet +{ +public: + void init(loco::Graph *g, loco::DataType in_type, loco::DataType out_type) + { + _const_a = const_node_of_dtype(g, in_type, {1}, 1); + + _add_a = g->nodes()->create(); + // _add_a->x(input_of_the_net); + _add_a->y(_const_a); + _add_a->dtype(in_type); + _add_a->shape({1}); + _add_a->name("add_a"); + + _cast = g->nodes()->create(); + _cast->in_data_type(in_type); + _cast->out_data_type(out_type); + _cast->x(_add_a); + + _const_b = const_node_of_dtype(g, out_type, {1}, 2); + + _add_b = g->nodes()->create(); + _add_b->x(_cast); + _add_b->y(_const_b); + _add_b->dtype(out_type); + _add_b->shape({1}); + _add_b->name("add_b"); + } + +protected: + luci::CircleCast *_cast = nullptr; + luci::CircleAdd *_add_a = nullptr; + luci::CircleConst *_const_a = nullptr; + luci::CircleAdd *_add_b = nullptr; + luci::CircleConst *_const_b = nullptr; +}; + +class RemoveUnnecessaryCastTestGraph : public TestIOGraph, public CastGraphlet +{ +public: + void init(loco::DataType in_type, loco::DataType out_type) + { + TestIOGraph::init({1}, {1}); + CastGraphlet::init(g(), in_type, out_type); + + _add_a->x(input()); + + output()->from(_add_b); + } +}; + 
+class RemoveUnnecessaryCastPassTest : public ::testing::Test +{ +public: + RemoveUnnecessaryCastTestGraph g; + luci::RemoveUnnecessaryCastPass pass; +}; + +} // namespace + +TEST_F(RemoveUnnecessaryCastPassTest, cast_remove) +{ + g.init(loco::DataType::FLOAT32 /* in_type */, loco::DataType::FLOAT32 /* out_type */); + + EXPECT_EQ(true, pass.run(g.g())); + + auto last_add = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, last_add); + // Check if the cast was removed: + auto first_add = dynamic_cast(last_add->x()); + EXPECT_NE(nullptr, first_add); +} + +TEST_F(RemoveUnnecessaryCastPassTest, different_data_types_NEG) +{ + g.init(loco::DataType::FLOAT32 /* in_type */, loco::DataType::S32 /* out_type */); + + EXPECT_EQ(false, pass.run(g.g())); +} diff --git a/compiler/one-cmds/how-to-use-one-commands.txt b/compiler/one-cmds/how-to-use-one-commands.txt index d6656545ff8..817b0c76ebc 100644 --- a/compiler/one-cmds/how-to-use-one-commands.txt +++ b/compiler/one-cmds/how-to-use-one-commands.txt @@ -200,6 +200,7 @@ Current transformation options are - remove_redundant_reshape : This fuses or removes redundant reshape operators. - remove_redundant_transpose : This fuses or removes redundant transpose operators. - remove_unnecessary_add : This removes unnecessary add operators. +- remove_unnecessary_cast : This will remove unnecessary cast with the same input and output type. - remove_unnecessary_reshape : This removes unnecessary reshape operators. - remove_unnecessary_slice : This removes unnecessary slice operators. - remove_unnecessary_strided_slice : This removes unnecessary strided slice operators. 
diff --git a/compiler/one-cmds/onelib/constant.py b/compiler/one-cmds/onelib/constant.py index a8dabf139d0..97b207488f3 100644 --- a/compiler/one-cmds/onelib/constant.py +++ b/compiler/one-cmds/onelib/constant.py @@ -63,6 +63,7 @@ class CONSTANT: 'remove_redundant_reshape', 'remove_redundant_transpose', 'remove_unnecessary_add', + 'remove_unnecessary_cast', 'remove_unnecessary_reshape', 'remove_unnecessary_slice', 'remove_unnecessary_strided_slice', @@ -157,6 +158,7 @@ class CONSTANT: ('remove_redundant_reshape', 'fuse or remove subsequent Reshape ops'), ('remove_redundant_transpose', 'fuse or remove subsequent Transpose ops'), ('remove_unnecessary_add', 'remove unnecessary add ops'), + ('remove_unnecessary_cast', 'remove unnecessary cast ops'), ('remove_unnecessary_reshape', 'remove unnecessary reshape ops'), ('remove_unnecessary_slice', 'remove unnecessary slice ops'), ('remove_unnecessary_strided_slice', 'remove unnecessary strided slice ops'), diff --git a/res/TensorFlowLiteRecipes/Net_Unnecessary_Cast_000/test.recipe b/res/TensorFlowLiteRecipes/Net_Unnecessary_Cast_000/test.recipe new file mode 100644 index 00000000000..8db629c20d1 --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_Unnecessary_Cast_000/test.recipe @@ -0,0 +1,69 @@ +operand { + name: "ifm" + type: FLOAT32 + shape { dim: 1 dim: 2 dim: 4 dim: 4 } +} +operand { + name: "add_const_0" + type: FLOAT32 + shape { dim: 1 dim: 2 dim: 4 dim: 4 } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "add_output_0" + type: FLOAT32 + shape { dim: 1 dim: 2 dim: 4 dim: 4 } +} +operation { + type: "Add" + input: "ifm" + input: "add_const_0" + output: "add_output_0" + add_options { + activation: RELU + } +} +operand { + name: "cast_output" + type: FLOAT32 + shape { dim: 1 dim: 2 dim: 4 dim: 4 } +} +operation { + type: "Cast" + cast_options { + in_data_type: FLOAT32 + out_data_type: FLOAT32 + } + input: "add_output_0" + output: "cast_output" +} +operand { + name: "add_const_1" + type: 
FLOAT32 + shape { dim: 1 dim: 2 dim: 4 dim: 4 } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "ofm" + type: FLOAT32 + shape { dim: 1 dim: 2 dim: 4 dim: 4 } +} +operation { + type: "Add" + input: "cast_output" + input: "add_const_1" + output: "ofm" + add_options { + activation: NONE + } +} +input: "ifm" +output: "ofm" diff --git a/res/TensorFlowLiteRecipes/Net_Unnecessary_Cast_000/test.rule b/res/TensorFlowLiteRecipes/Net_Unnecessary_Cast_000/test.rule new file mode 100644 index 00000000000..897490fe0b4 --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_Unnecessary_Cast_000/test.rule @@ -0,0 +1,9 @@ +# This checks if: +# Add -> Cast(input_type == output_type) -> Add +# is converted to: +# Add -> Add + +RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1 + +RULE "NO_CAST" $(op_count CAST) '=' 0 +RULE "ADD_EXIST" $(op_count ADD) '=' 2