zhuzhurk commented on code in PR #21880:
URL: https://github.com/apache/flink/pull/21880#discussion_r1101002408


##########
flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java:
##########
@@ -119,43 +119,62 @@ void testAdaptiveBatchScheduler() throws Exception {
         assertThat(sink.getParallelism()).isEqualTo(10);
 
         // check aggregatedInputDataBytes of each ExecutionVertex calculated.
-        checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex);
+        checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex, 26_000L);
     }
 
     @Test
     void testDecideParallelismForForwardTarget() throws Exception {
-        JobGraph jobGraph = createJobGraph(true);
-        Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
-        JobVertex source1 = jobVertexIterator.next();
-        JobVertex source2 = jobVertexIterator.next();
-        JobVertex sink = jobVertexIterator.next();
+        final JobVertex source = createJobVertex("source", SOURCE_PARALLELISM_1);
+        final JobVertex map = createJobVertex("map", -1);
+        final JobVertex sink = createJobVertex("sink", -1);
 
-        SchedulerBase scheduler = createScheduler(jobGraph);
+        map.connectNewDataSetAsInput(
+                source, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
+        sink.connectNewDataSetAsInput(
+                map, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
+        map.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
+
+        SchedulerBase scheduler =
+                createScheduler(
+                        new JobGraph(new JobID(), "test job", source, map, sink),
+                        createCustomParallelismDecider(
+                                jobVertexId -> {
+                                    if (jobVertexId.equals(map.getID())) {
+                                        return 5;
+                                    } else {
+                                        return 10;
+                                    }
+                                }),
+                        128);
 
         final DefaultExecutionGraph graph = (DefaultExecutionGraph) scheduler.getExecutionGraph();
+        final ExecutionJobVertex mapExecutionJobVertex = graph.getJobVertex(map.getID());
         final ExecutionJobVertex sinkExecutionJobVertex = graph.getJobVertex(sink.getID());
 
         scheduler.startScheduling();
+        assertThat(mapExecutionJobVertex.getParallelism()).isEqualTo(-1);
         assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
-        // trigger source1 finished.
-        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
+        // trigger source finished.
+        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source);
+        assertThat(mapExecutionJobVertex.getParallelism()).isEqualTo(5);
         assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
-        // trigger source2 finished.
-        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
-        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
+        // trigger map finished.
+        transitionExecutionsState(scheduler, ExecutionState.FINISHED, map);
+        assertThat(mapExecutionJobVertex.getParallelism()).isEqualTo(5);
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(5);
 
         // check that the jobGraph is updated
-        assertThat(sink.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
+        assertThat(sink.getParallelism()).isEqualTo(5);
 
         // check aggregatedInputDataBytes of each ExecutionVertex calculated.
-        checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex);
+        checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex, 13_000L);

Review Comment:
   Maybe later we can improve `IntermediateResultPartition#computeNumberOfMaxPossiblePartitionConsumers` so that for a forward consumption group, the number of subpartitions can be 1.
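   For illustration only, here is a minimal standalone sketch of that idea. It is not Flink's actual `computeNumberOfMaxPossiblePartitionConsumers` implementation; the class, method, and parameter names below are hypothetical stand-ins. The point it demonstrates is the one from the comment above: a forward edge pins the consumer parallelism to the producer parallelism, so each partition can only ever be read by one consumer subtask and one subpartition suffices.

```java
// Illustrative sketch, not Flink's implementation. All names here are
// hypothetical; only the forward-edge reasoning comes from the review comment.
final class MaxConsumerEstimate {

    /**
     * Estimates the maximum possible number of consumer subtasks for a single
     * produced partition, before the consumer parallelism has been decided.
     */
    static int maxPossiblePartitionConsumers(
            boolean isForwardEdge,
            boolean isPointwise,
            int producerParallelism,
            int consumerMaxParallelism) {

        // A forward edge forces consumer parallelism == producer parallelism,
        // so each partition is read by exactly one consumer subtask and a
        // single subpartition is enough.
        if (isForwardEdge) {
            return 1;
        }

        // ALL_TO_ALL: every consumer subtask may read this partition.
        if (!isPointwise) {
            return consumerMaxParallelism;
        }

        // POINTWISE: at most ceil(consumerMaxParallelism / producerParallelism)
        // consumer subtasks can share one partition.
        return (consumerMaxParallelism + producerParallelism - 1) / producerParallelism;
    }

    private MaxConsumerEstimate() {}
}
```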


