From bcba028b03aabc64a5a2612183dc2153a43aa609 Mon Sep 17 00:00:00 2001 From: jackwener Date: Wed, 3 Jan 2024 11:32:07 +0800 Subject: [PATCH] [feature](Nereids): InferPredicates support In --- .../rules/rewrite/InferPredicates.java | 3 +- .../rules/rewrite/PredicatePropagation.java | 177 +++++++++++------- .../rules/rewrite/PullUpPredicates.java | 4 +- .../nereids/trees/expressions/EqualTo.java | 4 - .../trees/expressions/InPredicate.java | 11 ++ .../rules/rewrite/InferPredicatesTest.java | 62 +++--- .../rewrite/PredicatePropagationTest.java | 51 +++++ 7 files changed, 206 insertions(+), 106 deletions(-) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index 3c4593df54c81d4..477ff905d20f2ac 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -51,7 +51,6 @@ * round of predicate push-down */ public class InferPredicates extends DefaultPlanRewriter implements CustomRewriter { - private final PredicatePropagation propagation = new PredicatePropagation(); private final PullUpPredicates pollUpPredicates = new PullUpPredicates(); @Override @@ -109,7 +108,7 @@ private Set getAllExpressions(Plan left, Plan right, Optional baseExpressions = pullUpPredicates(left); baseExpressions.addAll(pullUpPredicates(right)); condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on))); - baseExpressions.addAll(propagation.infer(baseExpressions)); + baseExpressions.addAll(PredicatePropagation.infer(baseExpressions)); return baseExpressions; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index 72e9023dc45df52..7788bbb7f06fb3f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DataType; @@ -55,8 +56,7 @@ private enum InferType { INTEGRAL(IntegralType.class), STRING(CharacterType.class), DATE(DateLikeType.class), - OTHER(DataType.class) - ; + OTHER(DataType.class); private final Class superClazz; @@ -65,15 +65,15 @@ private enum InferType { } } - private class ComparisonInferInfo { + private static class EqualInferInfo { public final InferType inferType; - public final Optional left; - public final Optional right; + public final Expression left; + public final Expression right; public final ComparisonPredicate comparisonPredicate; - public ComparisonInferInfo(InferType inferType, - Optional left, Optional right, + public EqualInferInfo(InferType inferType, + Expression left, Expression right, ComparisonPredicate comparisonPredicate) { this.inferType = inferType; this.left = left; @@ -85,26 +85,27 @@ public ComparisonInferInfo(InferType inferType, /** * infer additional predicates. */ - public Set infer(Set predicates) { + public static Set infer(Set predicates) { Set inferred = Sets.newHashSet(); for (Expression predicate : predicates) { // if we support more infer predicate expression type, we should impl withInferred() method. // And should add inferred props in withChildren() method just like ComparisonPredicate, // and it's subclass, to mark the predicate is from infer. - if (!(predicate instanceof ComparisonPredicate)) { + if (!(predicate instanceof ComparisonPredicate + || (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) { continue; } - ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate); + if (predicate instanceof InPredicate) { + continue; + } + EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate); if (equalInfo.inferType == InferType.NONE) { continue; } Set newInferred = predicates.stream() - .filter(ComparisonPredicate.class::isInstance) .filter(p -> !p.equals(predicate)) - .map(ComparisonPredicate.class::cast) - .map(this::inferInferInfo) - .filter(predicateInfo -> predicateInfo.inferType != InferType.NONE) - .map(predicateInfo -> doInfer(equalInfo, predicateInfo)) + .filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate) + .map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo)) .filter(Objects::nonNull) .collect(Collectors.toSet()); inferred.addAll(newInferred); @@ -113,17 +114,64 @@ public Set infer(Set predicates) { return inferred; } + private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) { + Expression equalLeft = equalInfo.left; + Expression equalRight = equalInfo.right; + + DataType leftType = predicate.child(0).getDataType(); + InferType inferType; + if (leftType instanceof CharacterType) { + inferType = InferType.STRING; + } else if (leftType instanceof IntegralType) { + inferType = InferType.INTEGRAL; + } else if (leftType instanceof DateLikeType) { + inferType = InferType.DATE; + } else { + inferType = InferType.OTHER; + } + if (predicate instanceof ComparisonPredicate) { + ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate; + Optional left = validForInfer(comparisonPredicate.left(), inferType); + Optional right = validForInfer(comparisonPredicate.right(), inferType); + if (!left.isPresent() || !right.isPresent()) { + return null; + } + } else if (predicate instanceof InPredicate) { + InPredicate inPredicate = (InPredicate) predicate; + Optional left = validForInfer(inPredicate.getCompareExpr(), inferType); + if (!left.isPresent()) { + return null; + } + } + + Expression newPredicate = predicate.rewriteUp(e -> { + if (e.equals(equalLeft)) { + return equalRight; + } else if (e.equals(equalRight)) { + return equalLeft; + } else { + return e; + } + }); + if (predicate instanceof ComparisonPredicate) { + return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate).withInferred(true); + } else { + return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate).withInferred(true); + } + } + /** * Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression` * Now only support infer `ComparisonPredicate`. * TODO: We should determine whether `expression` satisfies the condition for replacement * eg: Satisfy `expression` is non-deterministic */ - private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) { - Expression predicateLeft = predicateInfo.left.get(); - Expression predicateRight = predicateInfo.right.get(); - Expression equalLeft = equalInfo.left.get(); - Expression equalRight = equalInfo.right.get(); + private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) { + Expression equalLeft = equalInfo.left; + Expression equalRight = equalInfo.right; + + Expression predicateLeft = predicateInfo.left; + Expression predicateRight = predicateInfo.right; Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight); Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight); if (newLeft == null || newRight == null) { @@ -136,7 +184,7 @@ private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo pr return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true); } - private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { + private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { if (predicateOneSide instanceof SlotReference) { if (predicateOneSide.equals(equalLeft)) { return equalRight; @@ -153,60 +201,55 @@ private Expression inferOneSide(Expression predicateOneSide, Expression equalLef return null; } - private Optional validForInfer(Expression expression, InferType inferType) { + private static Optional validForInfer(Expression expression, InferType inferType) { if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) { return Optional.empty(); } if (expression instanceof SlotReference || expression.isConstant()) { return Optional.of(expression); } + if (!(expression instanceof Cast)) { + return Optional.empty(); + } + Cast cast = (Cast) expression; + Expression child = cast.child(); + DataType dataType = cast.getDataType(); + DataType childType = child.getDataType(); if (inferType == InferType.INTEGRAL) { - if (expression instanceof Cast) { - // avoid cast from wider type to narrower type, such as cast(int as smallint) - // IntegralType dataType = (IntegralType) expression.getDataType(); - // DataType childType = ((Cast) expression).child().getDataType(); - // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) { - // return validForInfer(((Cast) expression).child(), inferType); - // } - return validForInfer(((Cast) expression).child(), inferType); - } + // avoid cast from wider type to narrower type, such as cast(int as smallint) + // IntegralType dataType = (IntegralType) expression.getDataType(); + // DataType childType = ((Cast) expression).child().getDataType(); + // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) { + // return validForInfer(((Cast) expression).child(), inferType); + // } + return validForInfer(child, inferType); } else if (inferType == InferType.DATE) { - if (expression instanceof Cast) { - DataType dataType = expression.getDataType(); - DataType childType = ((Cast) expression).child().getDataType(); - // avoid lost precision - if (dataType instanceof DateType) { - if (childType instanceof DateV2Type || childType instanceof DateType) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateV2Type) { - if (childType instanceof DateType || childType instanceof DateV2Type) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateTimeType) { - if (!(childType instanceof DateTimeV2Type)) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateTimeV2Type) { - return validForInfer(((Cast) expression).child(), inferType); + // avoid lost precision + if (dataType instanceof DateType) { + if (childType instanceof DateV2Type || childType instanceof DateType) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateV2Type) { + if (childType instanceof DateType || childType instanceof DateV2Type) { + return validForInfer(child, inferType); } + } else if (dataType instanceof DateTimeType) { + if (!(childType instanceof DateTimeV2Type)) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateTimeV2Type) { + return validForInfer(child, inferType); } } else if (inferType == InferType.STRING) { - if (expression instanceof Cast) { - DataType dataType = expression.getDataType(); - DataType childType = ((Cast) expression).child().getDataType(); - // avoid substring cast such as cast(char(3) as char(2)) - if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { - return validForInfer(((Cast) expression).child(), inferType); - } + // avoid substring cast such as cast(char(3) as char(2)) + if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { + return validForInfer(child, inferType); } - } else { - return Optional.empty(); } return Optional.empty(); } - private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { + private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { DataType leftType = comparisonPredicate.left().getDataType(); InferType inferType; if (leftType instanceof CharacterType) { @@ -223,25 +266,27 @@ private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredica if (!left.isPresent() || !right.isPresent()) { inferType = InferType.NONE; } - return new ComparisonInferInfo(inferType, left, right, comparisonPredicate); + return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()), + right.orElse(comparisonPredicate.right()), comparisonPredicate); } /** * Currently only equivalence derivation is supported * and requires that the left and right sides of an expression must be slot + *

+ * TODO: NullSafeEqual */ - private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) { + private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) { if (!(predicate instanceof EqualTo)) { - return new ComparisonInferInfo(InferType.NONE, - Optional.of(predicate.left()), Optional.of(predicate.right()), predicate); + return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate); } - ComparisonInferInfo info = inferInferInfo(predicate); + EqualInferInfo info = inferInferInfo(predicate); if (info.inferType == InferType.NONE) { return info; } - if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) { + if (info.left instanceof SlotReference && info.right instanceof SlotReference) { return info; } - return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); + return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 1a198c76ea5d217..26e1358c2e5e11b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -47,7 +47,6 @@ */ public class PullUpPredicates extends PlanVisitor, Void> { - PredicatePropagation propagation = new PredicatePropagation(); Map> cache = new IdentityHashMap<>(); @Override @@ -99,6 +98,7 @@ public ImmutableSet visitLogicalProject(LogicalProject visitLogicalAggregate(LogicalAggregate aggregate, Void context) { return cacheOrElse(aggregate, () -> { ImmutableSet childPredicates = aggregate.child().accept(this, context); + // TODO Map expressionSlotMap = aggregate.getOutputExpressions() .stream() .filter(this::hasAgg) @@ -130,7 +130,7 @@ private ImmutableSet cacheOrElse(Plan plan, Supplier getAvailableExpressions(Collection predicates, Plan plan) { Set expressions = Sets.newHashSet(predicates); - expressions.addAll(propagation.infer(expressions)); + expressions.addAll(PredicatePropagation.infer(expressions)); return expressions.stream() .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots())) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java index 2704d446555867c..3e71b3b89a01d8b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java @@ -39,10 +39,6 @@ public EqualTo(Expression left, Expression right, boolean inferred) { super(ImmutableList.of(left, right), "=", inferred); } - private EqualTo(List children) { - this(children, false); - } - private EqualTo(List children, boolean inferred) { super(children, "=", inferred); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java index d839a1e9062b0b2..c86a074dcfd3762 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java @@ -48,6 +48,12 @@ public InPredicate(Expression compareExpr, List options) { this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null")); } + public InPredicate(Expression compareExpr, List options, boolean inferred) { + super(new Builder().add(compareExpr).addAll(options).build(), inferred); + this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null"); + this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null")); + } + public R accept(ExpressionVisitor visitor, C context) { return visitor.visitInPredicate(this, context); } @@ -80,6 +86,11 @@ public void checkLegalityBeforeTypeCoercion() { }); } + @Override + public Expression withInferred(boolean inferred) { + return new InPredicate(children.get(0), ImmutableList.copyOf(children).subList(1, children.size()), true); + } + @Override public String toString() { return compareExpr + " IN " + options.stream() diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index c910e98fcd50c5e..0708ea3f172f18d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -25,7 +25,7 @@ import org.junit.jupiter.api.Test; -public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported { +class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported { @Override protected void runBeforeAll() throws Exception { @@ -77,7 +77,7 @@ protected void runBeforeAll() throws Exception { } @Test - public void inferPredicatesTest01() { + void inferPredicatesTest01() { String sql = "select * from student join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -100,7 +100,7 @@ public void inferPredicatesTest01() { } @Test - public void inferPredicatesTest02() { + void inferPredicatesTest02() { String sql = "select * from student join score on student.id = score.sid"; PlanChecker.from(connectContext) @@ -117,7 +117,7 @@ public void inferPredicatesTest02() { } @Test - public void inferPredicatesTest03() { + void inferPredicatesTest03() { String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)"; PlanChecker.from(connectContext) @@ -126,18 +126,17 @@ public void inferPredicatesTest03() { .matches( logicalProject( logicalJoin( - logicalFilter( - logicalOlapScan() - ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) & filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), - logicalOlapScan() + logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid IN (1, 2, 3)")) ) ) ); } @Test - public void inferPredicatesTest04() { + void inferPredicatesTest04() { String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)"; PlanChecker.from(connectContext) @@ -146,18 +145,17 @@ public void inferPredicatesTest04() { .matches( logicalProject( logicalJoin( - logicalFilter( - logicalOlapScan() - ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) & filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), - logicalOlapScan() + logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid IN (1, 2, 3)")) ) ) ); } @Test - public void inferPredicatesTest05() { + void inferPredicatesTest05() { String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1"; PlanChecker.from(connectContext) @@ -185,7 +183,7 @@ public void inferPredicatesTest05() { } @Test - public void inferPredicatesTest06() { + void inferPredicatesTest06() { String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1"; PlanChecker.from(connectContext) @@ -213,7 +211,7 @@ public void inferPredicatesTest06() { } @Test - public void inferPredicatesTest07() { + void inferPredicatesTest07() { String sql = "select * from student left join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -236,7 +234,7 @@ public void inferPredicatesTest07() { } @Test - public void inferPredicatesTest08() { + void inferPredicatesTest08() { String sql = "select * from student left join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -256,7 +254,7 @@ public void inferPredicatesTest08() { } @Test - public void inferPredicatesTest09() { + void inferPredicatesTest09() { // convert left join to inner join String sql = "select * from student left join score on student.id = score.sid where score.sid > 1"; @@ -280,7 +278,7 @@ public void inferPredicatesTest09() { } @Test - public void inferPredicatesTest10() { + void inferPredicatesTest10() { String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1"; PlanChecker.from(connectContext) @@ -305,7 +303,7 @@ public void inferPredicatesTest10() { } @Test - public void inferPredicatesTest11() { + void inferPredicatesTest11() { String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1"; PlanChecker.from(connectContext) @@ -327,7 +325,7 @@ public void inferPredicatesTest11() { } @Test - public void inferPredicatesTest12() { + void inferPredicatesTest12() { String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1"; PlanChecker.from(connectContext) @@ -356,7 +354,7 @@ public void inferPredicatesTest12() { } @Test - public void inferPredicatesTest13() { + void inferPredicatesTest13() { String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid"; PlanChecker.from(connectContext) @@ -381,7 +379,7 @@ public void inferPredicatesTest13() { } @Test - public void inferPredicatesTest14() { + void inferPredicatesTest14() { String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -406,7 +404,7 @@ public void inferPredicatesTest14() { } @Test - public void inferPredicatesTest15() { + void inferPredicatesTest15() { String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -431,7 +429,7 @@ public void inferPredicatesTest15() { } @Test - public void inferPredicatesTest16() { + void inferPredicatesTest16() { String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -453,7 +451,7 @@ public void inferPredicatesTest16() { } @Test - public void inferPredicatesTest17() { + void inferPredicatesTest17() { String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1"; PlanChecker.from(connectContext) @@ -475,7 +473,7 @@ public void inferPredicatesTest17() { } @Test - public void inferPredicatesTest18() { + void inferPredicatesTest18() { String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -500,7 +498,7 @@ public void inferPredicatesTest18() { } @Test - public void inferPredicatesTest19() { + void inferPredicatesTest19() { String sql = "select * from subquery1\n" + "left semi join (\n" + " select t1.k3\n" @@ -564,7 +562,7 @@ public void inferPredicatesTest19() { } @Test - public void inferPredicatesTest20() { + void inferPredicatesTest20() { String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -592,7 +590,7 @@ public void inferPredicatesTest20() { } @Test - public void inferPredicatesTest21() { + void inferPredicatesTest21() { String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -623,7 +621,7 @@ public void inferPredicatesTest21() { * test for #15310 */ @Test - public void inferPredicatesTest22() { + void inferPredicatesTest22() { String sql = "select * from student join (select sid as id1, sid as id2, grade from score) s on student.id = s.id1 where s.id1 > 1"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -651,7 +649,7 @@ public void inferPredicatesTest22() { * in this case, filter on relation s1 should not contain s1.id = 1. */ @Test - public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { + void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { String sql = "select * from student s1" + " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1" + " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2"; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java new file mode 100644 index 000000000000000..b1aa25df1b1ac2b --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.SmallIntType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +class PredicatePropagationTest { + private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE); + private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE); + + @Test + void equal() { + Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1))); + Set inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } + + @Test + void in() { + Set exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1)))); + Set inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } +}