wanglijie95 commented on code in PR #22966:
URL: https://github.com/apache/flink/pull/22966#discussion_r1260517558


##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/program/FlinkRuntimeFilterProgram.java:
##########
@@ -0,0 +1,571 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.optimize.program;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
+import 
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
+import 
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
+import 
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalHashJoin;
+import 
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSortMergeJoin;
+import 
org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalGlobalRuntimeFilterBuilder;
+import 
org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalLocalRuntimeFilterBuilder;
+import 
org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalRuntimeFilter;
+import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
+import org.apache.flink.table.planner.plan.utils.DefaultRelShuttle;
+import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+
+import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Calc;
+import org.apache.calcite.rel.core.Exchange;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinInfo;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.ImmutableIntList;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Planner program that tries to inject runtime filter for suitable join to 
improve join
+ * performance.
+ *
+ * <p>We build the runtime filter in a two-phase manner: First, each subtask 
on the build side
+ * builds a local filter based on its local data, and sends the built filter 
to a global aggregation
+ * node. Then the global aggregation node aggregates the received filters into 
a global filter, and
+ * sends the global filter to all probe side subtasks. Therefore, we will add 
{@link
+ * BatchPhysicalLocalRuntimeFilterBuilder}, {@link 
BatchPhysicalGlobalRuntimeFilterBuilder} and
+ * {@link BatchPhysicalRuntimeFilter} into the physical plan.
+ *
+ * <p>For example, for the following query:
+ *
+ * <pre>{@code SELECT * FROM fact, dim WHERE x = a AND z = 2}</pre>
+ *
+ * <p>The original physical plan:
+ *
+ * <pre>{@code
+ * Calc(select=[a, b, c, x, y, CAST(2 AS BIGINT) AS z])
+ * +- HashJoin(joinType=[InnerJoin], where=[=(x, a)], select=[a, b, c, x, y], 
build=[right])
+ *    :- Exchange(distribution=[hash[a]])
+ *    :  +- TableSourceScan(table=[[fact]], fields=[a, b, c])
+ *    +- Exchange(distribution=[hash[x]])
+ *       +- Calc(select=[x, y], where=[=(z, 2)])
+ *          +- TableSourceScan(table=[[dim, filter=[]]], fields=[x, y, z])
+ * }</pre>
+ *
+ * <p>The optimized physical plan:
+ *
+ * <pre>{@code
+ * Calc(select=[a, b, c, x, y, CAST(2 AS BIGINT) AS z])
+ * +- HashJoin(joinType=[InnerJoin], where=[=(x, a)], select=[a, b, c, x, y], 
build=[right])
+ *    :- Exchange(distribution=[hash[a]])
+ *    :  +- RuntimeFilter(select=[a])
+ *    :     :- Exchange(distribution=[broadcast])
+ *    :     :  +- GlobalRuntimeFilterBuilder
+ *    :     :     +- Exchange(distribution=[single])
+ *    :     :        +- LocalRuntimeFilterBuilder(select=[x])
+ *    :     :           +- Calc(select=[x, y], where=[=(z, 2)])
+ *    :     :              +- TableSourceScan(table=[[dim, filter=[]]], 
fields=[x, y, z])
+ *    :     +- TableSourceScan(table=[[fact]], fields=[a, b, c])
+ *    +- Exchange(distribution=[hash[x]])
+ *       +- Calc(select=[x, y], where=[=(z, 2)])
+ *          +- TableSourceScan(table=[[dim, filter=[]]], fields=[x, y, z])
+ *
+ * }</pre>
+ */
+public class FlinkRuntimeFilterProgram implements 
FlinkOptimizeProgram<BatchOptimizeContext> {
+
+    @Override
+    public RelNode optimize(RelNode root, BatchOptimizeContext context) { // context is not read in this method
+        final RuntimeFilterConfig config =
+                new 
RuntimeFilterConfig(ShortcutUtils.unwrapContext(root).getTableConfig());
+        if (!config.runtimeFilterEnabled) {
+            return root; // runtime filter is disabled: hand the plan back untouched
+        }
+
+        DefaultRelShuttle shuttle =
+                new DefaultRelShuttle() {
+                    @Override
+                    public RelNode visit(RelNode rel) {
+                        if (!(rel instanceof Join)) { // non-join node: only its inputs are rewritten
+                            List<RelNode> newInputs = new ArrayList<>();
+                            for (RelNode input : rel.getInputs()) {
+                                RelNode newInput = input.accept(this);
+                                newInputs.add(newInput);
+                            }
+                            return rel.copy(rel.getTraitSet(), newInputs); // same traits, rewritten inputs
+                        }
+
+                        Join join = (Join) rel;
+                        RelNode newLeft = join.getLeft().accept(this);
+                        RelNode newRight = join.getRight().accept(this); // both inputs are rewritten before this join is examined
+
+                        return tryInjectRuntimeFilter(
+                                join.copy(join.getTraitSet(), 
Arrays.asList(newLeft, newRight)),
+                                config);
+                    }
+                };
+        return shuttle.visit(root); // rewrite starts at the root and recurses through the inputs
+    }
+
+    /**
+     * Judge whether the join is suitable, and try to inject a runtime filter for it.
+     *
+     * <p>The join is returned unchanged unless all of the following hold: the join type is INNER,
+     * SEMI, LEFT or RIGHT; the join is a {@link BatchPhysicalHashJoin} or a {@link
+     * BatchPhysicalSortMergeJoin}; one of its inputs passes {@code canBeProbeSide} (the left
+     * input is tried first); for a LEFT join the chosen probe side is the right input, and for a
+     * RIGHT join it is the left input; and {@code findSuitableBuildSide} yields a build side
+     * starting from the other input.
+     *
+     * @param join the join node
+     * @param config the runtime filter configuration
+     * @return a copy of the join whose probe side has a runtime filter injected, or the given
+     *     join itself if no runtime filter was injected.
+     */
+    private static Join tryInjectRuntimeFilter(Join join, RuntimeFilterConfig 
config) {
+
+        // check supported join type
+        if (join.getJoinType() != JoinRelType.INNER
+                && join.getJoinType() != JoinRelType.SEMI
+                && join.getJoinType() != JoinRelType.LEFT
+                && join.getJoinType() != JoinRelType.RIGHT) {
+            return join;
+        }
+
+        // check supported join implementation
+        if (!(join instanceof BatchPhysicalHashJoin)
+                && !(join instanceof BatchPhysicalSortMergeJoin)) {
+            return join;
+        }
+
+        boolean leftIsBuild; // true when the right input is the probe side
+        if (canBeProbeSide(join.getLeft(), config)) {
+            leftIsBuild = false; // the left input is preferred as probe side whenever it qualifies
+        } else if (canBeProbeSide(join.getRight(), config)) {
+            leftIsBuild = true;
+        } else {
+            return join; // neither input can be the probe side
+        }
+
+        // a LEFT join is only handled when the left input is the build side; give up otherwise
+        if (join.getJoinType() == JoinRelType.LEFT && !leftIsBuild) {
+            return join;
+        }
+
+        // a RIGHT join is only handled when the right input is the build side; give up otherwise
+        if (join.getJoinType() == JoinRelType.RIGHT && leftIsBuild) {
+            return join;
+        }
+
+        JoinInfo joinInfo = join.analyzeCondition();
+        RelNode buildSide;
+        RelNode probeSide;
+        ImmutableIntList buildIndices;
+        ImmutableIntList probeIndices;
+        if (leftIsBuild) {
+            buildSide = join.getLeft();
+            probeSide = join.getRight();
+            buildIndices = joinInfo.leftKeys;
+            probeIndices = joinInfo.rightKeys;
+        } else {
+            buildSide = join.getRight();
+            probeSide = join.getLeft();
+            buildIndices = joinInfo.rightKeys;
+            probeIndices = joinInfo.leftKeys;
+        }
+
+        Optional<BuildSideInfo> suitableBuildOpt =
+                findSuitableBuildSide(
+                        buildSide,
+                        buildIndices,
+                        (build, indices) ->
+                                isSuitableDataSize(
+                                        build, probeSide, indices, 
probeIndices, config));
+
+        if (suitableBuildOpt.isPresent()) {
+            BuildSideInfo suitableBuildInfo = suitableBuildOpt.get(); // carries the build side found and its key indices
+            RelNode newProbe =
+                    createNewProbeWithRuntimeFilter(
+                            ignoreExchange(suitableBuildInfo.buildSide),
+                            ignoreExchange(probeSide),
+                            suitableBuildInfo.buildIndices,
+                            probeIndices,
+                            config);
+            if (probeSide instanceof Exchange) { // re-wrap the new probe side in a copy of the original exchange
+                newProbe =
+                        ((Exchange) probeSide)
+                                .copy(probeSide.getTraitSet(), 
Collections.singletonList(newProbe));
+            }
+            if (leftIsBuild) { // the join's original build input is kept as is; only the probe input is replaced
+                return join.copy(join.getTraitSet(), Arrays.asList(buildSide, 
newProbe));
+            } else {
+                return join.copy(join.getTraitSet(), Arrays.asList(newProbe, 
buildSide));
+            }
+        }
+
+        return join; // no suitable build side found
+    }
+
+    /**
+     * Inject runtime filter and return the new probe side (without exchange).
+     *
+     * <p>The nodes are stacked in this order: the build side, a {@link
+     * BatchPhysicalLocalRuntimeFilterBuilder}, an exchange with singleton distribution, a {@link
+     * BatchPhysicalGlobalRuntimeFilterBuilder}, an exchange with broadcast distribution, and
+     * finally a {@link BatchPhysicalRuntimeFilter}, which is given that exchange and the probe
+     * side.
+     *
+     * @param buildSide the build side; an estimated row count must be available for it
+     * @param probeSide the probe side
+     * @param buildIndices the key field indices on the build side, handed to the local builder
+     * @param probeIndices the key field indices on the probe side, handed to the runtime filter
+     * @param config the runtime filter configuration; {@code maxBuildSize} is read from it
+     * @return the new probe side, i.e. the created {@link BatchPhysicalRuntimeFilter}
+     */
+    private static RelNode createNewProbeWithRuntimeFilter(
+            RelNode buildSide,
+            RelNode probeSide,
+            ImmutableIntList buildIndices,
+            ImmutableIntList probeIndices,
+            RuntimeFilterConfig config) {
+        Optional<Double> buildRowCountOpt = getEstimatedRowCount(buildSide);
+        checkState(buildRowCountOpt.isPresent()); // a row count estimate for the build side is required here
+        int buildRowCount = buildRowCountOpt.get().intValue(); // the double estimate is narrowed to int
+        int maxRowCount = // ceil(maxBuildSize / average binary row size of the build side), narrowed to int
+                (int)
+                        Math.ceil(
+                                config.maxBuildSize
+                                        / 
FlinkRelMdUtil.binaryRowAverageSize(buildSide));
+        double filterRatio = computeFilterRatio(buildSide, probeSide, 
buildIndices, probeIndices);
+
+        RelNode localBuilder =
+                new BatchPhysicalLocalRuntimeFilterBuilder(
+                        buildSide.getCluster(),
+                        buildSide.getTraitSet(),
+                        buildSide,
+                        buildIndices.toIntArray(),
+                        buildRowCount,
+                        maxRowCount);
+        RelNode globalBuilder =
+                new BatchPhysicalGlobalRuntimeFilterBuilder(
+                        localBuilder.getCluster(),
+                        localBuilder.getTraitSet(),
+                        createExchange(localBuilder, 
FlinkRelDistribution.SINGLETON()),
+                        maxRowCount);
+        RelNode runtimeFilter =
+                new BatchPhysicalRuntimeFilter(
+                        probeSide.getCluster(),
+                        probeSide.getTraitSet(), // the filter node reuses the probe side's trait set
+                        createExchange(globalBuilder, 
FlinkRelDistribution.BROADCAST_DISTRIBUTED()),
+                        probeSide,
+                        probeIndices.toIntArray(),
+                        filterRatio);
+
+        return runtimeFilter;
+    }
+
+    /**
+     * Find a suitable build side. In order not to affect MultiInput, when the 
original build side
+     * of runtime filter is not an {@link Exchange}, we need to push down the 
builder, until we find
+     * an exchange and inject the builder there.
+     *
+     * @param rel the original build side
+     * @param buildIndices build indices
+     * @param buildSideChecker check whether current build side is suitable
+     * @return An optional info of the suitable build side. It will be empty if no suitable build
+     *     side can be found.
+     */
+    private static Optional<BuildSideInfo> findSuitableBuildSide(
+            RelNode rel,
+            ImmutableIntList buildIndices,
+            BiFunction<RelNode, ImmutableIntList, Boolean> buildSideChecker) {
+        if (rel instanceof Exchange) {
+            Exchange exchange = (Exchange) rel;
+            if (!(exchange.getInput() instanceof BatchPhysicalRuntimeFilter)
+                    && buildSideChecker.apply(exchange.getInput(), 
buildIndices)) {
+                return Optional.of(new BuildSideInfo(exchange.getInput(), 
buildIndices));
+            }
+        } else if (rel instanceof BatchPhysicalRuntimeFilter) {
+            // a runtime filter should not be used as the build side
+            return Optional.empty();
+        } else if (rel instanceof Calc) {
+            Calc calc = ((Calc) rel);
+            RexProgram program = calc.getProgram();
+            List<RexNode> projects =
+                    program.getProjectList().stream()
+                            .map(program::expandLocalRef)
+                            .collect(Collectors.toList());
+            ImmutableIntList inputIndices = getInputIndices(projects, 
buildIndices);
+            if (inputIndices.isEmpty()) {
+                return Optional.empty();
+            }
+            return findSuitableBuildSide(calc.getInput(), inputIndices, 
buildSideChecker);
+
+        } else if (rel instanceof Join) {
+            Join join = (Join) rel;
+            if (!(join.getLeft() instanceof Exchange) && !(join.getRight() 
instanceof Exchange)) {
+                // TODO: Is there such case?
+                return Optional.empty();
+            }
+
+            Tuple2<ImmutableIntList, ImmutableIntList> tuple2 = 
getInputIndices(join, buildIndices);
+            ImmutableIntList leftIndices = tuple2.f0;
+            ImmutableIntList rightIndices = tuple2.f1;
+
+            if (leftIndices.isEmpty() && rightIndices.isEmpty()) {
+                return Optional.empty();
+            }
+
+            boolean firstCheckLeft = !leftIndices.isEmpty() && join.getLeft() 
instanceof Exchange;
+            Optional<BuildSideInfo> buildSideInfoOpt = Optional.empty();
+            if (firstCheckLeft) {
+                buildSideInfoOpt =
+                        findSuitableBuildSide(join.getLeft(), leftIndices, 
buildSideChecker);
+                if (!buildSideInfoOpt.isPresent() && !rightIndices.isEmpty()) {
+                    buildSideInfoOpt =
+                            findSuitableBuildSide(join.getRight(), 
rightIndices, buildSideChecker);
+                }
+                return buildSideInfoOpt;
+            } else {
+                if (!rightIndices.isEmpty()) {
+                    buildSideInfoOpt =
+                            findSuitableBuildSide(join.getRight(), 
rightIndices, buildSideChecker);
+                    if (!buildSideInfoOpt.isPresent() && 
!leftIndices.isEmpty()) {
+                        buildSideInfoOpt =
+                                findSuitableBuildSide(
+                                        join.getLeft(), leftIndices, 
buildSideChecker);
+                    }
+                }
+            }
+            return buildSideInfoOpt;
+        } else if (rel instanceof BatchPhysicalGroupAggregateBase) {
+            BatchPhysicalGroupAggregateBase agg = 
(BatchPhysicalGroupAggregateBase) rel;
+            int[] grouping = agg.grouping();
+
+            // If any of the keys is an aggregate function field, return directly.
+            for (int k : buildIndices) {
+                if (k >= grouping.length) {
+                    return Optional.empty();
+                }
+            }
+
+            return findSuitableBuildSide(
+                    agg.getInput(),
+                    ImmutableIntList.copyOf(
+                            buildIndices.stream()
+                                    .map(index -> agg.grouping()[index])
+                                    .collect(Collectors.toList())),
+                    buildSideChecker);
+
+        } else {
+            // more cases

Review Comment:
   TPCDS-related tests are still in progress



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to