Skip to content

Commit

Permalink
Merge pull request #121 from avast/fix_bool_simplifiers
Browse files Browse the repository at this point in the history
Fix bool simplifiers
  • Loading branch information
metthal authored Aug 13, 2020
2 parents eee73e8 + 148bd7c commit 9297e2c
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 64 deletions.
36 changes: 21 additions & 15 deletions include/yaramod/utils/modifying_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@

namespace yaramod {

/**
* Represents error during modifying by ModifyingVisitor.
*/
class ModifyingVisitorError : public YaramodError
{
public:
ModifyingVisitorError(const std::string& errorMsg) : YaramodError("ModifyingVisitorError error: " + errorMsg) {}
ModifyingVisitorError(const ModifyingVisitorError&) = default;
};

/**
* Class capturing the current TokenStream and the first and last TokenIt of the expression which it is given
* in its constructor. This data can be very useful at the end of the visit method of ModifyingVisitors
Expand Down Expand Up @@ -670,24 +680,20 @@ class ModifyingVisitor : public Visitor
* Removes all tokens in the TokenStream that the given expression has been associated with.
* Then all tokens that are currently associated with the given expression are moved to that TokenStream.
* Finally, the expression is assigned that TokenStream.
* @return true iff any changes performed (when the expression has different tokenstream)
*/
// NOTE(review): a bool-returning and a void-returning signature appear back to back below,
// followed by two variants of the body; this looks like both sides of a diff rendered
// together — confirm which variant is current before relying on either.
bool cleanUpTokenStreams(const TokenStreamContext& context, Expression* new_expression)
void cleanUpTokenStreams(const TokenStreamContext& context, Expression* new_expression)
{
// TokenStream recorded by the context for the expression that is being replaced
auto& oldTokenStream = context.oldTokenStream();
// variant 1: do the clean-up only when the new expression uses a different TokenStream
if (oldTokenStream.get() != new_expression->getTokenStream())
{
auto oldBeforeFirst = context.oldBeforeFirst();
auto oldAfterLast = context.oldAfterLast();
// remove old tokens which have not been moved away by the builder
oldTokenStream->erase(std::next(oldBeforeFirst), oldAfterLast);
// transfer the built tokens
oldTokenStream->moveAppend(oldAfterLast, new_expression->getTokenStream());
// from now on the new expression refers to the old TokenStream
new_expression->setTokenStream(oldTokenStream);
return true;
}
else
return false;
// variant 2: a new expression that still uses the old TokenStream is rejected; the error
// message tells the caller to extract it with yaramod::YaraExpressionBuilder first
if (oldTokenStream.get() == new_expression->getTokenStream())
throw ModifyingVisitorError("The expressions have the same TokenStreams. Use yaramod::YaraExpressionBuilder to extract new Expression first.");

auto oldBeforeFirst = context.oldBeforeFirst();
auto oldAfterLast = context.oldAfterLast();
// remove old tokens which have not been moved away by the builder
oldTokenStream->erase(std::next(oldBeforeFirst), oldAfterLast);
// transfer the built tokens
oldTokenStream->moveAppend(oldAfterLast, new_expression->getTokenStream());
// from now on the new expression refers to the old TokenStream
new_expression->setTokenStream(oldTokenStream);
}

protected:
Expand Down
81 changes: 52 additions & 29 deletions src/examples/cpp/simplify_bools/bool_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
auto retLeft = expr->getLeftOperand()->accept(this);
auto retRight = expr->getRightOperand()->accept(this);

auto leftExpr = std::get_if<yaramod::Expression::Ptr>(&retLeft);
auto rightExpr = std::get_if<yaramod::Expression::Ptr>(&retRight);

yaramod::BoolLiteralExpression* leftBool = nullptr;
if (auto leftExpr = std::get_if<yaramod::Expression::Ptr>(&retLeft))
{
if (*leftExpr)
leftBool = (*leftExpr)->as<yaramod::BoolLiteralExpression>();
}
if (leftExpr and *leftExpr)
leftBool = (*leftExpr)->as<yaramod::BoolLiteralExpression>();

yaramod::BoolLiteralExpression* rightBool = nullptr;
if (auto rightExpr = std::get_if<yaramod::Expression::Ptr>(&retRight))
{
if (*rightExpr)
rightBool = (*rightExpr)->as<yaramod::BoolLiteralExpression>();
}
if (rightExpr and *rightExpr)
rightBool = (*rightExpr)->as<yaramod::BoolLiteralExpression>();

std::shared_ptr<yaramod::Expression> output = nullptr;
// If both sides of AND are boolean constants then determine the value based on truth table of AND
Expand All @@ -49,7 +46,14 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
output = std::make_shared<yaramod::BoolLiteralExpression>(false);
// T and X = X
else
output = expr->getRightOperand();
{
if (rightExpr and *rightExpr)
output = *rightExpr;
else
output = expr->getRightOperand();
if (output)
output = yaramod::YaraExpressionBuilder(output).get();
}
}
// Only right-hand side is boolean constant
else if (rightBool)
Expand All @@ -59,12 +63,19 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
output = std::make_shared<yaramod::BoolLiteralExpression>(false);
// X and T = X
else
output = expr->getLeftOperand();
{
if (leftExpr and *leftExpr)
output = *leftExpr;
else
output = expr->getLeftOperand();
if (output)
output = yaramod::YaraExpressionBuilder(output).get();
}
}

