godfreyhe commented on a change in pull request #10174: 
[FLINK-14625][table-planner-blink] Add a rule to eliminate cross join as much 
as possible without statistics
URL: https://github.com/apache/flink/pull/10174#discussion_r346283696
 
 

 ##########
 File path: 
flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/EliminateCrossJoinRule.java
 ##########
 @@ -0,0 +1,477 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.util.Preconditions;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.rules.LoptMultiJoin;
+import org.apache.calcite.rel.rules.MultiJoin;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.mapping.Mappings;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * This rule tries to eliminate cross joins by reordering joins.
+ * The new order of joins is determined by the following steps:
+ *
+ * <p>1. The inputs related with an equi-join filter (= or IS NOT DISTINCT 
FROM) will be joined first.
+ *       Inputs with smaller indices have higher priority.
+ *
+ * <p>2. The inputs related with other join filters will then be joined.
+ *
+ * <p>3. If not all inner join inputs are joined, they will be joined in input 
order.
+ *
+ * <p>4. Outer joins are added.
+ */
+public class EliminateCrossJoinRule extends RelOptRule {
+
+       public static final EliminateCrossJoinRule INSTANCE = new 
EliminateCrossJoinRule();
+
+       /**
+        * Private constructor. Registers the rule with an operand built from
+        * {@code operand(MultiJoin.class, any())} and the description
+        * "EliminateCrossJoinRule".
+        */
+       private EliminateCrossJoinRule() {
+               super(operand(MultiJoin.class, any()), 
"EliminateCrossJoinRule");
+       }
+
+       /**
+        * Rewrites the matched {@link MultiJoin} into a tree of binary joins.
+        *
+        * <p>A full outer multi-join is emitted unchanged as a single FULL join of its two
+        * inputs, followed by the post-join filter if one exists. Any other multi-join is
+        * reordered through {@code multiJoinToJoinVertexTree}; afterwards the post-join
+        * filter (if any) is applied with its input refs remapped, and a projection puts
+        * the fields back into the order of the original multi-join.
+        *
+        * <p>The method fails a precondition check if a full outer multi-join does not have
+        * exactly 2 inputs, if more than one input is not an inner join input, or if the
+        * single outer join input is neither a RIGHT input at index 0 nor a LEFT input at
+        * the last index.
+        *
+        * @param call the rule call; its operand 0 is the {@link MultiJoin} to rewrite
+        */
+       @Override
+       public void onMatch(RelOptRuleCall call) {
+               MultiJoin join = call.rel(0);
+               RelBuilder relBuilder = call.builder();
+
+               if (join.isFullOuterJoin()) {
+                       // full outer join, do not reorder joins
+                       Preconditions.checkArgument(
+                               join.getInputs().size() == 2,
+                               "Full outer join must have exactly 2 inputs. This is a bug.");
+
+                       relBuilder
+                               .push(join.getInput(0))
+                               .push(join.getInput(1))
+                               .join(JoinRelType.FULL, join.getJoinFilter());
+                       if (join.getPostJoinFilter() != null) {
+                               relBuilder.filter(join.getPostJoinFilter());
+                       }
+               } else {
+                       List<JoinRelType> joinTypes = join.getJoinTypes();
+                       int numInputs = join.getInputs().size();
+
+                       // every input whose join type is not INNER counts as an outer join input
+                       int outerJoinCount = 0;
+                       for (int i = 0; i < numInputs; i++) {
+                               if (joinTypes.get(i) != JoinRelType.INNER) {
+                                       outerJoinCount++;
+                               }
+                       }
+                       Preconditions.checkState(
+                               outerJoinCount <= 1,
+                               "EliminateCrossJoinRule assumes that there is at most 1 outer join " +
+                                       "in a layer of multi-join, but " + outerJoinCount + " outer joins were found.");
+                       if (outerJoinCount == 1) {
+                               // the message fragments are concatenated, so the separating space after "or" is required
+                               Preconditions.checkState(
+                                       joinTypes.get(0) == JoinRelType.RIGHT ||
+                                               joinTypes.get(numInputs - 1) == JoinRelType.LEFT,
+                                       "EliminateCrossJoinRule assumes that " +
+                                               "the only left outer join input must locate at the end, or " +
+                                               "the only right outer join input must locate at the beginning");
+                       }
+
+                       LoptMultiJoin loptMultiJoin = new LoptMultiJoin(join);
+
+                       // try to eliminate cross join
+                       Vertex joinVertexTree = multiJoinToJoinVertexTree(loptMultiJoin);
+                       Mappings.TargetMapping mapping =
+                               joinVertexTreeToJoinRelTree(joinVertexTree, loptMultiJoin, relBuilder);
+
+                       // apply post-join filters, rewritten to the field positions of the new join tree
+                       if (join.getPostJoinFilter() != null) {
+                               RexBuilder rexBuilder = join.getCluster().getRexBuilder();
+                               relBuilder.filter(mapFilter(join.getPostJoinFilter(), mapping, rexBuilder));
+                       }
+
+                       // use projections to keep the output of the join unchanged
+                       List<RexNode> projects = generateProjection(join, mapping);
+                       relBuilder.project(projects);
+               }
+
+               RelNode rel = relBuilder.build();
+               call.transformTo(rel);
+       }
+
+       /**
+        * Builds a join {@link Vertex} tree for the given multi-join.
+        *
+        * <p>While the builder still reports a best filter, the inputs referenced by that
+        * filter are inner-joined. Filters are ranked so that equi-filters come before
+        * non-equi-filters; filters of the same class are ranked by the input indices
+        * they reference (see {@code compareInputBitSets}).
+        *
+        * @param multiJoin the multi-join to convert
+        * @return the root of the resulting vertex tree
+        */
+       private Vertex multiJoinToJoinVertexTree(LoptMultiJoin multiJoin) {
+               JoinVertexTreeBuilder builder = new JoinVertexTreeBuilder(
+                       multiJoin,
+                       (left, right) -> {
+                               boolean leftIsEqui = isEquiFilter(left.filter);
+                               boolean rightIsEqui = isEquiFilter(right.filter);
+                               if (leftIsEqui != rightIsEqui) {
+                                       // exactly one of the filters is an equi-filter,
+                                       // and the equi-filter has higher priority
+                                       return leftIsEqui ? -1 : 1;
+                               }
+                               // both or none of the filters are equi-filters,
+                               // the one with the smallest input wins
+                               return compareInputBitSets(left.inputBitSet, right.inputBitSet);
+                       });
+
+               JoinFilter bestFilter;
+               while ((bestFilter = builder.getBestFilter()) != null) {
+                       builder.innerJoin(bestFilter.inputBitSet);
+               }
+
+               return builder.toJoinVertexTree();
+       }
+
+       /**
+        * Compares two input bit sets by walking their set bits in ascending order.
+        *
+        * <p>The first position where the set bits differ decides the order (smaller index
+        * first). If one bit set runs out of bits before a difference is found, it sorts
+        * first. The result is always a proper sign, so swapping the arguments always
+        * negates the result.
+        *
+        * @return a negative value, zero or a positive value if {@code left} sorts before,
+        *         equal to or after {@code right}
+        */
+       private static int compareInputBitSets(ImmutableBitSet left, ImmutableBitSet right) {
+               int a = -1;
+               int b = -1;
+               do {
+                       a = left.nextSetBit(a + 1);
+                       b = right.nextSetBit(b + 1);
+               } while (a == b && a >= 0);
+
+               if (a >= 0 && b >= 0) {
+                       return Integer.compare(a, b);
+               } else if (a < 0 && b < 0) {
+                       return 0;
+               } else {
+                       // exactly one side is exhausted; returning the raw index here would give 0
+                       // for a == 0, which is not the negation of the mirrored comparison
+                       return a < 0 ? -1 : 1;
+               }
+       }
+
+       /**
+        * Tells whether the filter is of kind {@code EQUALS} or {@code IS_NOT_DISTINCT_FROM}.
+        *
+        * @param filter the filter expression to inspect
+        * @return {@code true} if the filter is one of the two kinds above
+        */
+       private boolean isEquiFilter(RexNode filter) {
+               if (filter.isA(SqlKind.EQUALS)) {
+                       return true;
+               }
+               return filter.isA(SqlKind.IS_NOT_DISTINCT_FROM);
+       }
+
+       /**
+        * Recursively pushes the vertex tree onto the given builder as a tree of joins.
+        *
+        * <p>A leaf pushes its input and yields a shift mapping for the fields of that input.
+        * A join vertex first emits its left and right subtrees, merges their mappings,
+        * and then joins the two topmost builder entries with the conjunction of its
+        * filters, rewritten through the merged mapping.
+        *
+        * @param joinVertexTree the (sub)tree to emit
+        * @param multiJoin      the multi-join the tree was built from
+        * @param relBuilder     the builder that receives the inputs and joins
+        * @return the field mapping produced for the emitted (sub)tree
+        */
+       private static Mappings.TargetMapping joinVertexTreeToJoinRelTree(
+               Vertex joinVertexTree,
+               LoptMultiJoin multiJoin,
+               RelBuilder relBuilder) {
+               if (joinVertexTree instanceof LeafVertex) {
+                       LeafVertex leafVertex = (LeafVertex) joinVertexTree;
+                       int inputIdx = leafVertex.smallestInputIdx;
+                       int fieldStart = multiJoin.getJoinStart(inputIdx);
+                       int fieldCount = multiJoin.getNumFieldsInJoinFactor(inputIdx);
+
+                       relBuilder.push(leafVertex.input);
+                       return Mappings.createShiftMapping(
+                               fieldStart + fieldCount, 0, fieldStart, fieldCount);
+               }
+
+               JoinVertex node = (JoinVertex) joinVertexTree;
+               // the left subtree must be emitted before the right one
+               Mappings.TargetMapping mappingOfLeft =
+                       joinVertexTreeToJoinRelTree(node.left, multiJoin, relBuilder);
+               Mappings.TargetMapping mappingOfRight =
+                       joinVertexTreeToJoinRelTree(node.right, multiJoin, relBuilder);
+               Mappings.TargetMapping combined = mergeMapping(mappingOfLeft, mappingOfRight);
+
+               RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
+               RexNode condition = RexUtil.composeConjunction(rexBuilder, node.joinFilters, false);
+               relBuilder.join(node.joinType, mapFilter(condition, combined, rexBuilder));
+               return combined;
+       }
+
+       /**
+        * Creates one input ref per source position of the mapping, in source order.
+        * Each ref points at the mapped target index and carries the type of the
+        * multi-join field at the same source position.
+        *
+        * @param join    the original multi-join, used for its field types
+        * @param mapping the mapping queried for the target index of every source position
+        * @return the projection expressions
+        */
+       private static List<RexNode> generateProjection(MultiJoin join, Mappings.TargetMapping mapping) {
+               List<RelDataTypeField> originalFields = join.getRowType().getFieldList();
+               List<RexNode> result = new ArrayList<>();
+               int sourceCount = mapping.getSourceCount();
+               int source = 0;
+               while (source < sourceCount) {
+                       int target = mapping.getTargetOpt(source);
+                       result.add(new RexInputRef(target, originalFields.get(source).getType()));
+                       source++;
+               }
+               return result;
+       }
+
+       /**
+        * Runs the filter through a {@code RexInputConverter} created from the given
+        * builder and mapping, and returns what the converter produces.
+        */
+       private static RexNode mapFilter(RexNode filter, Mappings.TargetMapping mapping, RexBuilder rexBuilder) {
+               RexInputConverter converter = new RexInputConverter(rexBuilder, mapping);
+               return filter.accept(converter);
+       }
+
+       /**
+        * Merges two mappings after offsetting the targets of the right one by the
+        * target count of the left one.
+        */
+       private static Mappings.TargetMapping mergeMapping(Mappings.TargetMapping left, Mappings.TargetMapping right) {
+               int leftTargetCount = left.getTargetCount();
+               Mappings.TargetMapping shiftedRight = Mappings.offsetTarget(right, leftTargetCount);
+               return Mappings.merge(left, shiftedRight);
+       }
+
+       /**
+        * A wrapper class for a join filter.
+        * The bit set indicates which inputs this filter is related to.
+        */
+       private static class JoinFilter {
+               /** The wrapped filter expression. */
+               final RexNode filter;
+               /** The inputs associated with the filter, exactly as passed to the constructor. */
+               final ImmutableBitSet inputBitSet;
+
+               /**
+                * @param filter      the filter expression to wrap
+                * @param inputBitSet the inputs this filter is related to
+                */
+               JoinFilter(RexNode filter, ImmutableBitSet inputBitSet) {
+                       this.filter = filter;
+                       this.inputBitSet = inputBitSet;
+               }
+       }
+
+       /**
+        * A vertex in the join tree.
+        * Holds the values shared by leaf and join vertices; all of them are fixed at construction.
+        */
+       private abstract static class Vertex {
+               /** Number of fields of this vertex, as passed in by the subclass. */
+               final int numFields;
+               /** Bit set of input indices, as passed in by the subclass. */
+               final ImmutableBitSet inputBitSet;
+               /** The input index passed in by the subclass as the smallest one of this vertex. */
+               final int smallestInputIdx;
+
+               /**
+                * @param numFields        number of fields of this vertex
+                * @param inputBitSet      bit set of the input indices of this vertex
+                * @param smallestInputIdx smallest input index of this vertex
+                */
+               Vertex(int numFields, ImmutableBitSet inputBitSet, int 
smallestInputIdx) {
+                       this.numFields = numFields;
+                       this.inputBitSet = inputBitSet;
+                       this.smallestInputIdx = smallestInputIdx;
+               }
+       }
+
+       /**
+        * A non-leaf vertex in the join tree.
+        * Its field count is the sum of both children's field counts, its bit set is the union
+        * of their bit sets, and its smallest input index is the minimum of theirs.
+        */
+       private static class JoinVertex extends Vertex {
+               final JoinRelType joinType;
+               final Vertex left;
+               final Vertex right;
+               final List<RexNode> joinFilters;
+
+               /**
+                * @param joinType    Join type of this join vertex (INNER, LEFT, RIGHT or FULL)
+                * @param left        The left input of this join vertex
+                * @param right       The right input of this join vertex
+                * @param joinFilters A list of join filters applicable to this join vertex
+                */
+               JoinVertex(
+                       JoinRelType joinType,
+                       Vertex left,
+                       Vertex right,
+                       List<RexNode> joinFilters) {
+                       super(left.numFields + right.numFields,
+                               left.inputBitSet.union(right.inputBitSet),
+                               Math.min(left.smallestInputIdx, 
right.smallestInputIdx));
+                       this.joinType = joinType;
+                       this.left = left;
+                       this.right = right;
+                       this.joinFilters = joinFilters;
+               }
+       }
+
+       /**
+        * A leaf vertex of a join tree, representing an input of the join.
+        * Its field count is the field count of the input's row type, and its bit set
+        * contains only its own input index.
+        */
+       private static class LeafVertex extends Vertex {
+               /** The relational input represented by this leaf. */
+               final RelNode input;
+
+               /**
+                * @param input    The input
+                * @param inputIdx The index of this input in the original multi-join
+                */
+               LeafVertex(RelNode input, int inputIdx) {
+                       super(input.getRowType().getFieldCount(), 
ImmutableBitSet.of(inputIdx), inputIdx);
+                       this.input = input;
+               }
+       }
+
+       /**
+        * Build a {@link Vertex} tree from a {@link LoptMultiJoin}.
+        *
+        * <p>The builder starts with one {@link LeafVertex} per join factor and a list of the
+        * multi-join's join filters. Callers repeatedly merge vertices through
+        * {@code innerJoin}; filters that become fully covered by a merged vertex are moved
+        * from the pending list onto that vertex. {@code toJoinVertexTree} finishes the tree.
+        */
+       private static class JoinVertexTreeBuilder {
+               private final LoptMultiJoin multiJoin;
+               // indexed by input index; the entry may be stale after merges,
+               // always resolve it through findRootVertex
+               private final Vertex[] rootVertex;
+
+               // ranks the pending filters, the smallest one is the "best" filter
+               private final Comparator<JoinFilter> comparator;
+               // join filters not yet attached to any join vertex
+               private final List<JoinFilter> filters;
+               // smallest pending filter according to the comparator, null if no filter is pending
+               private JoinFilter bestFilter;
+
+               // set by toJoinVertexTree to reject a second call
+               private boolean finished = false;
+
+               /**
+                * @param multiJoin  the multi-join providing the join factors and join filters
+                * @param comparator the order used to select the best pending filter
+                */
+               JoinVertexTreeBuilder(LoptMultiJoin multiJoin, 
Comparator<JoinFilter> comparator) {
+                       this.multiJoin = multiJoin;
+                       this.comparator = comparator;
+
+                       this.filters = new LinkedList<>();
+                       List<RexNode> rexFilters = multiJoin.getJoinFilters();
+                       for (RexNode rex : rexFilters) {
+                               filters.add(new JoinFilter(rex, 
multiJoin.getFactorsRefByJoinFilter(rex)));
+                       }
+                       updateBestFilter();
+
+                       // initially every join factor is its own root
+                       int numInputs = multiJoin.getNumJoinFactors();
+                       this.rootVertex = new Vertex[numInputs];
+                       for (int i = 0; i < numInputs; i++) {
+                               RelNode input = multiJoin.getJoinFactor(i);
+                               rootVertex[i] = new LeafVertex(input, i);
+                       }
+               }
+
+               /**
+                * This method will iterate through a list of filters.
+                * If all the input refs in a filter have appeared in the given bit set,
+                * it will be put into the returned list and be removed from the original list.
+                *
+                * @param mergedBitSet All the input refs in a filter must also appear in this bit set,
+                *                     so that the filter can be picked out
+                * @param filters      List of filters to check.
+                *                     {@link java.util.LinkedList} (or other lists whose iterator can remove
+                *                     current element in O(1) time) is recommended for better performance
+                * @return the filter expressions that were removed from {@code filters}
+                */
+               static List<RexNode> pickJoinFilters(ImmutableBitSet 
mergedBitSet, List<JoinFilter> filters) {
+                       List<RexNode> ret = new ArrayList<>();
+                       Iterator<JoinFilter> iter = filters.iterator();
+                       while (iter.hasNext()) {
+                               JoinFilter filter = iter.next();
+                               if (mergedBitSet.contains(filter.inputBitSet)) {
+                                       ret.add(filter.filter);
+                                       iter.remove();
+                               }
+                       }
+                       return ret;
+               }
+
+               /**
+                * @return the smallest pending filter according to the comparator,
+                *         or {@code null} if no filter is pending
+                */
+               JoinFilter getBestFilter() {
+                       return bestFilter;
+               }
+
+               /**
+                * Perform inner joins between the inputs specified by the bit set.
+                * If the inputs have been joined into other {@link Vertex}, those vertices will be joined instead.
+                *
+                * @param joinedInput A bit set specifying which inputs to join
+                */
+               void innerJoin(ImmutableBitSet joinedInput) {
+                       innerJoin(joinedInput.toList());
+               }
+
+               /**
+                * Joins the root vertices of every pair of neighbouring entries in the list.
+                * The vertex with the smaller {@code smallestInputIdx} becomes the left side;
+                * nothing is joined for a pair whose roots have the same {@code smallestInputIdx}.
+                *
+                * <p>NOTE(review): a list with fewer than 2 entries fails the argument check, so a
+                * filter referencing fewer than 2 inputs cannot be handled here. Presumably such
+                * filters never reach this rule; verify against the rules that run before it.
+                *
+                * @param joinedInput A list specifying which inputs to join
+                */
+               void innerJoin(List<Integer> joinedInput) {
+                       Preconditions.checkArgument(
+                               joinedInput.size() >= 2, "At least 2 inputs are 
needed to perform a join.");
+                       for (int i = 1; i < joinedInput.size(); i++) {
+                               Vertex left = findRootVertex(joinedInput.get(i 
- 1));
+                               Vertex right = 
findRootVertex(joinedInput.get(i));
+                               if (left.smallestInputIdx < 
right.smallestInputIdx) {
+                                       innerJoin(left, right);
+                               } else if (left.smallestInputIdx > 
right.smallestInputIdx) {
+                                       innerJoin(right, left);
+                               }
+                       }
+               }
+
+               /**
+                * NOTE: This method can only be called once for each builder instance.
+                *
+                * @return A {@link Vertex} tree equivalent to the given {@link LoptMultiJoin}
+                */
+               Vertex toJoinVertexTree() {
+                       Preconditions.checkState(
+                               !finished,
+                               "`toJoinVertexTree` can only be called once for 
each builder instance");
+                       finished = true;
+
+                       // if not all filters have been applied, join and apply them first
+                       // NOTE(review): the loop only ends if each innerJoin call removes the head
+                       // filter from the list; confirm this holds for every filter that can get here
+                       while (!filters.isEmpty()) {
+                               innerJoin(filters.get(0).inputBitSet);
+                       }
+
+                       // it is possible that all inner join inputs haven't been joined together,
+                       // so join them in order
+                       List<JoinRelType> joinTypes = 
multiJoin.getMultiJoinRel().getJoinTypes();
+                       Vertex lastVertex = null;
+                       for (int i = 0; i < rootVertex.length; i++) {
+                               if (joinTypes.get(i) == JoinRelType.INNER) {
+                                       Vertex currentVertex = 
findRootVertex(i);
+                                       if (lastVertex == null) {
+                                               lastVertex = currentVertex;
+                                       } else if (lastVertex.smallestInputIdx 
!= currentVertex.smallestInputIdx) {
+                                               lastVertex = 
innerJoin(lastVertex, currentVertex);
+                                       }
+                               }
+                       }
+                       Preconditions.checkNotNull(lastVertex, "Inner join 
input not found. This is a bug.");
+
+                       // finally add outer join:
+                       // a RIGHT input at index 0 becomes the left child,
+                       // a LEFT input at the last index becomes the right child
+                       int numInputs = joinTypes.size();
+                       if (joinTypes.get(0) == JoinRelType.RIGHT) {
+                               lastVertex = new JoinVertex(
+                                       JoinRelType.RIGHT,
+                                       findRootVertex(0),
+                                       lastVertex,
+                                       
Collections.singletonList(multiJoin.getOuterJoinCond(0)));
+                       } else if (joinTypes.get(numInputs - 1) == 
JoinRelType.LEFT) {
+                               lastVertex = new JoinVertex(
+                                       JoinRelType.LEFT,
+                                       lastVertex,
+                                       findRootVertex(numInputs - 1),
+                                       
Collections.singletonList(multiJoin.getOuterJoinCond(numInputs - 1)));
+                       }
+
+                       return lastVertex;
+               }
+
+               /**
+                * Creates an inner join vertex over the two given vertices.
+                * Pending filters whose inputs are all contained in the union of both bit sets are
+                * removed from the pending list and attached to the new vertex, then the best
+                * filter is recomputed. The new vertex is stored in the root slots of both
+                * children's smallest input indices.
+                *
+                * @return the newly created join vertex
+                */
+               private Vertex innerJoin(Vertex left, Vertex right) {
+                       List<RexNode> pickedFilters = 
pickJoinFilters(left.inputBitSet.union(right.inputBitSet), filters);
+                       updateBestFilter();
+                       Vertex merged = new JoinVertex(JoinRelType.INNER, left, 
right, pickedFilters);
+                       rootVertex[left.smallestInputIdx] = merged;
+                       rootVertex[right.smallestInputIdx] = merged;
+                       return merged;
+               }
+
+               /**
+                * Finds the vertex the given input currently belongs to.
+                * Follows {@code smallestInputIdx} links until it reaches a slot whose vertex has
+                * that slot's index as its smallest input index, caching the result on the way back.
+                */
+               private Vertex findRootVertex(int inputIdx) {
+                       // union-find algorithm to find out
+                       // which join vertex this inputIdx belongs to
+                       if (rootVertex[inputIdx].smallestInputIdx != inputIdx) {
+                               rootVertex[inputIdx] = 
findRootVertex(rootVertex[inputIdx].smallestInputIdx);
+                       }
+                       return rootVertex[inputIdx];
+               }
+
+               /**
+                * Recomputes {@code bestFilter} as the smallest pending filter according to the
+                * comparator. On ties the filter that comes first in the list is kept; the
+                * result is {@code null} when no filter is pending.
+                */
+               private void updateBestFilter() {
+                       bestFilter = null;
+                       for (JoinFilter joinFilter : filters) {
+                               if (bestFilter == null || 
comparator.compare(joinFilter, bestFilter) < 0) {
+                                       bestFilter = joinFilter;
+                               }
+                       }
+               }
+       }
+
+       /**
+        * Simple converter which converts input refs to a new index according 
to a {@link Mappings.TargetMapping}.
+        */
+       private static class RexInputConverter extends RexShuttle {
+               private RexBuilder rexBuilder;
+               private Mappings.TargetMapping mapping;
+
+               RexInputConverter(RexBuilder rexBuilder, Mappings.TargetMapping 
mapping) {
+                       this.rexBuilder = rexBuilder;
+                       this.mapping = mapping;
+               }
+
+               @Override
+               public RexNode visitInputRef(RexInputRef var) {
 
 Review comment:
   ref ?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to