swuferhong commented on code in PR #21530:
URL: https://github.com/apache/flink/pull/21530#discussion_r1066492364


##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkBushyJoinReorderRule.java:
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.table.planner.plan.cost.FlinkCost;
+
+import org.apache.calcite.plan.RelOptCost;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.rules.LoptMultiJoin;
+import org.apache.calcite.rel.rules.MultiJoin;
+import org.apache.calcite.rel.rules.TransformationRule;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.ImmutableBitSet;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Flink bushy join reorder rule, which will convert {@link MultiJoin} to a 
bushy join tree.
+ *
+ * <p>In this bushy join reorder strategy, the join reorder step is as follows:
+ *
+ * <p>First step, we will reorder all the inner join type inputs in the 
multiJoin. We adopt the
+ * concept of level in dynamic programming, and each later level will use the 
results stored in the
+ * previous levels. First, we put all input factor (each input factor in 
{@link MultiJoin}) into
+ * level 0, then we build all two-inputs join at level 1 based on the {@link 
FlinkCost} of level 0,
+ * then we will build three-inputs join based on the previous two levels, then 
four-inputs joins ...
+ * etc, until we reorder all the inner join type input factors in the 
multiJoin. When building
+ * m-inputs join, we only keep the best plan (have the lowest {@link 
FlinkCost}) for the same set of
+ * m inputs. E.g., for three-inputs join, we keep only the best plan for 
inputs {A, B, C} among
+ * plans (A J B) J C, (A J C) J B, (B J C) J A.
+ *
+ * <p>Second step, we will add all outer join factors to the top of the reordered 
join tree generated by
+ * the first step. E.g., for the example (((A LJ B) IJ C) IJ D), we will first 
reorder A, C and D
+ * using the first step strategy, get ((A IJ C) IJ D). Then, we will add B to 
the top, get (((A IJ
+ * C) IJ D) LJ B).
+ *
+ * <p>Third step, if there are factors whose join condition is true, we will 
add these factors to
+ * the top in the final step.
+ */
+public class FlinkBushyJoinReorderRule extends 
RelRule<FlinkBushyJoinReorderRule.Config>
+        implements TransformationRule {
+
+    /** Creates a FlinkBushyJoinReorderRule. */
+    protected FlinkBushyJoinReorderRule(Config config) {
+        super(config);
+    }
+
+    @Deprecated // to be removed before 2.0
+    public FlinkBushyJoinReorderRule(RelBuilderFactory relBuilderFactory) {
+        
this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class));
+    }
+
+    @Deprecated // to be removed before 2.0
+    public FlinkBushyJoinReorderRule(
+            RelFactories.JoinFactory joinFactory,
+            RelFactories.ProjectFactory projectFactory,
+            RelFactories.FilterFactory filterFactory) {
+        this(RelBuilder.proto(joinFactory, projectFactory, filterFactory));
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        final RelBuilder relBuilder = call.builder();
+        final MultiJoin multiJoinRel = call.rel(0);
+        final LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);
+        RelNode bestOrder = findBestOrder(relBuilder, multiJoin);
+        call.transformTo(bestOrder);
+    }
+
+    /**
+     * Find best join reorder using bushy join reorder strategy. We will first 
try to reorder all
+     * the inner join type input factors in the multiJoin. Then, we will add 
all outer join factors
+     * to the top of the reordered join tree generated by the first step. If there 
are factors whose
+     * join condition is true, we will add these factors to the top in the 
final step.
+     */
+    private static RelNode findBestOrder(RelBuilder relBuilder, LoptMultiJoin 
multiJoin) {
+        // Reorder all the inner join type input factors in the multiJoin.
+        List<Map<Set<Integer>, JoinPlan>> foundPlansForInnerJoin =
+                reorderInnerJoin(relBuilder, multiJoin);
+
+        Map<Set<Integer>, JoinPlan> lastLevelOfInnerJoin =
+                foundPlansForInnerJoin.get(foundPlansForInnerJoin.size() - 1);
+
+        JoinPlan containOuterJoinPlan;
+        // Add all outer join factors in the multiJoin (including 
left/right/full) on the
+        // top of tree if outer join condition exists in multiJoin.
+        if (outerJoinConditionExists(multiJoin)) {
+            containOuterJoinPlan =
+                    addToTopForOuterJoin(getBestPlan(lastLevelOfInnerJoin), 
multiJoin, relBuilder);
+        } else {
+            containOuterJoinPlan = getBestPlan(lastLevelOfInnerJoin);
+        }
+
+        JoinPlan finalPlan;
+        // Add these factors whose join condition is true to the top.
+        if (containOuterJoinPlan.factorIds.size() != 
multiJoin.getNumJoinFactors()) {
+            finalPlan = addToTopForTrueCondition(containOuterJoinPlan, 
multiJoin, relBuilder);
+        } else {
+            finalPlan = containOuterJoinPlan;
+        }
+
+        final List<String> fieldNames = 
multiJoin.getMultiJoinRel().getRowType().getFieldNames();
+        return creatTopProject(relBuilder, multiJoin, finalPlan, fieldNames);
+    }
+
+    private static List<Map<Set<Integer>, JoinPlan>> reorderInnerJoin(
+            RelBuilder relBuilder, LoptMultiJoin multiJoin) {
+        int numJoinFactors = multiJoin.getNumJoinFactors();
+        List<Map<Set<Integer>, JoinPlan>> foundPlans = new ArrayList<>();
+
+        // First, we put all join factors in MultiJoin into level 0.
+        Map<Set<Integer>, JoinPlan> firstLevelJoinPlanMap = new 
LinkedHashMap<>();
+        for (int i = 0; i < numJoinFactors; i++) {
+            if (!multiJoin.isNullGenerating(i)) {
+                HashSet<Integer> set1 = new HashSet<>();
+                LinkedHashSet<Integer> set2 = new LinkedHashSet<>();
+                set1.add(i);
+                set2.add(i);
+                RelNode joinFactor = multiJoin.getJoinFactor(i);
+                firstLevelJoinPlanMap.put(set1, new JoinPlan(set2, 
joinFactor));
+            }
+        }
+        foundPlans.add(firstLevelJoinPlanMap);
+
+        // Build plans for next levels until the found plans size equals the 
number of join factors,
+        // or no possible plan exists for next level.
+        while (foundPlans.size() < numJoinFactors) {
+            Map<Set<Integer>, JoinPlan> nextLevelJoinPlanMap =
+                    foundNextLevel(relBuilder, new ArrayList<>(foundPlans), 
multiJoin);
+            if (nextLevelJoinPlanMap.size() == 0) {
+                break;
+            }
+            foundPlans.add(nextLevelJoinPlanMap);
+        }
+
+        return foundPlans;
+    }
+
+    private static boolean outerJoinConditionExists(LoptMultiJoin multiJoin) {
+        for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) {
+            if (multiJoin.getOuterJoinCond(i) != null
+                    && 
RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(i)).size() != 0) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    private static JoinPlan getBestPlan(Map<Set<Integer>, JoinPlan> levelPlan) 
{
+        JoinPlan bestPlan = null;
+        for (Map.Entry<Set<Integer>, JoinPlan> entry : levelPlan.entrySet()) {
+            if (bestPlan == null || entry.getValue().betterThan(bestPlan)) {
+                bestPlan = entry.getValue();
+            }
+        }
+
+        return bestPlan;
+    }
+
+    private static JoinPlan addToTopForOuterJoin(
+            JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) 
{
+        List<Integer> remainIndexes = new ArrayList<>();
+        for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) {
+            if (!bestPlan.factorIds.contains(i)) {
+                remainIndexes.add(i);
+            }
+        }
+
+        RelNode leftNode = bestPlan.relNode;
+        LinkedHashSet<Integer> set = new LinkedHashSet<>(bestPlan.factorIds);
+        for (int index : remainIndexes) {
+            RelNode rightNode = multiJoin.getJoinFactor(index);
+
+            // Make new join condition and get new join type.
+            Optional<Tuple2<Set<RexCall>, JoinRelType>> joinConditions =
+                    getConditionsAndJoinType(
+                            bestPlan.factorIds, Collections.singleton(index), 
multiJoin, true);
+
+            if (!joinConditions.isPresent()) {
+                continue;
+            } else {
+                // This is a left/right outer join, but we always build it with the left join type.
+                Set<RexCall> conditions = joinConditions.get().f0;
+                List<RexNode> rexCalls = new ArrayList<>(conditions);
+                Set<RexCall> newCondition =
+                        convertToNewCondition(
+                                new ArrayList<>(set),
+                                Collections.singletonList(index),
+                                rexCalls,
+                                multiJoin);
+
+                leftNode =
+                        relBuilder
+                                .push(leftNode)
+                                .push(rightNode)
+                                .join(JoinRelType.LEFT, newCondition)
+                                .build();
+            }
+            set.add(index);
+        }
+        return new JoinPlan(set, leftNode);
+    }
+
+    private static JoinPlan addToTopForTrueCondition(
+            JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) 
{
+        RexBuilder rexBuilder = 
multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
+        List<Integer> remainIndexes = new ArrayList<>();
+        for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) {
+            if (!bestPlan.factorIds.contains(i)) {
+                remainIndexes.add(i);
+            }
+        }
+
+        RelNode leftNode = bestPlan.relNode;
+        LinkedHashSet<Integer> set = new LinkedHashSet<>(bestPlan.factorIds);
+        for (int index : remainIndexes) {
+            set.add(index);
+            RelNode rightNode = multiJoin.getJoinFactor(index);
+            leftNode =
+                    relBuilder
+                            .push(leftNode)
+                            .push(rightNode)
+                            .join(
+                                    
multiJoin.getMultiJoinRel().getJoinTypes().get(index),
+                                    rexBuilder.makeLiteral(true))
+                            .build();
+        }
+        return new JoinPlan(set, leftNode);
+    }
+
+    /**
+     * Creates the topmost projection that will sit on top of the selected 
join ordering. The
+     * projection needs to match the original join ordering. Also, places any 
post-join filters on
+     * top of the project.
+     */
+    private static RelNode creatTopProject(
+            RelBuilder relBuilder,
+            LoptMultiJoin multiJoin,
+            JoinPlan finalPlan,
+            List<String> fieldNames) {
+        List<RexNode> newProjExprs = new ArrayList<>();
+        RexBuilder rexBuilder = 
multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
+
+        List<Integer> newJoinOrder = new ArrayList<>(finalPlan.factorIds);
+        int nJoinFactors = multiJoin.getNumJoinFactors();
+        List<RelDataTypeField> fields = multiJoin.getMultiJoinFields();
+
+        // create a mapping from each factor to its field offset in the join
+        // ordering
+        final Map<Integer, Integer> factorToOffsetMap = new HashMap<>();
+        for (int pos = 0, fieldStart = 0; pos < nJoinFactors; pos++) {
+            factorToOffsetMap.put(newJoinOrder.get(pos), fieldStart);
+            fieldStart += 
multiJoin.getNumFieldsInJoinFactor(newJoinOrder.get(pos));
+        }
+
+        for (int currFactor = 0; currFactor < nJoinFactors; currFactor++) {
+            // if the factor is the right factor in a removable self-join,
+            // then where possible, remap references to the right factor to
+            // the corresponding reference in the left factor
+            Integer leftFactor = null;
+            if (multiJoin.isRightFactorInRemovableSelfJoin(currFactor)) {
+                leftFactor = multiJoin.getOtherSelfJoinFactor(currFactor);
+            }
+            for (int fieldPos = 0;
+                    fieldPos < multiJoin.getNumFieldsInJoinFactor(currFactor);
+                    fieldPos++) {
+                int newOffset =
+                        requireNonNull(
+                                        factorToOffsetMap.get(currFactor),
+                                        () -> 
"factorToOffsetMap.get(currFactor)")
+                                + fieldPos;
+                if (leftFactor != null) {
+                    Integer leftOffset = 
multiJoin.getRightColumnMapping(currFactor, fieldPos);
+                    if (leftOffset != null) {
+                        newOffset =
+                                requireNonNull(
+                                                
factorToOffsetMap.get(leftFactor),
+                                                
"factorToOffsetMap.get(leftFactor)")
+                                        + leftOffset;
+                    }
+                }
+                newProjExprs.add(
+                        rexBuilder.makeInputRef(
+                                fields.get(newProjExprs.size()).getType(), 
newOffset));
+            }
+        }
+
+        relBuilder.clear();
+        relBuilder.push(finalPlan.relNode);
+        relBuilder.project(newProjExprs, fieldNames);
+
+        // Place the post-join filter (if it exists) on top of the final 
projection.
+        RexNode postJoinFilter = 
multiJoin.getMultiJoinRel().getPostJoinFilter();
+        if (postJoinFilter != null) {
+            relBuilder.filter(postJoinFilter);
+        }
+        return relBuilder.build();
+    }
+
+    /** Finds possible join plans for the next level based on the plans found 
in the previous levels. */
+    private static Map<Set<Integer>, JoinPlan> foundNextLevel(
+            RelBuilder relBuilder,
+            List<Map<Set<Integer>, JoinPlan>> foundPlans,
+            LoptMultiJoin multiJoin) {
+        Map<Set<Integer>, JoinPlan> currentLevelJoinPlanMap = new 
LinkedHashMap<>();
+        int foundPlansLevel = foundPlans.size() - 1;
+        int joinLeftSideLevel = 0;
+        int joinRightSideLevel = foundPlansLevel;
+        while (joinLeftSideLevel <= joinRightSideLevel) {
+            List<JoinPlan> joinLeftSidePlans =
+                    new 
ArrayList<>(foundPlans.get(joinLeftSideLevel).values());
+            int planSize = joinLeftSidePlans.size();
+            for (int i = 0; i < planSize; i++) {
+                JoinPlan joinLeftSidePlan = joinLeftSidePlans.get(i);
+                List<JoinPlan> joinRightSidePlans;
+                if (joinLeftSideLevel == joinRightSideLevel) {
+                    // If the left side level equals the right side level, 
we can remove the
+                    // first 'i' plans, which were already considered, from the right side plans 
to reduce the search
+                    // space.
+                    joinRightSidePlans = new ArrayList<>(joinLeftSidePlans);
+                    if (i > 0) {
+                        joinRightSidePlans.subList(0, i).clear();
+                    }
+                } else {
+                    joinRightSidePlans =
+                            new 
ArrayList<>(foundPlans.get(joinRightSideLevel).values());
+                }
+                for (JoinPlan joinRightSidePlan : joinRightSidePlans) {
+                    Optional<JoinPlan> newJoinPlan =
+                            buildInnerJoin(
+                                    relBuilder, joinLeftSidePlan, 
joinRightSidePlan, multiJoin);
+                    if (newJoinPlan.isPresent()) {
+                        JoinPlan existingPlanInCurrentLevel =
+                                
currentLevelJoinPlanMap.get(newJoinPlan.get().factorIds);
+                        // check if it's the first plan for the factor set, or 
it's a better plan
+                        // than the existing one due to lower cost.
+                        if (existingPlanInCurrentLevel == null
+                                || 
newJoinPlan.get().betterThan(existingPlanInCurrentLevel)) {
+                            currentLevelJoinPlanMap.put(
+                                    newJoinPlan.get().factorIds, 
newJoinPlan.get());
+                        }
+                    }
+                }
+            }
+            joinLeftSideLevel += 1;
+            joinRightSideLevel = foundPlansLevel - joinLeftSideLevel;

Review Comment:
   > joinLeftSideLevel++; joinRightSideLevel--;
   
   Done!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to