Skip to content

Commit

Permalink
DRAFT CFE fix S64 paddings in Pad
Browse files Browse the repository at this point in the history
on-going draft to fix S64 paddings in Pad.

Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark committed Jul 25, 2024
1 parent 8a76a55 commit 9f4c80e
Show file tree
Hide file tree
Showing 9 changed files with 361 additions and 6 deletions.
1 change: 1 addition & 0 deletions compiler/luci-interpreter/src/kernels/Pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void Pad::configure()
const int32_t padding_after = paddings_data[i * 2 + 1];
assert(padding_before >= 0 && padding_after >= 0);
output_shape.dim(i) = input_shape.dim(i) + padding_before + padding_after;
printf("!!! Pad %d %d %d\r\n", i, input_shape.dim(i), output_shape.dim(i));
}

output()->resize(output_shape);
Expand Down
4 changes: 4 additions & 0 deletions compiler/luci-pass-value-py-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,7 @@ eval(Net_Dequantize_Add_000 fold_dequantize)
# test for common subexpression elimination
eval(CSE_Quantize_000 common_subexpression_elimination)
eval(CSE_Transpose_000 common_subexpression_elimination)

# test for canonicalization, with any optimization
# TODO enable Pad_001 when TF version up supports INT4 paddings
# eval(Pad_001 fuse_instnorm) --> tflite(v2.12.1) does not support INT64 paddings
21 changes: 18 additions & 3 deletions compiler/luci-pass-value-py-test/test_luci_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def luci_eval_verify(test_name,
atolint = int(atolf32)

# Build TFLite interpreter.
interpreter = tf.lite.Interpreter(tflite_model)
interpreter = tf.lite.Interpreter(
tflite_model, experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()

# Read SignatureDef and get output tensor id orders for remapping
Expand Down Expand Up @@ -87,16 +88,30 @@ def luci_eval_verify(test_name,
output_shape = [int(i) for i in shape_file.read().split(',')]
luci_output_data = np.reshape(output_data, output_shape)
output_tensor = output_details["index"]
print("!!! output_tensor 1", output_tensor)
if full_signatures_outputs_remap != None:
output_tensor = full_signatures_outputs_remap[idx]
print("!!! output_tensor 2", idx, output_tensor)
intp_output_data = interpreter.get_tensor(output_tensor)

print("!!! ", tflite_model, ":", output_tensor, intp_output_data.shape)
print("!!! ", circle_model, ":", output_shape)

err_msg = "Execution result of " + tflite_model + " does not match with " + circle_model
if output_details["dtype"] == np.uint8:
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
elif output_details["dtype"] == np.float32:
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32), err_msg
print("!!! float32")
print(intp_output_data.shape)
print(luci_output_data.shape)
res = np.allclose(
luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32)
if not res:
diff = np.isclose(
luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32)
print(diff)
assert res, err_msg
elif output_details["dtype"] == np.int64:
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
Expand Down
38 changes: 38 additions & 0 deletions compiler/luci/pass/include/luci/Pass/CanonicalizePass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_CANONICALIZE_PASS_H__
#define __LUCI_CANONICALIZE_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
* @brief Class to canoncalize CircleNodes
*
*/
struct CanonicalizePass final : public logo::Pass
{
const char *name(void) const final { return "luci::CanonicalizePass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_CANONICALIZE_PASS_H__
101 changes: 101 additions & 0 deletions compiler/luci/pass/src/CanonicalizePass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/CanonicalizePass.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <loco/IR/DataType.h>

#include <limits>

#define CHECK_OR_FALSE(condition) \
if (not(condition)) \
return false;

namespace
{

/**
* Convert S64 CircleConst paddings to S32
*/
bool paddings_to_s32(luci::CirclePad *pad)
{
// check conditions
auto paddings = dynamic_cast<luci::CircleConst *>(pad->paddings());
CHECK_OR_FALSE(paddings);
CHECK_OR_FALSE(paddings->dtype() == loco::DataType::S64);

// TODO relocate to helpers/CreateCircleConst.h when necessary
auto num_elements = paddings->size<loco::DataType::S64>();
for (uint32_t i = 0; i < num_elements; i++)
{
auto v64 = paddings->at<loco::DataType::S64>(i);
auto hval = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
auto lval = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
CHECK_OR_FALSE(v64 < hval);
CHECK_OR_FALSE(v64 > lval);
}

auto paddings_s32 = pad->graph()->nodes()->create<luci::CircleConst>();
paddings_s32->name(paddings->name() + "_S32");
paddings_s32->dtype(loco::DataType::S32);
paddings_s32->rank(paddings->rank());
for (uint32_t i = 0; i < paddings->rank(); i++)
paddings_s32->dim(i).set(paddings->dim(i).value());
paddings_s32->shape_status(luci::ShapeStatus::VALID);
luci::add_origin(paddings_s32, luci::get_origin(paddings));

paddings_s32->size<loco::DataType::S32>(num_elements);
for (uint32_t i = 0; i < num_elements; i++)
{
auto v64 = paddings->at<loco::DataType::S64>(i);
paddings_s32->at<loco::DataType::S32>(i) = static_cast<int32_t>(v64);
}

// replace paddings with S32 dtype
pad->paddings(paddings_s32);

return true;
}

} // namespace

namespace luci
{

/**
* Canonicalize circle nodes
*/
bool CanonicalizePass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
if (auto pad = dynamic_cast<luci::CirclePad *>(node))
{
if (paddings_to_s32(pad))
changed = true;
}

// TODO add more canonicalization
}

return changed;
}

} // namespace luci
152 changes: 152 additions & 0 deletions compiler/luci/pass/src/CanonicalizePass.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/CanonicalizePass.h"

#include <luci/IR/CircleNodes.h>

#include <luci/test/TestIOGraph.h>

#include <vector>

#include <gtest/gtest.h>

TEST(CanonicalizePassTest, name)
{
luci::CanonicalizePass pass;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

namespace
{

using namespace luci::test;

struct PadGraphlet
{
PadGraphlet() = default;

void init(loco::Graph *g)
{
_pad = g->nodes()->create<luci::CirclePad>();
_paddings_s32 = g->nodes()->create<luci::CircleConst>();
_paddings_s64 = g->nodes()->create<luci::CircleConst>();

_pad->name("pad");
_paddings_s32->name("paddings_s32");
_paddings_s64->name("paddings_s64");

_paddings_s64->dtype(loco::DataType::S64);
_paddings_s64->rank(2);
_paddings_s64->dim(0).set(4);
_paddings_s64->dim(1).set(2);
_paddings_s64->shape_status(luci::ShapeStatus::VALID);

_paddings_s32->dtype(loco::DataType::S32);
_paddings_s32->rank(2);
_paddings_s32->dim(0).set(4);
_paddings_s32->dim(1).set(2);
_paddings_s32->shape_status(luci::ShapeStatus::VALID);

std::vector<int64_t> ps = {0, 0, 1, 1, 1, 1, 0, 0};

uint32_t num_elements = static_cast<uint32_t>(ps.size());
_paddings_s64->size<loco::DataType::S64>(num_elements);
for (uint32_t i = 0; i < num_elements; i++)
_paddings_s64->at<loco::DataType::S64>(i) = ps[i];

_paddings_s32->size<loco::DataType::S32>(num_elements);
for (uint32_t i = 0; i < num_elements; i++)
_paddings_s32->at<loco::DataType::S32>(i) = static_cast<int32_t>(ps[i]);
}

luci::CirclePad *_pad = nullptr;
luci::CircleConst *_paddings_s32 = nullptr;
luci::CircleConst *_paddings_s64 = nullptr;
};

class CanonicalizePadTestGraph : public TestIOGraph, public PadGraphlet
{
public:
CanonicalizePadTestGraph() = default;

void init(void)
{
TestIOGraph::init({1, 3, 3, 2}, {1, 5, 5, 2});
PadGraphlet::init(g());

_pad->input(input());
_pad->paddings(_paddings_s64);

output()->from(_pad);
}
};

} // namespace

TEST(CanonicalizePassPadTest, paddings_64_to_32)
{
CanonicalizePadTestGraph g;
luci::CanonicalizePass pass;

g.init();

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._pad->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S64);

EXPECT_TRUE(pass.run(g.g()));

paddings = dynamic_cast<luci::CircleConst *>(g._pad->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);
}

TEST(CanonicalizePassPadTest, paddings_32_NEG)
{
CanonicalizePadTestGraph g;
luci::CanonicalizePass pass;

g.init();
g._pad->paddings(g._paddings_s32);

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._pad->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);

EXPECT_FALSE(pass.run(g.g()));

paddings = dynamic_cast<luci::CircleConst *>(g._pad->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);
}

TEST(CanonicalizePassPadTest, paddings_32_over_NEG)
{
CanonicalizePadTestGraph g;
luci::CanonicalizePass pass;

g.init();
g._paddings_s64->at<loco::DataType::S64>(2) =
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 100;

EXPECT_FALSE(pass.run(g.g()));

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._pad->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S64);
}
4 changes: 4 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "luci/CircleOptimizer.h"

#include "luci/Pass/CanonicalizePass.h"
#include "luci/Pass/ConvertNCHWToNHWCPass.h"
#include "luci/Pass/CommonSubExpressionEliminationPass.h"
#include "luci/Pass/ExpandBroadcastConstPass.h"
Expand Down Expand Up @@ -260,6 +261,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const
phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());

// Run canonicalization
phase.emplace_back(std::make_unique<luci::CanonicalizePass>());

if (_options->query(Options::Algorithm::CommonSubExpressionElimination))
{
phase.emplace_back(std::make_unique<luci::CommonSubExpressionEliminationPass>());
Expand Down
Loading

0 comments on commit 9f4c80e

Please sign in to comment.