This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit c78097a56c56de695913a6d4e32d2c3edd040f70
Author: noorall <863485...@qq.com>
AuthorDate: Thu Dec 26 11:09:44 2024 +0800

    [FLINK-36629][table-planner] Modify StreamGraphContext to provide more necessary interfaces
---
 .../api/graph/DefaultStreamGraphContext.java       | 18 ++++++
 .../streaming/api/graph/StreamGraphContext.java    | 12 ++++
 .../api/graph/util/ImmutableStreamEdge.java        | 15 +++++
 .../graph/util/StreamEdgeUpdateRequestInfo.java    | 16 ++++++
 .../adaptivebatch/StreamGraphOptimizerTest.java    |  7 +++
 .../api/graph/DefaultStreamGraphContextTest.java   | 67 ++++++++++++++++++++++
 6 files changed, 135 insertions(+)

diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
index b7017b36e77..ebca5afd9ba 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java
@@ -155,6 +155,10 @@ public class DefaultStreamGraphContext implements StreamGraphContext {
             if (requestInfo.getTypeNumber() != 0) {
                 targetEdge.setTypeNumber(requestInfo.getTypeNumber());
             }
+            if (requestInfo.getIntraInputKeyCorrelated() != null) {
+                modifyIntraInputKeyCorrelation(
+                        targetEdge, requestInfo.getIntraInputKeyCorrelated());
+            }
         }
 
         // Notify the listener that the StreamGraph has been updated.
@@ -205,6 +209,12 @@ public class DefaultStreamGraphContext implements StreamGraphContext {
         return consumerEdgeIdToIntermediateDataSetMap.get(edgeId).getId();
     }
 
+    @Override
+    public StreamPartitioner<?> getOutputPartitioner(
+            String edgeId, Integer sourceId, Integer targetId) {
+        return checkNotNull(getStreamEdge(sourceId, targetId, edgeId)).getPartitioner();
+    }
+
     private boolean validateStreamEdgeUpdateRequest(StreamEdgeUpdateRequestInfo requestInfo) {
         Integer sourceNodeId = requestInfo.getSourceId();
         Integer targetNodeId = requestInfo.getTargetId();
@@ -310,6 +320,14 @@ public class DefaultStreamGraphContext implements StreamGraphContext {
                 targetEdge.getPartitioner());
     }
 
