[FLINK-8345] Add iterator of keyed state on broadcast side of connected streams.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/26918c95
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/26918c95
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/26918c95

Branch: refs/heads/master
Commit: 26918c953287c7940120dfcfcc10dd5a42beaf81
Parents: c6c17be
Author: kkloudas <[email protected]>
Authored: Mon Jan 29 16:17:24 2018 +0100
Committer: kkloudas <[email protected]>
Committed: Wed Feb 7 14:08:16 2018 +0100

----------------------------------------------------------------------
 .../state/AbstractKeyedStateBackend.java        |  33 +++++
 .../flink/runtime/state/KeyedStateBackend.java  |  18 +++
 .../flink/runtime/state/KeyedStateFunction.java |  38 ++++++
 .../datastream/BroadcastConnectedStream.java    |  10 +-
 .../co/KeyedBroadcastProcessFunction.java       |  39 +++++-
 .../co/CoBroadcastWithKeyedOperator.java        |  39 ++++--
 .../flink/streaming/api/DataStreamTest.java     |   8 +-
 .../co/CoBroadcastWithKeyedOperatorTest.java    | 128 ++++++++++++++++---
 8 files changed, 274 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index cc53c0c..d159d46 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -283,6 +283,39 @@ public abstract class AbstractKeyedStateBackend<K>
         * @see KeyedStateBackend
         */
        @Override
