Skip to content

Commit

Permalink
[feature](Nereids): InferPredicates support In (apache#29458)
Browse files Browse the repository at this point in the history
(cherry picked from commit 7a0734d)
  • Loading branch information
jackwener committed Jan 16, 2024
1 parent aacf8f7 commit a3b0c2c
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

Expand All @@ -37,6 +36,7 @@

/**
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
* <pre>
* The logic is as follows:
* 1. pull up the predicates from the bottom, then infer additional predicates
* for example:
Expand All @@ -49,9 +49,9 @@
* select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
* 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
* round of predicate push-down
* </pre>
*/
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
private final PredicatePropagation propagation = new PredicatePropagation();
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();

@Override
Expand All @@ -62,6 +62,9 @@ public Plan rewriteRoot(Plan plan, JobContext jobContext) {
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
join = visitChildren(this, join, context);
if (join.isMarkJoin()) {
return join;
}
Plan left = join.left();
Plan right = join.right();
Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
Expand All @@ -86,7 +89,7 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, J
break;
}
if (left != join.left() || right != join.right()) {
return join.withChildren(ImmutableList.of(left, right));
return join.withChildren(left, right);
} else {
return join;
}
Expand All @@ -109,7 +112,7 @@ private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expres
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
baseExpressions.addAll(propagation.infer(baseExpressions));
baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
return baseExpressions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
Expand Down Expand Up @@ -55,8 +56,7 @@ private enum InferType {
INTEGRAL(IntegralType.class),
STRING(CharacterType.class),
DATE(DateLikeType.class),
OTHER(DataType.class)
;
OTHER(DataType.class);

private final Class<? extends DataType> superClazz;

Expand All @@ -65,15 +65,15 @@ private enum InferType {
}
}

private class ComparisonInferInfo {
private static class EqualInferInfo {

public final InferType inferType;
public final Optional<Expression> left;
public final Optional<Expression> right;
public final Expression left;
public final Expression right;
public final ComparisonPredicate comparisonPredicate;

public ComparisonInferInfo(InferType inferType,
Optional<Expression> left, Optional<Expression> right,
public EqualInferInfo(InferType inferType,
Expression left, Expression right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
Expand All @@ -85,23 +85,24 @@ public ComparisonInferInfo(InferType inferType,
/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
public static Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
if (!(predicate instanceof ComparisonPredicate)) {
if (!(predicate instanceof ComparisonPredicate
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
continue;
}
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
if (predicate instanceof InPredicate) {
continue;
}
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
.filter(ComparisonPredicate.class::isInstance)
.filter(p -> !p.equals(predicate))
.map(ComparisonPredicate.class::cast)
.map(this::inferInferInfo)
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
Expand All @@ -110,17 +111,65 @@ public Set<Expression> infer(Set<Expression> predicates) {
return inferred;
}

/**
 * Derives a new predicate from {@code predicate} by exchanging the two sides of the
 * equality described by {@code equalInfo}.
 *
 * <p>Every sub-expression equal to {@code equalInfo.left} is replaced with
 * {@code equalInfo.right} and vice versa, and the rewritten predicate is then passed
 * through {@code TypeCoercionUtils} again.
 *
 * @param equalInfo the equality whose two sides are swapped inside {@code predicate}
 * @param predicate the predicate to rewrite; anything that is not a
 *                  {@code ComparisonPredicate} is cast to {@code InPredicate} at the end
 * @return the rewritten predicate, or {@code null} when an operand of {@code predicate}
 *         is rejected by {@code validForInfer}
 */
private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
    final Expression lhsOfEqual = equalInfo.left;
    final Expression rhsOfEqual = equalInfo.right;

    // The infer category is chosen from the data type of the predicate's first child.
    final DataType firstChildType = predicate.child(0).getDataType();
    final InferType category = firstChildType instanceof CharacterType ? InferType.STRING
            : firstChildType instanceof IntegralType ? InferType.INTEGRAL
            : firstChildType instanceof DateLikeType ? InferType.DATE
            : InferType.OTHER;

    final boolean isComparison = predicate instanceof ComparisonPredicate;
    if (isComparison) {
        ComparisonPredicate cmp = (ComparisonPredicate) predicate;
        // Both operands are validated up front, then checked together.
        Optional<Expression> validLeft = validForInfer(cmp.left(), category);
        Optional<Expression> validRight = validForInfer(cmp.right(), category);
        if (!(validLeft.isPresent() && validRight.isPresent())) {
            return null;
        }
    } else if (predicate instanceof InPredicate) {
        // For IN, only the compared expression is validated.
        Optional<Expression> validCompareExpr =
                validForInfer(((InPredicate) predicate).getCompareExpr(), category);
        if (!validCompareExpr.isPresent()) {
            return null;
        }
    }

    // Swap the two sides of the equality wherever they occur inside the predicate.
    Expression swapped = predicate.rewriteUp(node ->
            node.equals(lhsOfEqual) ? rhsOfEqual
                    : node.equals(rhsOfEqual) ? lhsOfEqual
                    : node);

    if (!isComparison) {
        return TypeCoercionUtils.processInPredicate((InPredicate) swapped);
    }
    return TypeCoercionUtils.processComparisonPredicate(
            (ComparisonPredicate) swapped, swapped.child(0), swapped.child(1));
}

/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition for replacement
* e.g. check whether `expression` is non-deterministic
*/
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
Expression predicateLeft = predicateInfo.left.get();
Expression predicateRight = predicateInfo.right.get();
Expression equalLeft = equalInfo.left.get();
Expression equalRight = equalInfo.right.get();
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;

Expression predicateLeft = predicateInfo.left;
Expression predicateRight = predicateInfo.right;
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
if (newLeft == null || newRight == null) {
Expand All @@ -133,7 +182,7 @@ private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo pr
return DateFunctionRewrite.INSTANCE.rewrite(expr, null);
}

private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
Expand All @@ -150,60 +199,55 @@ private Expression inferOneSide(Expression predicateOneSide, Expression equalLef
return null;
}

private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
if (!(expression instanceof Cast)) {
return Optional.empty();
}
Cast cast = (Cast) expression;
Expression child = cast.child();
DataType dataType = cast.getDataType();
DataType childType = child.getDataType();
if (inferType == InferType.INTEGRAL) {
if (expression instanceof Cast) {
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(child, inferType);
} else if (inferType == InferType.DATE) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(child, inferType);
}
} else if (inferType == InferType.STRING) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(child, inferType);
}
} else {
return Optional.empty();
}
return Optional.empty();
}

private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
Expand All @@ -220,25 +264,27 @@ private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredica
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
right.orElse(comparisonPredicate.right()), comparisonPredicate);
}

/**
* Currently only equality-based derivation is supported,
* and both the left and right sides of the expression must be slots.
* <p>
* TODO: NullSafeEqual
*/
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
return new ComparisonInferInfo(InferType.NONE,
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
}
ComparisonInferInfo info = inferInferInfo(predicate);
EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
return info;
}
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
*/
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {

PredicatePropagation propagation = new PredicatePropagation();
Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();

@Override
Expand Down Expand Up @@ -99,6 +98,7 @@ public ImmutableSet<Expression> visitLogicalProject(LogicalProject<? extends Pla
public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
return cacheOrElse(aggregate, () -> {
ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
// TODO
Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions()
.stream()
.filter(this::hasAgg)
Expand Down Expand Up @@ -130,7 +130,7 @@ private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Ex

private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
Set<Expression> expressions = Sets.newHashSet(predicates);
expressions.addAll(propagation.infer(expressions));
expressions.addAll(PredicatePropagation.infer(expressions));
return expressions.stream()
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());
Expand Down
Loading

0 comments on commit a3b0c2c

Please sign in to comment.