[FLINK-2662] [dataSet] [optimizer] Fix merging of unions with multiple outputs.
Translate union with N outputs into N unions with single output. This closes #2508. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/303f6fee Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/303f6fee Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/303f6fee Branch: refs/heads/flip-6 Commit: 303f6fee99b731dd138e37513705271f97f76d72 Parents: 5c02988 Author: Fabian Hueske <fhue...@apache.org> Authored: Fri Sep 16 18:40:32 2016 +0200 Committer: Fabian Hueske <fhue...@apache.org> Committed: Tue Sep 20 21:52:08 2016 +0200 ---------------------------------------------------------------------- .../api/java/operators/OperatorTranslation.java | 23 +++-- .../flink/optimizer/dag/BinaryUnionNode.java | 8 +- .../flink/optimizer/UnionReplacementTest.java | 102 ++++++++++++++++++- .../dataexchange/UnionClosedBranchingTest.java | 26 +++-- 4 files changed, 141 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/303f6fee/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java index 3f44d58..88c9c37 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java @@ -40,11 +40,11 @@ import java.util.Map; public class OperatorTranslation { /** The already translated operations */ - private Map<DataSet<?>, Operator<?>> translated = new HashMap<DataSet<?>, Operator<?>>(); + private Map<DataSet<?>, Operator<?>> translated = new HashMap<>(); public Plan translateToPlan(List<DataSink<?>> sinks, String 
jobName) { - List<GenericDataSinkBase<?>> planSinks = new ArrayList<GenericDataSinkBase<?>>(); + List<GenericDataSinkBase<?>> planSinks = new ArrayList<>(); for (DataSink<?> sink : sinks) { planSinks.add(translate(sink)); @@ -74,11 +74,18 @@ public class OperatorTranslation { } // check if we have already translated that data set (operation or source) - Operator<?> previous = (Operator<?>) this.translated.get(dataSet); + Operator<?> previous = this.translated.get(dataSet); if (previous != null) { - @SuppressWarnings("unchecked") - Operator<T> typedPrevious = (Operator<T>) previous; - return typedPrevious; + + // Union operators may only have a single output. + // We ensure this by not reusing previously created union operators. + // The optimizer will merge subsequent binary unions into one n-ary union. + if (!(dataSet instanceof UnionOperator)) { + // all other operators are reused. + @SuppressWarnings("unchecked") + Operator<T> typedPrevious = (Operator<T>) previous; + return typedPrevious; + } } Operator<T> dataFlowOp; @@ -190,7 +197,7 @@ public class OperatorTranslation { BulkIterationResultSet<T> iterationEnd = (BulkIterationResultSet<T>) untypedIterationEnd; BulkIterationBase<T> iterationOperator = - new BulkIterationBase<T>(new UnaryOperatorInformation<T, T>(iterationEnd.getType(), iterationEnd.getType()), "Bulk Iteration"); + new BulkIterationBase<>(new UnaryOperatorInformation<>(iterationEnd.getType(), iterationEnd.getType()), "Bulk Iteration"); IterativeDataSet<T> iterationHead = iterationEnd.getIterationHead(); translated.put(iterationHead, iterationOperator.getPartialSolution()); @@ -216,7 +223,7 @@ public class OperatorTranslation { String name = iterationHead.getName() == null ? 
"Unnamed Delta Iteration" : iterationHead.getName(); - DeltaIterationBase<D, W> iterationOperator = new DeltaIterationBase<D, W>(new BinaryOperatorInformation<D, W, D>(iterationEnd.getType(), iterationEnd.getWorksetType(), iterationEnd.getType()), + DeltaIterationBase<D, W> iterationOperator = new DeltaIterationBase<>(new BinaryOperatorInformation<>(iterationEnd.getType(), iterationEnd.getWorksetType(), iterationEnd.getType()), iterationEnd.getKeyPositions(), name); iterationOperator.setMaximumNumberOfIterations(iterationEnd.getMaxIterations()); http://git-wip-us.apache.org/repos/asf/flink/blob/303f6fee/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java index fdd76a8..d262cf6 100644 --- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java @@ -98,6 +98,12 @@ public class BinaryUnionNode extends TwoInputNode { @Override public List<PlanNode> getAlternativePlans(CostEstimator estimator) { + + // check that union has only a single successor + if (this.getOutgoingConnections().size() > 1) { + throw new CompilerException("BinaryUnionNode has more than one successor."); + } + // check if we have a cached version if (this.cachedPlans != null) { return this.cachedPlans; @@ -173,7 +179,7 @@ public class BinaryUnionNode extends TwoInputNode { } } - // create a candidate channel for the first input. mark it cached, if the connection says so + // create a candidate channel for the second input. 
mark it cached, if the connection says so Channel c2 = new Channel(child2, this.input2.getMaterializationMode()); if (this.input2.getShipStrategy() == null) { // free to choose the ship strategy http://git-wip-us.apache.org/repos/asf/flink/blob/303f6fee/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java index 65dd2b3..3be7657 100644 --- a/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java @@ -18,16 +18,25 @@ package org.apache.flink.optimizer; +import junit.framework.Assert; +import org.apache.flink.api.common.operators.util.FieldList; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.common.Plan; import org.apache.flink.api.java.io.DiscardingOutputFormat; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.optimizer.plan.Channel; +import org.apache.flink.optimizer.plan.NAryUnionPlanNode; import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; import org.apache.flink.optimizer.plantranslate.JobGraphGenerator; import org.apache.flink.optimizer.util.CompilerTestBase; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; import org.junit.Test; -import static org.junit.Assert.fail; +import java.util.List; + +import static org.junit.Assert.*; @SuppressWarnings("serial") public class UnionReplacementTest extends CompilerTestBase { @@ -54,4 +63,95 @@ public class UnionReplacementTest extends CompilerTestBase { fail(e.getMessage()); } } + + /** + * + * Test for FLINK-2662. 
+ * + * Checks that a plan with a union with two outputs is correctly translated. + * The program can be illustrated as follows: + * + * Src1 ----------------\ + * >-> Union123 -> GroupBy(0) -> Sum -> Output + * Src2 -\ / + * >-> Union23--< + * Src3 -/ \ + * >-> Union234 -> GroupBy(1) -> Sum -> Output + * Src4 ----------------/ + * + * The fix for FLINK-2662 translates the union with two outputs (Union-23) into two separate + * unions (Union-23_1 and Union-23_2) with one output each. Due to this change, the interesting + * partitioning properties for GroupBy(0) and GroupBy(1) are pushed through Union-23_1 and + * Union-23_2 and do not interfere with each other (which would be the case if Union-23 were + * a single operator with two outputs). + * + */ + @Test + public void testUnionWithTwoOutputsTest() throws Exception { + + // ----------------------------------------------------------------------------------------- + // Build test program + // ----------------------------------------------------------------------------------------- + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(DEFAULT_PARALLELISM); + + DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new Tuple2<>(0L, 0L)); + DataSet<Tuple2<Long, Long>> src4 = env.fromElements(new Tuple2<>(0L, 0L)); + + DataSet<Tuple2<Long, Long>> union23 = src2.union(src3); + DataSet<Tuple2<Long, Long>> union123 = src1.union(union23); + DataSet<Tuple2<Long, Long>> union234 = src4.union(union23); + + union123.groupBy(0).sum(1).name("1").output(new DiscardingOutputFormat<Tuple2<Long, Long>>()); + union234.groupBy(1).sum(0).name("2").output(new DiscardingOutputFormat<Tuple2<Long, Long>>()); + + // ----------------------------------------------------------------------------------------- + // Verify optimized plan + // 
----------------------------------------------------------------------------------------- + + OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan()); + + OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan); + + SingleInputPlanNode groupRed1 = resolver.getNode("1"); + SingleInputPlanNode groupRed2 = resolver.getNode("2"); + + // check partitioning is correct + Assert.assertTrue("Reduce input should be partitioned on 0.", + groupRed1.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new FieldList(0))); + Assert.assertTrue("Reduce input should be partitioned on 1.", + groupRed2.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new FieldList(1))); + + // check group reduce inputs are n-ary unions with three inputs + Assert.assertTrue("Reduce input should be n-ary union with three inputs.", + groupRed1.getInput().getSource() instanceof NAryUnionPlanNode && + ((NAryUnionPlanNode) groupRed1.getInput().getSource()).getListOfInputs().size() == 3); + Assert.assertTrue("Reduce input should be n-ary union with three inputs.", + groupRed2.getInput().getSource() instanceof NAryUnionPlanNode && + ((NAryUnionPlanNode) groupRed2.getInput().getSource()).getListOfInputs().size() == 3); + + // check channel from union to group reduce is forwarding + Assert.assertTrue("Channel between union and group reduce should be forwarding", + groupRed1.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD)); + Assert.assertTrue("Channel between union and group reduce should be forwarding", + groupRed2.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD)); + + // check that all inputs of unions are hash partitioned + List<Channel> union123In = ((NAryUnionPlanNode) groupRed1.getInput().getSource()).getListOfInputs(); + for(Channel i : union123In) { + Assert.assertTrue("Union input channel should hash partition on 0", + i.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) && + 
i.getShipStrategyKeys().isExactMatch(new FieldList(0))); + } + List<Channel> union234In = ((NAryUnionPlanNode) groupRed2.getInput().getSource()).getListOfInputs(); + for(Channel i : union234In) { + Assert.assertTrue("Union input channel should hash partition on 0", + i.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) && + i.getShipStrategyKeys().isExactMatch(new FieldList(1))); + } + + } } http://git-wip-us.apache.org/repos/asf/flink/blob/303f6fee/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java index b870a91..77b150a 100644 --- a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java @@ -37,6 +37,7 @@ import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSet; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -73,9 +74,9 @@ public class UnionClosedBranchingTest extends CompilerTestBase { @Parameterized.Parameters public static Collection<Object[]> params() { Collection<Object[]> params = Arrays.asList(new Object[][]{ - {ExecutionMode.PIPELINED, PIPELINED, BATCH}, + {ExecutionMode.PIPELINED, BATCH, PIPELINED}, {ExecutionMode.PIPELINED_FORCED, PIPELINED, PIPELINED}, - {ExecutionMode.BATCH, BATCH, BATCH}, + {ExecutionMode.BATCH, BATCH, PIPELINED}, {ExecutionMode.BATCH_FORCED, BATCH, 
BATCH}, }); @@ -93,10 +94,16 @@ public class UnionClosedBranchingTest extends CompilerTestBase { /** Expected {@link DataExchangeMode} from union to join. */ private final DataExchangeMode unionToJoin; + /** Expected {@link ShipStrategyType} from source to union. */ + private final ShipStrategyType sourceToUnionStrategy = ShipStrategyType.PARTITION_HASH; + + /** Expected {@link ShipStrategyType} from union to join. */ + private final ShipStrategyType unionToJoinStrategy = ShipStrategyType.FORWARD; + public UnionClosedBranchingTest( - ExecutionMode executionMode, - DataExchangeMode sourceToUnion, - DataExchangeMode unionToJoin) { + ExecutionMode executionMode, + DataExchangeMode sourceToUnion, + DataExchangeMode unionToJoin) { this.executionMode = executionMode; this.sourceToUnion = sourceToUnion; @@ -140,12 +147,16 @@ public class UnionClosedBranchingTest extends CompilerTestBase { for (Channel channel : joinNode.getInputs()) { assertEquals("Unexpected data exchange mode between union and join node.", unionToJoin, channel.getDataExchangeMode()); + assertEquals("Unexpected ship strategy between union and join node.", + unionToJoinStrategy, channel.getShipStrategy()); } for (SourcePlanNode src : optimizedPlan.getDataSources()) { for (Channel channel : src.getOutgoingChannels()) { assertEquals("Unexpected data exchange mode between source and union node.", sourceToUnion, channel.getDataExchangeMode()); + assertEquals("Unexpected ship strategy between source and union node.", + sourceToUnionStrategy, channel.getShipStrategy()); } } @@ -176,9 +187,8 @@ public class UnionClosedBranchingTest extends CompilerTestBase { for (IntermediateDataSet dataSet : src.getProducedDataSets()) { ResultPartitionType dsType = dataSet.getResultType(); - // The result type is determined by the channel between the union and the join node - // and *not* the channel between source and union. - if (unionToJoin.equals(BATCH)) { + // Ensure batch exchange unless PIPELINED_FORCED is enabled. 
+ if (!executionMode.equals(ExecutionMode.PIPELINED_FORCED)) { assertTrue("Expected batch exchange, but result type is " + dsType + ".", dsType.isBlocking()); } else {