+    private void modifyIntraInputKeyCorrelation(
+            StreamEdge targetEdge, boolean existIntraInputKeyCorrelation) {
+        if (targetEdge.isIntraInputKeyCorrelated() == existIntraInputKeyCorrelation) {
+            return;
+        }
+        targetEdge.setIntraInputKeyCorrelated(existIntraInputKeyCorrelation);
+    }
+
     private void tryConvertForwardPartitionerAndMergeForwardGroup(StreamEdge targetEdge) {
         checkState(targetEdge.getPartitioner() instanceof ForwardPartitioner);
         Integer sourceNodeId = targetEdge.getSourceId();
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
index 2368c7fa329..e697b98083a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphContext.java
@@ -25,6 +25,7 @@ import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
 import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
 import org.apache.flink.streaming.api.graph.util.StreamNodeUpdateRequestInfo;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 
 import javax.annotation.Nullable;
 
@@ -96,4 +97,15 @@ public interface StreamGraphContext {
         /** This method is called whenever the StreamGraph is updated. */
         void onStreamGraphUpdated();
     }
+
+    /**
+     * Gets the output partitioner of the specified edge.
+     *
+     * @param edgeId id of the edge
+     * @param sourceId source node id of the edge
+     * @param targetId target node id of the edge
+     * @return the output partitioner
+     */
+    @Nullable
+    StreamPartitioner<?> getOutputPartitioner(String edgeId, Integer sourceId, Integer targetId);
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamEdge.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamEdge.java
index f0ab4d42e12..d62cde9f63c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamEdge.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamEdge.java
@@ -20,6 +20,9 @@ package org.apache.flink.streaming.api.graph.util;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.streaming.api.graph.StreamEdge;
+import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.ForwardForConsecutiveHashPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
 
 /** Helper class that provides read-only StreamEdge. */
 @Internal
@@ -45,4 +48,16 @@ public class ImmutableStreamEdge {
     public String getEdgeId() {
         return streamEdge.getEdgeId();
     }
+
+    public boolean isForwardForConsecutiveHashEdge() {
+        return streamEdge.getPartitioner() instanceof ForwardForConsecutiveHashPartitioner;
+    }
+
+    public boolean isExactForwardEdge() {
+        return streamEdge.getPartitioner().getClass().equals(ForwardPartitioner.class);
+    }
+
+    public boolean isBroadcastEdge() {
+        return streamEdge.getPartitioner() instanceof BroadcastPartitioner;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
index 4d0c4e78565..ad97726f59d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/StreamEdgeUpdateRequestInfo.java
@@ -21,6 +21,8 @@ package org.apache.flink.streaming.api.graph.util;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 
+import javax.annotation.Nullable;
+
 /** Helper class carries the data required to updates a stream edge. */
 @Internal
 public class StreamEdgeUpdateRequestInfo {
@@ -35,6 +37,9 @@ public class StreamEdgeUpdateRequestInfo {
     // typeNumber.
     private int typeNumber;
 
+    // Null means no modifications will be applied to it
+    @Nullable private Boolean intraInputKeyCorrelated;
+
     public StreamEdgeUpdateRequestInfo(String edgeId, Integer sourceId, Integer targetId) {
         this.edgeId = edgeId;
         this.sourceId = sourceId;
@@ -52,6 +57,12 @@ public class StreamEdgeUpdateRequestInfo {
         return this;
     }
 
+    public StreamEdgeUpdateRequestInfo withIntraInputKeyCorrelated(
+            boolean intraInputKeyCorrelated) {
+        this.intraInputKeyCorrelated = intraInputKeyCorrelated;
+        return this;
+    }
+
     public String getEdgeId() {
         return edgeId;
     }
@@ -71,4 +82,9 @@ public class StreamEdgeUpdateRequestInfo {
     public int getTypeNumber() {
         return typeNumber;
     }
+
+    @Nullable
+    public Boolean getIntraInputKeyCorrelated() {
+        return intraInputKeyCorrelated;
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java
index e2316a09e62..ef9af85537b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
 import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
 import org.apache.flink.streaming.api.graph.util.StreamNodeUpdateRequestInfo;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -96,6 +97,12 @@ class StreamGraphOptimizerTest {
                     public IntermediateDataSetID getConsumedIntermediateDataSetId(String edgeId) {
                         return null;
                     }
+
+                    @Override
+                    public @Nullable StreamPartitioner<?> getOutputPartitioner(
+                            String edgeId, Integer sourceId, Integer targetId) {
+                        return null;
+                    }
                 };
 
         optimizer.onOperatorsFinished(operatorsFinished, context);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
index dcb1566ddd2..d3ef738e961 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContextTest.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.streaming.api.graph;
 
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup;
@@ -26,12 +28,14 @@ import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
 import org.apache.flink.streaming.api.transformations.PartitionTransformation;
 import org.apache.flink.streaming.api.transformations.StreamExchangeMode;
+import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
 import org.apache.flink.streaming.runtime.partitioner.ForwardForUnspecifiedPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
 
 import org.junit.jupiter.api.Test;
 
+import java.time.Duration;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -193,6 +197,46 @@ class DefaultStreamGraphContextTest {
         assertThat(targetEdge.getPartitioner() instanceof RescalePartitioner).isTrue();
     }
 
+    @Test
+    void testModifyIntraInputKeyCorrelation() {
+        StreamGraph streamGraph = createStreamGraphWithCorrelatedInputs();
+        StreamGraphContext streamGraphContext =
+                new DefaultStreamGraphContext(
+                        streamGraph,
+                        new HashMap<>(),
+                        new HashMap<>(),
+                        new HashMap<>(),
+                        new HashMap<>(),
+                        new HashSet<>(),
+                        Thread.currentThread().getContextClassLoader());
+        StreamNode sourceNode =
+                streamGraph.getStreamNode(streamGraph.getSourceIDs().iterator().next());
+        StreamEdge targetEdge = sourceNode.getOutEdges().get(0);
+        assertThat(targetEdge.areInterInputsKeysCorrelated()).isTrue();
+        assertThat(targetEdge.isIntraInputKeyCorrelated()).isTrue();
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                Collections.singletonList(
+                                        new StreamEdgeUpdateRequestInfo(
+                                                        targetEdge.getEdgeId(),
+                                                        targetEdge.getSourceId(),
+                                                        targetEdge.getTargetId())
+                                                .withIntraInputKeyCorrelated(false))))
+                .isTrue();
+        assertThat(targetEdge.isIntraInputKeyCorrelated()).isFalse();
+
+        assertThat(
+                        streamGraphContext.modifyStreamEdge(
+                                Collections.singletonList(
+                                        new StreamEdgeUpdateRequestInfo(
+                                                        targetEdge.getEdgeId(),
+                                                        targetEdge.getSourceId(),
+                                                        targetEdge.getTargetId())
+                                                .withIntraInputKeyCorrelated(true))))
+                .isTrue();
+        assertThat(targetEdge.isIntraInputKeyCorrelated()).isTrue();
+    }
+
     private StreamGraph createStreamGraphForModifyStreamEdgeTest() {
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         // fromElements(1) -> Map(2) -> Print
@@ -221,4 +265,27 @@ class DefaultStreamGraphContextTest {
 
         return env.getStreamGraph();
     }
+
+    private StreamGraph createStreamGraphWithCorrelatedInputs() {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        // A--
+        //    -(join)->  C
+        // B--
+        DataStream<Tuple2<Integer, String>> streamA =
+                env.fromData(new Tuple2<>(1, "a1"), new Tuple2<>(2, "a2"), new Tuple2<>(3, "a3"))
+                        .keyBy(value -> value.f0);
+        DataStream<Tuple2<Integer, String>> streamB =
+                env.fromData(new Tuple2<>(1, "b1"), new Tuple2<>(2, "b2"), new Tuple2<>(3, "b3"))
+                        .keyBy(value -> value.f0);
+        DataStream<String> joinedStream =
+                streamA.join(streamB)
+                        .where(v -> v.f0)
+                        .equalTo(v -> v.f0)
+                        .window(TumblingEventTimeWindows.of(Duration.ofMillis(1)))
+                        .apply(
+                                (first, second) -> first.f1 + second.f1,
+                                BasicTypeInfo.STRING_TYPE_INFO);
+        joinedStream.print();
+        return env.getStreamGraph();
+    }
 }

Reply via email to