This is an automated email from the ASF dual-hosted git repository.

aljoscha pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 4b80bf813f96616caabbb5c8809900252607bd71
Author: ifndef-SleePy <[email protected]>
AuthorDate: Fri Jun 14 21:38:25 2019 +0800

    [FLINK-12686][datastream] Configure co-location group of iteration nodes in 
StreamGraph instead of in StreamingJobGraphGenerator
---
 .../jobmanager/scheduler/CoLocationGroup.java      |  4 ++
 .../flink/streaming/api/graph/StreamGraph.java     | 15 +++++--
 .../api/graph/StreamingJobGraphGenerator.java      | 15 +------
 .../api/graph/StreamGraphGeneratorTest.java        | 27 ++++++++++++
 .../api/graph/StreamingJobGraphGeneratorTest.java  | 48 ++++++++++++++++++++++
 5 files changed, 91 insertions(+), 18 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmanager/scheduler/CoLocationGroup.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmanager/scheduler/CoLocationGroup.java
index 35c48e1..ef8bd67 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmanager/scheduler/CoLocationGroup.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmanager/scheduler/CoLocationGroup.java
@@ -63,6 +63,10 @@ public class CoLocationGroup implements java.io.Serializable 
{
                Preconditions.checkNotNull(vertex);
                this.vertices.add(vertex);
        }
+
+       public List<JobVertex> getVertices() {
+               return vertices;
+       }
        
        public void mergeInto(CoLocationGroup other) {
                Preconditions.checkNotNull(other);
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
index 2d72276..942da2c 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
@@ -79,6 +79,10 @@ public class StreamGraph extends StreamingPlan {
 
        private static final Logger LOG = 
LoggerFactory.getLogger(StreamGraph.class);
 
+       public static final String ITERATION_SOURCE_NAME_PREFIX = 
"IterationSource";
+
+       public static final String ITERATION_SINK_NAME_PREFIX = "IterationSink";
+
        private String jobName;
 
        private final ExecutionConfig executionConfig;
@@ -620,12 +624,15 @@ public class StreamGraph extends StreamingPlan {
                int maxParallelism,
                ResourceSpec minResources,
                ResourceSpec preferredResources) {
+
+               final String coLocationGroup = "IterationCoLocationGroup-" + 
loopId;
+
                StreamNode source = this.addNode(sourceId,
                        null,
-                       null,
+                       coLocationGroup,
                        StreamIterationHead.class,
                        null,
-                       "IterationSource-" + loopId);
+                       ITERATION_SOURCE_NAME_PREFIX + "-" + loopId);
                sources.add(source.getId());
                setParallelism(source.getId(), parallelism);
                setMaxParallelism(source.getId(), maxParallelism);
@@ -633,10 +640,10 @@ public class StreamGraph extends StreamingPlan {
 
                StreamNode sink = this.addNode(sinkId,
                        null,
-                       null,
+                       coLocationGroup,
                        StreamIterationTail.class,
                        null,
-                       "IterationSink-" + loopId);
+                       ITERATION_SINK_NAME_PREFIX + "-" + loopId);
                sinks.add(sink.getId());
                setParallelism(sink.getId(), parallelism);
                setMaxParallelism(sink.getId(), parallelism);
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 19151ce..dbedcdc 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -598,22 +598,9 @@ public class StreamingJobGraphGenerator {
                                }
 
                                vertex.updateCoLocationGroup(constraint.f1);
+                               constraint.f1.addVertex(vertex);
                        }
                }
