Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature](Nereids): inferPredicate support InPredicate into Branch 2.0 #30007

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

Expand All @@ -37,6 +36,7 @@

/**
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
* <pre>
* The logic is as follows:
* 1. poll up bottom predicate then infer additional predicates
* for example:
Expand All @@ -49,9 +49,9 @@
* select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
* 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
* round of predicate push-down
* </pre>
*/
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
private final PredicatePropagation propagation = new PredicatePropagation();
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();

@Override
Expand All @@ -62,6 +62,9 @@ public Plan rewriteRoot(Plan plan, JobContext jobContext) {
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
join = visitChildren(this, join, context);
if (join.isMarkJoin()) {
return join;
}
Plan left = join.left();
Plan right = join.right();
Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
Expand All @@ -86,7 +89,7 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, J
break;
}
if (left != join.left() || right != join.right()) {
return join.withChildren(ImmutableList.of(left, right));
return join.withChildren(left, right);
} else {
return join;
}
Expand All @@ -109,7 +112,7 @@ private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expres
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
baseExpressions.addAll(propagation.infer(baseExpressions));
baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
return baseExpressions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
Expand Down Expand Up @@ -55,8 +56,7 @@ private enum InferType {
INTEGRAL(IntegralType.class),
STRING(CharacterType.class),
DATE(DateLikeType.class),
OTHER(DataType.class)
;
OTHER(DataType.class);

private final Class<? extends DataType> superClazz;

Expand All @@ -65,15 +65,15 @@ private enum InferType {
}
}

private class ComparisonInferInfo {
private static class EqualInferInfo {

public final InferType inferType;
public final Optional<Expression> left;
public final Optional<Expression> right;
public final Expression left;
public final Expression right;
public final ComparisonPredicate comparisonPredicate;

public ComparisonInferInfo(InferType inferType,
Optional<Expression> left, Optional<Expression> right,
public EqualInferInfo(InferType inferType,
Expression left, Expression right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
Expand All @@ -85,23 +85,24 @@ public ComparisonInferInfo(InferType inferType,
/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
public static Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
if (!(predicate instanceof ComparisonPredicate)) {
if (!(predicate instanceof ComparisonPredicate
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
continue;
}
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
if (predicate instanceof InPredicate) {
continue;
}
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
.filter(ComparisonPredicate.class::isInstance)
.filter(p -> !p.equals(predicate))
.map(ComparisonPredicate.class::cast)
.map(this::inferInferInfo)
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
Expand All @@ -110,17 +111,66 @@ public Set<Expression> infer(Set<Expression> predicates) {
return inferred;
}

private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;

DataType leftType = predicate.child(0).getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
inferType = InferType.STRING;
} else if (leftType instanceof IntegralType) {
inferType = InferType.INTEGRAL;
} else if (leftType instanceof DateLikeType) {
inferType = InferType.DATE;
} else {
inferType = InferType.OTHER;
}
if (predicate instanceof ComparisonPredicate) {
ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate;
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
if (!left.isPresent() || !right.isPresent()) {
return null;
}
} else if (predicate instanceof InPredicate) {
InPredicate inPredicate = (InPredicate) predicate;
Optional<Expression> left = validForInfer(inPredicate.getCompareExpr(), inferType);
if (!left.isPresent()) {
return null;
}
}

Expression newPredicate = predicate.rewriteUp(e -> {
if (e.equals(equalLeft)) {
return equalRight;
} else if (e.equals(equalRight)) {
return equalLeft;
} else {
return e;
}
});
if (predicate instanceof ComparisonPredicate) {
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate,
newPredicate.child(0),
newPredicate.child(1));
} else {
return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate);
}
}

/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
Expression predicateLeft = predicateInfo.left.get();
Expression predicateRight = predicateInfo.right.get();
Expression equalLeft = equalInfo.left.get();
Expression equalRight = equalInfo.right.get();
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;

Expression predicateLeft = predicateInfo.left;
Expression predicateRight = predicateInfo.right;
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
if (newLeft == null || newRight == null) {
Expand All @@ -133,7 +183,7 @@ private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo pr
return DateFunctionRewrite.INSTANCE.rewrite(expr, null);
}

private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
Expand All @@ -150,60 +200,55 @@ private Expression inferOneSide(Expression predicateOneSide, Expression equalLef
return null;
}

private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
if (!(expression instanceof Cast)) {
return Optional.empty();
}
Cast cast = (Cast) expression;
Expression child = cast.child();
DataType dataType = cast.getDataType();
DataType childType = child.getDataType();
if (inferType == InferType.INTEGRAL) {
if (expression instanceof Cast) {
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(child, inferType);
} else if (inferType == InferType.DATE) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(child, inferType);
}
} else if (inferType == InferType.STRING) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(child, inferType);
}
} else {
return Optional.empty();
}
return Optional.empty();
}

private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
Expand All @@ -220,25 +265,27 @@ private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredica
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
right.orElse(comparisonPredicate.right()), comparisonPredicate);
}

/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
* <p>
* TODO: NullSafeEqual
*/
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
return new ComparisonInferInfo(InferType.NONE,
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
}
ComparisonInferInfo info = inferInferInfo(predicate);
EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
return info;
}
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
*/
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {

PredicatePropagation propagation = new PredicatePropagation();
Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();

@Override
Expand Down Expand Up @@ -99,6 +98,7 @@ public ImmutableSet<Expression> visitLogicalProject(LogicalProject<? extends Pla
public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
return cacheOrElse(aggregate, () -> {
ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
// TODO
Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions()
.stream()
.filter(this::hasAgg)
Expand Down Expand Up @@ -130,7 +130,7 @@ private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Ex

private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
Set<Expression> expressions = Sets.newHashSet(predicates);
expressions.addAll(propagation.infer(expressions));
expressions.addAll(PredicatePropagation.infer(expressions));
return expressions.stream()
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());
Expand Down
Loading
Loading