[FLINK-8446] Support multiple broadcast states.

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/28768235
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/28768235
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/28768235

Branch: refs/heads/master
Commit: 28768235068039e4ff50c5235ab79c54410b4ec0
Parents: 26918c9
Author: kkloudas <[email protected]>
Authored: Mon Jan 29 16:23:04 2018 +0100
Committer: kkloudas <[email protected]>
Committed: Wed Feb 7 14:08:52 2018 +0100

----------------------------------------------------------------------
 .../datastream/BroadcastConnectedStream.java    |  24 ++--
 .../api/datastream/BroadcastStream.java         |  29 ++---
 .../streaming/api/datastream/DataStream.java    |  11 +-
 .../functions/co/BroadcastProcessFunction.java  |   2 +-
 .../co/KeyedBroadcastProcessFunction.java       |   2 +-
 .../flink/streaming/api/DataStreamTest.java     |   6 +-
 .../co/CoBroadcastWithNonKeyedOperatorTest.java | 116 ++++++++++++++++---
 7 files changed, 134 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/28768235/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
index 453c850..f3c4838 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
@@ -33,14 +33,14 @@ import 
org.apache.flink.streaming.api.operators.co.CoBroadcastWithNonKeyedOperat
 import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
 import org.apache.flink.util.Preconditions;
 
-import java.util.Collections;
+import java.util.List;
 
 import static java.util.Objects.requireNonNull;
 
 /**
  * A BroadcastConnectedStream represents the result of connecting a keyed or 
non-keyed stream,
  * with a {@link BroadcastStream} with {@link 
org.apache.flink.api.common.state.BroadcastState
- * BroadcastState}. As in the case of {@link ConnectedStreams} these streams 
are useful for cases
+ * broadcast state(s)}. As in the case of {@link ConnectedStreams} these 
streams are useful for cases
  * where operations on one stream directly affect the operations on the other 
stream, usually via
  * shared state between the streams.
  *
@@ -52,26 +52,24 @@ import static java.util.Objects.requireNonNull;
  *
  * @param <IN1> The input type of the non-broadcast side.
  * @param <IN2> The input type of the broadcast side.
- * @param <K> The key type of the elements in the {@link 
org.apache.flink.api.common.state.BroadcastState BroadcastState}.
- * @param <V> The value type of the elements in the {@link 
org.apache.flink.api.common.state.BroadcastState BroadcastState}.
  */
 @PublicEvolving