+       public <N, S extends State, T> void applyToAllKeys(
+                       final N namespace,
+                       final TypeSerializer<N> namespaceSerializer,
+                       final StateDescriptor<S, T> stateDescriptor,
+                       final KeyedStateFunction<K, S> function) throws Exception {
+
+               // getKeys() does not support modifications of the state while its key stream is
+               // being consumed, but the function may well update or clear the state it is handed.
+               // The keys are therefore materialized first; closing the stream (try-with-resources)
+               // also releases whatever resources the backend attached to it.
+               final java.util.List<K> keys;
+               try (java.util.stream.Stream<K> keyStream = getKeys(stateDescriptor.getName(), namespace)) {
+                       keys = keyStream.collect(java.util.stream.Collectors.toList());
+               }
+
+               for (final K key : keys) {
+                       // Switch the key context first, so that the state obtained below
+                       // refers to the entry of this key.
+                       setCurrentKey(key);
+
+                       // A plain loop (instead of Stream#forEach with a wrapping lambda) lets any
+                       // exception thrown by the function reach the caller unwrapped.
+                       function.process(
+                                       key,
+                                       getPartitionedState(namespace, namespaceSerializer, stateDescriptor));
+               }
+       }
+
+       /**
+        * @see KeyedStateBackend
+        */
+       @Override
        public <N, S extends State, V> S getOrCreateKeyedState(
                        final TypeSerializer<N> namespaceSerializer,
                        StateDescriptor<S, V> stateDescriptor) throws Exception 
{

http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
index c74cfcf..cbe40ee 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
@@ -39,6 +39,24 @@ public interface KeyedStateBackend<K> extends 
InternalKeyContext<K> {
        void setCurrentKey(K newKey);
 
        /**
+        * Applies the provided {@link KeyedStateFunction} to the state with the provided
+        * {@link StateDescriptor}, for all the currently active keys, i.e. all keys that
+        * currently have an entry for that state in the given namespace.
+        *
+        * <p>Implementations may change the current key (see {@link #setCurrentKey(Object)})
+        * while iterating over the keys.
+        *
+        * @param namespace the namespace of the state.
+        * @param namespaceSerializer the serializer for the namespace.
+        * @param stateDescriptor the descriptor of the state to which the function is going to be applied.
+        * @param function the function to be applied to the keyed state.
+        *
+        * @param <N> The type of the namespace.
+        * @param <S> The type of the state.
+        * @param <T> The type of the values held by the state.
+        * @throws Exception if obtaining the state or applying the function fails.
+        */
+       <N, S extends State, T> void applyToAllKeys(
+                       final N namespace,
+                       final TypeSerializer<N> namespaceSerializer,
+                       final StateDescriptor<S, T> stateDescriptor,
+                       final KeyedStateFunction<K, S> function) throws 
Exception;
+
+       /**
         * @return A stream of all keys for the given state and namespace. 
Modifications to the state during iterating
         *                 over it keys are not supported.
         * @param state State variable for which existing keys will be returned.

http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java
new file mode 100644
index 0000000..de23dec
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.State;
+
+/**
+ * A function to be applied to all keyed states.
+ *
+ * <p>This functionality is only available through the
+ * {@code BroadcastConnectedStream.process(final KeyedBroadcastProcessFunction function)}.
+ *
+ * <p>This is a functional interface: it can be implemented by a lambda expression,
+ * a method reference, or an (anonymous) class.
+ *
+ * @param <K> The type of the keys.
+ * @param <S> The type of the state the function is applied to.
+ */
+@FunctionalInterface
+public interface KeyedStateFunction<K, S extends State> {
+
+       /**
+        * The actual method to be applied on each of the states.
+        *
+        * @param key the key whose state is being processed.
+        * @param state the state associated with the aforementioned key.
+        * @throws Exception the function may throw exceptions, for example when accessing the state fails.
+        */
+       void process(K key, S state) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
index aeb3bc2..453c850 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java
@@ -119,18 +119,19 @@ public class BroadcastConnectedStream<IN1, IN2, K, V> {
         * {@link KeyedBroadcastProcessFunction} on them, thereby creating a 
transformed output stream.
         *
         * @param function The {@link KeyedBroadcastProcessFunction} that is 
called for each element in the stream.
+        * @param <KS> The type of the keys in the keyed stream.
         * @param <OUT> The type of the output elements.
         * @return The transformed {@link DataStream}.
         */
        @PublicEvolving
-       public <OUT> SingleOutputStreamOperator<OUT> process(final 
KeyedBroadcastProcessFunction<IN1, IN2, OUT> function) {
+       public <KS, OUT> SingleOutputStreamOperator<OUT> process(final 
KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function) {
 
                TypeInformation<OUT> outTypeInfo = 
TypeExtractor.getBinaryOperatorReturnType(
                                function,
                                KeyedBroadcastProcessFunction.class,
-                               0,
                                1,
                                2,
+                               3,
                                TypeExtractor.NO_INDEX,
                                TypeExtractor.NO_INDEX,
                                TypeExtractor.NO_INDEX,
@@ -148,12 +149,13 @@ public class BroadcastConnectedStream<IN1, IN2, K, V> {
         *
         * @param function The {@link KeyedBroadcastProcessFunction} that is 
called for each element in the stream.
         * @param outTypeInfo The type of the output elements.
+        * @param <KS> The type of the keys in the keyed stream.
         * @param <OUT> The type of the output elements.
         * @return The transformed {@link DataStream}.
         */
        @PublicEvolving
-       public <OUT> SingleOutputStreamOperator<OUT> process(
-                       final KeyedBroadcastProcessFunction<IN1, IN2, OUT> 
function,
+       public <KS, OUT> SingleOutputStreamOperator<OUT> process(
+                       final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> 
function,
                        final TypeInformation<OUT> outTypeInfo) {
 
                Preconditions.checkNotNull(function);

http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
index 9d14259..4b9f138 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java
@@ -20,6 +20,9 @@ package org.apache.flink.streaming.api.functions.co;
 
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.State;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.runtime.state.KeyedStateFunction;
 import org.apache.flink.streaming.api.TimeDomain;
 import org.apache.flink.streaming.api.TimerService;
 import org.apache.flink.util.Collector;
@@ -36,7 +39,7 @@ import org.apache.flink.util.Collector;
  *
  * <p>The user has to implement two methods:
  * <ol>
- *     <li>the {@link #processBroadcastElement(Object, Context, Collector)} 
which will be applied to
+ *     <li>the {@link #processBroadcastElement(Object, KeyedContext, 
Collector)} which will be applied to
  *     each element in the broadcast side
  *     <li> and the {@link #processElement(Object, KeyedReadOnlyContext, 
Collector)} which will be applied to the
  *     non-broadcasted/keyed side.
@@ -47,12 +50,13 @@ import org.apache.flink.util.Collector;
  * {@code processElement()} has read-only access to the broadcast state, but 
can read/write to the keyed state and
  * register timers.
  *
+ * @param <KS> The key type of the input keyed stream.
  * @param <IN1> The input type of the keyed (non-broadcast) side.
  * @param <IN2> The input type of the broadcast side.
  * @param <OUT> The output type of the operator.
  */
 @PublicEvolving
-public abstract class KeyedBroadcastProcessFunction<IN1, IN2, OUT> extends 
BaseBroadcastProcessFunction {
+public abstract class KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> extends 
BaseBroadcastProcessFunction {
 
        private static final long serialVersionUID = -2584726797564976453L;
 
@@ -83,19 +87,22 @@ public abstract class KeyedBroadcastProcessFunction<IN1, 
IN2, OUT> extends BaseB
         *
         * <p>It can output zero or more elements using the {@link Collector} 
parameter,
         * query the current processing/event time, and also query and update 
the internal
-        * {@link org.apache.flink.api.common.state.BroadcastState broadcast 
state}. These can
-        * be done through the provided {@link Context}.
+        * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. In addition, it
+        * can apply a {@link KeyedStateFunction function} to the keyed states of all keys on the
+        * local partition. All of this can be done through the provided {@link KeyedContext}.
         * The context is only valid during the invocation of this method, do 
not store it.
         *
         * @param value The stream element.
         * @param ctx A {@link Context} that allows querying the timestamp of 
the element,
         *            querying the current processing/event time and updating 
the broadcast state.
+        *            In addition, it allows applying a {@link KeyedStateFunction function}
+        *            to all keyed state with a given {@link StateDescriptor} on the local partition.
         *            The context is only valid during the invocation of this 
method, do not store it.
         * @param out The collector to emit resulting elements to
         * @throws Exception The function may throw exceptions which cause the 
streaming program
         *                   to fail and go into recovery.
         */
-       public abstract void processBroadcastElement(final IN2 value, final 
Context ctx, final Collector<OUT> out) throws Exception;
+       public abstract void processBroadcastElement(final IN2 value, final 
KeyedContext ctx, final Collector<OUT> out) throws Exception;
 
        /**
         * Called when a timer set using {@link TimerService} fires.
@@ -116,6 +123,28 @@ public abstract class KeyedBroadcastProcessFunction<IN1, 
IN2, OUT> extends BaseB
        }
 
        /**
+        * A {@link BaseBroadcastProcessFunction.Context context} available to the broadcast side of
+        * a {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream}.
+        *
+        * <p>Apart from the basic functionality of a {@link BaseBroadcastProcessFunction.Context context},
+        * this also allows applying a {@link KeyedStateFunction} to the (local) states of all active keys
+        * in the keyed state backend of the local partition.
+        */
+       public abstract class KeyedContext extends Context {
+
+               /**
+                * Applies the provided {@code function} to the state associated with the provided
+                * {@code stateDescriptor}, for every active key of the local partition.
+                *
+                * @param stateDescriptor the descriptor of the state to be processed.
+                * @param function the function to be applied.
+                * @param <VS> The type of the values held by the state.
+                * @param <S> The type of the state.
+                * @throws Exception if accessing the state or applying the function fails.
+                */
+               public abstract <VS, S extends State> void applyToKeyedState(
+                               final StateDescriptor<S, VS> stateDescriptor,
+                               final KeyedStateFunction<KS, S> function) 
throws Exception;
+       }
+
+       /**
         * A {@link BaseBroadcastProcessFunction.Context context} available to 
the keyed stream side of
         * a {@link 
org.apache.flink.streaming.api.datastream.BroadcastConnectedStream} (if any).
         *

http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
index 794b0db..4872c61 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java
@@ -22,12 +22,15 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.state.BroadcastState;
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.state.ReadOnlyBroadcastState;
+import org.apache.flink.api.common.state.State;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.KeyedStateFunction;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.SimpleTimerService;
 import org.apache.flink.streaming.api.TimeDomain;
 import org.apache.flink.streaming.api.TimerService;
-import 
org.apache.flink.streaming.api.functions.co.BaseBroadcastProcessFunction;
 import 
org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.InternalTimer;
@@ -56,7 +59,7 @@ import static org.apache.flink.util.Preconditions.checkState;
  */
 @Internal
 public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT>
-               extends AbstractUdfStreamOperator<OUT, 
KeyedBroadcastProcessFunction<IN1, IN2, OUT>>
+               extends AbstractUdfStreamOperator<OUT, 
KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>>
                implements TwoInputStreamOperator<IN1, IN2, OUT>, 
Triggerable<KS, VoidNamespace> {
 
        private static final long serialVersionUID = 5926499536290284870L;
@@ -74,7 +77,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT>
        private transient OnTimerContextImpl onTimerContext;
 
        public CoBroadcastWithKeyedOperator(
-                       final KeyedBroadcastProcessFunction<IN1, IN2, OUT> 
function,
+                       final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> 
function,
                        final List<MapStateDescriptor<?, ?>> 
broadcastStateDescriptors) {
                super(function);
                this.broadcastStateDescriptors = 
Preconditions.checkNotNull(broadcastStateDescriptors);
@@ -96,7 +99,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT>
                        broadcastStates.put(descriptor, 
getOperatorStateBackend().getBroadcastState(descriptor));
                }
 
-               rwContext = new ReadWriteContextImpl(userFunction, 
broadcastStates, timerService);
+               rwContext = new ReadWriteContextImpl(getKeyedStateBackend(), 
userFunction, broadcastStates, timerService);
                rContext = new ReadOnlyContextImpl(userFunction, 
broadcastStates, timerService);
                onTimerContext = new OnTimerContextImpl(userFunction, 
broadcastStates, timerService);
        }
@@ -137,7 +140,9 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT>
                onTimerContext.timer = null;
        }
 
-       private class ReadWriteContextImpl extends 
BaseBroadcastProcessFunction.Context {
+       private class ReadWriteContextImpl extends 
KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.KeyedContext {
+
+               private final KeyedStateBackend<KS> keyedStateBackend;
 
                private final Map<MapStateDescriptor<?, ?>, BroadcastState<?, 
?>> states;
 
@@ -146,11 +151,13 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, 
OUT>
                private StreamRecord<IN2> element;
 
+               /**
+                * Creates the read-write context used on the broadcast side.
+                *
+                * @param keyedStateBackend the backend whose keys are iterated by {@link #applyToKeyedState}.
+                * @param function the user function; it is also the enclosing instance of this
+                *                 inner context (see {@code function.super()}).
+                * @param broadcastStates the broadcast states, by their descriptors.
+                * @param timerService the timer service, used e.g. to query the current watermark.
+                */
                ReadWriteContextImpl (
-                               final KeyedBroadcastProcessFunction<IN1, IN2, 
OUT> function,
+                               final KeyedStateBackend<KS> keyedStateBackend,
+                               final KeyedBroadcastProcessFunction<KS, IN1, 
IN2, OUT> function,
                                final Map<MapStateDescriptor<?, ?>, 
BroadcastState<?, ?>> broadcastStates,
                                final TimerService timerService) {
 
                        function.super();
+                       this.keyedStateBackend = 
Preconditions.checkNotNull(keyedStateBackend);
                        this.states = 
Preconditions.checkNotNull(broadcastStates);
                        this.timerService = 
Preconditions.checkNotNull(timerService);
                }
@@ -192,9 +199,21 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, 
OUT>
                public long currentWatermark() {
                        return timerService.currentWatermark();
                }
+
+               /**
+                * Runs {@code function} over the state described by {@code stateDescriptor} for all keys
+                * of the keyed state backend, always in the {@link VoidNamespace} namespace.
+                */
+               @Override
+               public <VS, S extends State> void applyToKeyedState(
+                               final StateDescriptor<S, VS> stateDescriptor,
+                               final KeyedStateFunction<KS, S> function) throws Exception {
+
+                       final StateDescriptor<S, VS> descriptor = Preconditions.checkNotNull(stateDescriptor);
+                       final KeyedStateFunction<KS, S> keyedFunction = Preconditions.checkNotNull(function);
+
+                       keyedStateBackend.applyToAllKeys(
+                                       VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor, keyedFunction);
+               }
        }
 
-       private class ReadOnlyContextImpl extends 
KeyedBroadcastProcessFunction<IN1, IN2, OUT>.KeyedReadOnlyContext {
+       private class ReadOnlyContextImpl extends 
KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.KeyedReadOnlyContext {
 
                private final Map<MapStateDescriptor<?, ?>, BroadcastState<?, 
?>> states;
 
@@ -203,7 +222,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT>
                private StreamRecord<IN1> element;
 
                ReadOnlyContextImpl(
-                               final KeyedBroadcastProcessFunction<IN1, IN2, 
OUT> function,
+                               final KeyedBroadcastProcessFunction<KS, IN1, 
IN2, OUT> function,
                                final Map<MapStateDescriptor<?, ?>, 
BroadcastState<?, ?>> broadcastStates,
                                final TimerService timerService) {
 
@@ -256,7 +275,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT>
                }
        }
 
-       private class OnTimerContextImpl extends 
KeyedBroadcastProcessFunction<IN1, IN2, OUT>.OnTimerContext {
+       private class OnTimerContextImpl extends 
KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.OnTimerContext {
 
                private final Map<MapStateDescriptor<?, ?>, BroadcastState<?, 
?>> states;
 
@@ -267,7 +286,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT>
                private InternalTimer<KS, VoidNamespace> timer;
 
                OnTimerContextImpl(
-                               final KeyedBroadcastProcessFunction<IN1, IN2, 
OUT> function,
+                               final KeyedBroadcastProcessFunction<KS, IN1, 
IN2, OUT> function,
                                final Map<MapStateDescriptor<?, ?>, 
BroadcastState<?, ?>> broadcastStates,
                                final TimerService timerService) {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
index 59f54b5..bcbbfd6 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
@@ -813,7 +813,7 @@ public class DataStreamTest extends TestLogger {
                }
        }
 
-       private static class TestBroadcastProcessFunction extends 
KeyedBroadcastProcessFunction<Long, String, String> {
+       private static class TestBroadcastProcessFunction extends 
KeyedBroadcastProcessFunction<Long, Long, String, String> {
 
                private final Map<Long, String> expectedState;
 
@@ -837,7 +837,7 @@ public class DataStreamTest extends TestLogger {
                }
 
                @Override
-               public void processBroadcastElement(String value, Context ctx, 
Collector<String> out) throws Exception {
+               public void processBroadcastElement(String value, KeyedContext 
ctx, Collector<String> out) throws Exception {
+                       // the key is the numeric, second ':'-separated field of the broadcast value;
+                       // the whole value is then stored under that key in the broadcast state
                        long key = Long.parseLong(value.split(":")[1]);
                        ctx.getBroadcastState(DESCRIPTOR).put(key, value);
                }
@@ -925,10 +925,10 @@ public class DataStreamTest extends TestLogger {
 
                BroadcastStream<String, Long, String> broadcast = 
srcTwo.broadcast(descriptor);
                srcOne.connect(broadcast)
-                               .process(new 
KeyedBroadcastProcessFunction<Long, String, String>() {
+                               .process(new 
KeyedBroadcastProcessFunction<String, Long, String, String>() {
 
                                        @Override
-                                       public void 
processBroadcastElement(String value, Context ctx, Collector<String> out) 
throws Exception {
+                                       public void 
processBroadcastElement(String value, KeyedContext ctx, Collector<String> out) 
throws Exception {
                                                // do nothing
                                        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
index 3398d14..3fa439f 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
@@ -18,11 +18,14 @@
 
 package org.apache.flink.streaming.api.operators.co;
 
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.runtime.state.KeyedStateFunction;
 import 
org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -63,6 +66,99 @@ public class CoBroadcastWithKeyedOperatorTest {
                                        BasicTypeInfo.INT_TYPE_INFO
                        );
 
+       /** Test the iteration over the keyed state on the broadcast side. */
+       @Test
+       public void testAccessToKeyedStateIt() throws Exception {
+               // key -> the values the keyed side is expected to have stored for that key
+               final Map<String, List<String>> expectedState = new HashMap<>();
+               final String[] keys = {"test1", "test2", "test3"};
+               final int[] occurrences = {2, 4, 3};
+               for (int i = 0; i < keys.length; i++) {
+                       final List<String> values = new ArrayList<>();
+                       for (int j = 0; j < occurrences[i]; j++) {
+                               values.add(keys[i]);
+                       }
+                       expectedState.put(keys[i], values);
+               }
+
+               try (TwoInputStreamOperatorTestHarness<String, Integer, String> testHarness =
+                               getInitializedTestHarness(
+                                               BasicTypeInfo.STRING_TYPE_INFO,
+                                               new IdentityKeySelector<>(),
+                                               new StatefulFunctionWithKeyedStateAccessedOnBroadcast(expectedState))) {
+
+                       // elements of the keyed (non-broadcast) side, in arrival order, with their timestamps
+                       final String[] elements = {"test1", "test1", "test2", "test2", "test2", "test3", "test3", "test3", "test2"};
+                       final long[] timestamps = {12L, 12L, 13L, 13L, 13L, 14L, 14L, 14L, 13L};
+                       for (int i = 0; i < elements.length; i++) {
+                               testHarness.processElement1(new StreamRecord<>(elements[i], timestamps[i]));
+                       }
+
+                       // a single broadcast element triggers the verification, see
+                       // StatefulFunctionWithKeyedStateAccessedOnBroadcast#processBroadcastElement()
+                       testHarness.processElement2(new StreamRecord<>(1, 13L));
+               }
+       }
+
+       /**
+        * Simple {@link KeyedBroadcastProcessFunction} that adds all incoming elements of the non-broadcast
+        * side to a {@link ListState} and, on the broadcast side, verifies that the keyed states hold
+        * exactly the expected data, for exactly the expected keys.
+        */
+       private static class StatefulFunctionWithKeyedStateAccessedOnBroadcast
+                       extends KeyedBroadcastProcessFunction<String, String, Integer, String> {
+
+               private static final long serialVersionUID = 7496674620398203933L;
+
+               private final ListStateDescriptor<String> listStateDesc =
+                               new ListStateDescriptor<>("listStateTest", BasicTypeInfo.STRING_TYPE_INFO);
+
+               /** The expected content of the list state, per key. */
+               private final Map<String, List<String>> expectedKeyedStates;
+
+               StatefulFunctionWithKeyedStateAccessedOnBroadcast(Map<String, List<String>> expectedKeyedState) {
+                       this.expectedKeyedStates = Preconditions.checkNotNull(expectedKeyedState);
+               }
+
+               @Override
+               public void processBroadcastElement(Integer value, KeyedContext ctx, Collector<String> out) throws Exception {
+                       // collect the content of the list state of every key visited on the local partition...
+                       final Map<String, List<String>> actualKeyedStates = new HashMap<>();
+                       ctx.applyToKeyedState(
+                                       listStateDesc,
+                                       new KeyedStateFunction<String, ListState<String>>() {
+                                               @Override
+                                               public void process(String key, ListState<String> state) throws Exception {
+                                                       final Iterator<String> it = state.get().iterator();
+
+                                                       final List<String> list = new ArrayList<>();
+                                                       while (it.hasNext()) {
+                                                               list.add(it.next());
+                                                       }
+                                                       actualKeyedStates.put(key, list);
+                                               }
+                                       });
+
+                       // ... and compare it as a whole: a key that is never visited (or an unexpected
+                       // extra key) now fails the test instead of letting it pass silently.
+                       Assert.assertEquals(expectedKeyedStates, actualKeyedStates);
+               }
+
+               @Override
+               public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String> out) throws Exception {
+                       getRuntimeContext().getListState(listStateDesc).add(value);
+               }
+       }
+
        @Test
        public void testFunctionWithTimer() throws Exception {
 
@@ -102,7 +198,7 @@ public class CoBroadcastWithKeyedOperatorTest {
         * {@link KeyedBroadcastProcessFunction} that registers a timer and 
emits
         * for every element the watermark and the timestamp of the element.
         */
-       private static class FunctionWithTimerOnKeyed extends 
KeyedBroadcastProcessFunction<String, Integer, String> {
+       private static class FunctionWithTimerOnKeyed extends 
KeyedBroadcastProcessFunction<String, String, Integer, String> {
 
                private static final long serialVersionUID = 
7496674620398203933L;
 
@@ -113,7 +209,7 @@ public class CoBroadcastWithKeyedOperatorTest {
                }
 
                @Override
-               public void processBroadcastElement(Integer value, Context ctx, 
Collector<String> out) throws Exception {
+               public void processBroadcastElement(Integer value, KeyedContext 
ctx, Collector<String> out) throws Exception {
                        out.collect("BR:" + value + " WM:" + 
ctx.currentWatermark() + " TS:" + ctx.timestamp());
                }
 
@@ -172,7 +268,7 @@ public class CoBroadcastWithKeyedOperatorTest {
        /**
         * {@link KeyedBroadcastProcessFunction} that emits elements on side 
outputs.
         */
-       private static class FunctionWithSideOutput extends 
KeyedBroadcastProcessFunction<String, Integer, String> {
+       private static class FunctionWithSideOutput extends 
KeyedBroadcastProcessFunction<String, String, Integer, String> {
 
                private static final long serialVersionUID = 
7496674620398203933L;
 
@@ -185,7 +281,7 @@ public class CoBroadcastWithKeyedOperatorTest {
                };
 
                @Override
-               public void processBroadcastElement(Integer value, Context ctx, 
Collector<String> out) throws Exception {
+               public void processBroadcastElement(Integer value, KeyedContext 
ctx, Collector<String> out) throws Exception {
                        ctx.output(BROADCAST_TAG, "BR:" + value + " WM:" + 
ctx.currentWatermark() + " TS:" + ctx.timestamp());
                }
 
@@ -254,7 +350,7 @@ public class CoBroadcastWithKeyedOperatorTest {
                }
        }
 
-       private static class FunctionWithBroadcastState extends 
KeyedBroadcastProcessFunction<String, Integer, String> {
+       private static class FunctionWithBroadcastState extends 
KeyedBroadcastProcessFunction<String, String, Integer, String> {
 
                private static final long serialVersionUID = 
7496674620398203933L;
 
@@ -273,7 +369,7 @@ public class CoBroadcastWithKeyedOperatorTest {
                }
 
                @Override
-               public void processBroadcastElement(Integer value, Context ctx, 
Collector<String> out) throws Exception {
+               public void processBroadcastElement(Integer value, KeyedContext 
ctx, Collector<String> out) throws Exception {
                        // put an element in the broadcast state
                        final String key = value + "." + keyPostfix;
                        ctx.getBroadcastState(STATE_DESCRIPTOR).put(key, value);
@@ -501,7 +597,7 @@ public class CoBroadcastWithKeyedOperatorTest {
                }
        }
 
-       private static class TestFunctionWithOutput extends 
KeyedBroadcastProcessFunction<String, Integer, String> {
+       private static class TestFunctionWithOutput extends 
KeyedBroadcastProcessFunction<String, String, Integer, String> {
 
                private static final long serialVersionUID = 
7496674620398203933L;
 
@@ -512,7 +608,7 @@ public class CoBroadcastWithKeyedOperatorTest {
                }
 
                @Override
-               public void processBroadcastElement(Integer value, Context ctx, 
Collector<String> out) throws Exception {
+               public void processBroadcastElement(Integer value, KeyedContext 
ctx, Collector<String> out) throws Exception {
                        // put an element in the broadcast state
                        for (String k : keysToRegister) {
                                ctx.getBroadcastState(STATE_DESCRIPTOR).put(k, 
value);
@@ -536,14 +632,14 @@ public class CoBroadcastWithKeyedOperatorTest {
                                TwoInputStreamOperatorTestHarness<String, 
Integer, String> testHarness = getInitializedTestHarness(
                                                BasicTypeInfo.STRING_TYPE_INFO,
                                                new IdentityKeySelector<>(),
-                                               new 
KeyedBroadcastProcessFunction<String, Integer, String>() {
+                                               new 
KeyedBroadcastProcessFunction<String, String, Integer, String>() {
 
                                                        private static final 
long serialVersionUID = -1725365436500098384L;
 
                                                        private final 
ValueStateDescriptor<String> valueState = new ValueStateDescriptor<>("any", 
BasicTypeInfo.STRING_TYPE_INFO);
 
                                                        @Override
-                                                       public void 
processBroadcastElement(Integer value, Context ctx, Collector<String> out) 
throws Exception {
+                                                       public void 
processBroadcastElement(Integer value, KeyedContext ctx, Collector<String> out) 
throws Exception {
                                                                
getRuntimeContext().getState(valueState).value(); // this should fail
                                                        }
 
@@ -575,10 +671,10 @@ public class CoBroadcastWithKeyedOperatorTest {
                }
        }
 
-       private static <KEY, IN1, IN2, K, V, OUT> 
TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness(
+       private static <KEY, IN1, IN2, OUT> 
TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness(
                        final TypeInformation<KEY> keyTypeInfo,
                        final KeySelector<IN1, KEY> keyKeySelector,
-                       final KeyedBroadcastProcessFunction<IN1, IN2, OUT> 
function) throws Exception {
+                       final KeyedBroadcastProcessFunction<KEY, IN1, IN2, OUT> 
function) throws Exception {
 
                return getInitializedTestHarness(
                                keyTypeInfo,
@@ -589,10 +685,10 @@ public class CoBroadcastWithKeyedOperatorTest {
                                0);
        }
 
-       private static <KEY, IN1, IN2, K, V, OUT> 
TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness(
+       private static <KEY, IN1, IN2, OUT> 
TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness(
                        final TypeInformation<KEY> keyTypeInfo,
                        final KeySelector<IN1, KEY> keyKeySelector,
-                       final KeyedBroadcastProcessFunction<IN1, IN2, OUT> 
function,
+                       final KeyedBroadcastProcessFunction<KEY, IN1, IN2, OUT> 
function,
                        final int maxParallelism,
                        final int numTasks,
                        final int taskIdx) throws Exception {
@@ -607,10 +703,10 @@ public class CoBroadcastWithKeyedOperatorTest {
                                null);
        }
 
-       private static <KEY, IN1, IN2, K, V, OUT> 
TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness(
+       private static <KEY, IN1, IN2, OUT> 
TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness(
                        final TypeInformation<KEY> keyTypeInfo,
                        final KeySelector<IN1, KEY> keyKeySelector,
-                       final KeyedBroadcastProcessFunction<IN1, IN2, OUT> 
function,
+                       final KeyedBroadcastProcessFunction<KEY, IN1, IN2, OUT> 
function,
                        final int maxParallelism,
                        final int numTasks,
                        final int taskIdx,

Reply via email to