Github user fhueske commented on a diff in the pull request: https://github.com/apache/flink/pull/1255#discussion_r46575587 --- Diff: flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/RangePartitionRewriter.java --- @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.optimizer.traversals; + +import org.apache.flink.api.common.distributions.CommonRangeBoundaries; +import org.apache.flink.api.common.operators.UnaryOperatorInformation; +import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase; +import org.apache.flink.api.common.operators.base.MapOperatorBase; +import org.apache.flink.api.common.operators.base.MapPartitionOperatorBase; +import org.apache.flink.api.common.operators.util.FieldList; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparatorFactory; +import org.apache.flink.api.java.typeutils.RecordTypeInfo; +import org.apache.flink.runtime.io.network.DataExchangeMode; +import org.apache.flink.runtime.operators.udf.AssignRangeIndex; +import org.apache.flink.runtime.operators.udf.PartitionIDRemoveWrapper; +import org.apache.flink.runtime.operators.udf.RangeBoundaryBuilder; +import org.apache.flink.api.java.functions.SampleInCoordinator; +import org.apache.flink.api.java.functions.SampleInPartition; +import org.apache.flink.api.java.sampling.IntermediateSampleData; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.optimizer.dag.GroupReduceNode; +import org.apache.flink.optimizer.dag.MapNode; +import org.apache.flink.optimizer.dag.MapPartitionNode; +import org.apache.flink.optimizer.dag.TempMode; +import org.apache.flink.optimizer.plan.Channel; +import org.apache.flink.optimizer.plan.NamedChannel; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.PlanNode; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.util.Utils; +import org.apache.flink.runtime.operators.DriverStrategy; +import 
org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.apache.flink.util.Visitor; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; + +public class RangePartitionRewriter implements Visitor<PlanNode> { + final static long SEED = org.apache.flink.api.java.Utils.RNG.nextLong(); + + final OptimizedPlan plan; + + public RangePartitionRewriter(OptimizedPlan plan) { + this.plan = plan; + } + + @Override + public boolean preVisit(PlanNode visitable) { + return true; + } + + @Override + public void postVisit(PlanNode visitable) { + final List<Channel> outgoingChannels = visitable.getOutgoingChannels(); + final List<Channel> newOutGoingChannels = new LinkedList<>(); + final List<Channel> toBeRemoveChannels = new ArrayList<>(); + for (Channel channel : outgoingChannels) { + ShipStrategyType shipStrategy = channel.getShipStrategy(); + if (shipStrategy == ShipStrategyType.PARTITION_RANGE) { + TypeInformation<?> outputType = channel.getSource().getProgramOperator().getOperatorInfo().getOutputType(); + // Do not optimize for record type, it's a special case for range partitioner, and should be removed later. + if (!(outputType instanceof RecordTypeInfo)) { + newOutGoingChannels.addAll(rewriteRangePartitionChannel(channel)); + toBeRemoveChannels.add(channel); + } + } + } + + for (Channel chan : toBeRemoveChannels) { + outgoingChannels.remove(chan); + } + outgoingChannels.addAll(newOutGoingChannels); + } + + private List<Channel> rewriteRangePartitionChannel(Channel channel) { + final List<Channel> sourceNewOutputChannels = new ArrayList<>(); + final PlanNode sourceNode = channel.getSource(); + final PlanNode targetNode = channel.getTarget(); + final int sourceParallelism = sourceNode.getParallelism(); + final int targetParallelism = targetNode.getParallelism(); + final TypeComparatorFactory<?> comparator = Utils.getShipComparator(channel, this.plan.getOriginalPlan().getExecutionConfig()); + // 1. 
Fixed size sample in each partitions. + final int sampleSize = 20 * targetParallelism; + final SampleInPartition sampleInPartition = new SampleInPartition(false, sampleSize, SEED); + final TypeInformation<?> sourceOutputType = sourceNode.getOptimizerNode().getOperator().getOperatorInfo().getOutputType(); + final TypeInformation<IntermediateSampleData> isdTypeInformation = TypeExtractor.getForClass(IntermediateSampleData.class); + final UnaryOperatorInformation sipOperatorInformation = new UnaryOperatorInformation(sourceOutputType, isdTypeInformation); + final MapPartitionOperatorBase sipOperatorBase = new MapPartitionOperatorBase(sampleInPartition, sipOperatorInformation, "Sample in partitions"); + final MapPartitionNode sipNode = new MapPartitionNode(sipOperatorBase); + final Channel sipChannel = new Channel(sourceNode, TempMode.NONE); + sipChannel.setShipStrategy(ShipStrategyType.FORWARD, DataExchangeMode.PIPELINED); + final SingleInputPlanNode sipPlanNode = new SingleInputPlanNode(sipNode, "SampleInPartition PlanNode", sipChannel, DriverStrategy.MAP_PARTITION); + sipPlanNode.setParallelism(sourceParallelism); + sipChannel.setTarget(sipPlanNode); + this.plan.getAllNodes().add(sipPlanNode); + sourceNewOutputChannels.add(sipChannel); + + // 2. Fixed size sample in a single coordinator. 
+ final SampleInCoordinator sampleInCoordinator = new SampleInCoordinator(false, sampleSize, SEED); + final UnaryOperatorInformation sicOperatorInformation = new UnaryOperatorInformation(isdTypeInformation, sourceOutputType); + final GroupReduceOperatorBase sicOperatorBase = new GroupReduceOperatorBase(sampleInCoordinator, sicOperatorInformation, "Sample in coordinator"); + final GroupReduceNode sicNode = new GroupReduceNode(sicOperatorBase); + final Channel sicChannel = new Channel(sipPlanNode, TempMode.NONE); + sicChannel.setShipStrategy(ShipStrategyType.FORWARD, channel.getShipStrategyKeys(), channel.getShipStrategySortOrder(), null, DataExchangeMode.PIPELINED); + final SingleInputPlanNode sicPlanNode = new SingleInputPlanNode(sicNode, "SampleInCoordinator PlanNode", sicChannel, DriverStrategy.ALL_GROUP_REDUCE); + sicPlanNode.setParallelism(1); + sicChannel.setTarget(sicPlanNode); + sipPlanNode.addOutgoingChannel(sicChannel); + this.plan.getAllNodes().add(sicPlanNode); + + // 3. Use sampled data to build range boundaries. 
+ final RangeBoundaryBuilder rangeBoundaryBuilder = new RangeBoundaryBuilder(comparator, targetParallelism); + final TypeInformation<CommonRangeBoundaries> rbTypeInformation = TypeExtractor.getForClass(CommonRangeBoundaries.class); + final UnaryOperatorInformation rbOperatorInformation = new UnaryOperatorInformation(sourceOutputType, rbTypeInformation); + final MapPartitionOperatorBase rbOperatorBase = new MapPartitionOperatorBase(rangeBoundaryBuilder, rbOperatorInformation, "RangeBoundaryBuilder"); + final MapPartitionNode rbNode = new MapPartitionNode(rbOperatorBase); + final Channel rbChannel = new Channel(sicPlanNode, TempMode.NONE); + rbChannel.setShipStrategy(ShipStrategyType.FORWARD, DataExchangeMode.PIPELINED); + final SingleInputPlanNode rbPlanNode = new SingleInputPlanNode(rbNode, "RangeBoundary PlanNode", rbChannel, DriverStrategy.MAP_PARTITION); + rbPlanNode.setParallelism(1); + rbChannel.setTarget(rbPlanNode); + sicPlanNode.addOutgoingChannel(rbChannel); + this.plan.getAllNodes().add(rbPlanNode); + + // 4. Take range boundaries as broadcast input and take the tuple of partition id and record as output. + final AssignRangeIndex assignRangeIndex = new AssignRangeIndex(comparator); + final TypeInformation<Tuple2> ariOutputTypeInformation = new TupleTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO, sourceOutputType); + final UnaryOperatorInformation ariOperatorInformation = new UnaryOperatorInformation(sourceOutputType, ariOutputTypeInformation); + final MapPartitionOperatorBase ariOperatorBase = new MapPartitionOperatorBase(assignRangeIndex, ariOperatorInformation, "Assign Range Index"); + final MapPartitionNode ariNode = new MapPartitionNode(ariOperatorBase); + final Channel ariChannel = new Channel(sourceNode, TempMode.NONE); + // To avoid deadlock, set the DataExchangeMode of channel between source node and this to Batch. 
+ ariChannel.setShipStrategy(ShipStrategyType.FORWARD, DataExchangeMode.BATCH); + final SingleInputPlanNode ariPlanNode = new SingleInputPlanNode(ariNode, "AssignRangeIndex PlanNode", ariChannel, DriverStrategy.MAP_PARTITION); + ariPlanNode.setParallelism(sourceParallelism); + ariChannel.setTarget(ariPlanNode); + this.plan.getAllNodes().add(ariPlanNode); + sourceNewOutputChannels.add(ariChannel); + + final NamedChannel broadcastChannel = new NamedChannel("RangeBoundaries", rbPlanNode); + broadcastChannel.setShipStrategy(ShipStrategyType.BROADCAST, DataExchangeMode.PIPELINED); + broadcastChannel.setTarget(ariPlanNode); + List<NamedChannel> broadcastChannels = new ArrayList<>(1); + broadcastChannels.add(broadcastChannel); + ariPlanNode.setBroadcastInputs(broadcastChannels); + + // 5. Remove the partition id. + final Channel partChannel = new Channel(ariPlanNode, TempMode.NONE); + partChannel.setDataExchangeMode(DataExchangeMode.PIPELINED); + final FieldList keys = new FieldList(0); + final boolean[] sortDirection = { true }; + partChannel.setShipStrategy(ShipStrategyType.PARTITION_RANGE, keys, sortDirection, null, DataExchangeMode.PIPELINED); + ariPlanNode.addOutgoingChannel(channel); + partChannel.setLocalStrategy(channel.getLocalStrategy(), keys, sortDirection); --- End diff -- We cannot remove the target node and apply its local strategy on the `PartitionIDRemoveWrapper` because: 1. the keys of the local strategy are changed (here keys are fixed to `{0}` and order to `{true}`). 2. the local strategy is applied before the user functions, i.e., originally it is applied on the `SourceOut` type and here it is applied on the `Tuple2<Int, SourceOut>` type. I would like to: - keep the `target` node and its local strategy - connect the `PartitionIDRemoveWrapper` to the `target` node - change the ship strategy of the channel to the target node from `PARTITION_RANGE` to `FORWARD` and the data exchange mode to `PIPELINED`.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. ---