Github user fhueske commented on a diff in the pull request:
https://github.com/apache/flink/pull/1255#discussion_r46575587
--- Diff:
flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/RangePartitionRewriter.java
---
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.optimizer.traversals;
+
+import org.apache.flink.api.common.distributions.CommonRangeBoundaries;
+import org.apache.flink.api.common.operators.UnaryOperatorInformation;
+import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
+import org.apache.flink.api.common.operators.base.MapOperatorBase;
+import org.apache.flink.api.common.operators.base.MapPartitionOperatorBase;
+import org.apache.flink.api.common.operators.util.FieldList;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
+import org.apache.flink.api.java.typeutils.RecordTypeInfo;
+import org.apache.flink.runtime.io.network.DataExchangeMode;
+import org.apache.flink.runtime.operators.udf.AssignRangeIndex;
+import org.apache.flink.runtime.operators.udf.PartitionIDRemoveWrapper;
+import org.apache.flink.runtime.operators.udf.RangeBoundaryBuilder;
+import org.apache.flink.api.java.functions.SampleInCoordinator;
+import org.apache.flink.api.java.functions.SampleInPartition;
+import org.apache.flink.api.java.sampling.IntermediateSampleData;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.optimizer.dag.GroupReduceNode;
+import org.apache.flink.optimizer.dag.MapNode;
+import org.apache.flink.optimizer.dag.MapPartitionNode;
+import org.apache.flink.optimizer.dag.TempMode;
+import org.apache.flink.optimizer.plan.Channel;
+import org.apache.flink.optimizer.plan.NamedChannel;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.PlanNode;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.util.Utils;
+import org.apache.flink.runtime.operators.DriverStrategy;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.apache.flink.util.Visitor;
+
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+
+public class RangePartitionRewriter implements Visitor<PlanNode> {
+ final static long SEED = org.apache.flink.api.java.Utils.RNG.nextLong();
+
+ final OptimizedPlan plan;
+
+ public RangePartitionRewriter(OptimizedPlan plan) {
+ this.plan = plan;
+ }
+
+ @Override
+ public boolean preVisit(PlanNode visitable) {
+ return true;
+ }
+
+ @Override
+ public void postVisit(PlanNode visitable) {
+ final List<Channel> outgoingChannels =
visitable.getOutgoingChannels();
+ final List<Channel> newOutGoingChannels = new LinkedList<>();
+ final List<Channel> toBeRemoveChannels = new ArrayList<>();
+ for (Channel channel : outgoingChannels) {
+ ShipStrategyType shipStrategy =
channel.getShipStrategy();
+ if (shipStrategy == ShipStrategyType.PARTITION_RANGE) {
+ TypeInformation<?> outputType =
channel.getSource().getProgramOperator().getOperatorInfo().getOutputType();
+ // Do not optimize for record type, it's a
special case for range partitioner, and should be removed later.
+ if (!(outputType instanceof RecordTypeInfo)) {
+
newOutGoingChannels.addAll(rewriteRangePartitionChannel(channel));
+ toBeRemoveChannels.add(channel);
+ }
+ }
+ }
+
+ for (Channel chan : toBeRemoveChannels) {
+ outgoingChannels.remove(chan);
+ }
+ outgoingChannels.addAll(newOutGoingChannels);
+ }
+
+ private List<Channel> rewriteRangePartitionChannel(Channel channel) {
+ final List<Channel> sourceNewOutputChannels = new ArrayList<>();
+ final PlanNode sourceNode = channel.getSource();
+ final PlanNode targetNode = channel.getTarget();
+ final int sourceParallelism = sourceNode.getParallelism();
+ final int targetParallelism = targetNode.getParallelism();
+ final TypeComparatorFactory<?> comparator =
Utils.getShipComparator(channel,
this.plan.getOriginalPlan().getExecutionConfig());
+ // 1. Fixed size sample in each partitions.
+ final int sampleSize = 20 * targetParallelism;
+ final SampleInPartition sampleInPartition = new
SampleInPartition(false, sampleSize, SEED);
+ final TypeInformation<?> sourceOutputType =
sourceNode.getOptimizerNode().getOperator().getOperatorInfo().getOutputType();
+ final TypeInformation<IntermediateSampleData>
isdTypeInformation = TypeExtractor.getForClass(IntermediateSampleData.class);
+ final UnaryOperatorInformation sipOperatorInformation = new
UnaryOperatorInformation(sourceOutputType, isdTypeInformation);
+ final MapPartitionOperatorBase sipOperatorBase = new
MapPartitionOperatorBase(sampleInPartition, sipOperatorInformation, "Sample in
partitions");
+ final MapPartitionNode sipNode = new
MapPartitionNode(sipOperatorBase);
+ final Channel sipChannel = new Channel(sourceNode,
TempMode.NONE);
+ sipChannel.setShipStrategy(ShipStrategyType.FORWARD,
DataExchangeMode.PIPELINED);
+ final SingleInputPlanNode sipPlanNode = new
SingleInputPlanNode(sipNode, "SampleInPartition PlanNode", sipChannel,
DriverStrategy.MAP_PARTITION);
+ sipPlanNode.setParallelism(sourceParallelism);
+ sipChannel.setTarget(sipPlanNode);
+ this.plan.getAllNodes().add(sipPlanNode);
+ sourceNewOutputChannels.add(sipChannel);
+
+ // 2. Fixed size sample in a single coordinator.
+ final SampleInCoordinator sampleInCoordinator = new
SampleInCoordinator(false, sampleSize, SEED);
+ final UnaryOperatorInformation sicOperatorInformation = new
UnaryOperatorInformation(isdTypeInformation, sourceOutputType);
+ final GroupReduceOperatorBase sicOperatorBase = new
GroupReduceOperatorBase(sampleInCoordinator, sicOperatorInformation, "Sample in
coordinator");
+ final GroupReduceNode sicNode = new
GroupReduceNode(sicOperatorBase);
+ final Channel sicChannel = new Channel(sipPlanNode,
TempMode.NONE);
+ sicChannel.setShipStrategy(ShipStrategyType.FORWARD,
channel.getShipStrategyKeys(), channel.getShipStrategySortOrder(), null,
DataExchangeMode.PIPELINED);
+ final SingleInputPlanNode sicPlanNode = new
SingleInputPlanNode(sicNode, "SampleInCoordinator PlanNode", sicChannel,
DriverStrategy.ALL_GROUP_REDUCE);
+ sicPlanNode.setParallelism(1);
+ sicChannel.setTarget(sicPlanNode);
+ sipPlanNode.addOutgoingChannel(sicChannel);
+ this.plan.getAllNodes().add(sicPlanNode);
+
+ // 3. Use sampled data to build range boundaries.
+ final RangeBoundaryBuilder rangeBoundaryBuilder = new
RangeBoundaryBuilder(comparator, targetParallelism);
+ final TypeInformation<CommonRangeBoundaries> rbTypeInformation
= TypeExtractor.getForClass(CommonRangeBoundaries.class);
+ final UnaryOperatorInformation rbOperatorInformation = new
UnaryOperatorInformation(sourceOutputType, rbTypeInformation);
+ final MapPartitionOperatorBase rbOperatorBase = new
MapPartitionOperatorBase(rangeBoundaryBuilder, rbOperatorInformation,
"RangeBoundaryBuilder");
+ final MapPartitionNode rbNode = new
MapPartitionNode(rbOperatorBase);
+ final Channel rbChannel = new Channel(sicPlanNode,
TempMode.NONE);
+ rbChannel.setShipStrategy(ShipStrategyType.FORWARD,
DataExchangeMode.PIPELINED);
+ final SingleInputPlanNode rbPlanNode = new
SingleInputPlanNode(rbNode, "RangeBoundary PlanNode", rbChannel,
DriverStrategy.MAP_PARTITION);
+ rbPlanNode.setParallelism(1);
+ rbChannel.setTarget(rbPlanNode);
+ sicPlanNode.addOutgoingChannel(rbChannel);
+ this.plan.getAllNodes().add(rbPlanNode);
+
+ // 4. Take range boundaries as broadcast input and take the
tuple of partition id and record as output.
+ final AssignRangeIndex assignRangeIndex = new
AssignRangeIndex(comparator);
+ final TypeInformation<Tuple2> ariOutputTypeInformation = new
TupleTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO, sourceOutputType);
+ final UnaryOperatorInformation ariOperatorInformation = new
UnaryOperatorInformation(sourceOutputType, ariOutputTypeInformation);
+ final MapPartitionOperatorBase ariOperatorBase = new
MapPartitionOperatorBase(assignRangeIndex, ariOperatorInformation, "Assign
Range Index");
+ final MapPartitionNode ariNode = new
MapPartitionNode(ariOperatorBase);
+ final Channel ariChannel = new Channel(sourceNode,
TempMode.NONE);
+ // To avoid deadlock, set the DataExchangeMode of channel
between source node and this to Batch.
+ ariChannel.setShipStrategy(ShipStrategyType.FORWARD,
DataExchangeMode.BATCH);
+ final SingleInputPlanNode ariPlanNode = new
SingleInputPlanNode(ariNode, "AssignRangeIndex PlanNode", ariChannel,
DriverStrategy.MAP_PARTITION);
+ ariPlanNode.setParallelism(sourceParallelism);
+ ariChannel.setTarget(ariPlanNode);
+ this.plan.getAllNodes().add(ariPlanNode);
+ sourceNewOutputChannels.add(ariChannel);
+
+ final NamedChannel broadcastChannel = new
NamedChannel("RangeBoundaries", rbPlanNode);
+ broadcastChannel.setShipStrategy(ShipStrategyType.BROADCAST,
DataExchangeMode.PIPELINED);
+ broadcastChannel.setTarget(ariPlanNode);
+ List<NamedChannel> broadcastChannels = new ArrayList<>(1);
+ broadcastChannels.add(broadcastChannel);
+ ariPlanNode.setBroadcastInputs(broadcastChannels);
+
+ // 5. Remove the partition id.
+ final Channel partChannel = new Channel(ariPlanNode,
TempMode.NONE);
+ partChannel.setDataExchangeMode(DataExchangeMode.PIPELINED);
+ final FieldList keys = new FieldList(0);
+ final boolean[] sortDirection = { true };
+ partChannel.setShipStrategy(ShipStrategyType.PARTITION_RANGE,
keys, sortDirection, null, DataExchangeMode.PIPELINED);
+ ariPlanNode.addOutgoingChannel(channel);
+ partChannel.setLocalStrategy(channel.getLocalStrategy(), keys,
sortDirection);
--- End diff --
We cannot remove the target node and apply its local strategy to the
`PartitionIDRemoveWrapper` because
1. the keys of the local strategy are changed (here the keys are fixed to `{0}`
and the order to `{true}`).
2. the local strategy is applied before the user functions, i.e.,
originally it is applied to the `SourceOut` type, whereas here it is applied to the
`Tuple2<Int, SourceOut>` type.
Instead, I would like to
- keep the `target` node and its local strategy
- connect the `PartitionIDRemoveWrapper` to the `target` node
- change the ship strategy of the channel to the target node from
`PARTITION_RANGE` to `FORWARD` and set its data exchange mode to `PIPELINED`.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---