Skip to content

Commit

Permalink
[CALCITE-6846] Support basic dphyp join reorder algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
silundong committed Mar 6, 2025
1 parent e8117ce commit 3c053d2
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 60 deletions.
55 changes: 50 additions & 5 deletions core/src/main/java/org/apache/calcite/rel/rules/DpHyp.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.calcite.rel.rules;

import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
Expand All @@ -31,23 +32,37 @@
/**
* The core process of dphyp enumeration algorithm.
*/
@Experimental
public class DpHyp {

private HyperGraph hyperGraph;
private final HyperGraph hyperGraph;

private HashMap<Long, RelNode> dpTable;
private final HashMap<Long, RelNode> dpTable;

private RelBuilder builder;
private final RelBuilder builder;

private RelMetadataQuery mq;
private final RelMetadataQuery mq;

public DpHyp(HyperGraph hyperGraph, RelBuilder builder, RelMetadataQuery relMetadataQuery) {
this.hyperGraph = hyperGraph;
this.hyperGraph =
hyperGraph.copy(
hyperGraph.getTraitSet(),
hyperGraph.getInputs());
this.dpTable = new HashMap<>();
this.builder = builder;
this.mq = relMetadataQuery;
// make all field name unique and convert the
// HyperEdge condition from RexInputRef to RexInputFieldName
this.hyperGraph.convertHyperEdgeCond(builder);
}

/**
* The entry function of the algorithm. We use a bitmap to represent a leaf node,
* which indicates the position of the corresponding leaf node in {@link HyperGraph}.
*
* <p>After the enumeration is completed, the best join order will be stored
* in the {@link DpHyp#dpTable}.
*/
public void startEnumerateJoin() {
int size = hyperGraph.getInputs().size();
for (int i = 0; i < size; i++) {
Expand All @@ -65,6 +80,12 @@ public void startEnumerateJoin() {
}
}

/**
* Given a connected subgraph (csg), enumerate all possible complements subgraph (cmp)
* that do not include anything from the exclusion subset.
*
* <p>Corresponding to EmitCsg in origin paper.
*/
private void emitCsg(long csg) {
long forbidden = csg | LongBitmap.getBvBitmap(csg);
long neighbors = hyperGraph.getNeighborBitmap(csg, forbidden);
Expand All @@ -86,6 +107,16 @@ private void emitCsg(long csg) {
}
}

/**
* Given a connected subgraph (csg), expands it recursively by its neighbors.
* If the expanded csg is connected, try to enumerate its cmp (note that for complex hyperedge,
* we only select a single representative node to add to the neighbors, so csg and subNeighbor
* are not necessarily connected. However, it still needs to be expanded to prevent missing
* complex hyperedge). This method is called after the enumeration of csg is completed,
* that is, after {@link DpHyp#emitCsg(long csg)}.
*
* <p>Corresponding to EnumerateCsgRec in origin paper.
*/
private void enumerateCsgRec(long csg, long forbidden) {
long neighbors = hyperGraph.getNeighborBitmap(csg, forbidden);
LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors);
Expand All @@ -104,6 +135,13 @@ private void enumerateCsgRec(long csg, long forbidden) {
}
}

/**
* Given a connected subgraph (csg) and its complement subgraph (cmp), expands the cmp
* recursively by neighbors of cmp (cmp and subNeighbor are not necessarily connected,
* which is the same logic as in {@link DpHyp#enumerateCsgRec}).
*
* <p>Corresponding to EnumerateCmpRec in origin paper.
*/
private void enumerateCmpRec(long csg, long cmp, long forbidden) {
long neighbors = hyperGraph.getNeighborBitmap(cmp, forbidden);
LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors);
Expand All @@ -125,6 +163,13 @@ private void enumerateCmpRec(long csg, long cmp, long forbidden) {
}
}

