Skip to content

Commit

Permalink
[feature](Nereids): InferPredicates support In
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener committed Jan 3, 2024
1 parent 28ff349 commit bcba028
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
* round of predicate push-down
*/
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
private final PredicatePropagation propagation = new PredicatePropagation();
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();

@Override
Expand Down Expand Up @@ -109,7 +108,7 @@ private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expres
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
baseExpressions.addAll(propagation.infer(baseExpressions));
baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
return baseExpressions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
Expand Down Expand Up @@ -55,8 +56,7 @@ private enum InferType {
INTEGRAL(IntegralType.class),
STRING(CharacterType.class),
DATE(DateLikeType.class),
OTHER(DataType.class)
;
OTHER(DataType.class);

private final Class<? extends DataType> superClazz;

Expand All @@ -65,15 +65,15 @@ private enum InferType {
}
}

private class ComparisonInferInfo {
private static class EqualInferInfo {

public final InferType inferType;
public final Optional<Expression> left;
public final Optional<Expression> right;
public final Expression left;
public final Expression right;
public final ComparisonPredicate comparisonPredicate;

public ComparisonInferInfo(InferType inferType,
Optional<Expression> left, Optional<Expression> right,
public EqualInferInfo(InferType inferType,
Expression left, Expression right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
Expand All @@ -85,26 +85,27 @@ public ComparisonInferInfo(InferType inferType,
/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
public static Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
// if we support more infer predicate expression type, we should impl withInferred() method.
// And should add inferred props in withChildren() method just like ComparisonPredicate,
// and it's subclass, to mark the predicate is from infer.
if (!(predicate instanceof ComparisonPredicate)) {
if (!(predicate instanceof ComparisonPredicate
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
continue;
}
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
if (predicate instanceof InPredicate) {
continue;
}
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
.filter(ComparisonPredicate.class::isInstance)
.filter(p -> !p.equals(predicate))
.map(ComparisonPredicate.class::cast)
.map(this::inferInferInfo)
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
Expand All @@ -113,17 +114,64 @@ public Set<Expression> infer(Set<Expression> predicates) {
return inferred;
}

private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;

DataType leftType = predicate.child(0).getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
inferType = InferType.STRING;
} else if (leftType instanceof IntegralType) {
inferType = InferType.INTEGRAL;
} else if (leftType instanceof DateLikeType) {
inferType = InferType.DATE;
} else {
inferType = InferType.OTHER;
}
if (predicate instanceof ComparisonPredicate) {
ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate;
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
if (!left.isPresent() || !right.isPresent()) {
return null;
}
} else if (predicate instanceof InPredicate) {
InPredicate inPredicate = (InPredicate) predicate;
Optional<Expression> left = validForInfer(inPredicate.getCompareExpr(), inferType);
if (!left.isPresent()) {
return null;
}
}

Expression newPredicate = predicate.rewriteUp(e -> {
if (e.equals(equalLeft)) {
return equalRight;
} else if (e.equals(equalRight)) {
return equalLeft;
} else {
return e;
}
});
if (predicate instanceof ComparisonPredicate) {
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate).withInferred(true);
} else {
return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate).withInferred(true);
}
}

/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
Expression predicateLeft = predicateInfo.left.get();
Expression predicateRight = predicateInfo.right.get();
Expression equalLeft = equalInfo.left.get();
Expression equalRight = equalInfo.right.get();
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;

Expression predicateLeft = predicateInfo.left;
Expression predicateRight = predicateInfo.right;
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
if (newLeft == null || newRight == null) {
Expand All @@ -136,7 +184,7 @@ private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo pr
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
}

private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
Expand All @@ -153,60 +201,55 @@ private Expression inferOneSide(Expression predicateOneSide, Expression equalLef
return null;
}

private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
if (!(expression instanceof Cast)) {
return Optional.empty();
}
Cast cast = (Cast) expression;
Expression child = cast.child();
DataType dataType = cast.getDataType();
DataType childType = child.getDataType();
if (inferType == InferType.INTEGRAL) {
if (expression instanceof Cast) {
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(child, inferType);
} else if (inferType == InferType.DATE) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(child, inferType);
}
} else if (inferType == InferType.STRING) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(child, inferType);
}
} else {
return Optional.empty();
}
return Optional.empty();
}

private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
Expand All @@ -223,25 +266,27 @@ private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredica
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
right.orElse(comparisonPredicate.right()), comparisonPredicate);
}

/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
* <p>
* TODO: NullSafeEqual
*/
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
return new ComparisonInferInfo(InferType.NONE,
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
}
ComparisonInferInfo info = inferInferInfo(predicate);
EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
return info;
}
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
*/
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {

PredicatePropagation propagation = new PredicatePropagation();
Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();

@Override
Expand Down Expand Up @@ -99,6 +98,7 @@ public ImmutableSet<Expression> visitLogicalProject(LogicalProject<? extends Pla
public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
return cacheOrElse(aggregate, () -> {
ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
// TODO
Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions()
.stream()
.filter(this::hasAgg)
Expand Down Expand Up @@ -130,7 +130,7 @@ private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Ex

private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
Set<Expression> expressions = Sets.newHashSet(predicates);
expressions.addAll(propagation.infer(expressions));
expressions.addAll(PredicatePropagation.infer(expressions));
return expressions.stream()
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ public EqualTo(Expression left, Expression right, boolean inferred) {
super(ImmutableList.of(left, right), "=", inferred);
}

private EqualTo(List<Expression> children) {
this(children, false);
}

private EqualTo(List<Expression> children, boolean inferred) {
super(children, "=", inferred);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ public InPredicate(Expression compareExpr, List<Expression> options) {
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
}

public InPredicate(Expression compareExpr, List<Expression> options, boolean inferred) {
super(new Builder<Expression>().add(compareExpr).addAll(options).build(), inferred);
this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null");
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
}

public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitInPredicate(this, context);
}
Expand Down Expand Up @@ -80,6 +86,11 @@ public void checkLegalityBeforeTypeCoercion() {
});
}

@Override
public Expression withInferred(boolean inferred) {
return new InPredicate(children.get(0), ImmutableList.copyOf(children).subList(1, children.size()), true);
}

@Override
public String toString() {
return compareExpr + " IN " + options.stream()
Expand Down
Loading

0 comments on commit bcba028

Please sign in to comment.