This is an automated email from the ASF dual-hosted git repository. zhuzh pushed a commit to branch release-1.18 in repository https://gitbox.apache.org/repos/asf/flink.git
commit 940b3bbda5b10abe3a41d60467d33fd424c7dae6 Author: Jeyhun Karimov <je.kari...@gmail.com> AuthorDate: Sun Mar 10 15:56:02 2024 +0100 [FLINK-32513][core] Add predecessor caching This closes #24475. --- .../org/apache/flink/api/dag/Transformation.java | 28 +++- .../apache/flink/api/dag/TransformationTest.java | 2 +- .../AbstractBroadcastStateTransformation.java | 12 +- .../AbstractMultipleInputTransformation.java | 12 +- .../api/transformations/CacheTransformation.java | 2 +- .../transformations/CoFeedbackTransformation.java | 2 +- .../transformations/FeedbackTransformation.java | 2 +- .../transformations/LegacySinkTransformation.java | 2 +- .../LegacySourceTransformation.java | 2 +- .../transformations/OneInputTransformation.java | 2 +- .../transformations/PartitionTransformation.java | 2 +- .../api/transformations/ReduceTransformation.java | 2 +- .../transformations/SideOutputTransformation.java | 2 +- .../api/transformations/SinkTransformation.java | 2 +- .../api/transformations/SourceTransformation.java | 2 +- .../TimestampsAndWatermarksTransformation.java | 2 +- .../transformations/TwoInputTransformation.java | 18 ++- .../api/transformations/UnionTransformation.java | 16 +- .../api/graph/StreamGraphGeneratorTest.java | 2 +- .../GetTransitivePredecessorsTest.java | 162 +++++++++++++++++++++ .../TableOperatorWrapperGeneratorTest.java | 2 +- 21 files changed, 237 insertions(+), 41 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java index a0448697dd1..07c64907c82 100644 --- a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java +++ b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java @@ -19,6 +19,7 @@ package org.apache.flink.api.dag; import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.ExecutionConfig; import 
org.apache.flink.api.common.functions.InvalidTypesException; import org.apache.flink.api.common.operators.ResourceSpec; @@ -161,6 +162,13 @@ public abstract class Transformation<T> { private final Map<ManagedMemoryUseCase, Integer> managedMemoryOperatorScopeUseCaseWeights = new HashMap<>(); + /** + * This map is a cache that stores transitive predecessors and is used in {@code + * getTransitivePredecessors()}. + */ + private final Map<Transformation<T>, List<Transformation<?>>> predecessorsCache = + new HashMap<>(); + /** Slot scope use cases that this transformation needs managed memory for. */ private final Set<ManagedMemoryUseCase> managedMemorySlotScopeUseCases = new HashSet<>(); @@ -230,6 +238,12 @@ public abstract class Transformation<T> { return name; } + /** Returns the predecessorsCache of this {@code Transformation}. */ + @VisibleForTesting + Map<Transformation<T>, List<Transformation<?>>> getPredecessorsCache() { + return predecessorsCache; + } + /** Changes the description of this {@code Transformation}. */ public void setDescription(String description) { this.description = Preconditions.checkNotNull(description); @@ -578,7 +592,19 @@ public abstract class Transformation<T> { * * @return The list of transitive predecessors. */ - public abstract List<Transformation<?>> getTransitivePredecessors(); + protected abstract List<Transformation<?>> getTransitivePredecessorsInternal(); + + /** + * Returns all transitive predecessor {@code Transformation}s of this {@code Transformation}. + * This is, for example, used when determining whether a feedback edge of an iteration actually + * has the iteration head as a predecessor. This method is just a wrapper on top of {@code + * getTransitivePredecessorsInternal} method with public access. It uses caching internally. + * + * @return The list of transitive predecessors.
+ */ + public final List<Transformation<?>> getTransitivePredecessors() { + return predecessorsCache.computeIfAbsent(this, key -> getTransitivePredecessorsInternal()); + } /** * Returns the {@link Transformation transformations} that are the immediate predecessors of the diff --git a/flink-core/src/test/java/org/apache/flink/api/dag/TransformationTest.java b/flink-core/src/test/java/org/apache/flink/api/dag/TransformationTest.java index ced452cd14f..c200fac33fc 100644 --- a/flink-core/src/test/java/org/apache/flink/api/dag/TransformationTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/dag/TransformationTest.java @@ -142,7 +142,7 @@ class TransformationTest { } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { return Collections.emptyList(); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractBroadcastStateTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractBroadcastStateTransformation.java index c7869e61389..1ce6c7c9bfd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractBroadcastStateTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractBroadcastStateTransformation.java @@ -26,6 +26,8 @@ import org.apache.flink.streaming.api.operators.ChainingStrategy; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -103,12 +105,10 @@ public class AbstractBroadcastStateTransformation<IN1, IN2, OUT> } @Override - public List<Transformation<?>> getTransitivePredecessors() { - final List<Transformation<?>> predecessors = new ArrayList<>(); - predecessors.add(this); - predecessors.add(regularInput); - 
predecessors.add(broadcastInput); - return predecessors; + protected List<Transformation<?>> getTransitivePredecessorsInternal() { + return Stream.of(this, regularInput, broadcastInput) + .distinct() + .collect(Collectors.toList()); } @Override diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractMultipleInputTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractMultipleInputTransformation.java index 65bbc53188f..a733cebbaaf 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractMultipleInputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/AbstractMultipleInputTransformation.java @@ -76,10 +76,14 @@ public abstract class AbstractMultipleInputTransformation<OUT> extends PhysicalT } @Override - public List<Transformation<?>> getTransitivePredecessors() { - return inputs.stream() - .flatMap(input -> input.getTransitivePredecessors().stream()) - .collect(Collectors.toList()); + protected List<Transformation<?>> getTransitivePredecessorsInternal() { + List<Transformation<?>> predecessors = + getInputs().stream() + .flatMap(input -> input.getTransitivePredecessors().stream()) + .distinct() + .collect(Collectors.toList()); + predecessors.add(this); + return predecessors; } @Override diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java index 4632c9b020d..e8bd22adcc8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CacheTransformation.java @@ -57,7 +57,7 @@ public class CacheTransformation<T> extends Transformation<T> { } @Override - 
public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { List<Transformation<?>> result = Lists.newArrayList(); result.add(this); if (isCached) { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java index 11f8a912608..4d366ba2a8f 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java @@ -106,7 +106,7 @@ public class CoFeedbackTransformation<F> extends Transformation<F> { } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { return Collections.singletonList(this); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java index ab8f7424741..ef4ec2766bf 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java @@ -102,7 +102,7 @@ public class FeedbackTransformation<T> extends Transformation<T> { } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { List<Transformation<?>> result = Lists.newArrayList(); result.add(this); result.addAll(input.getTransitivePredecessors()); diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java index a0bd6417881..fcd753c13b4 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySinkTransformation.java @@ -136,7 +136,7 @@ public class LegacySinkTransformation<T> extends PhysicalTransformation<T> { } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { List<Transformation<?>> result = Lists.newArrayList(); result.add(this); result.addAll(input.getTransitivePredecessors()); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySourceTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySourceTransformation.java index 4718ff18386..9443a3a14ea 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySourceTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/LegacySourceTransformation.java @@ -92,7 +92,7 @@ public class LegacySourceTransformation<T> extends PhysicalTransformation<T> } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { return Collections.singletonList(this); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java index 0a3957130fc..7222fad97da 100644 --- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java @@ -168,7 +168,7 @@ public class OneInputTransformation<IN, OUT> extends PhysicalTransformation<OUT> } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { List<Transformation<?>> result = Lists.newArrayList(); result.add(this); result.addAll(input.getTransitivePredecessors()); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java index 92d3f264731..966e565c78f 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java @@ -89,7 +89,7 @@ public class PartitionTransformation<T> extends Transformation<T> { } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { List<Transformation<?>> result = Lists.newArrayList(); result.add(this); result.addAll(input.getTransitivePredecessors()); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/ReduceTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/ReduceTransformation.java index da159f3af11..82c4e1c4694 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/ReduceTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/ReduceTransformation.java @@ -89,7 +89,7 @@ public final 
class ReduceTransformation<IN, K> extends PhysicalTransformation<IN } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { List<Transformation<?>> result = Lists.newArrayList(); result.add(this); result.addAll(input.getTransitivePredecessors()); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java index 3177bfa0eb1..e054d37be54 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java @@ -53,7 +53,7 @@ public class SideOutputTransformation<T> extends Transformation<T> { } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { List<Transformation<?>> result = Lists.newArrayList(); result.add(this); result.addAll(input.getTransitivePredecessors()); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java index bdaa8d2dede..a4b4310588e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java @@ -73,7 +73,7 @@ public class SinkTransformation<InputT, OutputT> extends PhysicalTransformation< } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { final List<Transformation<?>> result = 
Lists.newArrayList(); result.add(this); result.addAll(input.getTransitivePredecessors()); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformation.java index 86c07ca5511..0260da43b82 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformation.java @@ -92,7 +92,7 @@ public class SourceTransformation<OUT, SplitT extends SourceSplit, EnumChkT> } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { return Collections.singletonList(this); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TimestampsAndWatermarksTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TimestampsAndWatermarksTransformation.java index c8bde911a04..26541f74673 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TimestampsAndWatermarksTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TimestampsAndWatermarksTransformation.java @@ -77,7 +77,7 @@ public class TimestampsAndWatermarksTransformation<IN> extends PhysicalTransform } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { List<Transformation<?>> transformations = Lists.newArrayList(); transformations.add(this); transformations.addAll(input.getTransitivePredecessors()); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java index 4c576170741..48005bf95c4 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java @@ -28,10 +28,10 @@ import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; -import org.apache.flink.shaded.guava31.com.google.common.collect.Lists; - import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * This Transformation represents the application of a {@link TwoInputStreamOperator} to two input @@ -217,12 +217,14 @@ public class TwoInputTransformation<IN1, IN2, OUT> extends PhysicalTransformatio } @Override - public List<Transformation<?>> getTransitivePredecessors() { - List<Transformation<?>> result = Lists.newArrayList(); - result.add(this); - result.addAll(input1.getTransitivePredecessors()); - result.addAll(input2.getTransitivePredecessors()); - return result; + protected List<Transformation<?>> getTransitivePredecessorsInternal() { + List<Transformation<?>> predecessors = + Stream.of(input1, input2) + .flatMap(input -> input.getTransitivePredecessors().stream()) + .distinct() + .collect(Collectors.toList()); + predecessors.add(this); + return predecessors; } @Override diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java index dd29206a9ac..59ce6adceb7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java +++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java @@ -25,6 +25,7 @@ import org.apache.flink.shaded.guava31.com.google.common.collect.Lists; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; /** * This transformation represents a union of several input {@link Transformation Transformations}. @@ -63,12 +64,13 @@ public class UnionTransformation<T> extends Transformation<T> { } @Override - public List<Transformation<?>> getTransitivePredecessors() { - List<Transformation<?>> result = Lists.newArrayList(); - result.add(this); - for (Transformation<T> input : inputs) { - result.addAll(input.getTransitivePredecessors()); - } - return result; + protected List<Transformation<?>> getTransitivePredecessorsInternal() { + List<Transformation<?>> predecessors = + inputs.stream() + .flatMap(input -> input.getTransitivePredecessors().stream()) + .distinct() + .collect(Collectors.toList()); + predecessors.add(this); + return predecessors; } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java index 13495a0d933..ade882a3083 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java @@ -1023,7 +1023,7 @@ public class StreamGraphGeneratorTest extends TestLogger { } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { return Collections.emptyList(); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/transformations/GetTransitivePredecessorsTest.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/transformations/GetTransitivePredecessorsTest.java new file mode 100644 index 00000000000..c658e7bb3fa --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/transformations/GetTransitivePredecessorsTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.transformations; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@code getTransitivePredecessors} method of {@link Transformation}. */ +class GetTransitivePredecessorsTest { + + private TestTransformation<Integer> commonNode; + private Transformation<Integer> midNode; + + @BeforeEach + void setup() { + commonNode = new TestTransformation<>("commonNode", new MockIntegerTypeInfo(), 1); + midNode = + new OneInputTransformation<>( + commonNode, + "midNode", + new DummyOneInputOperator(), + new MockIntegerTypeInfo(), + 1); + } + + @Test + void testTwoInputTransformation() { + Transformation<Integer> topNode = + new TwoInputTransformation<>( + commonNode, + midNode, + "topNode", + new DummyTwoInputOperator<>(), + midNode.getOutputType(), + 1); + List<Transformation<?>> predecessors = topNode.getTransitivePredecessors(); + assertThat(predecessors.size()).isEqualTo(3); + assertThat(commonNode.getNumGetTransitivePredecessor()).isEqualTo(1); + } + + @Test + void testUnionTransformation() { + Transformation<Integer> topNode = + new UnionTransformation<>(Arrays.asList(commonNode, midNode)); + List<Transformation<?>> predecessors = topNode.getTransitivePredecessors(); + assertThat(predecessors.size()).isEqualTo(3); + 
assertThat(commonNode.getNumGetTransitivePredecessor()).isEqualTo(1); + } + + @Test + void testBroadcastStateTransformation() { + Transformation<Integer> topNode = + new AbstractBroadcastStateTransformation<>( + "topNode", commonNode, midNode, null, midNode.getOutputType(), 1); + List<Transformation<?>> predecessors = topNode.getTransitivePredecessors(); + assertThat(predecessors.size()).isEqualTo(3); + assertThat(commonNode.getNumGetTransitivePredecessor()).isEqualTo(0); + } + + @Test + void testAbstractMultipleInputTransformation() { + Transformation<Integer> topNode = + new AbstractMultipleInputTransformation<Integer>( + "topNode", + SimpleOperatorFactory.of(new DummyTwoInputOperator<>()), + midNode.getOutputType(), + 1) { + @Override + public List<Transformation<?>> getInputs() { + return Arrays.asList(commonNode, midNode); + } + }; + List<Transformation<?>> predecessors = topNode.getTransitivePredecessors(); + assertThat(predecessors.size()).isEqualTo(3); + assertThat(commonNode.getNumGetTransitivePredecessor()).isEqualTo(1); + } + + /** A test implementation of {@link Transformation}. */ + private static class TestTransformation<T> extends Transformation<T> { + private int numGetTransitivePredecessor = 0; + + public TestTransformation(String name, TypeInformation<T> outputType, int parallelism) { + super(name, outputType, parallelism); + } + + @Override + protected List<Transformation<?>> getTransitivePredecessorsInternal() { + ++numGetTransitivePredecessor; + return Collections.singletonList(this); + } + + @Override + public List<Transformation<?>> getInputs() { + return Collections.emptyList(); + } + + public int getNumGetTransitivePredecessor() { + return numGetTransitivePredecessor; + } + } + + /** A test implementation of {@link OneInputStreamOperator}.
*/ + private static class DummyOneInputOperator extends AbstractStreamOperator<Integer> + implements OneInputStreamOperator<Integer, Integer> { + + @Override + public void processElement(StreamRecord<Integer> element) throws Exception {} + } + + /** A test implementation of {@link TwoInputStreamOperator}. */ + private static class DummyTwoInputOperator<T> extends AbstractStreamOperator<T> + implements TwoInputStreamOperator<T, T, T> { + + @Override + public void processElement1(StreamRecord<T> element) throws Exception { + output.collect(element); + } + + @Override + public void processElement2(StreamRecord<T> element) throws Exception { + output.collect(element); + } + } + + /** A test implementation of {@link TypeInformation}. */ + private static class MockIntegerTypeInfo extends GenericTypeInfo<Integer> { + public MockIntegerTypeInfo() { + super(Integer.class); + } + } +} diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/multipleinput/TableOperatorWrapperGeneratorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/multipleinput/TableOperatorWrapperGeneratorTest.java index 9d4041bbf2a..6ee22d8f9a4 100644 --- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/multipleinput/TableOperatorWrapperGeneratorTest.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/multipleinput/TableOperatorWrapperGeneratorTest.java @@ -484,7 +484,7 @@ public class TableOperatorWrapperGeneratorTest extends MultipleInputTestBase { } @Override - public List<Transformation<?>> getTransitivePredecessors() { + protected List<Transformation<?>> getTransitivePredecessorsInternal() { return Collections.emptyList(); }