This is an automated email from the ASF dual-hosted git repository. zhuzh pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit a1555b0c53493fa80d0cf4b933396941c7de7073 Author: noorall <[email protected]> AuthorDate: Fri Oct 25 17:50:55 2024 +0800 [FLINK-36066][runtime] Introduce StreamGraphContext to provide the ability to modify the StreamGraph --- .../jobgraph/forwardgroup/ForwardGroup.java | 135 ++++------- .../forwardgroup/ForwardGroupComputeUtil.java | 114 +++++++-- ...orwardGroup.java => JobVertexForwardGroup.java} | 34 ++- .../forwardgroup/StreamNodeForwardGroup.java | 174 +++++++++++++ .../adaptivebatch/AdaptiveBatchScheduler.java | 11 +- .../AdaptiveBatchSchedulerFactory.java | 4 +- .../api/graph/DefaultStreamGraphContext.java | 270 +++++++++++++++++++++ .../streaming/api/graph/StreamGraphContext.java | 65 +++++ .../flink/streaming/api/graph/StreamNode.java | 4 +- .../api/graph/StreamingJobGraphGenerator.java | 3 +- .../graph/util/StreamEdgeUpdateRequestInfo.java | 59 +++++ .../forwardgroup/ForwardGroupComputeUtilTest.java | 163 ++++++++++++- .../forwardgroup/StreamNodeForwardGroupTest.java | 134 ++++++++++ .../api/graph/DefaultStreamGraphContextTest.java | 217 +++++++++++++++++ 14 files changed, 1254 insertions(+), 133 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java index 922acf231e2..0f81075f95d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java @@ -19,90 +19,61 @@ package org.apache.flink.runtime.jobgraph.forwardgroup; -import org.apache.flink.api.common.ExecutionConfig; -import org.apache.flink.runtime.jobgraph.JobVertex; -import org.apache.flink.runtime.jobgraph.JobVertexID; - -import java.util.Collections; -import java.util.HashSet; import java.util.Set; -import java.util.stream.Collectors; - -import static 
org.apache.flink.util.Preconditions.checkNotNull; -import static org.apache.flink.util.Preconditions.checkState; /** - * A forward group is a set of job vertices connected via forward edges. Parallelisms of all job - * vertices in the same {@link ForwardGroup} must be the same. + * A forward group is a set of job vertices or stream nodes connected via forward edges. + * Parallelisms of all job vertices or stream nodes in the same {@link ForwardGroup} must be the + * same. */ -public class ForwardGroup { - - private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; - - private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT; - private final Set<JobVertexID> jobVertexIds = new HashSet<>(); - - public ForwardGroup(final Set<JobVertex> jobVertices) { - checkNotNull(jobVertices); - - Set<Integer> configuredParallelisms = - jobVertices.stream() - .filter( - jobVertex -> { - jobVertexIds.add(jobVertex.getID()); - return jobVertex.getParallelism() > 0; - }) - .map(JobVertex::getParallelism) - .collect(Collectors.toSet()); - - checkState(configuredParallelisms.size() <= 1); - if (configuredParallelisms.size() == 1) { - this.parallelism = configuredParallelisms.iterator().next(); - } - - Set<Integer> configuredMaxParallelisms = - jobVertices.stream() - .map(JobVertex::getMaxParallelism) - .filter(val -> val > 0) - .collect(Collectors.toSet()); - - if (!configuredMaxParallelisms.isEmpty()) { - this.maxParallelism = Collections.min(configuredMaxParallelisms); - checkState( - parallelism == ExecutionConfig.PARALLELISM_DEFAULT - || maxParallelism >= parallelism, - "There is a job vertex in the forward group whose maximum parallelism is smaller than the group's parallelism"); - } - } - - public void setParallelism(int parallelism) { - checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT); - this.parallelism = parallelism; - } - - public boolean isParallelismDecided() { - return parallelism > 0; - } - - public int getParallelism() { - 
checkState(isParallelismDecided()); - return parallelism; - } - - public boolean isMaxParallelismDecided() { - return maxParallelism > 0; - } - - public int getMaxParallelism() { - checkState(isMaxParallelismDecided()); - return maxParallelism; - } - - public int size() { - return jobVertexIds.size(); - } - - public Set<JobVertexID> getJobVertexIds() { - return jobVertexIds; - } +public interface ForwardGroup<T> { + + /** + * Sets the parallelism for this forward group. + * + * @param parallelism the parallelism to set. + */ + void setParallelism(int parallelism); + + /** + * Returns if parallelism has been decided for this forward group. + * + * @return is parallelism decided for this forward group. + */ + boolean isParallelismDecided(); + + /** + * Returns the parallelism for this forward group. + * + * @return parallelism for this forward group. + */ + int getParallelism(); + + /** + * Sets the max parallelism for this forward group. + * + * @param maxParallelism the max parallelism to set. + */ + void setMaxParallelism(int maxParallelism); + + /** + * Returns if max parallelism has been decided for this forward group. + * + * @return is max parallelism decided for this forward group. + */ + boolean isMaxParallelismDecided(); + + /** + * Returns the max parallelism for this forward group. + * + * @return max parallelism for this forward group. + */ + int getMaxParallelism(); + + /** + * Returns the vertex ids in this forward group. + * + * @return vertex ids in this forward group. 
+ */ + Set<T> getVertexIds(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java index dc2dc702e34..0db296b62ac 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSet; import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.streaming.api.graph.StreamNode; import java.util.HashMap; import java.util.HashSet; @@ -38,16 +39,17 @@ import static org.apache.flink.util.Preconditions.checkState; /** Common utils for computing forward groups. */ public class ForwardGroupComputeUtil { - public static Map<JobVertexID, ForwardGroup> computeForwardGroupsAndCheckParallelism( + public static Map<JobVertexID, JobVertexForwardGroup> computeForwardGroupsAndCheckParallelism( final Iterable<JobVertex> topologicallySortedVertices) { - final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId = + final Map<JobVertexID, JobVertexForwardGroup> forwardGroupsByJobVertexId = computeForwardGroups( topologicallySortedVertices, ForwardGroupComputeUtil::getForwardProducers); // the vertex's parallelism in parallelism-decided forward group should have been set at // compilation phase topologicallySortedVertices.forEach( jobVertex -> { - ForwardGroup forwardGroup = forwardGroupsByJobVertexId.get(jobVertex.getID()); + JobVertexForwardGroup forwardGroup = + forwardGroupsByJobVertexId.get(jobVertex.getID()); if (forwardGroup != null && forwardGroup.isParallelismDecided()) { checkState(jobVertex.getParallelism() == forwardGroup.getParallelism()); } @@ 
-55,28 +57,73 @@ public class ForwardGroupComputeUtil { return forwardGroupsByJobVertexId; } - public static Map<JobVertexID, ForwardGroup> computeForwardGroups( + public static Map<JobVertexID, JobVertexForwardGroup> computeForwardGroups( final Iterable<JobVertex> topologicallySortedVertices, final Function<JobVertex, Set<JobVertex>> forwardProducersRetriever) { - final Map<JobVertex, Set<JobVertex>> vertexToGroup = new IdentityHashMap<>(); + final Map<JobVertex, Set<JobVertex>> vertexToGroup = + computeVertexToGroup(topologicallySortedVertices, forwardProducersRetriever); + + final Map<JobVertexID, JobVertexForwardGroup> ret = new HashMap<>(); + for (Set<JobVertex> vertexGroup : + VertexGroupComputeUtil.uniqueVertexGroups(vertexToGroup)) { + if (vertexGroup.size() > 1) { + JobVertexForwardGroup forwardGroup = new JobVertexForwardGroup(vertexGroup); + for (JobVertexID jobVertexId : forwardGroup.getVertexIds()) { + ret.put(jobVertexId, forwardGroup); + } + } + } + return ret; + } + + /** + * We calculate forward group by a set of stream nodes. + * + * @param topologicallySortedStreamNodes topologically sorted chained stream nodes + * @param forwardProducersRetriever records all upstream stream nodes which connected to the + * given stream node with forward edge + * @return a map of forward groups, with the stream node id as the key + */ + public static Map<Integer, StreamNodeForwardGroup> computeStreamNodeForwardGroup( + final Iterable<StreamNode> topologicallySortedStreamNodes, + final Function<StreamNode, Set<StreamNode>> forwardProducersRetriever) { + // In the forwardProducersRetriever, only the upstream nodes connected to the given start + // node by the forward edge are saved. We need to calculate the chain groups that can be + // accessed with consecutive forward edges and put them in the same forward group. 
+ final Map<StreamNode, Set<StreamNode>> nodeToGroup = + computeVertexToGroup(topologicallySortedStreamNodes, forwardProducersRetriever); + final Map<Integer, StreamNodeForwardGroup> ret = new HashMap<>(); + for (Set<StreamNode> nodeGroup : VertexGroupComputeUtil.uniqueVertexGroups(nodeToGroup)) { + StreamNodeForwardGroup forwardGroup = new StreamNodeForwardGroup(nodeGroup); + for (Integer vertexId : forwardGroup.getVertexIds()) { + ret.put(vertexId, forwardGroup); + } + } + return ret; + } + + private static <T> Map<T, Set<T>> computeVertexToGroup( + final Iterable<T> topologicallySortedVertices, + final Function<T, Set<T>> forwardProducersRetriever) { + final Map<T, Set<T>> vertexToGroup = new IdentityHashMap<>(); // iterate all the vertices which are topologically sorted - for (JobVertex vertex : topologicallySortedVertices) { - Set<JobVertex> currentGroup = new HashSet<>(); + for (T vertex : topologicallySortedVertices) { + Set<T> currentGroup = new HashSet<>(); currentGroup.add(vertex); vertexToGroup.put(vertex, currentGroup); - for (JobVertex producerVertex : forwardProducersRetriever.apply(vertex)) { - final Set<JobVertex> producerGroup = vertexToGroup.get(producerVertex); + for (T producerVertex : forwardProducersRetriever.apply(vertex)) { + final Set<T> producerGroup = vertexToGroup.get(producerVertex); if (producerGroup == null) { throw new IllegalStateException( "Producer task " - + producerVertex.getID() + + producerVertex + " forward group is null" + " while calculating forward group for the consumer task " - + vertex.getID() + + vertex + ". 
This should be a forward group building bug."); } @@ -87,18 +134,43 @@ public class ForwardGroupComputeUtil { } } } + return vertexToGroup; + } - final Map<JobVertexID, ForwardGroup> ret = new HashMap<>(); - for (Set<JobVertex> vertexGroup : - VertexGroupComputeUtil.uniqueVertexGroups(vertexToGroup)) { - if (vertexGroup.size() > 1) { - ForwardGroup forwardGroup = new ForwardGroup(vertexGroup); - for (JobVertexID jobVertexId : forwardGroup.getJobVertexIds()) { - ret.put(jobVertexId, forwardGroup); - } - } + /** + * Determines whether the target forward group can be merged into the source forward group. + * + * @param sourceForwardGroup The source forward group. + * @param forwardGroupToMerge The forward group needs to be merged. + * @return whether the merge is valid. + */ + public static boolean canTargetMergeIntoSourceForwardGroup( + ForwardGroup<?> sourceForwardGroup, ForwardGroup<?> forwardGroupToMerge) { + if (sourceForwardGroup == null || forwardGroupToMerge == null) { + return false; } - return ret; + + if (sourceForwardGroup == forwardGroupToMerge) { + return true; + } + + if (sourceForwardGroup.isParallelismDecided() + && forwardGroupToMerge.isParallelismDecided() + && sourceForwardGroup.getParallelism() != forwardGroupToMerge.getParallelism()) { + return false; + } + + // When the parallelism of source forward groups is determined, the maximum + // parallelism of the forwardGroupToMerge should not be less than the parallelism of the + // sourceForwardGroup to ensure the forwardGroupToMerge can also achieve the same + // parallelism. 
+ if (sourceForwardGroup.isParallelismDecided() + && forwardGroupToMerge.isMaxParallelismDecided() + && sourceForwardGroup.getParallelism() > forwardGroupToMerge.getMaxParallelism()) { + return false; + } + + return true; } static Set<JobVertex> getForwardProducers(final JobVertex jobVertex) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/JobVertexForwardGroup.java similarity index 81% copy from flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java copy to flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/JobVertexForwardGroup.java index 922acf231e2..d9ad77a61de 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/JobVertexForwardGroup.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.jobgraph.forwardgroup; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -31,18 +32,15 @@ import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; -/** - * A forward group is a set of job vertices connected via forward edges. Parallelisms of all job - * vertices in the same {@link ForwardGroup} must be the same. - */ -public class ForwardGroup { +/** Job vertex level implement for {@link ForwardGroup}. 
*/ +public class JobVertexForwardGroup implements ForwardGroup<JobVertexID> { private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT; private final Set<JobVertexID> jobVertexIds = new HashSet<>(); - public ForwardGroup(final Set<JobVertex> jobVertices) { + public JobVertexForwardGroup(final Set<JobVertex> jobVertices) { checkNotNull(jobVertices); Set<Integer> configuredParallelisms = @@ -75,34 +73,50 @@ public class ForwardGroup { } } + @Override public void setParallelism(int parallelism) { checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT); this.parallelism = parallelism; } + @Override public boolean isParallelismDecided() { return parallelism > 0; } + @Override public int getParallelism() { checkState(isParallelismDecided()); return parallelism; } + @Override + public void setMaxParallelism(int maxParallelism) { + checkState( + maxParallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism, + "There is a job vertex in the forward group whose maximum parallelism is smaller than the group's parallelism"); + this.maxParallelism = maxParallelism; + } + + @Override public boolean isMaxParallelismDecided() { return maxParallelism > 0; } + @Override public int getMaxParallelism() { checkState(isMaxParallelismDecided()); return maxParallelism; } - public int size() { - return jobVertexIds.size(); + @Override + public Set<JobVertexID> getVertexIds() { + return jobVertexIds; } - public Set<JobVertexID> getJobVertexIds() { - return jobVertexIds; + @VisibleForTesting + public int size() { + return jobVertexIds.size(); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java new file mode 100644 index 00000000000..af37d4262f7 --- /dev/null +++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.jobgraph.forwardgroup; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamNode; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Stream node level implement for {@link ForwardGroup}. */ +public class StreamNodeForwardGroup implements ForwardGroup<Integer> { + + private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; + + private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT; + + private final Set<StreamNode> streamNodes = new HashSet<>(); + + // For a group of chained stream nodes, their parallelism is consistent. In order to make + // calculation and usage easier, we only use the start node to calculate forward group. 
+ public StreamNodeForwardGroup(final Set<StreamNode> streamNodes) { + checkNotNull(streamNodes); + + Set<Integer> configuredParallelisms = + streamNodes.stream() + .map(StreamNode::getParallelism) + .filter(v -> v > 0) + .collect(Collectors.toSet()); + + checkState(configuredParallelisms.size() <= 1); + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } + + Set<Integer> configuredMaxParallelisms = + streamNodes.stream() + .map(StreamNode::getMaxParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + if (!configuredMaxParallelisms.isEmpty()) { + this.maxParallelism = Collections.min(configuredMaxParallelisms); + checkState( + parallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism, + "There is a start node in the forward group whose maximum parallelism is smaller than the group's parallelism"); + } + + this.streamNodes.addAll(streamNodes); + } + + @Override + public void setParallelism(int parallelism) { + checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT); + this.parallelism = parallelism; + this.streamNodes.forEach( + streamNode -> { + streamNode.setParallelism(parallelism); + }); + } + + @Override + public boolean isParallelismDecided() { + return parallelism > 0; + } + + @Override + public int getParallelism() { + checkState(isParallelismDecided()); + return parallelism; + } + + @Override + public void setMaxParallelism(int maxParallelism) { + checkState( + maxParallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism, + "There is a job vertex in the forward group whose maximum parallelism is smaller than the group's parallelism"); + this.maxParallelism = maxParallelism; + this.streamNodes.forEach( + streamNode -> { + streamNode.setMaxParallelism(maxParallelism); + }); + } + + @Override + public boolean isMaxParallelismDecided() { + return maxParallelism > 0; + } + + @Override + public int getMaxParallelism() { + 
checkState(isMaxParallelismDecided()); + return maxParallelism; + } + + @Override + public Set<Integer> getVertexIds() { + return streamNodes.stream().map(StreamNode::getId).collect(Collectors.toSet()); + } + + /** + * Merges forwardGroupToMerge into this and update the parallelism information for stream nodes + * in merged forward group. + * + * @param forwardGroupToMerge The forward group to be merged. + * @return whether the merge was successful. + */ + public boolean mergeForwardGroup(final StreamNodeForwardGroup forwardGroupToMerge) { + checkNotNull(forwardGroupToMerge); + + if (forwardGroupToMerge == this) { + return true; + } + + if (!ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup( + this, forwardGroupToMerge)) { + return false; + } + + if (this.isParallelismDecided() && !forwardGroupToMerge.isParallelismDecided()) { + forwardGroupToMerge.setParallelism(this.parallelism); + } else if (!this.isParallelismDecided() && forwardGroupToMerge.isParallelismDecided()) { + this.setParallelism(forwardGroupToMerge.parallelism); + } else { + checkState(this.parallelism == forwardGroupToMerge.parallelism); + } + + if (forwardGroupToMerge.isMaxParallelismDecided() + && (!this.isMaxParallelismDecided() + || this.maxParallelism > forwardGroupToMerge.maxParallelism)) { + this.setMaxParallelism(forwardGroupToMerge.maxParallelism); + } else if (this.isMaxParallelismDecided() + && (!forwardGroupToMerge.isMaxParallelismDecided() + || forwardGroupToMerge.maxParallelism > this.maxParallelism)) { + forwardGroupToMerge.setMaxParallelism(this.maxParallelism); + } else { + checkState(this.maxParallelism == forwardGroupToMerge.maxParallelism); + } + + this.streamNodes.addAll(forwardGroupToMerge.streamNodes); + + return true; + } + + @VisibleForTesting + public int size() { + return streamNodes.size(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java index a3d3e582c65..e012712739a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java @@ -58,7 +58,7 @@ import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup; +import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup; import org.apache.flink.runtime.jobgraph.jsonplan.JsonPlanGenerator; import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalResult; import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalTopology; @@ -117,7 +117,7 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { private final VertexParallelismAndInputInfosDecider vertexParallelismAndInputInfosDecider; - private final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId; + private final Map<JobVertexID, JobVertexForwardGroup> forwardGroupsByJobVertexId; private final Map<IntermediateDataSetID, BlockingResultInfo> blockingResultInfos; @@ -166,7 +166,7 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { final int defaultMaxParallelism, final BlocklistOperations blocklistOperations, final HybridPartitionDataConsumeConstraint hybridPartitionDataConsumeConstraint, - final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId, + final Map<JobVertexID, JobVertexForwardGroup> forwardGroupsByJobVertexId, final BatchJobRecoveryHandler jobRecoveryHandler) throws Exception { @@ -626,7 +626,8 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { private ParallelismAndInputInfos tryDecideParallelismAndInputInfos( final 
ExecutionJobVertex jobVertex, List<BlockingResultInfo> inputs) { int vertexInitialParallelism = jobVertex.getParallelism(); - ForwardGroup forwardGroup = forwardGroupsByJobVertexId.get(jobVertex.getJobVertexId()); + JobVertexForwardGroup forwardGroup = + forwardGroupsByJobVertexId.get(jobVertex.getJobVertexId()); if (!jobVertex.isParallelismDecided() && forwardGroup != null) { checkState(!forwardGroup.isParallelismDecided()); } @@ -681,7 +682,7 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { // the ordering of these elements received by the committer cannot be assured, which // would break the assumption that CommittableSummary is received before // CommittableWithLineage. - for (JobVertexID jobVertexId : forwardGroup.getJobVertexIds()) { + for (JobVertexID jobVertexId : forwardGroup.getVertexIds()) { ExecutionJobVertex executionJobVertex = getExecutionJobVertex(jobVertexId); if (!executionJobVertex.isParallelismDecided()) { log.info( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java index 24d9f64b2b8..91a578ea995 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java @@ -45,8 +45,8 @@ import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobType; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup; import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup; import 
org.apache.flink.runtime.jobmaster.ExecutionDeploymentTracker; import org.apache.flink.runtime.jobmaster.event.FileSystemJobEventStore; import org.apache.flink.runtime.jobmaster.event.JobEventManager; @@ -269,7 +269,7 @@ public class AdaptiveBatchSchedulerFactory implements SchedulerNGFactory { int defaultMaxParallelism = getDefaultMaxParallelism(jobMasterConfiguration, executionConfig); - final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId = + final Map<JobVertexID, JobVertexForwardGroup> forwardGroupsByJobVertexId = ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism( jobGraph.getVerticesSortedTopologicallyFromSources()); diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java new file mode 100644 index 00000000000..07d8631bf92 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.runtime.partitioner.ForwardForConsecutiveHashPartitioner; +import org.apache.flink.streaming.runtime.partitioner.ForwardForUnspecifiedPartitioner; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup; +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. */ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + + // The attributes below are reused from AdaptiveGraphManager as AdaptiveGraphManager also needs + // to use the modified information to create the job vertex. + + // A modifiable map which records the ids of stream nodes to their forward groups. 
+ // When a stream edge's partitioner is modified to forward, we need to get the forward groups of + // its source and target node ids and merge them. + private final Map<Integer, StreamNodeForwardGroup> steamNodeIdToForwardGroupMap; + // A read-only map keyed by the ids of stream nodes whose job vertices have already been created. + // It ensures that the stream nodes involved in a modification have no job vertices yet. + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + // A modifiable map whose key is the id of the stream node that creates the non-chained output, + // and whose value maps each stream edge connected to that node to the non-chained output + // subscribed by the edge. It is used to verify whether the edge being modified subscribes to a + // reused output, and to keep modifications to a StreamEdge synchronized with its + // NonChainedOutput, since the two share some attributes. + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> steamNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = checkNotNull(streamGraph); + this.steamNodeIdToForwardGroupMap = checkNotNull(steamNodeIdToForwardGroupMap); + this.frozenNodeToStartNodeMap = checkNotNull(frozenNodeToStartNodeMap); + this.opIntermediateOutputsCaches = checkNotNull(opIntermediateOutputsCaches); + this.immutableStreamGraph = new ImmutableStreamGraph(this.streamGraph); + } + + @Override + public ImmutableStreamGraph getStreamGraph() { + return immutableStreamGraph; + } + + @Override + public @Nullable StreamOperatorFactory<?> getOperatorFactory(Integer streamNodeId) { + return streamGraph.getStreamNode(streamNodeId).getOperatorFactory(); + } + + @Override + public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) { + // 
We validate every requestInfo before applying any of them, so that if any request is + // invalid, none of the requests is applied. + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + if (!validateStreamEdgeUpdateRequest(requestInfo)) { + return false; + } + } + + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + StreamEdge targetEdge = + getStreamEdge( + requestInfo.getSourceId(), + requestInfo.getTargetId(), + requestInfo.getEdgeId()); + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null) { + modifyOutputPartitioner(targetEdge, newPartitioner); + } + } + + return true; + } + + private boolean validateStreamEdgeUpdateRequest(StreamEdgeUpdateRequestInfo requestInfo) { + Integer sourceNodeId = requestInfo.getSourceId(); + Integer targetNodeId = requestInfo.getTargetId(); + + StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, requestInfo.getEdgeId()); + + if (targetEdge == null) { + return false; + } + + // Modification is not allowed when the output subscribed by this edge is shared with other edges. + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + Set<StreamEdge> consumerStreamEdges = + opIntermediateOutputs.entrySet().stream() + .filter(entry -> entry.getValue().equals(output)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + if (consumerStreamEdges.size() != 1) { + LOG.info( + "Skip modifying edge {} because the subscribing output is reused.", + targetEdge); + return false; + } + } + + if (frozenNodeToStartNodeMap.containsKey(targetNodeId)) { + LOG.info( + "Skip modifying edge {} because the target node with id {} is in frozen list.", + targetEdge, + targetNodeId); + return false; + } + + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + + if (newPartitioner != null) { + if (targetEdge.getPartitioner().getClass().equals(ForwardPartitioner.class)) { + LOG.info( + "Modification for edge {} is not allowed as the origin partitioner is ForwardPartitioner.", + targetEdge); + return false; + } + if (newPartitioner.getClass().equals(ForwardPartitioner.class) + && !canTargetMergeIntoSourceForwardGroup( + steamNodeIdToForwardGroupMap.get(targetEdge.getSourceId()), + steamNodeIdToForwardGroupMap.get(targetEdge.getTargetId()))) { + LOG.info( + "Skip modifying edge {} because forward groups can not be merged.", + targetEdge); + return false; + } + } + + return true; + } + + private void modifyOutputPartitioner( + StreamEdge targetEdge, StreamPartitioner<?> newPartitioner) { + if (newPartitioner == null || targetEdge == null) { + return; + } + StreamPartitioner<?> oldPartitioner = targetEdge.getPartitioner(); + targetEdge.setPartitioner(newPartitioner); + + if (targetEdge.getPartitioner() instanceof ForwardPartitioner) { + tryConvertForwardPartitionerAndMergeForwardGroup(targetEdge); + } + + // The partitioner in NonChainedOutput is derived from the consumer edge, so we need to ensure + // that any modification to the partitioner of the consumer edge is synchronized with + // NonChainedOutput.
+ Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(targetEdge.getSourceId()); + NonChainedOutput output = + opIntermediateOutputs != null ? opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + output.setPartitioner(targetEdge.getPartitioner()); + } + LOG.info( + "The original partitioner of the edge {} is: {} , requested change to: {} , and finally modified to: {}.", + targetEdge, + oldPartitioner, + newPartitioner, + targetEdge.getPartitioner()); + } + + private void tryConvertForwardPartitionerAndMergeForwardGroup(StreamEdge targetEdge) { + checkState(targetEdge.getPartitioner() instanceof ForwardPartitioner); + Integer sourceNodeId = targetEdge.getSourceId(); + Integer targetNodeId = targetEdge.getTargetId(); + if (canConvertToForwardPartitioner(targetEdge)) { + targetEdge.setPartitioner(new ForwardPartitioner<>()); + checkState(mergeForwardGroups(sourceNodeId, targetNodeId)); + } else if (targetEdge.getPartitioner() instanceof ForwardForUnspecifiedPartitioner) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } else if (targetEdge.getPartitioner() instanceof ForwardForConsecutiveHashPartitioner) { + targetEdge.setPartitioner( + ((ForwardForConsecutiveHashPartitioner<?>) targetEdge.getPartitioner()) + .getHashPartitioner()); + } else { + // For a plain ForwardPartitioner, validateStreamEdgeUpdateRequest has already verified that the forward groups can be merged.
+ checkState(mergeForwardGroups(sourceNodeId, targetNodeId)); + } + } + + private boolean canConvertToForwardPartitioner(StreamEdge targetEdge) { + Integer sourceNodeId = targetEdge.getSourceId(); + Integer targetNodeId = targetEdge.getTargetId(); + if (targetEdge.getPartitioner() instanceof ForwardForUnspecifiedPartitioner) { + return !frozenNodeToStartNodeMap.containsKey(sourceNodeId) + && StreamingJobGraphGenerator.isChainable(targetEdge, streamGraph, true) + && canTargetMergeIntoSourceForwardGroup( + steamNodeIdToForwardGroupMap.get(sourceNodeId), + steamNodeIdToForwardGroupMap.get(targetNodeId)); + } else if (targetEdge.getPartitioner() instanceof ForwardForConsecutiveHashPartitioner) { + return canTargetMergeIntoSourceForwardGroup( + steamNodeIdToForwardGroupMap.get(sourceNodeId), + steamNodeIdToForwardGroupMap.get(targetNodeId)); + } else { + return false; + } + } + + private boolean mergeForwardGroups(Integer sourceNodeId, Integer targetNodeId) { + StreamNodeForwardGroup sourceForwardGroup = steamNodeIdToForwardGroupMap.get(sourceNodeId); + StreamNodeForwardGroup forwardGroupToMerge = steamNodeIdToForwardGroupMap.get(targetNodeId); + if (sourceForwardGroup == null || forwardGroupToMerge == null) { + return false; + } + if (!sourceForwardGroup.mergeForwardGroup(forwardGroupToMerge)) { + return false; + } + // Point every node of the merged group to the source forward group in steamNodeIdToForwardGroupMap.
+ forwardGroupToMerge + .getVertexIds() + .forEach(nodeId -> steamNodeIdToForwardGroupMap.put(nodeId, sourceForwardGroup)); + return true; + } + + private StreamEdge getStreamEdge(Integer sourceId, Integer targetId, String edgeId) { + for (StreamEdge edge : streamGraph.getStreamEdges(sourceId, targetId)) { + if (edge.getEdgeId().equals(edgeId)) { + return edge; + } + } + return null; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java new file mode 100644 index 00000000000..82cbac56135 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; + +import javax.annotation.Nullable; + +import java.util.List; + +/** + * Defines a context for optimizing a StreamGraph. It exposes a read-only view of the StreamGraph + * and provides methods to modify the StreamEdges within it. + */ +@Internal +public interface StreamGraphContext { + + /** + * Returns a read-only view of the StreamGraph. + * + * @return a read-only view of the StreamGraph. + */ + ImmutableStreamGraph getStreamGraph(); + + /** + * Retrieves the {@link StreamOperatorFactory} for the specified stream node id. + * + * @param streamNodeId the id of the stream node + * @return the {@link StreamOperatorFactory} associated with the given {@code streamNodeId}, or + * {@code null} if no operator factory is available. + */ + @Nullable + StreamOperatorFactory<?> getOperatorFactory(Integer streamNodeId); + + /** + * Atomically modifies stream edges within the StreamGraph. + * + * <p>This method ensures that all the requested modifications to stream edges are applied + * atomically. This means that if any requested modification is rejected, none of the changes + * will be applied, maintaining the consistency of the StreamGraph. + * + * @param requestInfos the update requests, each identifying a stream edge and the changes to apply. + * @return true if all modifications were successful and applied atomically, false otherwise.
+ */ + boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java index dbc5c90db5f..90431aa6b83 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java @@ -224,7 +224,7 @@ public class StreamNode implements Serializable { * * @return Maximum parallelism */ - int getMaxParallelism() { + public int getMaxParallelism() { return maxParallelism; } @@ -233,7 +233,7 @@ public class StreamNode implements Serializable { * * @param maxParallelism Maximum parallelism to be set */ - void setMaxParallelism(int maxParallelism) { + public void setMaxParallelism(int maxParallelism) { this.maxParallelism = maxParallelism; } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index 6476cfdfb52..d8d89a3a68b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -46,6 +46,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup; import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup; import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable; import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalPipelinedRegion; import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalTopology; @@ -850,7 +851,7 @@ public class
StreamingJobGraphGenerator { }); // compute forward groups - final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId = + final Map<JobVertexID, JobVertexForwardGroup> forwardGroupsByJobVertexId = ForwardGroupComputeUtil.computeForwardGroups( topologicalOrderVertices, jobVertex -> diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java new file mode 100644 index 00000000000..7a33e6d9265 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.graph.util; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +/** Helper class that carries the data required to update a stream edge.
*/ +@Internal +public class StreamEdgeUpdateRequestInfo { + private final String edgeId; + private final Integer sourceId; + private final Integer targetId; + + private StreamPartitioner<?> outputPartitioner; + + public StreamEdgeUpdateRequestInfo(String edgeId, Integer sourceId, Integer targetId) { + this.edgeId = edgeId; + this.sourceId = sourceId; + this.targetId = targetId; + } + + public StreamEdgeUpdateRequestInfo outputPartitioner(StreamPartitioner<?> outputPartitioner) { + this.outputPartitioner = outputPartitioner; + return this; + } + + public String getEdgeId() { + return edgeId; + } + + public Integer getSourceId() { + return sourceId; + } + + public Integer getTargetId() { + return targetId; + } + + public StreamPartitioner<?> getOutputPartitioner() { + return outputPartitioner; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java index e9a90e98918..3e8e31da8b6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java @@ -22,19 +22,27 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.testtasks.NoOpInvokable; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.operators.StreamOperator; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.function.Function; import
java.util.stream.Collectors; +import static org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil.computeStreamNodeForwardGroup; +import static org.apache.flink.util.Preconditions.checkState; import static org.assertj.core.api.Assertions.assertThat; -/** - * Unit tests for {@link org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil}. - */ +/** Unit tests for {@link ForwardGroupComputeUtil}. */ class ForwardGroupComputeUtilTest { /** @@ -54,11 +62,26 @@ class ForwardGroupComputeUtilTest { JobVertex v2 = new JobVertex("v2"); JobVertex v3 = new JobVertex("v3"); - Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3); + Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3); checkGroupSize(groups, 0); } + @Test + void testIsolatedChainedStreamNodeGroups() throws Exception { + List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(3); + Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = Collections.emptyMap(); + + Set<ForwardGroup<?>> groups = + computeForwardGroups( + topologicallySortedStreamNodes, forwardProducersByConsumerNodeId); + + // Unlike a job vertex forward group, a stream node forward group may contain a single + // stream node, because such groups may be merged with other groups later. + checkGroupSize(groups, 3, 1, 1, 1); + } + /** * Tests that the computation of the vertices connected with edges which have various result * partition types works correctly.
@@ -96,7 +119,39 @@ class ForwardGroupComputeUtilTest { v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); } - Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3); + Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3); + + checkGroupSize(groups, numOfGroups, groupSizes); + } + + @Test + void testVariousConnectTypesBetweenChainedStreamNodeGroup() throws Exception { + testThreeChainedStreamNodeGroupsConnectSequentially(false, true, 2, 1, 2); + testThreeChainedStreamNodeGroupsConnectSequentially(false, false, 3, 1, 1, 1); + testThreeChainedStreamNodeGroupsConnectSequentially(true, true, 1, 3); + } + + private void testThreeChainedStreamNodeGroupsConnectSequentially( + boolean isForward1, boolean isForward2, int numOfGroups, Integer... groupSizes) + throws Exception { + List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(3); + Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = new HashMap<>(); + + if (isForward1) { + forwardProducersByConsumerNodeId + .computeIfAbsent(topologicallySortedStreamNodes.get(1), k -> new HashSet<>()) + .add(topologicallySortedStreamNodes.get(0)); + } + + if (isForward2) { + forwardProducersByConsumerNodeId + .computeIfAbsent(topologicallySortedStreamNodes.get(2), k -> new HashSet<>()) + .add(topologicallySortedStreamNodes.get(1)); + } + + Set<ForwardGroup<?>> groups = + computeForwardGroups( + topologicallySortedStreamNodes, forwardProducersByConsumerNodeId); checkGroupSize(groups, numOfGroups, groupSizes); } @@ -131,11 +186,31 @@ class ForwardGroupComputeUtilTest { v4.connectNewDataSetAsInput( v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING); - Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3, v4); + Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3, v4); checkGroupSize(groups, 1, 3); } + @Test + void testTwoInputsMergesIntoOneForStreamNodeForwardGroup() throws Exception { + List<StreamNode> topologicallySortedStreamNodes = 
createStreamNodes(4); + Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = new HashMap<>(); + + forwardProducersByConsumerNodeId + .computeIfAbsent(topologicallySortedStreamNodes.get(2), k -> new HashSet<>()) + .add(topologicallySortedStreamNodes.get(0)); + + forwardProducersByConsumerNodeId + .computeIfAbsent(topologicallySortedStreamNodes.get(2), k -> new HashSet<>()) + .add(topologicallySortedStreamNodes.get(1)); + + Set<ForwardGroup<?>> groups = + computeForwardGroups( + topologicallySortedStreamNodes, forwardProducersByConsumerNodeId); + + checkGroupSize(groups, 2, 3, 1); + } + /** * Tests that the computation of the job graph where one upstream vertex connect with two * downstream vertices works correctly. @@ -166,12 +241,28 @@ class ForwardGroupComputeUtilTest { v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); v2.getProducedDataSets().get(1).getConsumers().get(0).setForward(true); - Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3, v4); + Set<ForwardGroup<?>> groups = computeForwardGroups(v1, v2, v3, v4); checkGroupSize(groups, 1, 3); } - private static Set<ForwardGroup> computeForwardGroups(JobVertex... 
vertices) { + @Test + void testOneInputSplitsIntoTwoForStreamNodeForwardGroup() throws Exception { + List<StreamNode> topologicallySortedStreamNodes = createStreamNodes(4); + Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId = new HashMap<>(); + forwardProducersByConsumerNodeId + .computeIfAbsent(topologicallySortedStreamNodes.get(3), k -> new HashSet<>()) + .add(topologicallySortedStreamNodes.get(1)); + forwardProducersByConsumerNodeId + .computeIfAbsent(topologicallySortedStreamNodes.get(2), k -> new HashSet<>()) + .add(topologicallySortedStreamNodes.get(1)); + Set<ForwardGroup<?>> groups = + computeForwardGroups( + topologicallySortedStreamNodes, forwardProducersByConsumerNodeId); + checkGroupSize(groups, 2, 3, 1); + } + + private static Set<ForwardGroup<?>> computeForwardGroups(JobVertex... vertices) { Arrays.asList(vertices).forEach(vertex -> vertex.setInvokableClass(NoOpInvokable.class)); return new HashSet<>( ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism( @@ -180,9 +271,61 @@ class ForwardGroupComputeUtilTest { } private static void checkGroupSize( - Set<ForwardGroup> groups, int numOfGroups, Integer... 
sizes) { + Set<ForwardGroup<?>> groups, int numOfGroups, Integer... sizes) { assertThat(groups.size()).isEqualTo(numOfGroups); - assertThat(groups.stream().map(ForwardGroup::size).collect(Collectors.toList())) + assertThat( + groups.stream() + .map( + group -> { + if (group instanceof JobVertexForwardGroup) { + return ((JobVertexForwardGroup) group).size(); + } else { + return ((StreamNodeForwardGroup) group).size(); + } + }) + .collect(Collectors.toList())) .contains(sizes); } + + // Creates a bare stream node (no operator, no slot sharing group) with the given id. + private static StreamNode createStreamNode(int id) { + return new StreamNode(id, null, null, (StreamOperator<?>) null, null, null); + } + + // Creates stream nodes with ids 1..count, returned in ascending id order. + private static List<StreamNode> createStreamNodes(int count) { + List<StreamNode> streamNodes = new ArrayList<>(count); + for (int i = 1; i <= count; i++) { + streamNodes.add(createStreamNode(i)); + } + return streamNodes; + } + + private static Set<ForwardGroup<?>> computeForwardGroups( + List<StreamNode> topologicallySortedStreamNodes, + Map<StreamNode, Set<StreamNode>> forwardProducersByConsumerNodeId) { + return new HashSet<>( + computeStreamNodeForwardGroupAndCheckParallelism( + topologicallySortedStreamNodes, + node -> + forwardProducersByConsumerNodeId.getOrDefault( + node, Collections.emptySet())) + .values()); + } + + // Computes the stream node forward groups and checks that every node whose group has a + // decided parallelism carries exactly that parallelism. 
+ public static Map<Integer, StreamNodeForwardGroup> + computeStreamNodeForwardGroupAndCheckParallelism( + final Iterable<StreamNode> topologicallySortedStreamNodes, + final Function<StreamNode, Set<StreamNode>> forwardProducersRetriever) { + final Map<Integer, StreamNodeForwardGroup> forwardGroupsByStartNodeId = + computeStreamNodeForwardGroup( + topologicallySortedStreamNodes, forwardProducersRetriever); + topologicallySortedStreamNodes.forEach( + startNode -> { + StreamNodeForwardGroup forwardGroup = + forwardGroupsByStartNodeId.get(startNode.getId()); + if (forwardGroup != null && forwardGroup.isParallelismDecided()) { + checkState(startNode.getParallelism() == forwardGroup.getParallelism()); + } + }); + return forwardGroupsByStartNodeId; + } } diff
--git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroupTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroupTest.java new file mode 100644 index 00000000000..cc68923cf76 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroupTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.jobgraph.forwardgroup; + +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.operators.StreamOperator; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Unit tests for {@link StreamNodeForwardGroup}.
*/ +class StreamNodeForwardGroupTest { + @Test + void testStreamNodeForwardGroup() { + Set<StreamNode> streamNodes = new HashSet<>(); + + streamNodes.add(createStreamNode(0, 1, 1)); + streamNodes.add(createStreamNode(1, 1, 1)); + + StreamNodeForwardGroup forwardGroup = new StreamNodeForwardGroup(streamNodes); + assertThat(forwardGroup.getParallelism()).isEqualTo(1); + assertThat(forwardGroup.getMaxParallelism()).isEqualTo(1); + assertThat(forwardGroup.size()).isEqualTo(2); + + streamNodes.add(createStreamNode(3, 1, 1)); + + StreamNodeForwardGroup forwardGroup2 = new StreamNodeForwardGroup(streamNodes); + assertThat(forwardGroup2.size()).isEqualTo(3); + } + + @Test + void testMergeForwardGroup() { + Map<Integer, StreamNode> streamNodeRetriever = new HashMap<>(); + StreamNodeForwardGroup forwardGroup = + createForwardGroupAndUpdateStreamNodeRetriever( + createStreamNode(0, -1, -1), streamNodeRetriever); + + StreamNodeForwardGroup forwardGroupWithUnDecidedParallelism = + createForwardGroupAndUpdateStreamNodeRetriever( + createStreamNode(1, -1, -1), streamNodeRetriever); + + // Merging two groups whose parallelism is undecided succeeds and stays undecided. + assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithUnDecidedParallelism)).isTrue(); + assertThat(forwardGroup.isParallelismDecided()).isFalse(); + assertThat(forwardGroup.isMaxParallelismDecided()).isFalse(); + + StreamNodeForwardGroup forwardGroupWithDecidedParallelism = + createForwardGroupAndUpdateStreamNodeRetriever( + createStreamNode(2, 2, 4), streamNodeRetriever); + // Merging in a group with decided parallelism makes the merged group adopt it. 
+ assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithDecidedParallelism)).isTrue(); + assertThat(forwardGroup.getParallelism()).isEqualTo(2); + assertThat(forwardGroup.getMaxParallelism()).isEqualTo(4); + + StreamNodeForwardGroup forwardGroupWithLargerMaxParallelism = + createForwardGroupAndUpdateStreamNodeRetriever( + createStreamNode(3, 2, 5), streamNodeRetriever); + // The target max parallelism (5) is larger than the source's (4); the merge succeeds and + // the source keeps its max parallelism.
+ assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithLargerMaxParallelism)).isTrue(); + assertThat(forwardGroup.getMaxParallelism()).isEqualTo(4); + + // The target max parallelism (3) is smaller than the source's (4); the merged group + // adopts the smaller value. + StreamNodeForwardGroup forwardGroupWithSmallerMaxParallelism = + createForwardGroupAndUpdateStreamNodeRetriever( + createStreamNode(4, 2, 3), streamNodeRetriever); + assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithSmallerMaxParallelism)).isTrue(); + assertThat(forwardGroup.getMaxParallelism()).isEqualTo(3); + + // A target whose max parallelism (1) is smaller than the source parallelism (2) is rejected. + StreamNodeForwardGroup forwardGroupWithMaxParallelismSmallerThanSourceParallelism = + createForwardGroupAndUpdateStreamNodeRetriever( + createStreamNode(5, -1, 1), streamNodeRetriever); + assertThat( + forwardGroup.mergeForwardGroup( + forwardGroupWithMaxParallelismSmallerThanSourceParallelism)) + .isFalse(); + + // A target with a different decided parallelism (1 vs 2) is rejected. + StreamNodeForwardGroup forwardGroupWithDifferentParallelism = + createForwardGroupAndUpdateStreamNodeRetriever( + createStreamNode(6, 1, 3), streamNodeRetriever); + assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithDifferentParallelism)).isFalse(); + + // A target with undecided parallelism is merged and its max parallelism (2) is adopted. 
+ StreamNodeForwardGroup forwardGroupWithUndefinedParallelism = + createForwardGroupAndUpdateStreamNodeRetriever( + createStreamNode(7, -1, 2), streamNodeRetriever); + assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithUndefinedParallelism)).isTrue(); + assertThat(forwardGroup.size()).isEqualTo(6); + assertThat(forwardGroup.getParallelism()).isEqualTo(2); + assertThat(forwardGroup.getMaxParallelism()).isEqualTo(2); + + // Every member node must carry the parallelism settings of the merged group. + for (Integer nodeId : forwardGroup.getVertexIds()) { + StreamNode node = streamNodeRetriever.get(nodeId); + assertThat(node.getParallelism()).isEqualTo(forwardGroup.getParallelism()); + assertThat(node.getMaxParallelism()).isEqualTo(forwardGroup.getMaxParallelism()); + } + } + + // Non-positive parallelism / max parallelism arguments are not applied to the node. + private static StreamNode createStreamNode(int id, int parallelism, int maxParallelism) { + StreamNode streamNode = + new StreamNode(id, null, null, (StreamOperator<?>) null, null, null); + if (parallelism > 0) { + 
streamNode.setParallelism(parallelism); + } + if (maxParallelism > 0) { + streamNode.setMaxParallelism(maxParallelism); + } + return streamNode; + } + + // Registers the node in the id lookup map and wraps it in a single-node forward group. + private StreamNodeForwardGroup createForwardGroupAndUpdateStreamNodeRetriever( + StreamNode streamNode, Map<Integer, StreamNode> streamNodeRetriever) { + streamNodeRetriever.put(streamNode.getId(), streamNode); + return new StreamNodeForwardGroup(Collections.singleton(streamNode)); + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java new file mode 100644 index 00000000000..468f6883b73 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.api.transformations.PartitionTransformation; +import org.apache.flink.streaming.api.transformations.StreamExchangeMode; +import org.apache.flink.streaming.runtime.partitioner.ForwardForUnspecifiedPartitioner; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +/** Unit tests for {@link DefaultStreamGraphContext}.
*/ +class DefaultStreamGraphContextTest { + @Test + void testModifyStreamEdge() { + StreamGraph streamGraph = createStreamGraphForModifyStreamEdgeTest(); + Map<Integer, StreamNodeForwardGroup> forwardGroupsByEndpointNodeIdCache = new HashMap<>(); + Map<Integer, Integer> frozenNodeToStartNodeMap = new HashMap<>(); + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches = + new HashMap<>(); + StreamGraphContext streamGraphContext = + new DefaultStreamGraphContext( + streamGraph, + forwardGroupsByEndpointNodeIdCache, + frozenNodeToStartNodeMap, + opIntermediateOutputsCaches); + + StreamNode sourceNode = + streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next()); + StreamNode targetNode = + streamGraph.getStreamNode(sourceNode.getOutEdges().get(0).getTargetId()); + targetNode.setParallelism(1); + StreamEdge targetEdge = sourceNode.getOutEdges().get(0); + + StreamNodeForwardGroup forwardGroup1 = + new StreamNodeForwardGroup(Collections.singleton(sourceNode)); + StreamNodeForwardGroup forwardGroup2 = + new StreamNodeForwardGroup(Collections.singleton(targetNode)); + forwardGroupsByEndpointNodeIdCache.put(sourceNode.getId(), forwardGroup1); + forwardGroupsByEndpointNodeIdCache.put(targetNode.getId(), forwardGroup2); + + StreamEdgeUpdateRequestInfo streamEdgeUpdateRequestInfo = + new StreamEdgeUpdateRequestInfo( + targetEdge.getEdgeId(), + targetEdge.getSourceId(), + targetEdge.getTargetId()) + .outputPartitioner(new ForwardForUnspecifiedPartitioner<>()); + + // Source and target both have parallelism 1, so the request is expected to end up as a + // plain ForwardPartitioner. + assertThat( + streamGraphContext.modifyStreamEdge( + Collections.singletonList(streamEdgeUpdateRequestInfo))) + .isTrue(); + assertThat(targetEdge.getPartitioner()).isInstanceOf(ForwardPartitioner.class); + + // We cannot modify an edge whose partitioner is already a ForwardPartitioner. 
+ assertThat( + streamGraphContext.modifyStreamEdge( + Collections.singletonList(streamEdgeUpdateRequestInfo))) + .isFalse(); + + // We cannot modify an edge whose target node has already created its job vertex.
+ frozenNodeToStartNodeMap.put(targetEdge.getTargetId(), targetEdge.getTargetId()); + assertThat( + streamGraphContext.modifyStreamEdge( + Collections.singletonList(streamEdgeUpdateRequestInfo))) + .isFalse(); + + NonChainedOutput nonChainedOutput = + new NonChainedOutput( + targetEdge.supportsUnalignedCheckpoints(), + targetEdge.getSourceId(), + targetNode.getParallelism(), + targetNode.getMaxParallelism(), + targetEdge.getBufferTimeout(), + false, + new IntermediateDataSetID(), + targetEdge.getOutputTag(), + targetEdge.getPartitioner(), + ResultPartitionType.BLOCKING); + opIntermediateOutputsCaches.put( + targetEdge.getSourceId(), + Map.of( + targetEdge, + nonChainedOutput, + targetNode.getOutEdges().get(0), + nonChainedOutput)); + + // We cannot modify an edge whose subscribed output is shared with another edge. Unfreeze + // the target node first so that the frozen-node check does not mask the reused-output check. + frozenNodeToStartNodeMap.remove(targetEdge.getTargetId()); + assertThat( + streamGraphContext.modifyStreamEdge( + Collections.singletonList(streamEdgeUpdateRequestInfo))) + .isFalse(); + } + + @Test + void testModifyToForwardPartitionerButResultIsRescale() { + StreamGraph streamGraph = createStreamGraphForModifyStreamEdgeTest(); + + Map<Integer, StreamNodeForwardGroup> forwardGroupsByEndpointNodeIdCache = new HashMap<>(); + Map<Integer, Integer> frozenNodeToStartNodeMap = new HashMap<>(); + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches = + new HashMap<>(); + + StreamGraphContext streamGraphContext = + new DefaultStreamGraphContext( + streamGraph, + forwardGroupsByEndpointNodeIdCache, + frozenNodeToStartNodeMap, + opIntermediateOutputsCaches); + + StreamNode sourceNode = + streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next()); + StreamNode targetNode = + streamGraph.getStreamNode(sourceNode.getOutEdges().get(0).getTargetId()); + StreamEdge targetEdge = sourceNode.getOutEdges().get(0); + + StreamNodeForwardGroup forwardGroup1 = + new 
StreamNodeForwardGroup(Collections.singleton(sourceNode)); + StreamNodeForwardGroup forwardGroup2 = + new StreamNodeForwardGroup(Collections.singleton(targetNode)); + forwardGroupsByEndpointNodeIdCache.put(sourceNode.getId(), forwardGroup1); + forwardGroupsByEndpointNodeIdCache.put(targetNode.getId(), forwardGroup2); + + StreamEdgeUpdateRequestInfo streamEdgeUpdateRequestInfo = + new StreamEdgeUpdateRequestInfo( + targetEdge.getEdgeId(), + targetEdge.getSourceId(), + targetEdge.getTargetId()) + .outputPartitioner(new ForwardForUnspecifiedPartitioner<>()); + + // Request a change from the rescale partitioner to a forward partitioner; in each case + // below the request is accepted but the edge falls back to a RescalePartitioner. + + // 1. The source and target are not chainable. + assertThat( + streamGraphContext.modifyStreamEdge( + Collections.singletonList(streamEdgeUpdateRequestInfo))) + .isTrue(); + assertThat(targetEdge.getPartitioner()).isInstanceOf(RescalePartitioner.class); + + // 2. The forward groups of source and target cannot be merged. + targetNode.setParallelism(1); + assertThat( + streamGraphContext.modifyStreamEdge( + Collections.singletonList(streamEdgeUpdateRequestInfo))) + .isTrue(); + assertThat(targetEdge.getPartitioner()).isInstanceOf(RescalePartitioner.class); + + // 3. The upstream (source) node has already created its job vertex.
+ frozenNodeToStartNodeMap.put( + streamGraph.getSourceIDs().iterator().next(), + streamGraph.getSourceIDs().iterator().next()); + assertThat( + streamGraphContext.modifyStreamEdge( + Collections.singletonList(streamEdgeUpdateRequestInfo))) + .isTrue(); + assertThat(targetEdge.getPartitioner()).isInstanceOf(RescalePartitioner.class); + } + + private StreamGraph createStreamGraphForModifyStreamEdgeTest() { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + // Source (parallelism 1) -rescale-> Map (parallelism 2) -rescale-> Print + DataStream<Integer> sourceDataStream = env.fromData(1, 2, 3).setParallelism(1); + + DataStream<Integer> partitionAfterSourceDataStream = + new DataStream<>( + env, + new PartitionTransformation<>( + sourceDataStream.getTransformation(), + new RescalePartitioner<>(), + StreamExchangeMode.PIPELINED)); + + DataStream<Integer> mapDataStream = + partitionAfterSourceDataStream.map(value -> value).setParallelism(2); + + DataStream<Integer> partitionAfterMapDataStream = + new DataStream<>( + env, + new PartitionTransformation<>( + mapDataStream.getTransformation(), + new RescalePartitioner<>(), + StreamExchangeMode.PIPELINED)); + + partitionAfterMapDataStream.print(); + + return env.getStreamGraph(); + } +}