-public class BroadcastConnectedStream<IN1, IN2, K, V> {
+public class BroadcastConnectedStream<IN1, IN2> {
 
        private final StreamExecutionEnvironment environment;
        private final DataStream<IN1> inputStream1;
-       private final BroadcastStream<IN2, K, V> inputStream2;
-       private final MapStateDescriptor<K, V> broadcastStateDescriptor;
+       private final BroadcastStream<IN2> inputStream2;
+       private final List<MapStateDescriptor<?, ?>> broadcastStateDescriptors;
 
        protected BroadcastConnectedStream(
                        final StreamExecutionEnvironment env,
                        final DataStream<IN1> input1,
-                       final BroadcastStream<IN2, K, V> input2,
-                       final MapStateDescriptor<K, V> 
broadcastStateDescriptor) {
+                       final BroadcastStream<IN2> input2,
+                       final List<MapStateDescriptor<?, ?>> 
broadcastStateDescriptors) {
                this.environment = requireNonNull(env);
                this.inputStream1 = requireNonNull(input1);
                this.inputStream2 = requireNonNull(input2);
-               this.broadcastStateDescriptor = 
requireNonNull(broadcastStateDescriptor);
+               this.broadcastStateDescriptors = 
requireNonNull(broadcastStateDescriptors);
        }
 
        public StreamExecutionEnvironment getExecutionEnvironment() {
@@ -92,7 +90,7 @@ public class BroadcastConnectedStream<IN1, IN2, K, V> {
         *
         * @return The stream which, by convention, is the broadcast one.
         */
-       public BroadcastStream<IN2, K, V> getSecondInput() {
+       public BroadcastStream<IN2> getSecondInput() {
                return inputStream2;
        }
 
@@ -163,7 +161,7 @@ public class BroadcastConnectedStream<IN1, IN2, K, V> {
                                "A KeyedBroadcastProcessFunction can only be 
used with a keyed stream as the second input.");
 
                TwoInputStreamOperator<IN1, IN2, OUT> operator =
-                               new CoBroadcastWithKeyedOperator<>(function, 
Collections.singletonList(broadcastStateDescriptor));
+                               new CoBroadcastWithKeyedOperator<>(function, 
broadcastStateDescriptors);
                return transform("Co-Process-Broadcast-Keyed", outTypeInfo, 
operator);
        }
 
@@ -214,7 +212,7 @@ public class BroadcastConnectedStream<IN1, IN2, K, V> {
                                "A BroadcastProcessFunction can only be used 
with a non-keyed stream as the second input.");
 
                TwoInputStreamOperator<IN1, IN2, OUT> operator =
-                               new CoBroadcastWithNonKeyedOperator<>(function, 
Collections.singletonList(broadcastStateDescriptor));
+                               new CoBroadcastWithNonKeyedOperator<>(function, 
broadcastStateDescriptors);
                return transform("Co-Process-Broadcast", outTypeInfo, operator);
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/28768235/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java
index e21e36f..6c56f98 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java
@@ -24,12 +24,15 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.transformations.StreamTransformation;
 
+import java.util.Arrays;
+import java.util.List;
+
 import static java.util.Objects.requireNonNull;
 
 /**
- * A {@code BroadcastStream} is a stream with {@link 
org.apache.flink.api.common.state.BroadcastState BroadcastState}.
- * This can be created by any stream using the {@link 
DataStream#broadcast(MapStateDescriptor)} method and
- * implicitly creates a state where the user can store elements of the created 
{@code BroadcastStream}.
+ * A {@code BroadcastStream} is a stream with {@link 
org.apache.flink.api.common.state.BroadcastState broadcast state(s)}.
+ * This can be created by any stream using the {@link 
DataStream#broadcast(MapStateDescriptor[])} method and
+ * implicitly creates states where the user can store elements of the created 
{@code BroadcastStream}.
  * (see {@link BroadcastConnectedStream}).
  *
  * <p>Note that no further operation can be applied to these streams. The only 
available option is to connect them
@@ -38,31 +41,29 @@ import static java.util.Objects.requireNonNull;
  * {@link BroadcastConnectedStream} for further processing.
  *
  * @param <T> The type of input/output elements.
- * @param <K> The key type of the elements in the {@link 
org.apache.flink.api.common.state.BroadcastState BroadcastState}.
- * @param <V> The value type of the elements in the {@link 
org.apache.flink.api.common.state.BroadcastState BroadcastState}.
  */
 @PublicEvolving
-public class BroadcastStream<T, K, V> {
+public class BroadcastStream<T> {
 
        private final StreamExecutionEnvironment environment;
 
        private final DataStream<T> inputStream;
 
        /**
-        * The {@link org.apache.flink.api.common.state.StateDescriptor state 
descriptor} of the
-        * {@link org.apache.flink.api.common.state.BroadcastState broadcast 
state}. This state
-        * has a {@code key-value} format.
+        * The {@link org.apache.flink.api.common.state.StateDescriptor state 
descriptors} of the
+        * registered {@link org.apache.flink.api.common.state.BroadcastState 
broadcast states}. These
+        * states have {@code key-value} format.
         */
-       private final MapStateDescriptor<K, V> broadcastStateDescriptor;
+       private final List<MapStateDescriptor<?, ?>> broadcastStateDescriptors;
 
        protected BroadcastStream(
                        final StreamExecutionEnvironment env,
                        final DataStream<T> input,
-                       final MapStateDescriptor<K, V> 
broadcastStateDescriptor) {
+                       final MapStateDescriptor<?, ?>... 
broadcastStateDescriptors) {
 
                this.environment = requireNonNull(env);
                this.inputStream = requireNonNull(input);
-               this.broadcastStateDescriptor = 
requireNonNull(broadcastStateDescriptor);
+               this.broadcastStateDescriptors = 
Arrays.asList(requireNonNull(broadcastStateDescriptors));
        }
 
        public TypeInformation<T> getType() {
@@ -77,8 +78,8 @@ public class BroadcastStream<T, K, V> {
                return inputStream.getTransformation();
        }
 
-       public MapStateDescriptor<K, V> getBroadcastStateDescriptor() {
-               return broadcastStateDescriptor;
+       public List<MapStateDescriptor<?, ?>> getBroadcastStateDescriptor() {
+               return broadcastStateDescriptors;
        }
 
        public StreamExecutionEnvironment getEnvironment() {

http://git-wip-us.apache.org/repos/asf/flink/blob/28768235/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
index d859689..8d18b80 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
@@ -257,7 +257,7 @@ public class DataStream<T> {
         * Creates a new {@link BroadcastConnectedStream} by connecting the 
current
         * {@link DataStream} or {@link KeyedStream} with a {@link 
BroadcastStream}.
         *
-        * <p>The latter can be created using the {@link 
#broadcast(MapStateDescriptor)} method.
+        * <p>The latter can be created using the {@link 
#broadcast(MapStateDescriptor[])} method.
         *
         * <p>The resulting stream can be further processed using the {@code 
BroadcastConnectedStream.process(MyFunction)}
         * method, where {@code MyFunction} can be either a
@@ -269,7 +269,7 @@ public class DataStream<T> {
         * @return The {@link BroadcastConnectedStream}.
         */
        @PublicEvolving
-       public <R, K, V> BroadcastConnectedStream<T, R, K, V> 
connect(BroadcastStream<R, K, V> broadcastStream) {
+       public <R> BroadcastConnectedStream<T, R> connect(BroadcastStream<R> 
broadcastStream) {
                return new BroadcastConnectedStream<>(
                                environment,
                                this,
@@ -402,14 +402,15 @@ public class DataStream<T> {
         * it implicitly creates a {@link 
org.apache.flink.api.common.state.BroadcastState broadcast state}
         * which can be used to store the element of the stream.
         *
+        * @param broadcastStateDescriptors the descriptors of the broadcast 
states to create.
         * @return A {@link BroadcastStream} which can be used in the {@link 
#connect(BroadcastStream)} to
         * create a {@link BroadcastConnectedStream} for further processing of 
the elements.
         */
        @PublicEvolving
-       public <K, V> BroadcastStream<T, K, V> broadcast(final 
MapStateDescriptor<K, V> broadcastStateDescriptor) {
-               Preconditions.checkNotNull(broadcastStateDescriptor);
+       public BroadcastStream<T> broadcast(final MapStateDescriptor<?, ?>... 
broadcastStateDescriptors) {
+               Preconditions.checkNotNull(broadcastStateDescriptors);
                final DataStream<T> broadcastStream = setConnectionType(new 
BroadcastPartitioner<>());
-               return new BroadcastStream<>(environment, broadcastStream, 
broadcastStateDescriptor);
+               return new BroadcastStream<>(environment, broadcastStream, 
broadcastStateDescriptors);
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/flink/blob/28768235/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java
index 4dcc929..257ea83 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java
@@ -29,7 +29,7 @@ import org.apache.flink.util.Collector;
  * with broadcast state, with a <b>non-keyed</b> {@link 
org.apache.flink.streaming.api.datastream.DataStream DataStream}.
  *
  * <p>The stream with the broadcast state can be created using the
- * {@link 
org.apache.flink.streaming.api.datastream.DataStream#broadcast(MapStateDescriptor)
+ * {@link 
org.apache.flink.streaming.api.datastream.DataStream#broadcast(MapStateDescriptor[])
  * stream.broadcast(MapStateDescriptor)} method.
  *
  * <p>The user has to implement two methods:

http://git-wip-us.apache.org/repos/asf/flink/blob/28768235/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
index 4b9f138..de9cb32 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
@@ -34,7 +34,7 @@ import org.apache.flink.util.Collector;
  * with broadcast state, with a {@link 
org.apache.flink.streaming.api.datastream.KeyedStream KeyedStream}.
  *
  * <p>The stream with the broadcast state can be created using the
- * {@link 
org.apache.flink.streaming.api.datastream.KeyedStream#broadcast(MapStateDescriptor)
+ * {@link 
org.apache.flink.streaming.api.datastream.KeyedStream#broadcast(MapStateDescriptor[])
  * keyedStream.broadcast(MapStateDescriptor)} method.
  *
  * <p>The user has to implement two methods:

http://git-wip-us.apache.org/repos/asf/flink/blob/28768235/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
index bcbbfd6..ca76ef4 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
@@ -794,7 +794,7 @@ public class DataStreamTest extends TestLogger {
                                        }
                                });
 
-               final BroadcastStream<String, Long, String> broadcast = 
srcTwo.broadcast(TestBroadcastProcessFunction.DESCRIPTOR);
+               final BroadcastStream<String> broadcast = 
srcTwo.broadcast(TestBroadcastProcessFunction.DESCRIPTOR);
 
                // the timestamp should be high enough to trigger the timer 
after all the elements arrive.
                final DataStream<String> output = 
srcOne.connect(broadcast).process(
@@ -880,7 +880,7 @@ public class DataStreamTest extends TestLogger {
                                        }
                                });
 
-               BroadcastStream<String, Long, String> broadcast = 
srcTwo.broadcast(descriptor);
+               BroadcastStream<String> broadcast = 
srcTwo.broadcast(descriptor);
                srcOne.connect(broadcast)
                                .process(new BroadcastProcessFunction<Long, 
String, String>() {
                                        @Override
@@ -923,7 +923,7 @@ public class DataStreamTest extends TestLogger {
                                        }
                                });
 
-               BroadcastStream<String, Long, String> broadcast = 
srcTwo.broadcast(descriptor);
+               BroadcastStream<String> broadcast = 
srcTwo.broadcast(descriptor);
                srcOne.connect(broadcast)
                                .process(new 
KeyedBroadcastProcessFunction<String, Long, String, String>() {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/28768235/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java
index 066a80f..96e1c3e 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java
@@ -35,7 +35,7 @@ import org.apache.flink.util.Preconditions;
 import org.junit.Assert;
 import org.junit.Test;
 
-import java.util.Collections;
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Queue;
@@ -54,6 +54,59 @@ public class CoBroadcastWithNonKeyedOperatorTest {
                                        BasicTypeInfo.INT_TYPE_INFO
                        );
 
+       private static final MapStateDescriptor<Integer, String> 
STATE_DESCRIPTOR_A =
+                       new MapStateDescriptor<>(
+                                       "broadcast-state-A",
+                                       BasicTypeInfo.INT_TYPE_INFO,
+                                       BasicTypeInfo.STRING_TYPE_INFO
+                       );
+
+       @Test
+       public void testMultiStateSupport() throws Exception {
+               try (
+                               TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness =
+                                               getInitializedTestHarness(new 
FunctionWithMultipleStates(), STATE_DESCRIPTOR, STATE_DESCRIPTOR_A)
+               ) {
+                       testHarness.processElement2(new StreamRecord<>(5, 12L));
+                       testHarness.processElement2(new StreamRecord<>(6, 13L));
+
+                       testHarness.processElement1(new StreamRecord<>("9", 
15L));
+
+                       Queue<Object> expectedBr = new 
ConcurrentLinkedQueue<>();
+                       expectedBr.add(new StreamRecord<>("9:key.6->6", 15L));
+                       expectedBr.add(new StreamRecord<>("9:key.5->5", 15L));
+                       expectedBr.add(new StreamRecord<>("9:5->value.5", 15L));
+                       expectedBr.add(new StreamRecord<>("9:6->value.6", 15L));
+
+                       TestHarnessUtil.assertOutputEquals("Wrong Side Output", 
expectedBr, testHarness.getOutput());
+               }
+       }
+
+       /**
+        * {@link BroadcastProcessFunction} that puts elements on multiple 
broadcast states.
+        */
+       private static class FunctionWithMultipleStates extends 
BroadcastProcessFunction<String, Integer, String> {
+
+               private static final long serialVersionUID = 
7496674620398203933L;
+
+               @Override
+               public void processBroadcastElement(Integer value, Context ctx, 
Collector<String> out) throws Exception {
+                       ctx.getBroadcastState(STATE_DESCRIPTOR).put("key." + 
value, value);
+                       ctx.getBroadcastState(STATE_DESCRIPTOR_A).put(value, 
"value." + value);
+               }
+
+               @Override
+               public void processElement(String value, ReadOnlyContext ctx, 
Collector<String> out) throws Exception {
+                       for (Map.Entry<String, Integer> entry: 
ctx.getBroadcastState(STATE_DESCRIPTOR).immutableEntries()) {
+                               out.collect(value + ":" + entry.getKey() + "->" 
+ entry.getValue());
+                       }
+
+                       for (Map.Entry<Integer, String> entry: 
ctx.getBroadcastState(STATE_DESCRIPTOR_A).immutableEntries()) {
+                               out.collect(value + ":" + entry.getKey() + "->" 
+ entry.getValue());
+                       }
+               }
+       }
+
        @Test
        public void testBroadcastState() throws Exception {
 
@@ -64,7 +117,7 @@ public class CoBroadcastWithNonKeyedOperatorTest {
 
                try (
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness = getInitializedTestHarness(
-                                               new 
TestFunction(keysToRegister))
+                                               new 
TestFunction(keysToRegister), STATE_DESCRIPTOR)
                ) {
                        testHarness.processWatermark1(new Watermark(10L));
                        testHarness.processWatermark2(new Watermark(10L));
@@ -127,7 +180,7 @@ public class CoBroadcastWithNonKeyedOperatorTest {
        public void testSideOutput() throws Exception {
                try (
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness = getInitializedTestHarness(
-                                               new FunctionWithSideOutput())
+                                               new FunctionWithSideOutput(), 
STATE_DESCRIPTOR)
                ) {
 
                        testHarness.processWatermark1(new Watermark(10L));
@@ -197,13 +250,15 @@ public class CoBroadcastWithNonKeyedOperatorTest {
                                                new 
TestFunctionWithOutput(keysToRegister),
                                                10,
                                                2,
-                                               0);
+                                               0,
+                                               STATE_DESCRIPTOR);
 
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness2 = getInitializedTestHarness(
                                                new 
TestFunctionWithOutput(keysToRegister),
                                                10,
                                                2,
-                                               1)
+                                               1,
+                                               STATE_DESCRIPTOR)
                ) {
                        // make sure all operators have the same state
                        testHarness1.processElement2(new StreamRecord<>(3));
@@ -226,21 +281,24 @@ public class CoBroadcastWithNonKeyedOperatorTest {
                                                10,
                                                3,
                                                0,
-                                               mergedSnapshot);
+                                               mergedSnapshot,
+                                               STATE_DESCRIPTOR);
 
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness2 = getInitializedTestHarness(
                                                new 
TestFunctionWithOutput(keysToRegister),
                                                10,
                                                3,
                                                1,
-                                               mergedSnapshot);
+                                               mergedSnapshot,
+                                               STATE_DESCRIPTOR);
 
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness3 = getInitializedTestHarness(
                                                new 
TestFunctionWithOutput(keysToRegister),
                                                10,
                                                3,
                                                2,
-                                               mergedSnapshot)
+                                               mergedSnapshot,
+                                               STATE_DESCRIPTOR)
                ) {
                        testHarness1.processElement1(new 
StreamRecord<>("trigger"));
                        testHarness2.processElement1(new 
StreamRecord<>("trigger"));
@@ -284,19 +342,22 @@ public class CoBroadcastWithNonKeyedOperatorTest {
                                                new 
TestFunctionWithOutput(keysToRegister),
                                                10,
                                                3,
-                                               0);
+                                               0,
+                                               STATE_DESCRIPTOR);
 
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness2 = getInitializedTestHarness(
                                                new 
TestFunctionWithOutput(keysToRegister),
                                                10,
                                                3,
-                                               1);
+                                               1,
+                                               STATE_DESCRIPTOR);
 
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness3 = getInitializedTestHarness(
                                                new 
TestFunctionWithOutput(keysToRegister),
                                                10,
                                                3,
-                                               2)
+                                               2,
+                                               STATE_DESCRIPTOR)
                ) {
 
                        // make sure all operators have the same state
@@ -322,14 +383,16 @@ public class CoBroadcastWithNonKeyedOperatorTest {
                                                10,
                                                2,
                                                0,
-                                               mergedSnapshot);
+                                               mergedSnapshot,
+                                               STATE_DESCRIPTOR);
 
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness2 = getInitializedTestHarness(
                                                new 
TestFunctionWithOutput(keysToRegister),
                                                10,
                                                2,
                                                1,
-                                               mergedSnapshot)
+                                               mergedSnapshot,
+                                               STATE_DESCRIPTOR)
                ) {
                        testHarness1.processElement1(new 
StreamRecord<>("trigger"));
                        testHarness2.processElement1(new 
StreamRecord<>("trigger"));
@@ -452,40 +515,55 @@ public class CoBroadcastWithNonKeyedOperatorTest {
        }
 
        private static <IN1, IN2, OUT> TwoInputStreamOperatorTestHarness<IN1, 
IN2, OUT> getInitializedTestHarness(
-                       final BroadcastProcessFunction<IN1, IN2, OUT> function) 
throws Exception {
+                       final BroadcastProcessFunction<IN1, IN2, OUT> function,
+                       final MapStateDescriptor<?, ?>... descriptors) throws 
Exception {
 
                return getInitializedTestHarness(
                                function,
                                1,
                                1,
-                               0);
+                               0,
+                               descriptors);
        }
 
        private static <IN1, IN2, OUT> TwoInputStreamOperatorTestHarness<IN1, 
IN2, OUT> getInitializedTestHarness(
                        final BroadcastProcessFunction<IN1, IN2, OUT> function,
                        final int maxParallelism,
                        final int numTasks,
-                       final int taskIdx) throws Exception {
+                       final int taskIdx,
+                       final MapStateDescriptor<?, ?>... descriptors) throws 
Exception {
 
                return getInitializedTestHarness(
                                function,
                                maxParallelism,
                                numTasks,
                                taskIdx,
-                               null);
+                               null,
+                               descriptors);
        }
 
+//     private static <IN1, IN2, OUT> TwoInputStreamOperatorTestHarness<IN1, 
IN2, OUT> getInitializedTestHarness(
+//                     final BroadcastProcessFunction<IN1, IN2, OUT> function,
+//                     final int maxParallelism,
+//                     final int numTasks,
+//                     final int taskIdx,
+//                     final OperatorStateHandles initState) throws Exception {
+//
+//             return getInitializedTestHarness(function, maxParallelism, 
numTasks, taskIdx, initState, STATE_DESCRIPTOR);
+//     }
+
        private static <IN1, IN2, OUT> TwoInputStreamOperatorTestHarness<IN1, 
IN2, OUT> getInitializedTestHarness(
                        final BroadcastProcessFunction<IN1, IN2, OUT> function,
                        final int maxParallelism,
                        final int numTasks,
                        final int taskIdx,
-                       final OperatorStateHandles initState) throws Exception {
+                       final OperatorStateHandles initState,
+                       final MapStateDescriptor<?, ?>... descriptors) throws 
Exception {
 
                TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> testHarness = 
new TwoInputStreamOperatorTestHarness<>(
                                new CoBroadcastWithNonKeyedOperator<>(
                                                
Preconditions.checkNotNull(function),
-                                               
Collections.singletonList(STATE_DESCRIPTOR)),
+                                               Arrays.asList(descriptors)),
                                maxParallelism, numTasks, taskIdx
                );
                testHarness.setup();

Reply via email to