noorall commented on code in PR #25414:
URL: https://github.com/apache/flink/pull/25414#discussion_r1856101261


##########
flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManagerTest.java:
##########
@@ -0,0 +1,354 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.graph;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.operators.ResourceSpec;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.connector.source.lib.NumberSequenceSource;
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.TaskManagerOptions;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+
+import org.apache.flink.shaded.guava32.com.google.common.collect.Iterables;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link AdaptiveGraphManager}. */
+public class AdaptiveGraphManagerTest extends JobGraphGeneratorTestBase {
+    @Override
+    JobGraph createJobGraph(StreamGraph streamGraph) {
+        return generateJobGraphInLazilyMode(streamGraph);
+    }
+
+    @Test
+    @Override
+    void testManagedMemoryFractionForUnknownResourceSpec() throws Exception {
+        final ResourceSpec resource = ResourceSpec.UNKNOWN;
+        final List<ResourceSpec> resourceSpecs =
+                Arrays.asList(resource, resource, resource, resource);
+
+        final Configuration taskManagerConfig =
+                new Configuration() {
+                    {
+                        set(
+                                TaskManagerOptions.MANAGED_MEMORY_CONSUMER_WEIGHTS,
+                                new HashMap<String, String>() {
+                                    {
+                                        put(
+                                                TaskManagerOptions
+                                                        .MANAGED_MEMORY_CONSUMER_NAME_OPERATOR,
+                                                "6");
+                                        put(
+                                                TaskManagerOptions
+                                                        .MANAGED_MEMORY_CONSUMER_NAME_PYTHON,
+                                                "4");
+                                    }
+                                });
+                    }
+                };
+
+        final List<Map<ManagedMemoryUseCase, Integer>> operatorScopeManagedMemoryUseCaseWeights =
+                new ArrayList<>();
+        final List<Set<ManagedMemoryUseCase>> slotScopeManagedMemoryUseCases = new ArrayList<>();
+
+        // source: batch
+        operatorScopeManagedMemoryUseCaseWeights.add(
+                Collections.singletonMap(ManagedMemoryUseCase.OPERATOR, 1));
+        slotScopeManagedMemoryUseCases.add(Collections.emptySet());
+
+        // map1: batch, python
+        operatorScopeManagedMemoryUseCaseWeights.add(
+                Collections.singletonMap(ManagedMemoryUseCase.OPERATOR, 1));
+        slotScopeManagedMemoryUseCases.add(Collections.singleton(ManagedMemoryUseCase.PYTHON));
+
+        // map2: python
+        operatorScopeManagedMemoryUseCaseWeights.add(Collections.emptyMap());
+        slotScopeManagedMemoryUseCases.add(Collections.singleton(ManagedMemoryUseCase.PYTHON));
+
+        // map3: batch
+        operatorScopeManagedMemoryUseCaseWeights.add(
+                Collections.singletonMap(ManagedMemoryUseCase.OPERATOR, 1));
+        slotScopeManagedMemoryUseCases.add(Collections.emptySet());
+
+        // slotSharingGroup1 contains batch and python use cases: v1(source[batch] -> map1[batch,
+        // python]), v2(map2[python])
+        // slotSharingGroup2 contains batch use case only: v3(map3[batch])
+        final JobGraph jobGraph =
+                createJobGraph(
+                        createStreamGraphForManagedMemoryFractionTest(
+                                resourceSpecs,
+                                operatorScopeManagedMemoryUseCaseWeights,
+                                slotScopeManagedMemoryUseCases));
+        final JobVertex vertex1 = jobGraph.getVerticesSortedTopologicallyFromSources().get(0);
+        final JobVertex vertex2 = jobGraph.getVerticesSortedTopologicallyFromSources().get(1);
+        final JobVertex vertex3 = jobGraph.getVerticesSortedTopologicallyFromSources().get(2);
+
+        final StreamConfig sourceConfig = new StreamConfig(vertex1.getConfiguration());
+        verifyFractions(sourceConfig, 0.6 / 2, 0.0, 0.0, taskManagerConfig);
+
+        final StreamConfig map1Config =
+                Iterables.getOnlyElement(
+                        sourceConfig
+                                .getTransitiveChainedTaskConfigs(
+                                        JobGraphGeneratorTestBase.class.getClassLoader())
+                                .values());
+        verifyFractions(map1Config, 0.6 / 2, 0.4, 0.0, taskManagerConfig);
+
+        // Since the job graph is generated progressively and cannot obtain global
+        // information, map2 only contains information about the Python fraction, so the
+        // value should be 1.0.
+        final StreamConfig map2Config = new StreamConfig(vertex2.getConfiguration());
+        verifyFractions(map2Config, 0.0, 1.0, 0.0, taskManagerConfig);
+
+        final StreamConfig map3Config = new StreamConfig(vertex3.getConfiguration());
+        verifyFractions(map3Config, 1.0, 0.0, 0.0, taskManagerConfig);
+    }
+
+    @Test
+    void testCreateJobVertexLazily() {
+        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+
+        DataStream<Tuple2<String, String>> input =
+                env.fromData("a", "b", "c", "d", "e", "f")
+                        .map(
+                                new MapFunction<String, Tuple2<String, String>>() {
+
+                                    @Override
+                                    public Tuple2<String, String> map(String value) {
+                                        return new Tuple2<>(value, value);
+                                    }
+                                });
+
+        DataStream<Tuple2<String, String>> result =
+                input.keyBy(x -> x.f0)
+                        .map(
+                                new MapFunction<Tuple2<String, String>, Tuple2<String, String>>() {
+
+                                    @Override
+                                    public Tuple2<String, String> map(
+                                            Tuple2<String, String> value) {
+                                        return value;
+                                    }
+                                });
+
+        result.sinkTo(new DiscardingSink<>());
+        StreamGraph streamGraph = env.getStreamGraph();
+        streamGraph.setDynamic(true);
+
+        AdaptiveGraphManager adaptiveGraphManager =
+                new AdaptiveGraphManager(
+                        Thread.currentThread().getContextClassLoader(), streamGraph, Runnable::run);
+        JobGraph jobGraph = adaptiveGraphManager.getJobGraph();
+        List<JobVertex> jobVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+        assertThat(jobVertices.size()).isEqualTo(1);
+        while (!jobVertices.isEmpty()) {
+            List<JobVertex> newJobVertices = new ArrayList<>();
+            for (JobVertex jobVertex : jobVertices) {
+                newJobVertices.addAll(adaptiveGraphManager.onJobVertexFinished(jobVertex.getID()));
+            }
+            jobVertices = newJobVertices;
+        }
+        jobVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+        List<StreamNode> streamNodes =
+                streamGraph.getStreamNodes().stream()
+                        .sorted(Comparator.comparingInt(StreamNode::getId))
+                        .collect(Collectors.toList());
+        assertThat(jobVertices.size()).isEqualTo(2);
+        assertThat(adaptiveGraphManager.getPendingOperatorsCount()).isEqualTo(0);
+        assertThat(adaptiveGraphManager.getStreamNodeIdsByJobVertexId(jobVertices.get(0).getID()))
+                .isEqualTo(List.of(streamNodes.get(1).getId(), streamNodes.get(0).getId()));
+        assertThat(adaptiveGraphManager.getStreamNodeIdsByJobVertexId(jobVertices.get(1).getID()))
+                .isEqualTo(List.of(streamNodes.get(3).getId(), streamNodes.get(2).getId()));
+        assertThat(
+                        adaptiveGraphManager.getProducerStreamNodeId(
+                                jobVertices.get(0).getProducedDataSets().get(0).getId()))
+                .isEqualTo(streamNodes.get(1).getId());
+        List<StreamNode> chainedStreamNodes =
+                adaptiveGraphManager
+                        .getStreamNodeForwardGroupByVertexId(jobVertices.get(0).getID())
+                        .getChainedStreamNodeGroups()
+                        .iterator()
+                        .next();
+        assertThat(chainedStreamNodes.size()).isEqualTo(2);
+        assertThat(chainedStreamNodes.get(0).getId()).isEqualTo(streamNodes.get(0).getId());
+        assertThat(chainedStreamNodes.get(1).getId()).isEqualTo(streamNodes.get(1).getId());
+    }
+
+    @Test
+    void testTheCorrectnessOfJobGraph() {
+        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+
+        DataStream<Tuple2<String, String>> input =
+                env.fromData("a", "b", "c", "d", "e", "f")
+                        .map(
+                                new MapFunction<String, Tuple2<String, String>>() {
+
+                                    @Override
+                                    public Tuple2<String, String> map(String value) {
+                                        return new Tuple2<>(value, value);
+                                    }
+                                });
+
+        DataStream<Tuple2<String, String>> result =
+                input.keyBy(x -> x.f0)
+                        .map(
+                                new MapFunction<Tuple2<String, String>, Tuple2<String, String>>() {
+
+                                    @Override
+                                    public Tuple2<String, String> map(
+                                            Tuple2<String, String> value) {
+                                        return value;
+                                    }
+                                });
+
+        result.sinkTo(new DiscardingSink<>());
+        StreamGraph streamGraph = env.getStreamGraph();
+        JobGraph jobGraph1 = generateJobGraphInLazilyMode(streamGraph);
+        JobGraph jobGraph2 = StreamingJobGraphGenerator.createJobGraph(streamGraph);
+        assertThat(isJobGraphEquivalent(jobGraph1, jobGraph2)).isTrue();
+    }
+
+    @Test
+    void testSourceChain() {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setMaxParallelism(100);
+        env.setParallelism(100);
+        MultipleInputTransformation<Long> transform =
+                new MultipleInputTransformation<>(
+                        "mit", new UnusedOperatorFactory(), Types.LONG, -1);
+
+        Transformation<Long> input1 =
+                env.fromSource(
+                                new NumberSequenceSource(1, 2),
+                                WatermarkStrategy.noWatermarks(),
+                                "input1")
+                        .setParallelism(100)
+                        .getTransformation();
+        Transformation<Long> input2 =
+                env.fromSource(
+                                new NumberSequenceSource(1, 2),
+                                WatermarkStrategy.noWatermarks(),
+                                "input2")
+                        .setParallelism(1)
+                        .getTransformation();
+        Transformation<Long> input3 =
+                env.fromSource(
+                                new NumberSequenceSource(1, 2),
+                                WatermarkStrategy.noWatermarks(),
+                                "input3")
+                        .setParallelism(1)
+                        .getTransformation();
+        transform.addInput(input1);
+        transform.addInput(input2);
+        transform.addInput(input3);
+        transform.setChainingStrategy(ChainingStrategy.HEAD_WITH_SOURCES);
+        DataStream<Long> dataStream = new DataStream<>(env, transform);
+        // do not chain with sink operator.
+        dataStream.rebalance().sinkTo(new DiscardingSink<>()).name("sink");
+        env.addOperator(transform);
+        StreamGraph streamGraph = env.getStreamGraph();
+        streamGraph.setDynamic(true);
+        JobGraph jobGraph = createJobGraph(streamGraph);
+        assertThat(jobGraph.getVerticesSortedTopologicallyFromSources().size()).isEqualTo(4);

Review Comment:
   > Only one source can be chained in this case? input3 will not be included in the multi-input chain?
   
   Yes. The parallelism of input2 and input3 is 1, so only input1 meets the conditions for source chaining.
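   
   For illustration, here is a minimal, self-contained sketch of the parallelism part of that condition. It is not the actual check in `StreamingJobGraphGenerator` (which also looks at the chaining strategy and at how the source connects to the chain head); the class name is hypothetical and the numbers are the ones used in `testSourceChain`:
   
   ```java
   public class SourceChainParallelismSketch {
       public static void main(String[] args) {
           // "mit" is created with parallelism -1, so it falls back to the
           // environment parallelism of 100 set in the test.
           int operatorParallelism = 100;
           // input1, input2, input3 as configured in testSourceChain.
           int[] sourceParallelisms = {100, 1, 1};
   
           for (int i = 0; i < sourceParallelisms.length; i++) {
               // A source can only be pulled into the multi-input chain when its
               // parallelism matches the parallelism of the chain head.
               boolean chainable = sourceParallelisms[i] == operatorParallelism;
               System.out.println("input" + (i + 1) + " chainable: " + chainable);
           }
           // Prints true only for input1. input2 and input3 stay separate
           // vertices, and the sink is separated by the rebalance edge, which
           // matches the assertion of 4 job vertices above.
       }
   }
   ```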



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
