validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
+ if (!(expression instanceof Cast)) {
+ return Optional.empty();
+ }
+ Cast cast = (Cast) expression;
+ Expression child = cast.child();
+ DataType dataType = cast.getDataType();
+ DataType childType = child.getDataType();
if (inferType == InferType.INTEGRAL) {
- if (expression instanceof Cast) {
- // avoid cast from wider type to narrower type, such as cast(int as smallint)
- // IntegralType dataType = (IntegralType) expression.getDataType();
- // DataType childType = ((Cast) expression).child().getDataType();
- // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
- // return validForInfer(((Cast) expression).child(), inferType);
- // }
- return validForInfer(((Cast) expression).child(), inferType);
- }
+ // avoid cast from wider type to narrower type, such as cast(int as smallint)
+ // IntegralType dataType = (IntegralType) expression.getDataType();
+ // DataType childType = ((Cast) expression).child().getDataType();
+ // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
+ // return validForInfer(((Cast) expression).child(), inferType);
+ // }
+ return validForInfer(child, inferType);
} else if (inferType == InferType.DATE) {
- if (expression instanceof Cast) {
- DataType dataType = expression.getDataType();
- DataType childType = ((Cast) expression).child().getDataType();
- // avoid lost precision
- if (dataType instanceof DateType) {
- if (childType instanceof DateV2Type || childType instanceof DateType) {
- return validForInfer(((Cast) expression).child(), inferType);
- }
- } else if (dataType instanceof DateV2Type) {
- if (childType instanceof DateType || childType instanceof DateV2Type) {
- return validForInfer(((Cast) expression).child(), inferType);
- }
- } else if (dataType instanceof DateTimeType) {
- if (!(childType instanceof DateTimeV2Type)) {
- return validForInfer(((Cast) expression).child(), inferType);
- }
- } else if (dataType instanceof DateTimeV2Type) {
- return validForInfer(((Cast) expression).child(), inferType);
+ // avoid lost precision
+ if (dataType instanceof DateType) {
+ if (childType instanceof DateV2Type || childType instanceof DateType) {
+ return validForInfer(child, inferType);
+ }
+ } else if (dataType instanceof DateV2Type) {
+ if (childType instanceof DateType || childType instanceof DateV2Type) {
+ return validForInfer(child, inferType);
}
+ } else if (dataType instanceof DateTimeType) {
+ if (!(childType instanceof DateTimeV2Type)) {
+ return validForInfer(child, inferType);
+ }
+ } else if (dataType instanceof DateTimeV2Type) {
+ return validForInfer(child, inferType);
}
} else if (inferType == InferType.STRING) {
- if (expression instanceof Cast) {
- DataType dataType = expression.getDataType();
- DataType childType = ((Cast) expression).child().getDataType();
- // avoid substring cast such as cast(char(3) as char(2))
- if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
- return validForInfer(((Cast) expression).child(), inferType);
- }
+ // avoid substring cast such as cast(char(3) as char(2))
+ if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
+ return validForInfer(child, inferType);
}
- } else {
- return Optional.empty();
}
return Optional.empty();
}
- private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
+ private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
@@ -220,25 +265,27 @@ private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredica
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
- return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
+ return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
+ right.orElse(comparisonPredicate.right()), comparisonPredicate);
}
/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
+ *
+ * TODO: NullSafeEqual
*/
- private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
+ private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
- return new ComparisonInferInfo(InferType.NONE,
- Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
+ return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
}
- ComparisonInferInfo info = inferInferInfo(predicate);
+ EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
- if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
+ if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
return info;
}
- return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
+ return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
index 1a198c76ea5d21..26e1358c2e5e11 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
@@ -47,7 +47,6 @@
*/
public class PullUpPredicates extends PlanVisitor, Void> {
- PredicatePropagation propagation = new PredicatePropagation();
Map> cache = new IdentityHashMap<>();
@Override
@@ -99,6 +98,7 @@ public ImmutableSet visitLogicalProject(LogicalProject extends Pla
public ImmutableSet visitLogicalAggregate(LogicalAggregate extends Plan> aggregate, Void context) {
return cacheOrElse(aggregate, () -> {
ImmutableSet childPredicates = aggregate.child().accept(this, context);
+ // TODO
Map expressionSlotMap = aggregate.getOutputExpressions()
.stream()
.filter(this::hasAgg)
@@ -130,7 +130,7 @@ private ImmutableSet cacheOrElse(Plan plan, Supplier getAvailableExpressions(Collection predicates, Plan plan) {
Set expressions = Sets.newHashSet(predicates);
- expressions.addAll(propagation.infer(expressions));
+ expressions.addAll(PredicatePropagation.infer(expressions));
return expressions.stream()
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
index d08d8abff73429..0bffb1c73abb97 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
@@ -68,18 +68,30 @@ public boolean nullable() throws UnboundException {
return children().stream().anyMatch(Expression::nullable);
}
+ @Override
+ public void checkLegalityBeforeTypeCoercion() {
+ children().forEach(c -> {
+ if (c.getDataType().isObjectType()) {
+ throw new AnalysisException("in predicate could not contains object type: " + this.toSql());
+ }
+ if (c.getDataType().isComplexType()) {
+ throw new AnalysisException("in predicate could not contains complex type: " + this.toSql());
+ }
+ });
+ }
+
@Override
public String toString() {
return compareExpr + " IN " + options.stream()
- .map(Expression::toString)
- .collect(Collectors.joining(", ", "(", ")"));
+ .map(Expression::toString)
+ .collect(Collectors.joining(", ", "(", ")"));
}
@Override
public String toSql() {
return compareExpr.toSql() + " IN " + options.stream()
- .map(Expression::toSql)
- .collect(Collectors.joining(", ", "(", ")"));
+ .map(Expression::toSql)
+ .collect(Collectors.joining(", ", "(", ")"));
}
@Override
@@ -92,7 +104,7 @@ public boolean equals(Object o) {
}
InPredicate that = (InPredicate) o;
return Objects.equals(compareExpr, that.getCompareExpr())
- && Objects.equals(options, that.getOptions());
+ && Objects.equals(options, that.getOptions());
}
@Override
@@ -119,16 +131,4 @@ public boolean isLiteralChildren() {
}
return true;
}
-
- @Override
- public void checkLegalityBeforeTypeCoercion() {
- children().forEach(c -> {
- if (c.getDataType().isObjectType()) {
- throw new AnalysisException("in predicate could not contains object type: " + this.toSql());
- }
- if (c.getDataType().isComplexType()) {
- throw new AnalysisException("in predicate could not contains complex type: " + this.toSql());
- }
- });
- }
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
index 5aff44f9411aab..243466b13c0352 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
@@ -24,7 +24,7 @@
import org.junit.jupiter.api.Test;
-public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
+class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
@@ -76,7 +76,7 @@ protected void runBeforeAll() throws Exception {
}
@Test
- public void inferPredicatesTest01() {
+ void inferPredicatesTest01() {
String sql = "select * from student join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -97,7 +97,7 @@ public void inferPredicatesTest01() {
}
@Test
- public void inferPredicatesTest02() {
+ void inferPredicatesTest02() {
String sql = "select * from student join score on student.id = score.sid";
PlanChecker.from(connectContext)
@@ -114,7 +114,7 @@ public void inferPredicatesTest02() {
}
@Test
- public void inferPredicatesTest03() {
+ void inferPredicatesTest03() {
String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)";
PlanChecker.from(connectContext)
@@ -123,17 +123,15 @@ public void inferPredicatesTest03() {
.matches(
logicalProject(
logicalJoin(
- logicalFilter(
- logicalOlapScan()
- ).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
- logicalOlapScan()
+ logicalFilter(logicalOlapScan()).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
+ logicalFilter(logicalOlapScan()).when(filter -> filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
)
)
);
}
@Test
- public void inferPredicatesTest04() {
+ void inferPredicatesTest04() {
String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)";
PlanChecker.from(connectContext)
@@ -142,17 +140,15 @@ public void inferPredicatesTest04() {
.matches(
logicalProject(
logicalJoin(
- logicalFilter(
- logicalOlapScan()
- ).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
- logicalOlapScan()
+ logicalFilter(logicalOlapScan()).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
+ logicalFilter(logicalOlapScan()).when(filter -> filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
)
)
);
}
@Test
- public void inferPredicatesTest05() {
+ void inferPredicatesTest05() {
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1";
PlanChecker.from(connectContext)
@@ -178,7 +174,7 @@ public void inferPredicatesTest05() {
}
@Test
- public void inferPredicatesTest06() {
+ void inferPredicatesTest06() {
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext)
@@ -204,7 +200,7 @@ public void inferPredicatesTest06() {
}
@Test
- public void inferPredicatesTest07() {
+ void inferPredicatesTest07() {
String sql = "select * from student left join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -225,7 +221,7 @@ public void inferPredicatesTest07() {
}
@Test
- public void inferPredicatesTest08() {
+ void inferPredicatesTest08() {
String sql = "select * from student left join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -244,7 +240,7 @@ public void inferPredicatesTest08() {
}
@Test
- public void inferPredicatesTest09() {
+ void inferPredicatesTest09() {
// convert left join to inner join
String sql = "select * from student left join score on student.id = score.sid where score.sid > 1";
@@ -266,7 +262,7 @@ public void inferPredicatesTest09() {
}
@Test
- public void inferPredicatesTest10() {
+ void inferPredicatesTest10() {
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1";
PlanChecker.from(connectContext)
@@ -289,7 +285,7 @@ public void inferPredicatesTest10() {
}
@Test
- public void inferPredicatesTest11() {
+ void inferPredicatesTest11() {
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1";
PlanChecker.from(connectContext)
@@ -310,7 +306,7 @@ public void inferPredicatesTest11() {
}
@Test
- public void inferPredicatesTest12() {
+ void inferPredicatesTest12() {
String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1";
PlanChecker.from(connectContext)
@@ -337,7 +333,7 @@ public void inferPredicatesTest12() {
}
@Test
- public void inferPredicatesTest13() {
+ void inferPredicatesTest13() {
String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid";
PlanChecker.from(connectContext)
@@ -360,7 +356,7 @@ public void inferPredicatesTest13() {
}
@Test
- public void inferPredicatesTest14() {
+ void inferPredicatesTest14() {
String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -383,7 +379,7 @@ public void inferPredicatesTest14() {
}
@Test
- public void inferPredicatesTest15() {
+ void inferPredicatesTest15() {
String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -406,7 +402,7 @@ public void inferPredicatesTest15() {
}
@Test
- public void inferPredicatesTest16() {
+ void inferPredicatesTest16() {
String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -427,7 +423,7 @@ public void inferPredicatesTest16() {
}
@Test
- public void inferPredicatesTest17() {
+ void inferPredicatesTest17() {
String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1";
PlanChecker.from(connectContext)
@@ -448,7 +444,7 @@ public void inferPredicatesTest17() {
}
@Test
- public void inferPredicatesTest18() {
+ void inferPredicatesTest18() {
String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -471,7 +467,7 @@ public void inferPredicatesTest18() {
}
@Test
- public void inferPredicatesTest19() {
+ void inferPredicatesTest19() {
String sql = "select * from subquery1\n"
+ "left semi join (\n"
+ " select t1.k3\n"
@@ -532,7 +528,7 @@ public void inferPredicatesTest19() {
}
@Test
- public void inferPredicatesTest20() {
+ void inferPredicatesTest20() {
String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -558,7 +554,7 @@ public void inferPredicatesTest20() {
}
@Test
- public void inferPredicatesTest21() {
+ void inferPredicatesTest21() {
String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -587,7 +583,7 @@ public void inferPredicatesTest21() {
* test for #15310
*/
@Test
- public void inferPredicatesTest22() {
+ void inferPredicatesTest22() {
String sql = "select * from student join (select sid as id1, sid as id2, grade from score) s on student.id = s.id1 where s.id1 > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -613,7 +609,7 @@ public void inferPredicatesTest22() {
* in this case, filter on relation s1 should not contain s1.id = 1.
*/
@Test
- public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
+ void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
String sql = "select * from student s1"
+ " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1"
+ " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2";
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
new file mode 100644
index 00000000000000..b1aa25df1b1ac2
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.junit.jupiter.api.Test;
+
+import java.util.Set;
+
+class PredicatePropagationTest {
+ private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE);
+ private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE);
+
+ @Test
+ void equal() {
+ Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1)));
+ Set inferExprs = PredicatePropagation.infer(exprs);
+ System.out.println(inferExprs);
+ }
+
+ @Test
+ void in() {
+ Set exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1))));
+ Set inferExprs = PredicatePropagation.infer(exprs);
+ System.out.println(inferExprs);
+ }
+}
diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out b/regression-test/data/nereids_p0/hint/fix_leading.out
index 54da890cced6c1..a71ca311e758ef 100644
--- a/regression-test/data/nereids_p0/hint/fix_leading.out
+++ b/regression-test/data/nereids_p0/hint/fix_leading.out
@@ -9,7 +9,7 @@ PhysicalResultSink
----------PhysicalDistribute
------------PhysicalOlapScan[t2]
--------PhysicalDistribute
-----------NestedLoopJoin[CROSS_JOIN]
+----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4)
------------PhysicalOlapScan[t3]
------------PhysicalDistribute
--------------PhysicalOlapScan[t4]
diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index c5942680ea7d09..55645ed8ea0950 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -41,7 +41,7 @@ suite("test_infer_predicate") {
explain {
sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;"
- contains "PREDICATES: k2"
+ contains "PREDICATES: CAST(k2"
}
explain {