This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch release-1.18
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 940b3bbda5b10abe3a41d60467d33fd424c7dae6
Author: Jeyhun Karimov <je.kari...@gmail.com>
AuthorDate: Sun Mar 10 15:56:02 2024 +0100

    [FLINK-32513][core] Add predecessor caching
    
    This closes #24475.
---
 .../org/apache/flink/api/dag/Transformation.java   |  28 +++-
 .../apache/flink/api/dag/TransformationTest.java   |   2 +-
 .../AbstractBroadcastStateTransformation.java      |  12 +-
 .../AbstractMultipleInputTransformation.java       |  12 +-
 .../api/transformations/CacheTransformation.java   |   2 +-
 .../transformations/CoFeedbackTransformation.java  |   2 +-
 .../transformations/FeedbackTransformation.java    |   2 +-
 .../transformations/LegacySinkTransformation.java  |   2 +-
 .../LegacySourceTransformation.java                |   2 +-
 .../transformations/OneInputTransformation.java    |   2 +-
 .../transformations/PartitionTransformation.java   |   2 +-
 .../api/transformations/ReduceTransformation.java  |   2 +-
 .../transformations/SideOutputTransformation.java  |   2 +-
 .../api/transformations/SinkTransformation.java    |   2 +-
 .../api/transformations/SourceTransformation.java  |   2 +-
 .../TimestampsAndWatermarksTransformation.java     |   2 +-
 .../transformations/TwoInputTransformation.java    |  18 ++-
 .../api/transformations/UnionTransformation.java   |  16 +-
 .../api/graph/StreamGraphGeneratorTest.java        |   2 +-
 .../GetTransitivePredecessorsTest.java             | 162 +++++++++++++++++++++
 .../TableOperatorWrapperGeneratorTest.java         |   2 +-
 21 files changed, 237 insertions(+), 41 deletions(-)

diff --git 
a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java 
b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
index a0448697dd1..07c64907c82 100644
--- a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
+++ b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
@@ -19,6 +19,7 @@
 package org.apache.flink.api.dag;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.InvalidTypesException;
 import org.apache.flink.api.common.operators.ResourceSpec;
