Skip to content

Commit

Permalink
[luci/logex] Support RoPE operation (Samsung#14102)
Browse files Browse the repository at this point in the history
This commit supports RoPE for luci logex

ONE-DCO-1.0-Signed-off-by: youngsik kim [email protected]
  • Loading branch information
ys44kim authored Sep 26, 2024
1 parent bb52b8b commit dffeb51
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node)
CIRCLE_NODE(REVERSE_SEQUENCE, CircleReverseSequenceSummaryBuilder)
CIRCLE_NODE(REVERSE_V2, CircleReverseV2SummaryBuilder)
CIRCLE_NODE(RMS_NORM, CircleRmsNormSummaryBuilder)
CIRCLE_NODE(ROPE, CircleRoPESummaryBuilder)
CIRCLE_NODE(ROUND, CircleRoundSummaryBuilder)
CIRCLE_NODE(RSQRT, CircleRsqrtSummaryBuilder)
CIRCLE_NODE(SCATTER_ND, CircleScatterNdSummaryBuilder)
Expand Down
14 changes: 14 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,20 @@ TEST_F(CircleNodeSummaryBuilderTest, Mul_validate_fused_NEG)
EXPECT_FALSE(mock_build(&node));
}

TEST_F(CircleNodeSummaryBuilderTest, RoPE_validate)
{
luci::CircleRoPE node;
node.mode(luci::RoPEMode::GPT_NEOX);
EXPECT_TRUE(mock_build(&node));
}

TEST_F(CircleNodeSummaryBuilderTest, RoPE_validate_NEG)
{
luci::CircleRoPE node;
node.mode(luci::RoPEMode::UNDEFINED);
EXPECT_FALSE(mock_build(&node));
}

TEST_F(CircleNodeSummaryBuilderTest, SVDF_validate)
{
luci::CircleSVDF node;
Expand Down
33 changes: 33 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,19 @@ std::string to_str(luci::MirrorPadMode mode)
}
}

std::string to_str(luci::RoPEMode mode)
{
switch (mode)
{
case luci::RoPEMode::GPT_NEOX:
return "GPT_NEOX";
case luci::RoPEMode::GPT_J:
return "GPT_J";
default:
return "Error";
}
}

} // namespace

namespace luci
Expand Down Expand Up @@ -902,6 +915,26 @@ void CircleRmsNormSummaryBuilder::build_attributes(const luci::CircleNode *node,
s.args().append("epsilon", std::to_string(rmsnorm->epsilon()));
}

bool CircleRoPESummaryBuilder::validate(const luci::CircleNode *node)
{
auto rope = loco::must_cast<const luci::CircleRoPE *>(node);
if (rope->mode() == luci::RoPEMode::UNDEFINED)
return false;

return true;
}

std::vector<std::string> CircleRoPESummaryBuilder::get_input_names(const luci::CircleNode *)
{
return {"input", "sin_table", "cos_table"};
}

void CircleRoPESummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
{
auto rope = loco::must_cast<const luci::CircleRoPE *>(node);
s.args().append("mode", to_str(rope->mode()));
}

std::vector<std::string> CircleScatterNdSummaryBuilder::get_input_names(const luci::CircleNode *)
{
return {"indices", "updates", "shape"};
Expand Down
8 changes: 8 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilders.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,14 @@ class CircleRmsNormSummaryBuilder final : public CircleNodeSummaryBuilder
void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
};

class CircleRoPESummaryBuilder final : public CircleNodeSummaryBuilder
{
private:
bool validate(const luci::CircleNode *node);
std::vector<std::string> get_input_names(const luci::CircleNode *);
void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
};

class CircleRoundSummaryBuilder final : public CircleNodeWithXSummaryBuilder
{
};
Expand Down

0 comments on commit dffeb51

Please sign in to comment.