[FLINK-2662] [optimizer] Fix computation of global properties of union operator.
- Fixes invalid shipping strategy between consecutive unions. This closes #2848. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7d91b9ec Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7d91b9ec Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7d91b9ec Branch: refs/heads/master Commit: 7d91b9ec71c9b711e04a91f847f5c85d3f561da6 Parents: e8318d6 Author: Fabian Hueske <[email protected]> Authored: Mon Nov 21 19:06:42 2016 +0100 Committer: Fabian Hueske <[email protected]> Committed: Wed Nov 23 18:35:44 2016 +0100 ---------------------------------------------------------------------- .../operators/BinaryUnionOpDescriptor.java | 30 ++- .../flink/optimizer/UnionReplacementTest.java | 240 ++++++++++++++++++- 2 files changed, 258 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/7d91b9ec/flink-optimizer/src/main/java/org/apache/flink/optimizer/operators/BinaryUnionOpDescriptor.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/operators/BinaryUnionOpDescriptor.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/operators/BinaryUnionOpDescriptor.java index 8cc517e..78ac3d6 100644 --- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/operators/BinaryUnionOpDescriptor.java +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/operators/BinaryUnionOpDescriptor.java @@ -69,11 +69,35 @@ public class BinaryUnionOpDescriptor extends OperatorDescriptorDual { if (in1.getPartitioning() == PartitioningProperty.HASH_PARTITIONED && in2.getPartitioning() == PartitioningProperty.HASH_PARTITIONED && - in1.getPartitioningFields().equals(in2.getPartitioningFields())) - { + in1.getPartitioningFields().equals(in2.getPartitioningFields())) { newProps.setHashPartitioned(in1.getPartitioningFields()); } - + else if (in1.getPartitioning() == PartitioningProperty.RANGE_PARTITIONED && + in2.getPartitioning() == PartitioningProperty.RANGE_PARTITIONED && + in1.getPartitioningOrdering().equals(in2.getPartitioningOrdering()) && + ( + in1.getDataDistribution() == null && in2.getDataDistribution() == null || + in1.getDataDistribution() != null && in1.getDataDistribution().equals(in2.getDataDistribution()) + ) + ) { + if (in1.getDataDistribution() == null) { + newProps.setRangePartitioned(in1.getPartitioningOrdering()); + } + else { + newProps.setRangePartitioned(in1.getPartitioningOrdering(), in1.getDataDistribution()); + } + } + else if (in1.getPartitioning() == PartitioningProperty.CUSTOM_PARTITIONING && + in2.getPartitioning() == PartitioningProperty.CUSTOM_PARTITIONING && + in1.getPartitioningFields().equals(in2.getPartitioningFields()) && + in1.getCustomPartitioner().equals(in2.getCustomPartitioner())) { + newProps.setCustomPartitioned(in1.getPartitioningFields(), in1.getCustomPartitioner()); + } + else if (in1.getPartitioning() == PartitioningProperty.FORCED_REBALANCED && + in2.getPartitioning() == PartitioningProperty.FORCED_REBALANCED) { + newProps.setForcedRebalanced(); + } + return newProps; } http://git-wip-us.apache.org/repos/asf/flink/blob/7d91b9ec/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java index 3be7657..d0bb376 100644 --- a/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java @@ -19,18 +19,25 @@ package org.apache.flink.optimizer; import junit.framework.Assert; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.common.operators.Ordering; import org.apache.flink.api.common.operators.util.FieldList; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.common.Plan; import org.apache.flink.api.java.io.DiscardingOutputFormat; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.optimizer.dataproperties.PartitioningProperty; import org.apache.flink.optimizer.plan.Channel; +import org.apache.flink.optimizer.plan.DualInputPlanNode; import org.apache.flink.optimizer.plan.NAryUnionPlanNode; import org.apache.flink.optimizer.plan.OptimizedPlan; import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SourcePlanNode; import org.apache.flink.optimizer.plantranslate.JobGraphGenerator; import org.apache.flink.optimizer.util.CompilerTestBase; +import org.apache.flink.runtime.operators.Driver; +import org.apache.flink.runtime.operators.DriverStrategy; import org.apache.flink.runtime.operators.shipping.ShipStrategyType; import org.junit.Test; @@ -87,7 +94,7 @@ public class UnionReplacementTest extends CompilerTestBase { * */ @Test - public void testUnionWithTwoOutputsTest() throws Exception { + public void testUnionWithTwoOutputs() throws Exception { // ----------------------------------------------------------------------------------------- // Build test program @@ -120,38 +127,253 @@ public class UnionReplacementTest extends CompilerTestBase { SingleInputPlanNode groupRed2 = resolver.getNode("2"); // check partitioning is correct - Assert.assertTrue("Reduce input should be partitioned on 0.", + assertTrue("Reduce input should be partitioned on 0.", groupRed1.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new FieldList(0))); - Assert.assertTrue("Reduce input should be partitioned on 1.", + assertTrue("Reduce input should be partitioned on 1.", groupRed2.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new FieldList(1))); // check group reduce inputs are n-ary unions with three inputs - Assert.assertTrue("Reduce input should be n-ary union with three inputs.", + assertTrue("Reduce input should be n-ary union with three inputs.", groupRed1.getInput().getSource() instanceof NAryUnionPlanNode && ((NAryUnionPlanNode) groupRed1.getInput().getSource()).getListOfInputs().size() == 3); - Assert.assertTrue("Reduce input should be n-ary union with three inputs.", + assertTrue("Reduce input should be n-ary union with three inputs.", groupRed2.getInput().getSource() instanceof NAryUnionPlanNode && ((NAryUnionPlanNode) groupRed2.getInput().getSource()).getListOfInputs().size() == 3); // check channel from union to group reduce is forwarding - Assert.assertTrue("Channel between union and group reduce should be forwarding", + assertTrue("Channel between union and group reduce should be forwarding", groupRed1.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD)); - Assert.assertTrue("Channel between union and group reduce should be forwarding", + assertTrue("Channel between union and group reduce should be forwarding", groupRed2.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD)); // check that all inputs of unions are hash partitioned List<Channel> union123In = ((NAryUnionPlanNode) groupRed1.getInput().getSource()).getListOfInputs(); for(Channel i : union123In) { - Assert.assertTrue("Union input channel should hash partition on 0", + assertTrue("Union input channel should hash partition on 0", i.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) && i.getShipStrategyKeys().isExactMatch(new FieldList(0))); } List<Channel> union234In = ((NAryUnionPlanNode) groupRed2.getInput().getSource()).getListOfInputs(); for(Channel i : union234In) { - Assert.assertTrue("Union input channel should hash partition on 0", + assertTrue("Union input channel should hash partition on 0", i.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) && i.getShipStrategyKeys().isExactMatch(new FieldList(1))); } } + + /** + * + * Checks that a plan with consecutive UNIONs followed by PartitionByHash is correctly translated. + * + * The program can be illustrated as follows: + * + * Src1 -\ + * >-> Union12--< + * Src2 -/ \ + * >-> Union123 -> PartitionByHash -> Output + * Src3 ----------------/ + * + * In the resulting plan, the hash partitioning (ShippingStrategy.PARTITION_HASH) must be + * pushed to the inputs of the unions (Src1, Src2, Src3). + * + */ + @Test + public void testConsecutiveUnionsWithHashPartitioning() throws Exception { + + // ----------------------------------------------------------------------------------------- + // Build test program + // ----------------------------------------------------------------------------------------- + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(DEFAULT_PARALLELISM); + + DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new Tuple2<>(0L, 0L)); + + DataSet<Tuple2<Long, Long>> union12 = src1.union(src2); + DataSet<Tuple2<Long, Long>> union123 = union12.union(src3); + + union123.partitionByHash(1).output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("out"); + + // ----------------------------------------------------------------------------------------- + // Verify optimized plan + // ----------------------------------------------------------------------------------------- + + OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan()); + + OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan); + + SingleInputPlanNode sink = resolver.getNode("out"); + + // check partitioning is correct + assertEquals("Sink input should be hash partitioned.", + PartitioningProperty.HASH_PARTITIONED, sink.getInput().getGlobalProperties().getPartitioning()); + assertEquals("Sink input should be hash partitioned on 1.", + new FieldList(1), sink.getInput().getGlobalProperties().getPartitioningFields()); + + SingleInputPlanNode partitioner = (SingleInputPlanNode)sink.getInput().getSource(); + assertTrue(partitioner.getDriverStrategy() == DriverStrategy.UNARY_NO_OP); + assertEquals("Partitioner input should be hash partitioned.", + PartitioningProperty.HASH_PARTITIONED, partitioner.getInput().getGlobalProperties().getPartitioning()); + assertEquals("Partitioner input should be hash partitioned on 1.", + new FieldList(1), partitioner.getInput().getGlobalProperties().getPartitioningFields()); + assertEquals("Partitioner input channel should be forwarding", + ShipStrategyType.FORWARD, partitioner.getInput().getShipStrategy()); + + NAryUnionPlanNode union = (NAryUnionPlanNode)partitioner.getInput().getSource(); + // all union inputs should be hash partitioned + for (Channel c : union.getInputs()) { + assertEquals("Union input should be hash partitioned", + PartitioningProperty.HASH_PARTITIONED, c.getGlobalProperties().getPartitioning()); + assertEquals("Union input channel should be hash partitioning", + ShipStrategyType.PARTITION_HASH, c.getShipStrategy()); + assertTrue("Union input should be data source", + c.getSource() instanceof SourcePlanNode); + } + } + + /** + * + * Checks that a plan with consecutive UNIONs followed by REBALANCE is correctly translated. + * + * The program can be illustrated as follows: + * + * Src1 -\ + * >-> Union12--< + * Src2 -/ \ + * >-> Union123 -> Rebalance -> Output + * Src3 ----------------/ + * + * In the resulting plan, the Rebalance (ShippingStrategy.PARTITION_FORCED_REBALANCE) must be + * pushed to the inputs of the unions (Src1, Src2, Src3). + * + */ + @Test + public void testConsecutiveUnionsWithRebalance() throws Exception { + + // ----------------------------------------------------------------------------------------- + // Build test program + // ----------------------------------------------------------------------------------------- + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(DEFAULT_PARALLELISM); + + DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new Tuple2<>(0L, 0L)); + + DataSet<Tuple2<Long, Long>> union12 = src1.union(src2); + DataSet<Tuple2<Long, Long>> union123 = union12.union(src3); + + union123.rebalance().output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("out"); + + // ----------------------------------------------------------------------------------------- + // Verify optimized plan + // ----------------------------------------------------------------------------------------- + + OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan()); + + OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan); + + SingleInputPlanNode sink = resolver.getNode("out"); + + // check partitioning is correct + assertEquals("Sink input should be force rebalanced.", + PartitioningProperty.FORCED_REBALANCED, sink.getInput().getGlobalProperties().getPartitioning()); + + SingleInputPlanNode partitioner = (SingleInputPlanNode)sink.getInput().getSource(); + assertTrue(partitioner.getDriverStrategy() == DriverStrategy.UNARY_NO_OP); + assertEquals("Partitioner input should be force rebalanced.", + PartitioningProperty.FORCED_REBALANCED, partitioner.getInput().getGlobalProperties().getPartitioning()); + assertEquals("Partitioner input channel should be forwarding", + ShipStrategyType.FORWARD, partitioner.getInput().getShipStrategy()); + + NAryUnionPlanNode union = (NAryUnionPlanNode)partitioner.getInput().getSource(); + // all union inputs should be force rebalanced + for (Channel c : union.getInputs()) { + assertEquals("Union input should be force rebalanced", + PartitioningProperty.FORCED_REBALANCED, c.getGlobalProperties().getPartitioning()); + assertEquals("Union input channel should be rebalancing", + ShipStrategyType.PARTITION_FORCED_REBALANCE, c.getShipStrategy()); + assertTrue("Union input should be data source", + c.getSource() instanceof SourcePlanNode); + } + } + + /** + * + * Checks that a plan with consecutive UNIONs followed by PARTITION_RANGE is correctly translated. + * + * The program can be illustrated as follows: + * + * Src1 -\ + * >-> Union12--< + * Src2 -/ \ + * >-> Union123 -> PartitionByRange -> Output + * Src3 ----------------/ + * + * In the resulting plan, the range partitioning must be + * pushed to the inputs of the unions (Src1, Src2, Src3). + * + */ + @Test + public void testConsecutiveUnionsWithRangePartitioning() throws Exception { + + // ----------------------------------------------------------------------------------------- + // Build test program + // ----------------------------------------------------------------------------------------- + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(DEFAULT_PARALLELISM); + + DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new Tuple2<>(0L, 0L)); + + DataSet<Tuple2<Long, Long>> union12 = src1.union(src2); + DataSet<Tuple2<Long, Long>> union123 = union12.union(src3); + + union123.partitionByRange(1).output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("out"); + + // ----------------------------------------------------------------------------------------- + // Verify optimized plan + // ----------------------------------------------------------------------------------------- + + OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan()); + + OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan); + + SingleInputPlanNode sink = resolver.getNode("out"); + + // check partitioning is correct + assertEquals("Sink input should be range partitioned.", + PartitioningProperty.RANGE_PARTITIONED, sink.getInput().getGlobalProperties().getPartitioning()); + assertEquals("Sink input should be range partitioned on 1", + new Ordering(1, null, Order.ASCENDING), sink.getInput().getGlobalProperties().getPartitioningOrdering()); + + SingleInputPlanNode partitioner = (SingleInputPlanNode)sink.getInput().getSource(); + assertTrue(partitioner.getDriverStrategy() == DriverStrategy.UNARY_NO_OP); + assertEquals("Partitioner input should be range partitioned.", + PartitioningProperty.RANGE_PARTITIONED, partitioner.getInput().getGlobalProperties().getPartitioning()); + assertEquals("Partitioner input should be range partitioned on 1", + new Ordering(1, null, Order.ASCENDING), partitioner.getInput().getGlobalProperties().getPartitioningOrdering()); + assertEquals("Partitioner input channel should be forwarding", + ShipStrategyType.FORWARD, partitioner.getInput().getShipStrategy()); + + NAryUnionPlanNode union = (NAryUnionPlanNode)partitioner.getInput().getSource(); + // all union inputs should be force rebalanced + for (Channel c : union.getInputs()) { + assertEquals("Union input should be force rebalanced", + PartitioningProperty.RANGE_PARTITIONED, c.getGlobalProperties().getPartitioning()); + assertEquals("Union input channel should be rebalancing", + ShipStrategyType.FORWARD, c.getShipStrategy()); + // range partitioning is executed as custom partitioning with prior sampling + SingleInputPlanNode partitionMap = (SingleInputPlanNode)c.getSource(); + assertEquals(DriverStrategy.MAP, partitionMap.getDriverStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitionMap.getInput().getShipStrategy()); + } + } + }