if (output)
{
expr->exchangeTokens(output.get());
cleanUpTokenStreams(context, output.get());
return output;
}
else
Expand All @@ -77,19 +88,16 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
auto retLeft = expr->getLeftOperand()->accept(this);
auto retRight = expr->getRightOperand()->accept(this);

auto leftExpr = std::get_if<yaramod::Expression::Ptr>(&retLeft);
auto rightExpr = std::get_if<yaramod::Expression::Ptr>(&retRight);

yaramod::BoolLiteralExpression* leftBool = nullptr;
if (auto leftExpr = std::get_if<yaramod::Expression::Ptr>(&retLeft))
{
if (*leftExpr)
leftBool = (*leftExpr)->as<yaramod::BoolLiteralExpression>();
}
if (leftExpr and *leftExpr)
leftBool = (*leftExpr)->as<yaramod::BoolLiteralExpression>();

yaramod::BoolLiteralExpression* rightBool = nullptr;
if (auto rightExpr = std::get_if<yaramod::Expression::Ptr>(&retRight))
{
if (*rightExpr)
rightBool = (*rightExpr)->as<yaramod::BoolLiteralExpression>();
}
if (rightExpr and *rightExpr)
rightBool = (*rightExpr)->as<yaramod::BoolLiteralExpression>();

std::shared_ptr<yaramod::Expression> output = nullptr;
// If both sides of OR are boolean constants then determine the value based on truth table of OR
Expand All @@ -109,7 +117,14 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
output = std::make_shared<yaramod::BoolLiteralExpression>(true);
// F or X = X
else
output = expr->getRightOperand();
{
if (rightExpr and *rightExpr)
output = *rightExpr;
else
output = expr->getRightOperand();
if(output)
output = yaramod::YaraExpressionBuilder(output).get();
}
}
// Only right-hand side is boolean constant
else if (rightBool)
Expand All @@ -119,13 +134,20 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
output = std::make_shared<yaramod::BoolLiteralExpression>(true);
// X or F = X
else
output = expr->getLeftOperand();
{
if (leftExpr and *leftExpr)
output = *leftExpr;
else
output = expr->getLeftOperand();
if(output)
output = yaramod::YaraExpressionBuilder(output).get();
}
}


if (output)
{
expr->exchangeTokens(output.get());
cleanUpTokenStreams(context, output.get());
return output;
}
else
Expand All @@ -144,7 +166,7 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
if (boolVal)
{
auto output = std::make_shared<yaramod::BoolLiteralExpression>(!boolVal->getValue());
expr->exchangeTokens(output.get());
cleanUpTokenStreams(context, output.get());
return output;
}
}
Expand All @@ -164,7 +186,7 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
if (boolVal)
{
auto output = std::make_shared<yaramod::BoolLiteralExpression>(boolVal->getValue());
expr->exchangeTokens(output.get());
cleanUpTokenStreams(context, output.get());
return output;
}
}
Expand All @@ -175,8 +197,9 @@ class BoolSimplifier : public yaramod::ModifyingVisitor
// Visits a boolean literal and returns a freshly created BoolLiteralExpression holding the same value.
virtual yaramod::VisitResult visit(yaramod::BoolLiteralExpression* expr) override
{
// Lift up boolean value
// Record the token context of expr before the replacement literal is created.
yaramod::TokenStreamContext context{expr};
auto output = std::make_shared<yaramod::BoolLiteralExpression>(expr->getValue());
// NOTE(review): both exchangeTokens() and cleanUpTokenStreams() are applied to the same
// output here; this looks like the old and new side of a diff shown together — confirm
// that only one of the two calls is meant to remain.
expr->exchangeTokens(output.get());
cleanUpTokenStreams(context, output.get());
return output;
}
};
25 changes: 13 additions & 12 deletions src/examples/python/simplify_bools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@ def visit_AndExpression(self, expr):
# F and F = F
if left_bool and right_bool:
output = yaramod.bool_val(left_bool.value and right_bool.value).get()
expr.exchange_tokens(output)
self.cleanUpTokenStreams(context, output)
return output
# Only left-hand side is boolean constant
elif left_bool:
# F and X = F
# T and X = X
output = yaramod.bool_val(False).get() if not left_bool.value else expr.right_operand
expr.exchange_tokens(output)
output = yaramod.bool_val(False).get() if not left_bool.value else yaramod.YaraExpressionBuilder(right_expr if right_expr else expr.right_operand).get()
self.cleanUpTokenStreams(context, output)
return output
# Only right-hand side is boolean constant
elif right_bool:
# X and F = F
# X and T = X
output yaramod.bool_val(False).get() if not right_bool.value else expr.left_operand
output = yaramod.bool_val(False).get() if not right_bool.value else yaramod.YaraExpressionBuilder(left_expr if left_expr else expr.left_operand).get()
self.cleanUpTokenStreams(context, output)
return output

