[FLINK-2662] [dataSet] [optimizer] Fix merging of unions with multiple outputs.

Translate union with N outputs into N unions with single output.

This closes #2508.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/303f6fee
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/303f6fee
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/303f6fee

Branch: refs/heads/flip-6
Commit: 303f6fee99b731dd138e37513705271f97f76d72
Parents: 5c02988
Author: Fabian Hueske <fhue...@apache.org>
Authored: Fri Sep 16 18:40:32 2016 +0200
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Tue Sep 20 21:52:08 2016 +0200

----------------------------------------------------------------------
 .../api/java/operators/OperatorTranslation.java |  23 +++--
 .../flink/optimizer/dag/BinaryUnionNode.java    |   8 +-
 .../flink/optimizer/UnionReplacementTest.java   | 102 ++++++++++++++++++-
 .../dataexchange/UnionClosedBranchingTest.java  |  26 +++--
 4 files changed, 141 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/303f6fee/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java
----------------------------------------------------------------------
diff --git 
a/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java
 
b/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java
index 3f44d58..88c9c37 100644
--- 
a/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java
+++ 
b/flink-java/src/main/java/org/apache/flink/api/java/operators/OperatorTranslation.java
@@ -40,11 +40,11 @@ import java.util.Map;
 public class OperatorTranslation {
        
        /** The already translated operations */
-       private Map<DataSet<?>, Operator<?>> translated = new 
HashMap<DataSet<?>, Operator<?>>();
+       private Map<DataSet<?>, Operator<?>> translated = new HashMap<>();
        
        
        public Plan translateToPlan(List<DataSink<?>> sinks, String jobName) {
-               List<GenericDataSinkBase<?>> planSinks = new 
ArrayList<GenericDataSinkBase<?>>();
+               List<GenericDataSinkBase<?>> planSinks = new ArrayList<>();
                
                for (DataSink<?> sink : sinks) {
                        planSinks.add(translate(sink));
@@ -74,11 +74,18 @@ public class OperatorTranslation {
                }
 
                // check if we have already translated that data set (operation 
or source)
-               Operator<?> previous = (Operator<?>) 
this.translated.get(dataSet);
+               Operator<?> previous = this.translated.get(dataSet);
                if (previous != null) {
-                       @SuppressWarnings("unchecked")
-                       Operator<T> typedPrevious = (Operator<T>) previous;
-                       return typedPrevious;
+
+                       // Union operators may only have a single output.
+                       // We ensure this by not reusing previously created 
union operators.
+                       // The optimizer will merge subsequent binary unions 
into one n-ary union.
+                       if (!(dataSet instanceof UnionOperator)) {
+                               // all other operators are reused.
+                               @SuppressWarnings("unchecked")
+                               Operator<T> typedPrevious = (Operator<T>) 
previous;
+                               return typedPrevious;
+                       }
                }
                
                Operator<T> dataFlowOp;
@@ -190,7 +197,7 @@ public class OperatorTranslation {
                BulkIterationResultSet<T> iterationEnd = 
(BulkIterationResultSet<T>) untypedIterationEnd;
                
                BulkIterationBase<T> iterationOperator =
-                               new BulkIterationBase<T>(new 
UnaryOperatorInformation<T, T>(iterationEnd.getType(), iterationEnd.getType()), 
"Bulk Iteration");
+                               new BulkIterationBase<>(new 
UnaryOperatorInformation<>(iterationEnd.getType(), iterationEnd.getType()), 
"Bulk Iteration");
                IterativeDataSet<T> iterationHead = 
iterationEnd.getIterationHead();
 
                translated.put(iterationHead, 
iterationOperator.getPartialSolution());
@@ -216,7 +223,7 @@ public class OperatorTranslation {
                
                String name = iterationHead.getName() == null ? "Unnamed Delta 
Iteration" : iterationHead.getName();
                
-               DeltaIterationBase<D, W> iterationOperator = new 
DeltaIterationBase<D, W>(new BinaryOperatorInformation<D, W, 
D>(iterationEnd.getType(), iterationEnd.getWorksetType(), 
iterationEnd.getType()),
+               DeltaIterationBase<D, W> iterationOperator = new 
DeltaIterationBase<>(new BinaryOperatorInformation<>(iterationEnd.getType(), 
iterationEnd.getWorksetType(), iterationEnd.getType()),
                                iterationEnd.getKeyPositions(), name);
                
                
iterationOperator.setMaximumNumberOfIterations(iterationEnd.getMaxIterations());

http://git-wip-us.apache.org/repos/asf/flink/blob/303f6fee/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java
----------------------------------------------------------------------
diff --git 
a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java
 
b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java
index fdd76a8..d262cf6 100644
--- 
a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java
+++ 
b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/BinaryUnionNode.java
@@ -98,6 +98,12 @@ public class BinaryUnionNode extends TwoInputNode {
        
        @Override
        public List<PlanNode> getAlternativePlans(CostEstimator estimator) {
+
+               // check that union has only a single successor
+               if (this.getOutgoingConnections().size() > 1) {
+                       throw new CompilerException("BinaryUnionNode has more 
than one successor.");
+               }
+
                // check if we have a cached version
                if (this.cachedPlans != null) {
                        return this.cachedPlans;
@@ -173,7 +179,7 @@ public class BinaryUnionNode extends TwoInputNode {
                                                }
                                        }
                                        
-                                       // create a candidate channel for the 
first input. mark it cached, if the connection says so
+                                       // create a candidate channel for the 
second input. mark it cached, if the connection says so
                                        Channel c2 = new Channel(child2, 
this.input2.getMaterializationMode());
                                        if (this.input2.getShipStrategy() == 
null) {
                                                // free to choose the ship 
strategy

http://git-wip-us.apache.org/repos/asf/flink/blob/303f6fee/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java
----------------------------------------------------------------------
diff --git 
a/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java
 
b/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java
index 65dd2b3..3be7657 100644
--- 
a/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java
+++ 
b/flink-optimizer/src/test/java/org/apache/flink/optimizer/UnionReplacementTest.java
@@ -18,16 +18,25 @@
 
 package org.apache.flink.optimizer;
 
+import junit.framework.Assert;
+import org.apache.flink.api.common.operators.util.FieldList;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.common.Plan;
 import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.optimizer.plan.Channel;
+import org.apache.flink.optimizer.plan.NAryUnionPlanNode;
 import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
 import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
 import org.apache.flink.optimizer.util.CompilerTestBase;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
 import org.junit.Test;
 
-import static org.junit.Assert.fail;
+import java.util.List;
+
+import static org.junit.Assert.*;
 
 @SuppressWarnings("serial")
 public class UnionReplacementTest extends CompilerTestBase {
@@ -54,4 +63,95 @@ public class UnionReplacementTest extends CompilerTestBase {
                        fail(e.getMessage());
                }
        }
+
+       /**
+        *
+        * Test for FLINK-2662.
+        *
+        * Checks that a plan with a union with two outputs is correctly 
translated.
+        * The program can be illustrated as follows:
+        *
+        * Src1 ----------------\
+        *                       >-> Union123 -> GroupBy(0) -> Sum -> Output
+        * Src2 -\              /
+        *        >-> Union23--<
+        * Src3 -/              \
+        *                       >-> Union234 -> GroupBy(1) -> Sum -> Output
+        * Src4 ----------------/
+        *
+        * The fix for FLINK-2662 translates the union with two outputs 
(Union-23) into two separate
+        * unions (Union-23_1 and Union-23_2) with one output each. Due to this 
change, the interesting
+        * partitioning properties for GroupBy(0) and GroupBy(1) are pushed 
through Union-23_1 and
+        * Union-23_2 and do not interfere with each other (which would be the 
case if Union-23 were
+        * a single operator with two outputs).
+        *
+        */
+       @Test
+       public void testUnionWithTwoOutputsTest() throws Exception {
+
+               // 
-----------------------------------------------------------------------------------------
+               // Build test program
+               // 
-----------------------------------------------------------------------------------------
+
+               ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(DEFAULT_PARALLELISM);
+
+               DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new 
Tuple2<>(0L, 0L));
+               DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new 
Tuple2<>(0L, 0L));
+               DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new 
Tuple2<>(0L, 0L));
+               DataSet<Tuple2<Long, Long>> src4 = env.fromElements(new 
Tuple2<>(0L, 0L));
+
+               DataSet<Tuple2<Long, Long>> union23 = src2.union(src3);
+               DataSet<Tuple2<Long, Long>> union123 = src1.union(union23);
+               DataSet<Tuple2<Long, Long>> union234 = src4.union(union23);
+
+               union123.groupBy(0).sum(1).name("1").output(new 
DiscardingOutputFormat<Tuple2<Long, Long>>());
+               union234.groupBy(1).sum(0).name("2").output(new 
DiscardingOutputFormat<Tuple2<Long, Long>>());
+
+               // 
-----------------------------------------------------------------------------------------
+               // Verify optimized plan
+               // 
-----------------------------------------------------------------------------------------
+
+               OptimizedPlan optimizedPlan = 
compileNoStats(env.createProgramPlan());
+
+               OptimizerPlanNodeResolver resolver = 
getOptimizerPlanNodeResolver(optimizedPlan);
+
+               SingleInputPlanNode groupRed1 = resolver.getNode("1");
+               SingleInputPlanNode groupRed2 = resolver.getNode("2");
+
+               // check partitioning is correct
+               Assert.assertTrue("Reduce input should be partitioned on 0.",
+                       
groupRed1.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new
 FieldList(0)));
+               Assert.assertTrue("Reduce input should be partitioned on 1.",
+                       
groupRed2.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new
 FieldList(1)));
+
+               // check group reduce inputs are n-ary unions with three inputs
+               Assert.assertTrue("Reduce input should be n-ary union with 
three inputs.",
+                       groupRed1.getInput().getSource() instanceof 
NAryUnionPlanNode &&
+                               ((NAryUnionPlanNode) 
groupRed1.getInput().getSource()).getListOfInputs().size() == 3);
+               Assert.assertTrue("Reduce input should be n-ary union with 
three inputs.",
+                       groupRed2.getInput().getSource() instanceof 
NAryUnionPlanNode &&
+                               ((NAryUnionPlanNode) 
groupRed2.getInput().getSource()).getListOfInputs().size() == 3);
+
+               // check channel from union to group reduce is forwarding
+               Assert.assertTrue("Channel between union and group reduce 
should be forwarding",
+                       
groupRed1.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD));
+               Assert.assertTrue("Channel between union and group reduce 
should be forwarding",
+                       
groupRed2.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD));
+
+               // check that all inputs of unions are hash partitioned
+               List<Channel> union123In = ((NAryUnionPlanNode) 
groupRed1.getInput().getSource()).getListOfInputs();
+               for(Channel i : union123In) {
+                       Assert.assertTrue("Union input channel should hash 
partition on 0",
+                               
i.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) &&
+                                       
i.getShipStrategyKeys().isExactMatch(new FieldList(0)));
+               }
+               List<Channel> union234In = ((NAryUnionPlanNode) 
groupRed2.getInput().getSource()).getListOfInputs();
+               for(Channel i : union234In) {
+                       Assert.assertTrue("Union input channel should hash 
partition on 1",
+                               
i.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) &&
+                                       
i.getShipStrategyKeys().isExactMatch(new FieldList(1)));
+               }
+
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/303f6fee/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java
----------------------------------------------------------------------
diff --git 
a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java
 
b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java
index b870a91..77b150a 100644
--- 
a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java
+++ 
b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/UnionClosedBranchingTest.java
@@ -37,6 +37,7 @@ import 
org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -73,9 +74,9 @@ public class UnionClosedBranchingTest extends 
CompilerTestBase {
        @Parameterized.Parameters
        public static Collection<Object[]> params() {
                Collection<Object[]> params = Arrays.asList(new Object[][]{
-                               {ExecutionMode.PIPELINED, PIPELINED, BATCH},
+                               {ExecutionMode.PIPELINED, BATCH, PIPELINED},
                                {ExecutionMode.PIPELINED_FORCED, PIPELINED, 
PIPELINED},
-                               {ExecutionMode.BATCH, BATCH, BATCH},
+                               {ExecutionMode.BATCH, BATCH, PIPELINED},
                                {ExecutionMode.BATCH_FORCED, BATCH, BATCH},
                });
 
@@ -93,10 +94,16 @@ public class UnionClosedBranchingTest extends 
CompilerTestBase {
        /** Expected {@link DataExchangeMode} from union to join. */
        private final DataExchangeMode unionToJoin;
 
+       /** Expected {@link ShipStrategyType} from source to union. */
+       private final ShipStrategyType sourceToUnionStrategy = 
ShipStrategyType.PARTITION_HASH;
+
+       /** Expected {@link ShipStrategyType} from union to join. */
+       private final ShipStrategyType unionToJoinStrategy = 
ShipStrategyType.FORWARD;
+
        public UnionClosedBranchingTest(
-                       ExecutionMode executionMode,
-                       DataExchangeMode sourceToUnion,
-                       DataExchangeMode unionToJoin) {
+               ExecutionMode executionMode,
+               DataExchangeMode sourceToUnion,
+               DataExchangeMode unionToJoin) {
 
                this.executionMode = executionMode;
                this.sourceToUnion = sourceToUnion;
@@ -140,12 +147,16 @@ public class UnionClosedBranchingTest extends 
CompilerTestBase {
                for (Channel channel : joinNode.getInputs()) {
                        assertEquals("Unexpected data exchange mode between 
union and join node.",
                                        unionToJoin, 
channel.getDataExchangeMode());
+                       assertEquals("Unexpected ship strategy between union 
and join node.",
+                                       unionToJoinStrategy, 
channel.getShipStrategy());
                }
 
                for (SourcePlanNode src : optimizedPlan.getDataSources()) {
                        for (Channel channel : src.getOutgoingChannels()) {
                                assertEquals("Unexpected data exchange mode 
between source and union node.",
                                                sourceToUnion, 
channel.getDataExchangeMode());
+                               assertEquals("Unexpected ship strategy between 
source and union node.",
+                                       sourceToUnionStrategy, 
channel.getShipStrategy());
                        }
                }
 
@@ -176,9 +187,8 @@ public class UnionClosedBranchingTest extends 
CompilerTestBase {
                        for (IntermediateDataSet dataSet : 
src.getProducedDataSets()) {
                                ResultPartitionType dsType = 
dataSet.getResultType();
 
-                               // The result type is determined by the channel 
between the union and the join node
-                               // and *not* the channel between source and 
union.
-                               if (unionToJoin.equals(BATCH)) {
+                               // Ensure batch exchange unless PIPELINED_FORCED 
is enabled.
+                               if 
(!executionMode.equals(ExecutionMode.PIPELINED_FORCED)) {
                                        assertTrue("Expected batch exchange, 
but result type is " + dsType + ".",
                                                        dsType.isBlocking());
                                } else {

Reply via email to