diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index 3c4593df54c81d..36236c3db8dcc7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; @@ -37,6 +36,7 @@ /** * infer additional predicates for `LogicalFilter` and `LogicalJoin`. + *
  * The logic is as follows:
  * 1. poll up bottom predicate then infer additional predicates
  *   for example:
@@ -49,9 +49,9 @@
  *      select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
  * 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
  *   round of predicate push-down
+ * 
*/ public class InferPredicates extends DefaultPlanRewriter implements CustomRewriter { - private final PredicatePropagation propagation = new PredicatePropagation(); private final PullUpPredicates pollUpPredicates = new PullUpPredicates(); @Override @@ -62,6 +62,9 @@ public Plan rewriteRoot(Plan plan, JobContext jobContext) { @Override public Plan visitLogicalJoin(LogicalJoin join, JobContext context) { join = visitChildren(this, join, context); + if (join.isMarkJoin()) { + return join; + } Plan left = join.left(); Plan right = join.right(); Set expressions = getAllExpressions(left, right, join.getOnClauseCondition()); @@ -86,7 +89,7 @@ public Plan visitLogicalJoin(LogicalJoin join, J break; } if (left != join.left() || right != join.right()) { - return join.withChildren(ImmutableList.of(left, right)); + return join.withChildren(left, right); } else { return join; } @@ -109,7 +112,7 @@ private Set getAllExpressions(Plan left, Plan right, Optional baseExpressions = pullUpPredicates(left); baseExpressions.addAll(pullUpPredicates(right)); condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on))); - baseExpressions.addAll(propagation.infer(baseExpressions)); + baseExpressions.addAll(PredicatePropagation.infer(baseExpressions)); return baseExpressions; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index 2317da427ea6a9..aa520362a7774d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DataType; @@ -55,8 +56,7 @@ private enum InferType { INTEGRAL(IntegralType.class), STRING(CharacterType.class), DATE(DateLikeType.class), - OTHER(DataType.class) - ; + OTHER(DataType.class); private final Class superClazz; @@ -65,15 +65,15 @@ private enum InferType { } } - private class ComparisonInferInfo { + private static class EqualInferInfo { public final InferType inferType; - public final Optional left; - public final Optional right; + public final Expression left; + public final Expression right; public final ComparisonPredicate comparisonPredicate; - public ComparisonInferInfo(InferType inferType, - Optional left, Optional right, + public EqualInferInfo(InferType inferType, + Expression left, Expression right, ComparisonPredicate comparisonPredicate) { this.inferType = inferType; this.left = left; @@ -85,23 +85,24 @@ public ComparisonInferInfo(InferType inferType, /** * infer additional predicates. */ - public Set infer(Set predicates) { + public static Set infer(Set predicates) { Set inferred = Sets.newHashSet(); for (Expression predicate : predicates) { - if (!(predicate instanceof ComparisonPredicate)) { + if (!(predicate instanceof ComparisonPredicate + || (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) { continue; } - ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate); + if (predicate instanceof InPredicate) { + continue; + } + EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate); if (equalInfo.inferType == InferType.NONE) { continue; } Set newInferred = predicates.stream() - .filter(ComparisonPredicate.class::isInstance) .filter(p -> !p.equals(predicate)) - .map(ComparisonPredicate.class::cast) - .map(this::inferInferInfo) - .filter(predicateInfo -> predicateInfo.inferType != InferType.NONE) - .map(predicateInfo -> doInfer(equalInfo, predicateInfo)) + .filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate) + .map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo)) .filter(Objects::nonNull) .collect(Collectors.toSet()); inferred.addAll(newInferred); @@ -110,17 +111,66 @@ public Set infer(Set predicates) { return inferred; } + private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) { + Expression equalLeft = equalInfo.left; + Expression equalRight = equalInfo.right; + + DataType leftType = predicate.child(0).getDataType(); + InferType inferType; + if (leftType instanceof CharacterType) { + inferType = InferType.STRING; + } else if (leftType instanceof IntegralType) { + inferType = InferType.INTEGRAL; + } else if (leftType instanceof DateLikeType) { + inferType = InferType.DATE; + } else { + inferType = InferType.OTHER; + } + if (predicate instanceof ComparisonPredicate) { + ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate; + Optional left = validForInfer(comparisonPredicate.left(), inferType); + Optional right = validForInfer(comparisonPredicate.right(), inferType); + if (!left.isPresent() || !right.isPresent()) { + return null; + } + } else if (predicate instanceof InPredicate) { + InPredicate inPredicate = (InPredicate) predicate; + Optional left = validForInfer(inPredicate.getCompareExpr(), inferType); + if (!left.isPresent()) { + return null; + } + } + + Expression newPredicate = predicate.rewriteUp(e -> { + if (e.equals(equalLeft)) { + return equalRight; + } else if (e.equals(equalRight)) { + return equalLeft; + } else { + return e; + } + }); + if (predicate instanceof ComparisonPredicate) { + return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate, + newPredicate.child(0), + newPredicate.child(1)); + } else { + return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate); + } + } + /** * Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression` * Now only support infer `ComparisonPredicate`. * TODO: We should determine whether `expression` satisfies the condition for replacement * eg: Satisfy `expression` is non-deterministic */ - private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) { - Expression predicateLeft = predicateInfo.left.get(); - Expression predicateRight = predicateInfo.right.get(); - Expression equalLeft = equalInfo.left.get(); - Expression equalRight = equalInfo.right.get(); + private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) { + Expression equalLeft = equalInfo.left; + Expression equalRight = equalInfo.right; + + Expression predicateLeft = predicateInfo.left; + Expression predicateRight = predicateInfo.right; Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight); Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight); if (newLeft == null || newRight == null) { @@ -133,7 +183,7 @@ private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo pr return DateFunctionRewrite.INSTANCE.rewrite(expr, null); } - private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { + private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { if (predicateOneSide instanceof SlotReference) { if (predicateOneSide.equals(equalLeft)) { return equalRight; @@ -150,60 +200,55 @@ private Expression inferOneSide(Expression predicateOneSide, Expression equalLef return null; } - private Optional validForInfer(Expression expression, InferType inferType) { + private static Optional validForInfer(Expression expression, InferType inferType) { if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) { return Optional.empty(); } if (expression instanceof SlotReference || expression.isConstant()) { return Optional.of(expression); } + if (!(expression instanceof Cast)) { + return Optional.empty(); + } + Cast cast = (Cast) expression; + Expression child = cast.child(); + DataType dataType = cast.getDataType(); + DataType childType = child.getDataType(); if (inferType == InferType.INTEGRAL) { - if (expression instanceof Cast) { - // avoid cast from wider type to narrower type, such as cast(int as smallint) - // IntegralType dataType = (IntegralType) expression.getDataType(); - // DataType childType = ((Cast) expression).child().getDataType(); - // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) { - // return validForInfer(((Cast) expression).child(), inferType); - // } - return validForInfer(((Cast) expression).child(), inferType); - } + // avoid cast from wider type to narrower type, such as cast(int as smallint) + // IntegralType dataType = (IntegralType) expression.getDataType(); + // DataType childType = ((Cast) expression).child().getDataType(); + // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) { + // return validForInfer(((Cast) expression).child(), inferType); + // } + return validForInfer(child, inferType); } else if (inferType == InferType.DATE) { - if (expression instanceof Cast) { - DataType dataType = expression.getDataType(); - DataType childType = ((Cast) expression).child().getDataType(); - // avoid lost precision - if (dataType instanceof DateType) { - if (childType instanceof DateV2Type || childType instanceof DateType) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateV2Type) { - if (childType instanceof DateType || childType instanceof DateV2Type) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateTimeType) { - if (!(childType instanceof DateTimeV2Type)) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateTimeV2Type) { - return validForInfer(((Cast) expression).child(), inferType); + // avoid lost precision + if (dataType instanceof DateType) { + if (childType instanceof DateV2Type || childType instanceof DateType) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateV2Type) { + if (childType instanceof DateType || childType instanceof DateV2Type) { + return validForInfer(child, inferType); } + } else if (dataType instanceof DateTimeType) { + if (!(childType instanceof DateTimeV2Type)) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateTimeV2Type) { + return validForInfer(child, inferType); } } else if (inferType == InferType.STRING) { - if (expression instanceof Cast) { - DataType dataType = expression.getDataType(); - DataType childType = ((Cast) expression).child().getDataType(); - // avoid substring cast such as cast(char(3) as char(2)) - if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { - return validForInfer(((Cast) expression).child(), inferType); - } + // avoid substring cast such as cast(char(3) as char(2)) + if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { + return validForInfer(child, inferType); } - } else { - return Optional.empty(); } return Optional.empty(); } - private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { + private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { DataType leftType = comparisonPredicate.left().getDataType(); InferType inferType; if (leftType instanceof CharacterType) { @@ -220,25 +265,27 @@ private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredica if (!left.isPresent() || !right.isPresent()) { inferType = InferType.NONE; } - return new ComparisonInferInfo(inferType, left, right, comparisonPredicate); + return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()), + right.orElse(comparisonPredicate.right()), comparisonPredicate); } /** * Currently only equivalence derivation is supported * and requires that the left and right sides of an expression must be slot + *

+ * TODO: NullSafeEqual */ - private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) { + private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) { if (!(predicate instanceof EqualTo)) { - return new ComparisonInferInfo(InferType.NONE, - Optional.of(predicate.left()), Optional.of(predicate.right()), predicate); + return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate); } - ComparisonInferInfo info = inferInferInfo(predicate); + EqualInferInfo info = inferInferInfo(predicate); if (info.inferType == InferType.NONE) { return info; } - if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) { + if (info.left instanceof SlotReference && info.right instanceof SlotReference) { return info; } - return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); + return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 1a198c76ea5d21..26e1358c2e5e11 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -47,7 +47,6 @@ */ public class PullUpPredicates extends PlanVisitor, Void> { - PredicatePropagation propagation = new PredicatePropagation(); Map> cache = new IdentityHashMap<>(); @Override @@ -99,6 +98,7 @@ public ImmutableSet visitLogicalProject(LogicalProject visitLogicalAggregate(LogicalAggregate aggregate, Void context) { return cacheOrElse(aggregate, () -> { ImmutableSet childPredicates = aggregate.child().accept(this, context); + // TODO Map expressionSlotMap = aggregate.getOutputExpressions() .stream() .filter(this::hasAgg) @@ -130,7 +130,7 @@ private ImmutableSet cacheOrElse(Plan plan, Supplier getAvailableExpressions(Collection predicates, Plan plan) { Set expressions = Sets.newHashSet(predicates); - expressions.addAll(propagation.infer(expressions)); + expressions.addAll(PredicatePropagation.infer(expressions)); return expressions.stream() .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots())) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java index d08d8abff73429..0bffb1c73abb97 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java @@ -68,18 +68,30 @@ public boolean nullable() throws UnboundException { return children().stream().anyMatch(Expression::nullable); } + @Override + public void checkLegalityBeforeTypeCoercion() { + children().forEach(c -> { + if (c.getDataType().isObjectType()) { + throw new AnalysisException("in predicate could not contains object type: " + this.toSql()); + } + if (c.getDataType().isComplexType()) { + throw new AnalysisException("in predicate could not contains complex type: " + this.toSql()); + } + }); + } + @Override public String toString() { return compareExpr + " IN " + options.stream() - .map(Expression::toString) - .collect(Collectors.joining(", ", "(", ")")); + .map(Expression::toString) + .collect(Collectors.joining(", ", "(", ")")); } @Override public String toSql() { return compareExpr.toSql() + " IN " + options.stream() - .map(Expression::toSql) - .collect(Collectors.joining(", ", "(", ")")); + .map(Expression::toSql) + .collect(Collectors.joining(", ", "(", ")")); } @Override @@ -92,7 +104,7 @@ public boolean equals(Object o) { } InPredicate that = (InPredicate) o; return Objects.equals(compareExpr, that.getCompareExpr()) - && Objects.equals(options, that.getOptions()); + && Objects.equals(options, that.getOptions()); } @Override @@ -119,16 +131,4 @@ public boolean isLiteralChildren() { } return true; } - - @Override - public void checkLegalityBeforeTypeCoercion() { - children().forEach(c -> { - if (c.getDataType().isObjectType()) { - throw new AnalysisException("in predicate could not contains object type: " + this.toSql()); - } - if (c.getDataType().isComplexType()) { - throw new AnalysisException("in predicate could not contains complex type: " + this.toSql()); - } - }); - } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index 5aff44f9411aab..243466b13c0352 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -24,7 +24,7 @@ import org.junit.jupiter.api.Test; -public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported { +class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported { @Override protected void runBeforeAll() throws Exception { @@ -76,7 +76,7 @@ protected void runBeforeAll() throws Exception { } @Test - public void inferPredicatesTest01() { + void inferPredicatesTest01() { String sql = "select * from student join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -97,7 +97,7 @@ public void inferPredicatesTest01() { } @Test - public void inferPredicatesTest02() { + void inferPredicatesTest02() { String sql = "select * from student join score on student.id = score.sid"; PlanChecker.from(connectContext) @@ -114,7 +114,7 @@ public void inferPredicatesTest02() { } @Test - public void inferPredicatesTest03() { + void inferPredicatesTest03() { String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)"; PlanChecker.from(connectContext) @@ -123,17 +123,15 @@ public void inferPredicatesTest03() { .matches( logicalProject( logicalJoin( - logicalFilter( - logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), - logicalOlapScan() + logicalFilter(logicalOlapScan()).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), + logicalFilter(logicalOlapScan()).when(filter -> filter.getPredicate().toSql().contains("sid IN (1, 2, 3)")) ) ) ); } @Test - public void inferPredicatesTest04() { + void inferPredicatesTest04() { String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)"; PlanChecker.from(connectContext) @@ -142,17 +140,15 @@ public void inferPredicatesTest04() { .matches( logicalProject( logicalJoin( - logicalFilter( - logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), - logicalOlapScan() + logicalFilter(logicalOlapScan()).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), + logicalFilter(logicalOlapScan()).when(filter -> filter.getPredicate().toSql().contains("sid IN (1, 2, 3)")) ) ) ); } @Test - public void inferPredicatesTest05() { + void inferPredicatesTest05() { String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1"; PlanChecker.from(connectContext) @@ -178,7 +174,7 @@ public void inferPredicatesTest05() { } @Test - public void inferPredicatesTest06() { + void inferPredicatesTest06() { String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1"; PlanChecker.from(connectContext) @@ -204,7 +200,7 @@ public void inferPredicatesTest06() { } @Test - public void inferPredicatesTest07() { + void inferPredicatesTest07() { String sql = "select * from student left join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -225,7 +221,7 @@ public void inferPredicatesTest07() { } @Test - public void inferPredicatesTest08() { + void inferPredicatesTest08() { String sql = "select * from student left join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -244,7 +240,7 @@ public void inferPredicatesTest08() { } @Test - public void inferPredicatesTest09() { + void inferPredicatesTest09() { // convert left join to inner join String sql = "select * from student left join score on student.id = score.sid where score.sid > 1"; @@ -266,7 +262,7 @@ public void inferPredicatesTest09() { } @Test - public void inferPredicatesTest10() { + void inferPredicatesTest10() { String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1"; PlanChecker.from(connectContext) @@ -289,7 +285,7 @@ public void inferPredicatesTest10() { } @Test - public void inferPredicatesTest11() { + void inferPredicatesTest11() { String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1"; PlanChecker.from(connectContext) @@ -310,7 +306,7 @@ public void inferPredicatesTest11() { } @Test - public void inferPredicatesTest12() { + void inferPredicatesTest12() { String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1"; PlanChecker.from(connectContext) @@ -337,7 +333,7 @@ public void inferPredicatesTest12() { } @Test - public void inferPredicatesTest13() { + void inferPredicatesTest13() { String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid"; PlanChecker.from(connectContext) @@ -360,7 +356,7 @@ public void inferPredicatesTest13() { } @Test - public void inferPredicatesTest14() { + void inferPredicatesTest14() { String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -383,7 +379,7 @@ public void inferPredicatesTest14() { } @Test - public void inferPredicatesTest15() { + void inferPredicatesTest15() { String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -406,7 +402,7 @@ public void inferPredicatesTest15() { } @Test - public void inferPredicatesTest16() { + void inferPredicatesTest16() { String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -427,7 +423,7 @@ public void inferPredicatesTest16() { } @Test - public void inferPredicatesTest17() { + void inferPredicatesTest17() { String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1"; PlanChecker.from(connectContext) @@ -448,7 +444,7 @@ public void inferPredicatesTest17() { } @Test - public void inferPredicatesTest18() { + void inferPredicatesTest18() { String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -471,7 +467,7 @@ public void inferPredicatesTest18() { } @Test - public void inferPredicatesTest19() { + void inferPredicatesTest19() { String sql = "select * from subquery1\n" + "left semi join (\n" + " select t1.k3\n" @@ -532,7 +528,7 @@ public void inferPredicatesTest19() { } @Test - public void inferPredicatesTest20() { + void inferPredicatesTest20() { String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -558,7 +554,7 @@ public void inferPredicatesTest20() { } @Test - public void inferPredicatesTest21() { + void inferPredicatesTest21() { String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -587,7 +583,7 @@ public void inferPredicatesTest21() { * test for #15310 */ @Test - public void inferPredicatesTest22() { + void inferPredicatesTest22() { String sql = "select * from student join (select sid as id1, sid as id2, grade from score) s on student.id = s.id1 where s.id1 > 1"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -613,7 +609,7 @@ public void inferPredicatesTest22() { * in this case, filter on relation s1 should not contain s1.id = 1. */ @Test - public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { + void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { String sql = "select * from student s1" + " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1" + " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2"; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java new file mode 100644 index 00000000000000..b1aa25df1b1ac2 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.SmallIntType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +class PredicatePropagationTest { + private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE); + private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE); + + @Test + void equal() { + Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1))); + Set inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } + + @Test + void in() { + Set exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1)))); + Set inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } +} diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out b/regression-test/data/nereids_p0/hint/fix_leading.out index 54da890cced6c1..a71ca311e758ef 100644 --- a/regression-test/data/nereids_p0/hint/fix_leading.out +++ b/regression-test/data/nereids_p0/hint/fix_leading.out @@ -9,7 +9,7 @@ PhysicalResultSink ----------PhysicalDistribute ------------PhysicalOlapScan[t2] --------PhysicalDistribute -----------NestedLoopJoin[CROSS_JOIN] +----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4) ------------PhysicalOlapScan[t3] ------------PhysicalDistribute --------------PhysicalOlapScan[t4] diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy index c5942680ea7d09..55645ed8ea0950 100644 --- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -41,7 +41,7 @@ suite("test_infer_predicate") { explain { sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;" - contains "PREDICATES: k2" + contains "PREDICATES: CAST(k2" } explain {