Expand All @@ -55,21 +55,21 @@ def visit_OrExpression(self, expr):
# F or F = F
if left_bool and right_bool:
output = yaramod.bool_val(left_bool.value or right_bool.value).get()
expr.exchange_tokens(output)
self.cleanUpTokenStreams(context, output)
return output
# Only left-hand side is boolean constant
elif left_bool:
# T or X = T
# F or X = X
output = yaramod.bool_val(True).get() if left_bool.value else expr.right_operand
expr.exchange_tokens(output)
output = yaramod.bool_val(True).get() if left_bool.value else yaramod.YaraExpressionBuilder(right_expr if right_expr else expr.right_operand).get()
self.cleanUpTokenStreams(context, output)
return output
# Only right-hand side is boolean constant
elif right_bool:
# X or T = T
# X or F = X
output = yaramod.bool_val(True).get() if right_bool.value else expr.left_operand
expr.exchange_tokens(output)
output = yaramod.bool_val(True).get() if right_bool.value else yaramod.YaraExpressionBuilder(left_expr if left_expr else expr.left_operand).get()
self.cleanUpTokenStreams(context, output)
return output

return self.default_handler(context, expr, left_expr, right_expr)
Expand All @@ -82,7 +82,7 @@ def visit_NotExpression(self, expr):
bool_val = new_expr if (new_expr and isinstance(new_expr, yaramod.BoolLiteralExpression)) else None
if bool_val:
output = yaramod.bool_val(not bool_val.value).get()
expr.exchange_tokens(output)
self.cleanUpTokenStreams(context, output)
return output

return self.default_handler(context, expr, new_expr)
Expand All @@ -95,15 +95,16 @@ def visit_ParenthesesExpression(self, expr):
bool_val = new_expr if (new_expr and isinstance(new_expr, yaramod.BoolLiteralExpression)) else None
if bool_val:
output = yaramod.bool_val(bool_val.value).get()
expr.exchange_tokens(output)
self.cleanUpTokenStreams(context, output)
return output

return self.default_handler(context, expr, new_expr)

def visit_BoolLiteralExpression(self, expr):
# Lift up boolean value
# Record the token-stream context of expr before the replacement literal is built.
context = yaramod.TokenStreamContext(expr)
output = yaramod.bool_val(expr.value).get()
# NOTE(review): exchange_tokens() and cleanUpTokenStreams() are both called on the same
# output; this looks like the old and new side of a diff shown together — confirm which stays.
expr.exchange_tokens(output)
# NOTE(review): the pybind11 registration of ModifyingVisitor shown in this page lists the
# helper under two spellings, "cleanUpTokenStreams" and "cleanup_tokenstreams"; verify which
# name is actually exported and keep this call in sync with it.
self.cleanUpTokenStreams(context, output)
return output


