Github user fhueske commented on a diff in the pull request:

    https://github.com/apache/flink/pull/1255#discussion_r46575587
  
    --- Diff: 
flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/RangePartitionRewriter.java
 ---
    @@ -0,0 +1,194 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.flink.optimizer.traversals;
    +
    +import org.apache.flink.api.common.distributions.CommonRangeBoundaries;
    +import org.apache.flink.api.common.operators.UnaryOperatorInformation;
    +import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
    +import org.apache.flink.api.common.operators.base.MapOperatorBase;
    +import org.apache.flink.api.common.operators.base.MapPartitionOperatorBase;
    +import org.apache.flink.api.common.operators.util.FieldList;
    +import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
    +import org.apache.flink.api.common.typeinfo.TypeInformation;
    +import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
    +import org.apache.flink.api.java.typeutils.RecordTypeInfo;
    +import org.apache.flink.runtime.io.network.DataExchangeMode;
    +import org.apache.flink.runtime.operators.udf.AssignRangeIndex;
    +import org.apache.flink.runtime.operators.udf.PartitionIDRemoveWrapper;
    +import org.apache.flink.runtime.operators.udf.RangeBoundaryBuilder;
    +import org.apache.flink.api.java.functions.SampleInCoordinator;
    +import org.apache.flink.api.java.functions.SampleInPartition;
    +import org.apache.flink.api.java.sampling.IntermediateSampleData;
    +import org.apache.flink.api.java.tuple.Tuple2;
    +import org.apache.flink.api.java.typeutils.TupleTypeInfo;
    +import org.apache.flink.api.java.typeutils.TypeExtractor;
    +import org.apache.flink.optimizer.dag.GroupReduceNode;
    +import org.apache.flink.optimizer.dag.MapNode;
    +import org.apache.flink.optimizer.dag.MapPartitionNode;
    +import org.apache.flink.optimizer.dag.TempMode;
    +import org.apache.flink.optimizer.plan.Channel;
    +import org.apache.flink.optimizer.plan.NamedChannel;
    +import org.apache.flink.optimizer.plan.OptimizedPlan;
    +import org.apache.flink.optimizer.plan.PlanNode;
    +import org.apache.flink.optimizer.plan.SingleInputPlanNode;
    +import org.apache.flink.optimizer.util.Utils;
    +import org.apache.flink.runtime.operators.DriverStrategy;
    +import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
    +import org.apache.flink.util.Visitor;
    +
    +import java.util.ArrayList;
    +import java.util.LinkedList;
    +import java.util.List;
    +
    +public class RangePartitionRewriter implements Visitor<PlanNode> {
    +   final static long SEED = org.apache.flink.api.java.Utils.RNG.nextLong();
    +
    +   final OptimizedPlan plan;
    +
    +   public RangePartitionRewriter(OptimizedPlan plan) {
    +           this.plan = plan;
    +   }
    +
    +   @Override
    +   public boolean preVisit(PlanNode visitable) {
    +           return true;
    +   }
    +
    +   @Override
    +   public void postVisit(PlanNode visitable) {
    +           final List<Channel> outgoingChannels = 
visitable.getOutgoingChannels();
    +           final List<Channel> newOutGoingChannels = new LinkedList<>();
    +           final List<Channel> toBeRemoveChannels = new ArrayList<>();
    +           for (Channel channel : outgoingChannels) {
    +                   ShipStrategyType shipStrategy = 
channel.getShipStrategy();
    +                   if (shipStrategy == ShipStrategyType.PARTITION_RANGE) {
    +                           TypeInformation<?> outputType = 
channel.getSource().getProgramOperator().getOperatorInfo().getOutputType();
    +                           // Do not optimize for record type, it's a 
special case for range partitioner, and should be removed later.
    +                           if (!(outputType instanceof RecordTypeInfo)) {
    +                                   
newOutGoingChannels.addAll(rewriteRangePartitionChannel(channel));
    +                                   toBeRemoveChannels.add(channel);
    +                           }
    +                   }
    +           }
    +
    +           for (Channel chan : toBeRemoveChannels) {
    +                   outgoingChannels.remove(chan);
    +           }
    +           outgoingChannels.addAll(newOutGoingChannels);
    +   }
    +
    +   private List<Channel> rewriteRangePartitionChannel(Channel channel) {
    +           final List<Channel> sourceNewOutputChannels = new ArrayList<>();
    +           final PlanNode sourceNode = channel.getSource();
    +           final PlanNode targetNode = channel.getTarget();
    +           final int sourceParallelism = sourceNode.getParallelism();
    +           final int targetParallelism = targetNode.getParallelism();
    +           final TypeComparatorFactory<?> comparator = 
Utils.getShipComparator(channel, 
this.plan.getOriginalPlan().getExecutionConfig());
    +           // 1. Fixed size sample in each partitions.
    +           final int sampleSize = 20 * targetParallelism;
    +           final SampleInPartition sampleInPartition = new 
SampleInPartition(false, sampleSize, SEED);
    +           final TypeInformation<?> sourceOutputType = 
sourceNode.getOptimizerNode().getOperator().getOperatorInfo().getOutputType();
    +           final TypeInformation<IntermediateSampleData> 
isdTypeInformation = TypeExtractor.getForClass(IntermediateSampleData.class);
    +           final UnaryOperatorInformation sipOperatorInformation = new 
UnaryOperatorInformation(sourceOutputType, isdTypeInformation);
    +           final MapPartitionOperatorBase sipOperatorBase = new 
MapPartitionOperatorBase(sampleInPartition, sipOperatorInformation, "Sample in 
partitions");
    +           final MapPartitionNode sipNode = new 
MapPartitionNode(sipOperatorBase);
    +           final Channel sipChannel = new Channel(sourceNode, 
TempMode.NONE);
    +           sipChannel.setShipStrategy(ShipStrategyType.FORWARD, 
DataExchangeMode.PIPELINED);
    +           final SingleInputPlanNode sipPlanNode = new 
SingleInputPlanNode(sipNode, "SampleInPartition PlanNode", sipChannel, 
DriverStrategy.MAP_PARTITION);
    +           sipPlanNode.setParallelism(sourceParallelism);
    +           sipChannel.setTarget(sipPlanNode);
    +           this.plan.getAllNodes().add(sipPlanNode);
    +           sourceNewOutputChannels.add(sipChannel);
    +
    +           // 2. Fixed size sample in a single coordinator.
    +           final SampleInCoordinator sampleInCoordinator = new 
SampleInCoordinator(false, sampleSize, SEED);
    +           final UnaryOperatorInformation sicOperatorInformation = new 
UnaryOperatorInformation(isdTypeInformation, sourceOutputType);
    +           final GroupReduceOperatorBase sicOperatorBase = new 
GroupReduceOperatorBase(sampleInCoordinator, sicOperatorInformation, "Sample in 
coordinator");
    +           final GroupReduceNode sicNode = new 
GroupReduceNode(sicOperatorBase);
    +           final Channel sicChannel = new Channel(sipPlanNode, 
TempMode.NONE);
    +           sicChannel.setShipStrategy(ShipStrategyType.FORWARD, 
channel.getShipStrategyKeys(), channel.getShipStrategySortOrder(), null, 
DataExchangeMode.PIPELINED);
    +           final SingleInputPlanNode sicPlanNode = new 
SingleInputPlanNode(sicNode, "SampleInCoordinator PlanNode", sicChannel, 
DriverStrategy.ALL_GROUP_REDUCE);
    +           sicPlanNode.setParallelism(1);
    +           sicChannel.setTarget(sicPlanNode);
    +           sipPlanNode.addOutgoingChannel(sicChannel);
    +           this.plan.getAllNodes().add(sicPlanNode);
    +
    +           // 3. Use sampled data to build range boundaries.
    +           final RangeBoundaryBuilder rangeBoundaryBuilder = new 
RangeBoundaryBuilder(comparator, targetParallelism);
    +           final TypeInformation<CommonRangeBoundaries> rbTypeInformation 
= TypeExtractor.getForClass(CommonRangeBoundaries.class);
    +           final UnaryOperatorInformation rbOperatorInformation = new 
UnaryOperatorInformation(sourceOutputType, rbTypeInformation);
    +           final MapPartitionOperatorBase rbOperatorBase = new 
MapPartitionOperatorBase(rangeBoundaryBuilder, rbOperatorInformation, 
"RangeBoundaryBuilder");
    +           final MapPartitionNode rbNode = new 
MapPartitionNode(rbOperatorBase);
    +           final Channel rbChannel = new Channel(sicPlanNode, 
TempMode.NONE);
    +           rbChannel.setShipStrategy(ShipStrategyType.FORWARD, 
DataExchangeMode.PIPELINED);
    +           final SingleInputPlanNode rbPlanNode = new 
SingleInputPlanNode(rbNode, "RangeBoundary PlanNode", rbChannel, 
DriverStrategy.MAP_PARTITION);
    +           rbPlanNode.setParallelism(1);
    +           rbChannel.setTarget(rbPlanNode);
    +           sicPlanNode.addOutgoingChannel(rbChannel);
    +           this.plan.getAllNodes().add(rbPlanNode);
    +
    +           // 4. Take range boundaries as broadcast input and take the 
tuple of partition id and record as output.
    +           final AssignRangeIndex assignRangeIndex = new 
AssignRangeIndex(comparator);
    +           final TypeInformation<Tuple2> ariOutputTypeInformation = new 
TupleTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO, sourceOutputType);
    +           final UnaryOperatorInformation ariOperatorInformation = new 
UnaryOperatorInformation(sourceOutputType, ariOutputTypeInformation);
    +           final MapPartitionOperatorBase ariOperatorBase = new 
MapPartitionOperatorBase(assignRangeIndex, ariOperatorInformation, "Assign 
Range Index");
    +           final MapPartitionNode ariNode = new 
MapPartitionNode(ariOperatorBase);
    +           final Channel ariChannel = new Channel(sourceNode, 
TempMode.NONE);
    +           // To avoid deadlock, set the DataExchangeMode of channel 
between source node and this to Batch.
    +           ariChannel.setShipStrategy(ShipStrategyType.FORWARD, 
DataExchangeMode.BATCH);
    +           final SingleInputPlanNode ariPlanNode = new 
SingleInputPlanNode(ariNode, "AssignRangeIndex PlanNode", ariChannel, 
DriverStrategy.MAP_PARTITION);
    +           ariPlanNode.setParallelism(sourceParallelism);
    +           ariChannel.setTarget(ariPlanNode);
    +           this.plan.getAllNodes().add(ariPlanNode);
    +           sourceNewOutputChannels.add(ariChannel);
    +
    +           final NamedChannel broadcastChannel = new 
NamedChannel("RangeBoundaries", rbPlanNode);
    +           broadcastChannel.setShipStrategy(ShipStrategyType.BROADCAST, 
DataExchangeMode.PIPELINED);
    +           broadcastChannel.setTarget(ariPlanNode);
    +           List<NamedChannel> broadcastChannels = new ArrayList<>(1);
    +           broadcastChannels.add(broadcastChannel);
    +           ariPlanNode.setBroadcastInputs(broadcastChannels);
    +
    +           // 5. Remove the partition id.
    +           final Channel partChannel = new Channel(ariPlanNode, 
TempMode.NONE);
    +           partChannel.setDataExchangeMode(DataExchangeMode.PIPELINED);
    +           final FieldList keys = new FieldList(0);
    +           final boolean[] sortDirection = { true };
    +           partChannel.setShipStrategy(ShipStrategyType.PARTITION_RANGE, 
keys, sortDirection, null, DataExchangeMode.PIPELINED);
    +           ariPlanNode.addOutgoingChannel(channel);
    +           partChannel.setLocalStrategy(channel.getLocalStrategy(), keys, 
sortDirection);
    --- End diff --
    
    We cannot remove the target node and apply its local strategy on the 
`PartitionIDRemoveWrapper` because
    1. the keys of the local strategy are changed (here keys are fixed to `{0}` 
and order to `{true}`).
    2. the local strategy is applied before the user functions, i.e., 
originally it is applied on the `SourceOut` type and here it is applied on the 
`Tuple2<Int, SourceOut>` type.
    
    Instead, I would like to:
    - keep the `target` node and its local strategy
    - connect the `PartitionIDRemoveWrapper` to the `target` node
    - change the ship strategy of the channel to the target node from 
`PARTITION_RANGE` to `FORWARD` and the data exchange mode to `PIPELINED`.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to