-
-               for (Tuple2<StreamNode, StreamNode> pair : 
streamGraph.getIterationSourceSinkPairs()) {
-
-                       CoLocationGroup ccg = new CoLocationGroup();
-
-                       JobVertex source = jobVertices.get(pair.f0.getId());
-                       JobVertex sink = jobVertices.get(pair.f1.getId());
-
-                       ccg.addVertex(source);
-                       ccg.addVertex(sink);
-                       source.updateCoLocationGroup(ccg);
-                       sink.updateCoLocationGroup(ccg);
-               }
-
        }
 
        private void configureCheckpointing() {
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
index 19e2fb6..5e42a55 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
@@ -21,8 +21,10 @@ package org.apache.flink.streaming.api.graph;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.streaming.api.datastream.ConnectedStreams;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.IterativeStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.co.CoMapFunction;
@@ -49,6 +51,7 @@ import org.apache.flink.streaming.util.NoOpIntMap;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 
 /**
@@ -420,6 +423,30 @@ public class StreamGraphGeneratorTest {
                env.getStreamGraph().getStreamingPlanAsJSON();
        }
 
+       /**
+        * Test iteration job, check slot sharing group and co-location group.
+        */
+       @Test
+       public void testIteration() {
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+
+               DataStream<Integer> source = env.fromElements(1, 2, 
3).name("source");
+               IterativeStream<Integer> iteration = source.iterate(3000);
+               iteration.name("iteration").setParallelism(2);
+               DataStream<Integer> map = iteration.map(x -> x + 
1).name("map").setParallelism(2);
+               DataStream<Integer> filter = map.filter((x) -> 
false).name("filter").setParallelism(2);
+               iteration.closeWith(filter).print();
+
+               StreamGraph streamGraph = env.getStreamGraph();
+               for (Tuple2<StreamNode, StreamNode> iterationPair : 
streamGraph.getIterationSourceSinkPairs()) {
+                       assertNotNull(iterationPair.f0.getCoLocationGroup());
+                       assertEquals(iterationPair.f0.getCoLocationGroup(), 
iterationPair.f1.getCoLocationGroup());
+
+                       assertNotNull(iterationPair.f0.getSlotSharingGroup());
+                       assertEquals(iterationPair.f0.getSlotSharingGroup(), 
iterationPair.f1.getSlotSharingGroup());
+               }
+       }
+
        private static class OutputTypeConfigurableOperationWithTwoInputs
                        extends AbstractStreamOperator<Integer>
                        implements TwoInputStreamOperator<Integer, Integer, 
Integer>, OutputTypeConfigurable<Integer> {
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
index 01b888e..4e3e49c 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
@@ -36,6 +36,8 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
+import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
+import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
 import org.apache.flink.runtime.operators.util.TaskConfig;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.DataStreamSink;
@@ -60,6 +62,8 @@ import java.util.Map;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
 /**
@@ -398,4 +402,48 @@ public class StreamingJobGraphGeneratorTest extends 
TestLogger {
                assertEquals(ResultPartitionType.PIPELINED_BOUNDED, 
mapVertex.getInputs().get(0).getSource().getResultType());
                assertEquals(ResultPartitionType.BLOCKING, 
printVertex.getInputs().get(0).getSource().getResultType());
        }
+
+       /**
+        * Test iteration job, check slot sharing group and co-location group.
+        */
+       @Test
+       public void testIteration() {
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+
+               DataStream<Integer> source = env.fromElements(1, 2, 
3).name("source");
+               IterativeStream<Integer> iteration = source.iterate(3000);
+               iteration.name("iteration").setParallelism(2);
+               DataStream<Integer> map = iteration.map(x -> x + 
1).name("map").setParallelism(2);
+               DataStream<Integer> filter = map.filter((x) -> 
false).name("filter").setParallelism(2);
+               iteration.closeWith(filter).print();
+
+               JobGraph jobGraph = 
StreamingJobGraphGenerator.createJobGraph(env.getStreamGraph());
+
+               SlotSharingGroup slotSharingGroup = 
jobGraph.getVerticesAsArray()[0].getSlotSharingGroup();
+               assertNotNull(slotSharingGroup);
+
+               CoLocationGroup iterationSourceCoLocationGroup = null;
+               CoLocationGroup iterationSinkCoLocationGroup = null;
+
+               for (JobVertex jobVertex : jobGraph.getVertices()) {
+                       // all vertices have same slot sharing group by default
+                       assertEquals(slotSharingGroup, 
jobVertex.getSlotSharingGroup());
+
+                       // all iteration vertices have same co-location group,
+                       // others have no co-location group by default
+                       if 
(jobVertex.getName().startsWith(StreamGraph.ITERATION_SOURCE_NAME_PREFIX)) {
+                               iterationSourceCoLocationGroup = 
jobVertex.getCoLocationGroup();
+                               
assertTrue(iterationSourceCoLocationGroup.getVertices().contains(jobVertex));
+                       } else if 
(jobVertex.getName().startsWith(StreamGraph.ITERATION_SINK_NAME_PREFIX)) {
+                               iterationSinkCoLocationGroup = 
jobVertex.getCoLocationGroup();
+                               
assertTrue(iterationSinkCoLocationGroup.getVertices().contains(jobVertex));
+                       } else {
+                               assertNull(jobVertex.getCoLocationGroup());
+                       }
+               }
+
+               assertNotNull(iterationSourceCoLocationGroup);
+               assertNotNull(iterationSinkCoLocationGroup);
+               assertEquals(iterationSourceCoLocationGroup, 
iterationSinkCoLocationGroup);
+       }
 }

Reply via email to