Expand Down
2 changes: 1 addition & 1 deletion src/python/py_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void addVisitorClasses(py::module& module)
py::class_<ModifyingVisitor, PyModifyingVisitor, Visitor>(module, "ModifyingVisitor")
.def(py::init<>())
.def("modify", &ModifyingVisitor::modify, py::arg("expr"), py::arg("when_deleted") = static_cast<Expression*>(nullptr))
.def("cleanUpTokenStreams", &ModifyingVisitor::cleanUpTokenStreams, py::arg("context"), py::arg("new_expression"))
.def("cleanup_tokenstreams", &ModifyingVisitor::cleanUpTokenStreams, py::arg("context"), py::arg("new_expression"))
.def("visit_StringExpression", py::overload_cast<StringExpression*>(&ModifyingVisitor::visit))
.def("visit_StringWildcardExpression", py::overload_cast<StringWildcardExpression*>(&ModifyingVisitor::visit))
.def("visit_StringAtExpression", py::overload_cast<StringAtExpression*>(&ModifyingVisitor::visit))
Expand Down
28 changes: 27 additions & 1 deletion src/types/yara_file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ void YaraFile::addRule(Rule&& rule)
*/
void YaraFile::addRule(std::unique_ptr<Rule>&& rule)
{
	// A rule that does not yet use this file's TokenStream has its stream handed to moveAppend().
	const bool needsTokenMove = rule->getTokenStream() != _tokenStream.get();
	if (needsTokenMove)
	{
		_tokenStream->moveAppend(rule->getTokenStream());
	}

	// Take ownership of the rule.
	_rules.emplace_back(std::move(rule));

	// Register the element now stored in _rules under its name.
	const auto& storedRule = _rules.back();
	_ruleTable.emplace(storedRule->getName(), storedRule.get());
}
Expand All @@ -234,8 +236,10 @@ void YaraFile::addRule(std::unique_ptr<Rule>&& rule)
*/
void YaraFile::addRule(const std::shared_ptr<Rule>& rule)
{
	// A rule that does not yet use this file's TokenStream has its stream handed to moveAppend().
	if (rule->getTokenStream() != _tokenStream.get())
		_tokenStream->moveAppend(rule->getTokenStream());

	_rules.emplace_back(rule);

	// Register the rule once, going through the element stored in _rules (same form as the
	// unique_ptr overload). The previous text issued a second emplace for the same name and
	// pointer right before this one, which was redundant.
	_ruleTable.emplace(_rules.back()->getName(), _rules.back().get());
}

/**
Expand Down Expand Up @@ -277,6 +281,17 @@ bool YaraFile::addImports(const std::vector<TokenIt>& imports, ModulesPool& modu
void YaraFile::insertRule(std::size_t position, std::unique_ptr<Rule>&& rule)
{
	// Positions past the end are clamped, so they behave like inserting at the end.
	position = std::min(position, _rules.size());

	// Same guard as addRule(): only hand the rule's TokenStream to moveAppend() when it is
	// not already this file's TokenStream, so the stream is never appended into itself.
	if (rule->getTokenStream() != _tokenStream.get())
	{
		// The rule's tokens go before the first token of the rule currently at `position`,
		// or at the end of the stream when inserting after the last rule.
		TokenIt before;
		if (position == _rules.size())
		{
			before = _tokenStream->end();
		}
		else
		{
			before = _rules[position]->getFirstTokenIt();
		}
		_tokenStream->moveAppend(before, rule->getTokenStream());
	}

	_rules.insert(_rules.begin() + position, std::move(rule));
	// Register the element now stored at `position` under its name.
	_ruleTable.emplace(_rules[position]->getName(), _rules[position].get());
}
Expand All @@ -290,6 +305,17 @@ void YaraFile::insertRule(std::size_t position, std::unique_ptr<Rule>&& rule)
void YaraFile::insertRule(std::size_t position, const std::shared_ptr<Rule>& rule)
{
	// Positions past the end are clamped, so they behave like inserting at the end.
	position = std::min(position, _rules.size());

	// Same guard as addRule(): only hand the rule's TokenStream to moveAppend() when it is
	// not already this file's TokenStream, so the stream is never appended into itself.
	if (rule->getTokenStream() != _tokenStream.get())
	{
		// The rule's tokens go before the first token of the rule currently at `position`,
		// or at the end of the stream when inserting after the last rule.
		TokenIt before;
		if (position == _rules.size())
		{
			before = _tokenStream->end();
		}
		else
		{
			before = _rules[position]->getFirstTokenIt();
		}
		_tokenStream->moveAppend(before, rule->getTokenStream());
	}

	_rules.insert(_rules.begin() + position, rule);
	// Register the element now stored at `position` under its name.
	_ruleTable.emplace(_rules[position]->getName(), _rules[position].get());
}
Expand Down
Loading

0 comments on commit 9297e2c

Please sign in to comment.