wanglijie95 commented on code in PR #22966: URL: https://github.com/apache/flink/pull/22966#discussion_r1260503527
########## flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/program/FlinkRuntimeFilterProgram.java: ########## @@ -0,0 +1,571 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.optimize.program; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.planner.plan.nodes.FlinkConventions; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalHashJoin; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSortMergeJoin; +import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalGlobalRuntimeFilterBuilder; +import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalLocalRuntimeFilterBuilder; +import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalRuntimeFilter; +import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; +import org.apache.flink.table.planner.plan.utils.DefaultRelShuttle; +import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; + +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.Exchange; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinInfo; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableIntList; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Planner program that tries to inject runtime filter for suitable join to improve join + * performance. + * + * <p>We build the runtime filter in a two-phase manner: First, each subtask on the build side + * builds a local filter based on its local data, and sends the built filter to a global aggregation + * node. Then the global aggregation node aggregates the received filters into a global filter, and + * sends the global filter to all probe side subtasks. Therefore, we will add {@link + * BatchPhysicalLocalRuntimeFilterBuilder}, {@link BatchPhysicalGlobalRuntimeFilterBuilder} and + * {@link BatchPhysicalRuntimeFilter} into the physical plan. + * + * <p>For example, for the following query: + * + * <pre>{@code SELECT * FROM fact, dim WHERE x = a AND z = 2}</pre> + * + * <p>The original physical plan: + * + * <pre>{@code + * Calc(select=[a, b, c, x, y, CAST(2 AS BIGINT) AS z]) + * +- HashJoin(joinType=[InnerJoin], where=[=(x, a)], select=[a, b, c, x, y], build=[right]) + * :- Exchange(distribution=[hash[a]]) + * : +- TableSourceScan(table=[[fact]], fields=[a, b, c]) + * +- Exchange(distribution=[hash[x]]) + * +- Calc(select=[x, y], where=[=(z, 2)]) + * +- TableSourceScan(table=[[dim, filter=[]]], fields=[x, y, z]) + * }</pre> + * + * <p>This optimized physical plan: + * + * <pre>{@code + * Calc(select=[a, b, c, x, y, CAST(2 AS BIGINT) AS z]) + * +- HashJoin(joinType=[InnerJoin], where=[=(x, a)], select=[a, b, c, x, y], build=[right]) + * :- Exchange(distribution=[hash[a]]) + * : +- RuntimeFilter(select=[a]) + * : :- Exchange(distribution=[broadcast]) + * : : +- GlobalRuntimeFilterBuilder + * : : +- Exchange(distribution=[single]) + * : : +- LocalRuntimeFilterBuilder(select=[x]) + * : : +- Calc(select=[x, y], where=[=(z, 2)]) + * : : +- TableSourceScan(table=[[dim, filter=[]]], fields=[x, y, z]) + * : +- TableSourceScan(table=[[fact]], fields=[a, b, c]) + * +- Exchange(distribution=[hash[x]]) + * +- Calc(select=[x, y], where=[=(z, 2)]) + * +- TableSourceScan(table=[[dim, filter=[]]], fields=[x, y, z]) + * + * }</pre> + */ +public class FlinkRuntimeFilterProgram implements FlinkOptimizeProgram<BatchOptimizeContext> { + + @Override + public RelNode optimize(RelNode root, BatchOptimizeContext context) { + final RuntimeFilterConfig config = + new RuntimeFilterConfig(ShortcutUtils.unwrapContext(root).getTableConfig()); + if (!config.runtimeFilterEnabled) { + return root; + } + + DefaultRelShuttle shuttle = + new DefaultRelShuttle() { + @Override + public RelNode visit(RelNode rel) { + if (!(rel instanceof Join)) { + List<RelNode> newInputs = new ArrayList<>(); + for (RelNode input : rel.getInputs()) { + RelNode newInput = input.accept(this); + newInputs.add(newInput); + } + return rel.copy(rel.getTraitSet(), newInputs); + } + + Join join = (Join) rel; + RelNode newLeft = join.getLeft().accept(this); + RelNode newRight = join.getRight().accept(this); + + return tryInjectRuntimeFilter( + join.copy(join.getTraitSet(), Arrays.asList(newLeft, newRight)), + config); + } + }; + return shuttle.visit(root); + } + + /** + * Judge whether the join is suitable, and try to inject runtime filter for it. + * + * @param join the join node + * @param config the runtime filter configuration + * @return the new join node with runtime filter. + */ + private static Join tryInjectRuntimeFilter(Join join, RuntimeFilterConfig config) { + + // check supported join type + if (join.getJoinType() != JoinRelType.INNER + && join.getJoinType() != JoinRelType.SEMI + && join.getJoinType() != JoinRelType.LEFT + && join.getJoinType() != JoinRelType.RIGHT) { + return join; + } + + // check supported join implementation + if (!(join instanceof BatchPhysicalHashJoin) + && !(join instanceof BatchPhysicalSortMergeJoin)) { + return join; + } + + boolean leftIsBuild; + if (canBeProbeSide(join.getLeft(), config)) { + leftIsBuild = false; + } else if (canBeProbeSide(join.getRight(), config)) { + leftIsBuild = true; + } else { + return join; + } + + // check left join + left build + if (join.getJoinType() == JoinRelType.LEFT && !leftIsBuild) { + return join; + } + + // check right join + right build + if (join.getJoinType() == JoinRelType.RIGHT && leftIsBuild) { + return join; + } + + JoinInfo joinInfo = join.analyzeCondition(); + RelNode buildSide; + RelNode probeSide; + ImmutableIntList buildIndices; + ImmutableIntList probeIndices; + if (leftIsBuild) { + buildSide = join.getLeft(); + probeSide = join.getRight(); + buildIndices = joinInfo.leftKeys; + probeIndices = joinInfo.rightKeys; + } else { + buildSide = join.getRight(); + probeSide = join.getLeft(); + buildIndices = joinInfo.rightKeys; + probeIndices = joinInfo.leftKeys; + } + + Optional<BuildSideInfo> suitableBuildOpt = + findSuitableBuildSide( + buildSide, + buildIndices, + (build, indices) -> + isSuitableDataSize( + build, probeSide, indices, probeIndices, config)); + + if (suitableBuildOpt.isPresent()) { + BuildSideInfo suitableBuildInfo = suitableBuildOpt.get(); + RelNode newProbe = + createNewProbeWithRuntimeFilter( + ignoreExchange(suitableBuildInfo.buildSide), + ignoreExchange(probeSide), + suitableBuildInfo.buildIndices, + probeIndices, + config); + if (probeSide instanceof Exchange) { + newProbe = + ((Exchange) probeSide) + .copy(probeSide.getTraitSet(), Collections.singletonList(newProbe)); + } + if (leftIsBuild) { + return join.copy(join.getTraitSet(), Arrays.asList(buildSide, newProbe)); + } else { + return join.copy(join.getTraitSet(), Arrays.asList(newProbe, buildSide)); + } + } + + return join; + } + + /** + * Inject runtime filter and return the new probe side (without exchange). + * + * @param buildSide the build side + * @param probeSide the probe side + * @param buildIndices the build projection + * @param probeIndices the probe projection + * @return the new probe side + */ + private static RelNode createNewProbeWithRuntimeFilter( + RelNode buildSide, + RelNode probeSide, + ImmutableIntList buildIndices, + ImmutableIntList probeIndices, + RuntimeFilterConfig config) { + Optional<Double> buildRowCountOpt = getEstimatedRowCount(buildSide); + checkState(buildRowCountOpt.isPresent()); + int buildRowCount = buildRowCountOpt.get().intValue(); + int maxRowCount = + (int) + Math.ceil( + config.maxBuildSize + / FlinkRelMdUtil.binaryRowAverageSize(buildSide)); + double filterRatio = computeFilterRatio(buildSide, probeSide, buildIndices, probeIndices); + + RelNode localBuilder = + new BatchPhysicalLocalRuntimeFilterBuilder( + buildSide.getCluster(), + buildSide.getTraitSet(), + buildSide, + buildIndices.toIntArray(), + buildRowCount, + maxRowCount); + RelNode globalBuilder = + new BatchPhysicalGlobalRuntimeFilterBuilder( + localBuilder.getCluster(), + localBuilder.getTraitSet(), + createExchange(localBuilder, FlinkRelDistribution.SINGLETON()), + maxRowCount); + RelNode runtimeFilter = + new BatchPhysicalRuntimeFilter( + probeSide.getCluster(), + probeSide.getTraitSet(), + createExchange(globalBuilder, FlinkRelDistribution.BROADCAST_DISTRIBUTED()), + probeSide, + probeIndices.toIntArray(), + filterRatio); + + return runtimeFilter; + } + + /** + * Find a suitable build side. In order not to affect MultiInput, when the original build side + * of runtime filter is not an {@link Exchange}, we need to push down the builder, until we find + * an exchange and inject the builder there. + * + * @param rel the original build side + * @param buildIndices build indices + * @param buildSideChecker check whether current build side is suitable + * @return An optional info of the suitable build side.It will be empty if we cannot find the + * suitable build side. + */ + private static Optional<BuildSideInfo> findSuitableBuildSide( + RelNode rel, + ImmutableIntList buildIndices, + BiFunction<RelNode, ImmutableIntList, Boolean> buildSideChecker) { + if (rel instanceof Exchange) { + Exchange exchange = (Exchange) rel; + if (!(exchange.getInput() instanceof BatchPhysicalRuntimeFilter) + && buildSideChecker.apply(exchange.getInput(), buildIndices)) { + return Optional.of(new BuildSideInfo(exchange.getInput(), buildIndices)); + } + } else if (rel instanceof BatchPhysicalRuntimeFilter) { + // runtime filter should not as build side + return Optional.empty(); + } else if (rel instanceof Calc) { + Calc calc = ((Calc) rel); + RexProgram program = calc.getProgram(); + List<RexNode> projects = + program.getProjectList().stream() + .map(program::expandLocalRef) + .collect(Collectors.toList()); + ImmutableIntList inputIndices = getInputIndices(projects, buildIndices); + if (inputIndices.isEmpty()) { + return Optional.empty(); + } + return findSuitableBuildSide(calc.getInput(), inputIndices, buildSideChecker); + + } else if (rel instanceof Join) { + Join join = (Join) rel; + if (!(join.getLeft() instanceof Exchange) && !(join.getRight() instanceof Exchange)) { + // TODO: Is there such case? + return Optional.empty(); + } + + Tuple2<ImmutableIntList, ImmutableIntList> tuple2 = getInputIndices(join, buildIndices); + ImmutableIntList leftIndices = tuple2.f0; + ImmutableIntList rightIndices = tuple2.f1; + + if (leftIndices.isEmpty() && rightIndices.isEmpty()) { + return Optional.empty(); + } + + boolean firstCheckLeft = !leftIndices.isEmpty() && join.getLeft() instanceof Exchange; + Optional<BuildSideInfo> buildSideInfoOpt = Optional.empty(); + if (firstCheckLeft) { + buildSideInfoOpt = + findSuitableBuildSide(join.getLeft(), leftIndices, buildSideChecker); + if (!buildSideInfoOpt.isPresent() && !rightIndices.isEmpty()) { + buildSideInfoOpt = + findSuitableBuildSide(join.getRight(), rightIndices, buildSideChecker); + } + return buildSideInfoOpt; + } else { + if (!rightIndices.isEmpty()) { + buildSideInfoOpt = + findSuitableBuildSide(join.getRight(), rightIndices, buildSideChecker); + if (!buildSideInfoOpt.isPresent() && !leftIndices.isEmpty()) { + buildSideInfoOpt = + findSuitableBuildSide( + join.getLeft(), leftIndices, buildSideChecker); + } + } + } + return buildSideInfoOpt; + } else if (rel instanceof BatchPhysicalGroupAggregateBase) { + BatchPhysicalGroupAggregateBase agg = (BatchPhysicalGroupAggregateBase) rel; + int[] grouping = agg.grouping(); + + // If one of keys are aggregate function field, return directly. + for (int k : buildIndices) { + if (k >= grouping.length) { + return Optional.empty(); + } + } + + return findSuitableBuildSide( + agg.getInput(), + ImmutableIntList.copyOf( + buildIndices.stream() + .map(index -> agg.grouping()[index]) + .collect(Collectors.toList())), + buildSideChecker); + + } else { + // more cases + } + + return Optional.empty(); + } + + private static BatchPhysicalExchange createExchange( + RelNode input, FlinkRelDistribution newDistribution) { + RelTraitSet newTraitSet = + input.getCluster() + .getPlanner() + .emptyTraitSet() + .replace(FlinkConventions.BATCH_PHYSICAL()) + .replace(newDistribution); + + return new BatchPhysicalExchange(input.getCluster(), newTraitSet, input, newDistribution); + } + + private static ImmutableIntList getInputIndices( + List<RexNode> projects, ImmutableIntList outputIndices) { + List<Integer> inputIndices = new ArrayList<>(); + for (int k : outputIndices) { + RexNode rexNode = projects.get(k); + if (!(rexNode instanceof RexInputRef)) { + return ImmutableIntList.of(); + } + inputIndices.add(((RexInputRef) rexNode).getIndex()); + } + return ImmutableIntList.copyOf(inputIndices); + } + + private static Tuple2<ImmutableIntList, ImmutableIntList> getInputIndices( + Join join, ImmutableIntList outputIndices) { + JoinInfo joinInfo = join.analyzeCondition(); + Map<Integer, Integer> leftToRightJoinKeysMapping = + createKeysMapping(joinInfo.leftKeys, joinInfo.rightKeys); + Map<Integer, Integer> rightToLeftJoinKeysMapping = + createKeysMapping(joinInfo.rightKeys, joinInfo.leftKeys); + + List<Integer> leftIndices = new ArrayList<>(); + List<Integer> rightIndices = new ArrayList<>(); + + int leftFieldCnt = join.getLeft().getRowType().getFieldCount(); + for (int index : outputIndices) { + if (index < leftFieldCnt) { + leftIndices.add(index); + // if it's join key, map to right + if (leftToRightJoinKeysMapping.containsKey(index)) { + rightIndices.add(leftToRightJoinKeysMapping.get(index)); + } + } else { + int rightIndex = index - leftFieldCnt; + rightIndices.add(rightIndex); + // if it's join key, map to left + if (rightToLeftJoinKeysMapping.containsKey(rightIndex)) { + leftIndices.add(rightToLeftJoinKeysMapping.get(rightIndex)); + } + } + } + + ImmutableIntList left = + leftIndices.size() == outputIndices.size() + ? ImmutableIntList.copyOf(leftIndices) + : ImmutableIntList.of(); + ImmutableIntList right = + rightIndices.size() == outputIndices.size() + ? ImmutableIntList.copyOf(rightIndices) + : ImmutableIntList.of(); + + return Tuple2.of(left, right); + } + + private static Map<Integer, Integer> createKeysMapping( + ImmutableIntList keyList1, ImmutableIntList keyList2) { + checkState(keyList1.size() == keyList2.size()); + Map<Integer, Integer> mapping = new HashMap<>(); + for (int i = 0; i < keyList1.size(); ++i) { + mapping.put(keyList1.get(i), keyList2.get(i)); + } + return mapping; + } + + private static boolean canBeProbeSide(RelNode rel, RuntimeFilterConfig config) { + Optional<Double> size = getEstimatedDataSize(rel); + return size.isPresent() && size.get() >= config.minProbeSize; + } + + private static boolean isSuitableDataSize( + RelNode buildSide, + RelNode probeSide, + ImmutableIntList buildIndices, + ImmutableIntList probeIndices, + RuntimeFilterConfig config) { + Optional<Double> buildSize = getEstimatedDataSize(buildSide); + Optional<Double> probeSize = getEstimatedDataSize(probeSide); + + if (!buildSize.isPresent() || !probeSize.isPresent()) { + return false; + } + + if (buildSize.get() > config.maxBuildSize || probeSize.get() < config.minProbeSize) { + return false; + } + + return computeFilterRatio(buildSide, probeSide, buildIndices, probeIndices) + >= config.minFilterRatio; + } + + private static double computeFilterRatio( + RelNode buildSide, + RelNode probeSide, + ImmutableIntList buildIndices, + ImmutableIntList probeIndices) { + + Optional<Double> buildNdv = getEstimatedNdv(buildSide, ImmutableBitSet.of(buildIndices)); + Optional<Double> probeNdv = getEstimatedNdv(probeSide, ImmutableBitSet.of(probeIndices)); + + if (buildNdv.isPresent() && probeNdv.isPresent()) { + return Math.max(0, 1 - buildNdv.get() / probeNdv.get()); + } else { + Optional<Double> buildSize = getEstimatedDataSize(buildSide); + Optional<Double> probeSize = getEstimatedDataSize(probeSide); + checkState(buildSize.isPresent() && probeSize.isPresent()); + return Math.max(0, 1 - buildSize.get() / probeSize.get()); + } + } + + private static RelNode ignoreExchange(RelNode relNode) { + if (relNode instanceof Exchange) { + return relNode.getInput(0); + } else { + return relNode; + } + } + + private static Optional<Double> getEstimatedDataSize(RelNode relNode) { Review Comment: Done -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