/**
* Given a connected csg-cmp pair and the hyperedges that connect them, build the
* corresponding Join plan. If the new Join plan is better than the existing plan,
* update the {@link DpHyp#dpTable}.
*
* <p>Corresponding to EmitCsgCmp in origin paper.
*/
private void emitCsgCmp(long csg, long cmp, List<HyperEdge> edges) {
RelNode child1 = dpTable.get(csg);
RelNode child2 = dpTable.get(cmp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.calcite.rel.rules;

import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
Expand All @@ -33,6 +34,7 @@
*
* @see CoreRules#HYPER_GRAPH_OPTIMIZE */
@Value.Enclosing
@Experimental
public class DphypJoinReorderRule
extends RelRule<DphypJoinReorderRule.Config>
implements TransformationRule {
Expand All @@ -44,9 +46,6 @@ protected DphypJoinReorderRule(Config config) {
@Override public void onMatch(RelOptRuleCall call) {
HyperGraph hyperGraph = call.rel(0);
RelBuilder relBuilder = call.builder();
// make all field name unique and convert the
// HyperEdge condition from RexInputRef to RexInputFieldName
hyperGraph.convertHyperEdgeCond();

// enumerate by Dphyp
DpHyp dpHyp = new DpHyp(hyperGraph, relBuilder, call.getMetadataQuery());
Expand Down
26 changes: 13 additions & 13 deletions core/src/main/java/org/apache/calcite/rel/rules/HyperEdge.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,34 @@
*/
package org.apache.calcite.rel.rules;

import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexNode;

/**
* Edge in HyperGraph, that represents a join predicate.
*/
@Experimental
public class HyperEdge {

private long leftNodeBits;
private final long leftNodeBits;

private long rightNodeBits;
private final long rightNodeBits;

private JoinRelType joinType;
private final JoinRelType joinType;

private RexNode condition;
private final boolean isSimple;

private final RexNode condition;

public HyperEdge(long leftNodeBits, long rightNodeBits, JoinRelType joinType, RexNode condition) {
this.leftNodeBits = leftNodeBits;
this.rightNodeBits = rightNodeBits;
this.joinType = joinType;
this.condition = condition;
boolean leftSimple = (leftNodeBits & (leftNodeBits - 1)) == 0;
boolean rightSimple = (rightNodeBits & (rightNodeBits - 1)) == 0;
this.isSimple = leftSimple && rightSimple;
}

public long getNodeBitmap() {
Expand All @@ -53,9 +60,7 @@ public long getRightNodeBitmap() {

// hyperedge (u, v) is simple if |u| = |v| = 1
public boolean isSimple() {
boolean leftSimple = (leftNodeBits & (leftNodeBits - 1)) == 0;
boolean rightSimple = (rightNodeBits & (rightNodeBits - 1)) == 0;
return leftSimple && rightSimple;
return isSimple;
}

public JoinRelType getJoinType() {
Expand All @@ -69,14 +74,9 @@ public RexNode getCondition() {
@Override public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(LongBitmap.printBitmap(leftNodeBits))
.append("——").append(joinType).append("——")
.append("——[").append(joinType).append(", ").append(condition).append("]——")
.append(LongBitmap.printBitmap(rightNodeBits));
return sb.toString();
}

// before starting dphyp, replace RexInputRef to RexInputFieldName
public void replaceCondition(RexNode fieldNameCond) {
this.condition = fieldNameCond;
}

}
87 changes: 60 additions & 27 deletions core/src/main/java/org/apache/calcite/rel/rules/HyperGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
package org.apache.calcite.rel.rules;

import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.AbstractRelNode;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBiVisitor;
Expand All @@ -33,8 +33,11 @@
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVariable;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import org.checkerframework.checker.nullness.qual.Nullable;

Expand All @@ -49,6 +52,7 @@
/**
* HyperGraph represents a join graph.
*/
@Experimental
public class HyperGraph extends AbstractRelNode {

private final List<RelNode> inputs;
Expand All @@ -59,7 +63,7 @@ public class HyperGraph extends AbstractRelNode {
private final List<HyperEdge> edges;

// record the indices of complex hyper edges in the 'edges'
private final BitSet complexEdgesBitmap;
private final ImmutableBitSet complexEdgesBitmap;

/**
* For the HashMap fields, key is the bitmap for inputs,
Expand All @@ -81,15 +85,16 @@ protected HyperGraph(RelOptCluster cluster,
List<HyperEdge> edges,
RelDataType rowType) {
super(cluster, traitSet);
this.inputs = inputs;
this.edges = edges;
this.inputs = Lists.newArrayList(inputs);
this.edges = Lists.newArrayList(edges);
this.rowType = rowType;
this.complexEdgesBitmap = new BitSet();
ImmutableBitSet.Builder bitSetBuilder = ImmutableBitSet.builder();
for (int i = 0; i < edges.size(); i++) {
if (!edges.get(i).isSimple()) {
complexEdgesBitmap.set(i);
bitSetBuilder.set(i);
}
}
this.complexEdgesBitmap = bitSetBuilder.build();
this.ccpUsedEdgesMap = new HashMap<>();
this.simpleEdgesMap = new HashMap<>();
this.complexEdgesMap = new HashMap<>();
Expand All @@ -101,23 +106,23 @@ protected HyperGraph(RelOptCluster cluster,
List<RelNode> inputs,
List<HyperEdge> edges,
RelDataType rowType,
BitSet complexEdgesBitmap,
ImmutableBitSet complexEdgesBitmap,
HashMap<Long, BitSet> ccpUsedEdgesMap,
HashMap<Long, BitSet> simpleEdgesMap,
HashMap<Long, BitSet> complexEdgesMap,
HashMap<Long, BitSet> overlapEdgesMap) {
super(cluster, traitSet);
this.inputs = inputs;
this.edges = edges;
this.inputs = Lists.newArrayList(inputs);
this.edges = Lists.newArrayList(edges);
this.rowType = rowType;
this.complexEdgesBitmap = complexEdgesBitmap;
this.ccpUsedEdgesMap = ccpUsedEdgesMap;
this.simpleEdgesMap = simpleEdgesMap;
this.complexEdgesMap = complexEdgesMap;
this.overlapEdgesMap = overlapEdgesMap;
this.ccpUsedEdgesMap = new HashMap<>(ccpUsedEdgesMap);
this.simpleEdgesMap = new HashMap<>(simpleEdgesMap);
this.complexEdgesMap = new HashMap<>(complexEdgesMap);
this.overlapEdgesMap = new HashMap<>(overlapEdgesMap);
}

@Override public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
@Override public HyperGraph copy(RelTraitSet traitSet, List<RelNode> inputs) {
return new HyperGraph(
getCluster(),
traitSet,
Expand Down Expand Up @@ -156,6 +161,31 @@ protected HyperGraph(RelOptCluster cluster,
return rowType;
}

@Override public RelNode accept(RexShuttle shuttle) {
List<HyperEdge> shuttleEdges = new ArrayList<>();
for (HyperEdge edge : edges) {
HyperEdge shuttleEdge =
new HyperEdge(
edge.getLeftNodeBitmap(),
edge.getRightNodeBitmap(),
edge.getJoinType(),
shuttle.apply(edge.getCondition()));
shuttleEdges.add(shuttleEdge);
}

return new HyperGraph(
getCluster(),
traitSet,
inputs,
shuttleEdges,
rowType,
complexEdgesBitmap,
ccpUsedEdgesMap,
simpleEdgesMap,
complexEdgesMap,
overlapEdgesMap);
}

//~ hyper graph method ----------------------------------------------------------

public List<HyperEdge> getEdges() {
Expand Down Expand Up @@ -211,7 +241,7 @@ public List<HyperEdge> connectCsgCmp(long csg, long cmp) {
// may omit some complex hyper edges. e.g.
// csg = {t1, t3}, cmp = {t2}, will omit the edge (t1, t2)——(t3)
BitSet mayMissedEdges = new BitSet();
mayMissedEdges.or(complexEdgesBitmap);
mayMissedEdges.or(complexEdgesBitmap.toBitSet());
mayMissedEdges.andNot(ccpUsedEdgesMap.getOrDefault(csg, new BitSet()));
mayMissedEdges.andNot(ccpUsedEdgesMap.getOrDefault(cmp, new BitSet()));
mayMissedEdges.andNot(connectedEdgesBitmap);
Expand Down Expand Up @@ -381,7 +411,7 @@ public RexNode extractJoinCond(RelNode left, RelNode right, List<HyperEdge> edge
* Before starting enumeration, add Project on every input, make all field name unique.
* Convert the HyperEdge condition from RexInputRef to RexInputFieldName
*/
public void convertHyperEdgeCond() {
public void convertHyperEdgeCond(RelBuilder builder) {
int fieldIndex = 0;
List<RelDataTypeField> fieldList = rowType.getFieldList();
for (int nodeIndex = 0; nodeIndex < inputs.size(); nodeIndex++) {
Expand All @@ -396,14 +426,10 @@ public void convertHyperEdgeCond() {
names.add(fieldList.get(fieldIndex).getName());
fieldIndex++;
}
RelNode renameProject =
LogicalProject.create(
input,
ImmutableList.of(),
projects,
names,
input.getVariablesSet());
replaceInput(nodeIndex, renameProject);

builder.push(input)
.project(projects, names, true);
replaceInput(nodeIndex, builder.build());
}

RexShuttle inputRef2inputNameShuttle = new RexShuttle() {
Expand All @@ -415,9 +441,16 @@ public void convertHyperEdgeCond() {
}
};

for (HyperEdge hyperEdge : edges) {
RexNode convertCond = hyperEdge.getCondition().accept(inputRef2inputNameShuttle);
hyperEdge.replaceCondition(convertCond);
for (int i = 0; i < edges.size(); i++) {
HyperEdge edge = edges.get(i);
RexNode convertCond = edge.getCondition().accept(inputRef2inputNameShuttle);
HyperEdge convertEdge =
new HyperEdge(
edge.getLeftNodeBitmap(),
edge.getRightNodeBitmap(),
edge.getJoinType(),
convertCond);
edges.set(i, convertEdge);
}
}

Expand Down
Loading

0 comments on commit 3c053d2

Please sign in to comment.