@@ -161,6 +162,13 @@ public abstract class Transformation<T> {
     private final Map<ManagedMemoryUseCase, Integer> 
managedMemoryOperatorScopeUseCaseWeights =
             new HashMap<>();
 
+    /**
+     * This map is a cache that stores transitive predecessors and is used in {@code
+     * getTransitivePredecessors()}.
+     */
+    private final Map<Transformation<T>, List<Transformation<?>>> 
predecessorsCache =
+            new HashMap<>();
+
     /** Slot scope use cases that this transformation needs managed memory 
for. */
     private final Set<ManagedMemoryUseCase> managedMemorySlotScopeUseCases = 
new HashSet<>();
 
@@ -230,6 +238,12 @@ public abstract class Transformation<T> {
         return name;
     }
 
+    /** Returns the predecessorsCache of this {@code Transformation}. */
+    @VisibleForTesting
+    Map<Transformation<T>, List<Transformation<?>>> getPredecessorsCache() {
+        return predecessorsCache;
+    }
+
     /** Changes the description of this {@code Transformation}. */
     public void setDescription(String description) {
         this.description = Preconditions.checkNotNull(description);
@@ -578,7 +592,19 @@ public abstract class Transformation<T> {
      *
      * @return The list of transitive predecessors.
      */
-    public abstract List<Transformation<?>> getTransitivePredecessors();
+    protected abstract List<Transformation<?>> 
getTransitivePredecessorsInternal();
+
+    /**
+     * Returns all transitive predecessor {@code Transformation}s of this 
{@code Transformation}.
+     * This is, for example, used when determining whether a feedback edge of 
an iteration actually
+     * has the iteration head as a predecessor. This public method is a caching wrapper around
+     * {@code getTransitivePredecessorsInternal()}: the result is computed once and then reused.
+     *
+     * @return The list of transitive predecessors.
+     */
+    public final List<Transformation<?>> getTransitivePredecessors() {
+        return predecessorsCache.computeIfAbsent(this, key -> 
getTransitivePredecessorsInternal());
+    }
 
     /**
      * Returns the {@link Transformation transformations} that are the 
immediate predecessors of the
diff --git 
a/flink-core/src/test/java/org/apache/flink/api/dag/TransformationTest.java 
b/flink-core/src/test/java/org/apache/flink/api/dag/TransformationTest.java
index ced452cd14f..c200fac33fc 100644
--- a/flink-core/src/test/java/org/apache/flink/api/dag/TransformationTest.java
+++ b/flink-core/src/test/java/org/apache/flink/api/dag/TransformationTest.java
@@ -142,7 +142,7 @@ class TransformationTest {
         }
 
         @Override
-        public List<Transformation<?>> getTransitivePredecessors() {
+        protected List<Transformation<?>> getTransitivePredecessorsInternal() {
             return Collections.emptyList();
         }
 
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractBroadcastStateTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractBroadcastStateTransformation.java
index c7869e61389..1ce6c7c9bfd 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractBroadcastStateTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractBroadcastStateTransformation.java
@@ -26,6 +26,8 @@ import 
org.apache.flink.streaming.api.operators.ChainingStrategy;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -103,12 +105,10 @@ public class AbstractBroadcastStateTransformation<IN1, 
IN2, OUT>
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
-        final List<Transformation<?>> predecessors = new ArrayList<>();
-        predecessors.add(this);
-        predecessors.add(regularInput);
-        predecessors.add(broadcastInput);
-        return predecessors;
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
+        return Stream.of(this, regularInput, broadcastInput)
+                .distinct()
+                .collect(Collectors.toList());
     }
 
     @Override
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractMultipleInputTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractMultipleInputTransformation.java
index 65bbc53188f..a733cebbaaf 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractMultipleInputTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractMultipleInputTransformation.java
@@ -76,10 +76,14 @@ public abstract class 
AbstractMultipleInputTransformation<OUT> extends PhysicalT
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
-        return inputs.stream()
-                .flatMap(input -> input.getTransitivePredecessors().stream())
-                .collect(Collectors.toList());
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
+        List<Transformation<?>> predecessors =
+                getInputs().stream()
+                        .flatMap(input -> 
input.getTransitivePredecessors().stream())
+                        .distinct()
+                        .collect(Collectors.toList());
+        predecessors.add(this);
+        return predecessors;
     }
 
     @Override
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java
index 4632c9b020d..e8bd22adcc8 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java
@@ -57,7 +57,7 @@ public class CacheTransformation<T> extends Transformation<T> 
{
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         List<Transformation<?>> result = Lists.newArrayList();
         result.add(this);
         if (isCached) {
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java
index 11f8a912608..4d366ba2a8f 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java
@@ -106,7 +106,7 @@ public class CoFeedbackTransformation<F> extends 
Transformation<F> {
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         return Collections.singletonList(this);
     }
 
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java
index ab8f7424741..ef4ec2766bf 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java
@@ -102,7 +102,7 @@ public class FeedbackTransformation<T> extends 
Transformation<T> {
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         List<Transformation<?>> result = Lists.newArrayList();
         result.add(this);
         result.addAll(input.getTransitivePredecessors());
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java
index a0bd6417881..fcd753c13b4 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java
@@ -136,7 +136,7 @@ public class LegacySinkTransformation<T> extends 
PhysicalTransformation<T> {
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         List<Transformation<?>> result = Lists.newArrayList();
         result.add(this);
         result.addAll(input.getTransitivePredecessors());
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySourceTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySourceTransformation.java
index 4718ff18386..9443a3a14ea 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySourceTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySourceTransformation.java
@@ -92,7 +92,7 @@ public class LegacySourceTransformation<T> extends 
PhysicalTransformation<T>
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         return Collections.singletonList(this);
     }
 
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java
index 0a3957130fc..7222fad97da 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java
@@ -168,7 +168,7 @@ public class OneInputTransformation<IN, OUT> extends 
PhysicalTransformation<OUT>
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         List<Transformation<?>> result = Lists.newArrayList();
         result.add(this);
         result.addAll(input.getTransitivePredecessors());
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java
index 92d3f264731..966e565c78f 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java
@@ -89,7 +89,7 @@ public class PartitionTransformation<T> extends 
Transformation<T> {
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         List<Transformation<?>> result = Lists.newArrayList();
         result.add(this);
         result.addAll(input.getTransitivePredecessors());
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/ReduceTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/ReduceTransformation.java
index da159f3af11..82c4e1c4694 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/ReduceTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/ReduceTransformation.java
@@ -89,7 +89,7 @@ public final class ReduceTransformation<IN, K> extends 
PhysicalTransformation<IN
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         List<Transformation<?>> result = Lists.newArrayList();
         result.add(this);
         result.addAll(input.getTransitivePredecessors());
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java
index 3177bfa0eb1..e054d37be54 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java
@@ -53,7 +53,7 @@ public class SideOutputTransformation<T> extends 
Transformation<T> {
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         List<Transformation<?>> result = Lists.newArrayList();
         result.add(this);
         result.addAll(input.getTransitivePredecessors());
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java
index bdaa8d2dede..a4b4310588e 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java
@@ -73,7 +73,7 @@ public class SinkTransformation<InputT, OutputT> extends 
PhysicalTransformation<
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         final List<Transformation<?>> result = Lists.newArrayList();
         result.add(this);
         result.addAll(input.getTransitivePredecessors());
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformation.java
index 86c07ca5511..0260da43b82 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformation.java
@@ -92,7 +92,7 @@ public class SourceTransformation<OUT, SplitT extends 
SourceSplit, EnumChkT>
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         return Collections.singletonList(this);
     }
 
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TimestampsAndWatermarksTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TimestampsAndWatermarksTransformation.java
index c8bde911a04..26541f74673 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TimestampsAndWatermarksTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TimestampsAndWatermarksTransformation.java
@@ -77,7 +77,7 @@ public class TimestampsAndWatermarksTransformation<IN> 
extends PhysicalTransform
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
         List<Transformation<?>> transformations = Lists.newArrayList();
         transformations.add(this);
         transformations.addAll(input.getTransitivePredecessors());
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java
index 4c576170741..48005bf95c4 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java
@@ -28,10 +28,10 @@ import 
org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 
-import org.apache.flink.shaded.guava31.com.google.common.collect.Lists;
-
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * This Transformation represents the application of a {@link 
TwoInputStreamOperator} to two input
@@ -217,12 +217,14 @@ public class TwoInputTransformation<IN1, IN2, OUT> 
extends PhysicalTransformatio
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
-        List<Transformation<?>> result = Lists.newArrayList();
-        result.add(this);
-        result.addAll(input1.getTransitivePredecessors());
-        result.addAll(input2.getTransitivePredecessors());
-        return result;
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
+        List<Transformation<?>> predecessors =
+                Stream.of(input1, input2)
+                        .flatMap(input -> 
input.getTransitivePredecessors().stream())
+                        .distinct()
+                        .collect(Collectors.toList());
+        predecessors.add(this);
+        return predecessors;
     }
 
     @Override
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java
index dd29206a9ac..59ce6adceb7 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java
@@ -25,6 +25,7 @@ import 
org.apache.flink.shaded.guava31.com.google.common.collect.Lists;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * This transformation represents a union of several input {@link 
Transformation Transformations}.
@@ -63,12 +64,13 @@ public class UnionTransformation<T> extends 
Transformation<T> {
     }
 
     @Override
-    public List<Transformation<?>> getTransitivePredecessors() {
-        List<Transformation<?>> result = Lists.newArrayList();
-        result.add(this);
-        for (Transformation<T> input : inputs) {
-            result.addAll(input.getTransitivePredecessors());
-        }
-        return result;
+    protected List<Transformation<?>> getTransitivePredecessorsInternal() {
+        List<Transformation<?>> predecessors =
+                inputs.stream()
+                        .flatMap(input -> 
input.getTransitivePredecessors().stream())
+                        .distinct()
+                        .collect(Collectors.toList());
+        predecessors.add(this);
+        return predecessors;
     }
 }
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
index 13495a0d933..ade882a3083 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
@@ -1023,7 +1023,7 @@ public class StreamGraphGeneratorTest extends TestLogger {
         }
 
         @Override
-        public List<Transformation<?>> getTransitivePredecessors() {
+        protected List<Transformation<?>> getTransitivePredecessorsInternal() {
             return Collections.emptyList();
         }
 
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/transformations/GetTransitivePredecessorsTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/transformations/GetTransitivePredecessorsTest.java
new file mode 100644
index 00000000000..c658e7bb3fa
--- /dev/null
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/transformations/GetTransitivePredecessorsTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.transformations;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@code getTransitivePredecessors} method of {@link 
Transformation}. */
+class GetTransitivePredecessorsTest {
+
+    private TestTransformation<Integer> commonNode;
+    private Transformation<Integer> midNode;
+
+    @BeforeEach
+    void setup() {
+        commonNode = new TestTransformation<>("commonNode", new 
MockIntegerTypeInfo(), 1);
+        midNode =
+                new OneInputTransformation<>(
+                        commonNode,
+                        "midNode",
+                        new DummyOneInputOperator(),
+                        new MockIntegerTypeInfo(),
+                        1);
+    }
+
+    @Test
+    void testTwoInputTransformation() {
+        Transformation<Integer> topNode =
+                new TwoInputTransformation<>(
+                        commonNode,
+                        midNode,
+                        "topNode",
+                        new DummyTwoInputOperator<>(),
+                        midNode.getOutputType(),
+                        1);
+        List<Transformation<?>> predecessors = 
topNode.getTransitivePredecessors();
+        assertThat(predecessors.size()).isEqualTo(3);
+        assertThat(commonNode.getNumGetTransitivePredecessor()).isEqualTo(1);
+    }
+
+    @Test
+    void testUnionTransformation() {
+        Transformation<Integer> topNode =
+                new UnionTransformation<>(Arrays.asList(commonNode, midNode));
+        List<Transformation<?>> predecessors = 
topNode.getTransitivePredecessors();
+        assertThat(predecessors.size()).isEqualTo(3);
+        assertThat(commonNode.getNumGetTransitivePredecessor()).isEqualTo(1);
+    }
+
+    @Test
+    void testBroadcastStateTransformation() {
+        Transformation<Integer> topNode =
+                new AbstractBroadcastStateTransformation<>(
+                        "topNode", commonNode, midNode, null, 
midNode.getOutputType(), 1);
+        List<Transformation<?>> predecessors = 
topNode.getTransitivePredecessors();
+        assertThat(predecessors.size()).isEqualTo(3);
+        assertThat(commonNode.getNumGetTransitivePredecessor()).isEqualTo(0);
+    }
+
+    @Test
+    void testAbstractMultipleInputTransformation() {
+        Transformation<Integer> topNode =
+                new AbstractMultipleInputTransformation<Integer>(
+                        "topNode",
+                        SimpleOperatorFactory.of(new 
DummyTwoInputOperator<>()),
+                        midNode.getOutputType(),
+                        1) {
+                    @Override
+                    public List<Transformation<?>> getInputs() {
+                        return Arrays.asList(commonNode, midNode);
+                    }
+                };
+        List<Transformation<?>> predecessors = 
topNode.getTransitivePredecessors();
+        assertThat(predecessors.size()).isEqualTo(3);
+        assertThat(commonNode.getNumGetTransitivePredecessor()).isEqualTo(1);
+    }
+
+    /** A test implementation of {@link Transformation}. */
+    private static class TestTransformation<T> extends Transformation<T> {
+        private int numGetTransitivePredecessor = 0;
+
+        public TestTransformation(String name, TypeInformation<T> outputType, 
int parallelism) {
+            super(name, outputType, parallelism);
+        }
+
+        @Override
+        protected List<Transformation<?>> getTransitivePredecessorsInternal() {
+            ++numGetTransitivePredecessor;
+            return Collections.singletonList(this);
+        }
+
+        @Override
+        public List<Transformation<?>> getInputs() {
+            return Collections.emptyList();
+        }
+
+        public int getNumGetTransitivePredecessor() {
+            return numGetTransitivePredecessor;
+        }
+    }
+
+    /** A test implementation of {@link OneInputStreamOperator}. */
+    private static class DummyOneInputOperator extends 
AbstractStreamOperator<Integer>
+            implements OneInputStreamOperator<Integer, Integer> {
+
+        @Override
+        public void processElement(StreamRecord<Integer> element) throws 
Exception {}
+    }
+
+    /** A test implementation of {@link TwoInputStreamOperator}. */
+    private static class DummyTwoInputOperator<T> extends 
AbstractStreamOperator<T>
+            implements TwoInputStreamOperator<T, T, T> {
+
+        @Override
+        public void processElement1(StreamRecord<T> element) throws Exception {
+            output.collect(element);
+        }
+
+        @Override
+        public void processElement2(StreamRecord<T> element) throws Exception {
+            output.collect(element);
+        }
+    }
+
+    /** A test implementation of {@link TypeInformation}. */
+    private static class MockIntegerTypeInfo extends GenericTypeInfo<Integer> {
+        public MockIntegerTypeInfo() {
+            super(Integer.class);
+        }
+    }
+}
diff --git 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/multipleinput/TableOperatorWrapperGeneratorTest.java
 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/multipleinput/TableOperatorWrapperGeneratorTest.java
index 9d4041bbf2a..6ee22d8f9a4 100644
--- 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/multipleinput/TableOperatorWrapperGeneratorTest.java
+++ 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/multipleinput/TableOperatorWrapperGeneratorTest.java
@@ -484,7 +484,7 @@ public class TableOperatorWrapperGeneratorTest extends 
MultipleInputTestBase {
         }
 
         @Override
-        public List<Transformation<?>> getTransitivePredecessors() {
+        protected List<Transformation<?>> getTransitivePredecessorsInternal() {
             return Collections.emptyList();
         }
 

Reply via email to