[FLINK-8345] Add iterator of keyed state on broadcast side of connected streams.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/26918c95 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/26918c95 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/26918c95 Branch: refs/heads/master Commit: 26918c953287c7940120dfcfcc10dd5a42beaf81 Parents: c6c17be Author: kkloudas <[email protected]> Authored: Mon Jan 29 16:17:24 2018 +0100 Committer: kkloudas <[email protected]> Committed: Wed Feb 7 14:08:16 2018 +0100 ---------------------------------------------------------------------- .../state/AbstractKeyedStateBackend.java | 29 +++++ .../flink/runtime/state/KeyedStateBackend.java | 19 +++ .../flink/runtime/state/KeyedStateFunction.java | 38 ++++++ .../datastream/BroadcastConnectedStream.java | 10 +- .../co/KeyedBroadcastProcessFunction.java | 39 +++++- .../co/CoBroadcastWithKeyedOperator.java | 39 ++++-- .../flink/streaming/api/DataStreamTest.java | 8 +- .../co/CoBroadcastWithKeyedOperatorTest.java | 128 ++++++++++++++++--- 8 files changed, 271 insertions(+), 39 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index cc53c0c..d159d46 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -283,6 +283,35 @@ public abstract class AbstractKeyedStateBackend<K> * @see KeyedStateBackend */ @Override + public <N, S extends State, T> void applyToAllKeys( + final N namespace, + final 
TypeSerializer<N> namespaceSerializer, + final StateDescriptor<S, T> stateDescriptor, + final KeyedStateFunction<K, S> function) throws Exception { + + // The key stream is AutoCloseable and may hold backend resources, so it is + // closed once the iteration ends, normally or exceptionally. Iterating + // explicitly, instead of through Stream#forEach(), lets checked exceptions + // thrown by the function propagate unchanged. + try (java.util.stream.Stream<K> keys = getKeys(stateDescriptor.getName(), namespace)) { + final java.util.Iterator<K> keyIterator = keys.iterator(); + while (keyIterator.hasNext()) { + final K key = keyIterator.next(); + setCurrentKey(key); + function.process( + key, + getPartitionedState( + namespace, + namespaceSerializer, + stateDescriptor)); + } + } + } + + /** + * @see KeyedStateBackend + */ + @Override public <N, S extends State, V> S getOrCreateKeyedState( final TypeSerializer<N> namespaceSerializer, StateDescriptor<S, V> stateDescriptor) throws Exception { http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java index c74cfcf..cbe40ee 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java @@ -39,6 +39,25 @@ public interface KeyedStateBackend<K> extends InternalKeyContext<K> { void setCurrentKey(K newKey); /** + * Applies the provided {@link KeyedStateFunction} to the state with the provided + * {@link StateDescriptor} of all the currently active keys. + * + * @param namespace the namespace of the state. + * @param namespaceSerializer the serializer for the namespace. + * @param stateDescriptor the descriptor of the state to which the function is going to 
be applied. + * @param function the function to be applied to the keyed state. + * + * @param <N> The type of the namespace.
+ * @param <S> The type of the state. + * @param <T> The value type of the state, as given by the {@link StateDescriptor}. + */ + <N, S extends State, T> void applyToAllKeys( + final N namespace, + final TypeSerializer<N> namespaceSerializer, + final StateDescriptor<S, T> stateDescriptor, + final KeyedStateFunction<K, S> function) throws Exception; + + /** * @return A stream of all keys for the given state and namespace. Modifications to the state during iterating * over it keys are not supported. * @param state State variable for which existing keys will be returned. http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java new file mode 100644 index 0000000..de23dec --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.state.State; + +/** + * A function to be applied to all keyed states. + * + * <p>This functionality is only available through the + * {@code BroadcastConnectedStream.process(final KeyedBroadcastProcessFunction function)}. + */ +public abstract class KeyedStateFunction<K, S extends State> { + + /** + * The actual method to be applied on each of the states. + * + * @param key the key whose state is being processed. It is passed as produced by the backend's + * key iteration and is not guaranteed to be a defensive copy, so it should not be mutated. + * @param state the state associated with the aforementioned key. + */ + public abstract void process(K key, S state) throws Exception; +} http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java index aeb3bc2..453c850 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java @@ -119,18 +119,19 @@ public class BroadcastConnectedStream<IN1, IN2, K, V> { * {@link KeyedBroadcastProcessFunction} on them, thereby creating a transformed output stream. * * @param function The {@link KeyedBroadcastProcessFunction} that is called for each element in the stream. + * @param <KS> The type of the keys in the keyed stream. * @param <OUT> The type of the output elements. * @return The transformed {@link DataStream}.
*/ @PublicEvolving - public <OUT> SingleOutputStreamOperator<OUT> process(final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function) { + public <KS, OUT> SingleOutputStreamOperator<OUT> process(final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function) { TypeInformation<OUT> outTypeInfo = TypeExtractor.getBinaryOperatorReturnType( function, KeyedBroadcastProcessFunction.class, - 0, 1, 2, + 3, TypeExtractor.NO_INDEX, TypeExtractor.NO_INDEX, TypeExtractor.NO_INDEX, @@ -148,12 +149,13 @@ public class BroadcastConnectedStream<IN1, IN2, K, V> { * * @param function The {@link KeyedBroadcastProcessFunction} that is called for each element in the stream. * @param outTypeInfo The type of the output elements. + * @param <KS> The type of the keys in the keyed stream. * @param <OUT> The type of the output elements. * @return The transformed {@link DataStream}. */ @PublicEvolving - public <OUT> SingleOutputStreamOperator<OUT> process( - final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function, + public <KS, OUT> SingleOutputStreamOperator<OUT> process( + final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function, final TypeInformation<OUT> outTypeInfo) { Preconditions.checkNotNull(function); http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java index 9d14259..4b9f138 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java @@ -20,6 +20,9 @@ package 
org.apache.flink.streaming.api.functions.co; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.streaming.api.TimeDomain; import org.apache.flink.streaming.api.TimerService; import org.apache.flink.util.Collector; @@ -36,7 +39,7 @@ import org.apache.flink.util.Collector; * * <p>The user has to implement two methods: * <ol> - * <li>the {@link #processBroadcastElement(Object, Context, Collector)} which will be applied to + * <li>the {@link #processBroadcastElement(Object, KeyedContext, Collector)} which will be applied to * each element in the broadcast side * <li> and the {@link #processElement(Object, KeyedReadOnlyContext, Collector)} which will be applied to the * non-broadcasted/keyed side. @@ -47,12 +50,13 @@ import org.apache.flink.util.Collector; * {@code processElement()} has read-only access to the broadcast state, but can read/write to the keyed state and * register timers. * + * @param <KS> The key type of the input keyed stream. * @param <IN1> The input type of the keyed (non-broadcast) side. * @param <IN2> The input type of the broadcast side. * @param <OUT> The output type of the operator. 
*/ @PublicEvolving -public abstract class KeyedBroadcastProcessFunction<IN1, IN2, OUT> extends BaseBroadcastProcessFunction { +public abstract class KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> extends BaseBroadcastProcessFunction { private static final long serialVersionUID = -2584726797564976453L; @@ -83,19 +87,22 @@ public abstract class KeyedBroadcastProcessFunction<IN1, IN2, OUT> extends BaseB * * <p>It can output zero or more elements using the {@link Collector} parameter, * query the current processing/event time, and also query and update the internal - * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. These can - * be done through the provided {@link Context}. + * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. In addition, it + * can apply a {@link KeyedStateFunction function} to all keyed states on the local + * partition. All of this can be done through the provided {@link KeyedContext}. * The context is only valid during the invocation of this method, do not store it. * * @param value The stream element. * @param ctx A {@link Context} that allows querying the timestamp of the element, * querying the current processing/event time and updating the broadcast state. + * In addition, it allows applying a {@link KeyedStateFunction function} to the + * keyed state with a given {@link StateDescriptor} of all keys on the local partition. * The context is only valid during the invocation of this method, do not store it. * @param out The collector to emit resulting elements to * @throws Exception The function may throw exceptions which cause the streaming program * to fail and go into recovery.
*/ - public abstract void processBroadcastElement(final IN2 value, final Context ctx, final Collector<OUT> out) throws Exception; + public abstract void processBroadcastElement(final IN2 value, final KeyedContext ctx, final Collector<OUT> out) throws Exception; /** * Called when a timer set using {@link TimerService} fires. @@ -116,6 +123,28 @@ public abstract class KeyedBroadcastProcessFunction<IN1, IN2, OUT> extends BaseB } /** + * A {@link BaseBroadcastProcessFunction.Context context} available to the broadcast side of + * a {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream}. + * + * <p>Apart from the basic functionality of a {@link BaseBroadcastProcessFunction.Context context}, + * this also allows applying a {@link KeyedStateFunction} to the (local) states of all active keys + * in the backend. + */ + public abstract class KeyedContext extends Context { + + /** + * Applies the provided {@code function} to the state associated with the + * provided {@code state descriptor}, for every key on the local partition. + * + * @param stateDescriptor the descriptor of the state to be processed. + * @param function the function to be applied. + */ + public abstract <VS, S extends State> void applyToKeyedState( + final StateDescriptor<S, VS> stateDescriptor, + final KeyedStateFunction<KS, S> function) throws Exception; + } + + /** * A {@link BaseBroadcastProcessFunction.Context context} available to the keyed stream side of * a {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream} (if any).
* http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java index 794b0db..4872c61 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java @@ -22,12 +22,15 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.state.BroadcastState; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReadOnlyBroadcastState; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.SimpleTimerService; import org.apache.flink.streaming.api.TimeDomain; import org.apache.flink.streaming.api.TimerService; -import org.apache.flink.streaming.api.functions.co.BaseBroadcastProcessFunction; import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimer; @@ -56,7 +59,7 @@ import static org.apache.flink.util.Preconditions.checkState; */ @Internal public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> - extends AbstractUdfStreamOperator<OUT, 
KeyedBroadcastProcessFunction<IN1, IN2, OUT>> + extends AbstractUdfStreamOperator<OUT, KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>> implements TwoInputStreamOperator<IN1, IN2, OUT>, Triggerable<KS, VoidNamespace> { private static final long serialVersionUID = 5926499536290284870L; @@ -74,7 +77,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> private transient OnTimerContextImpl onTimerContext; public CoBroadcastWithKeyedOperator( - final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function, + final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function, final List<MapStateDescriptor<?, ?>> broadcastStateDescriptors) { super(function); this.broadcastStateDescriptors = Preconditions.checkNotNull(broadcastStateDescriptors); @@ -96,7 +99,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> broadcastStates.put(descriptor, getOperatorStateBackend().getBroadcastState(descriptor)); } - rwContext = new ReadWriteContextImpl(userFunction, broadcastStates, timerService); + rwContext = new ReadWriteContextImpl(getKeyedStateBackend(), userFunction, broadcastStates, timerService); rContext = new ReadOnlyContextImpl(userFunction, broadcastStates, timerService); onTimerContext = new OnTimerContextImpl(userFunction, broadcastStates, timerService); } @@ -137,7 +140,9 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> onTimerContext.timer = null; } - private class ReadWriteContextImpl extends BaseBroadcastProcessFunction.Context { + private class ReadWriteContextImpl extends KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.KeyedContext { + + private final KeyedStateBackend<KS> keyedStateBackend; private final Map<MapStateDescriptor<?, ?>, BroadcastState<?, ?>> states; @@ -146,11 +151,13 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> private StreamRecord<IN2> element; ReadWriteContextImpl ( - final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function, + final KeyedStateBackend<KS> keyedStateBackend, + final 
KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function, final Map<MapStateDescriptor<?, ?>, BroadcastState<?, ?>> broadcastStates, final TimerService timerService) { function.super(); + this.keyedStateBackend = Preconditions.checkNotNull(keyedStateBackend); this.states = Preconditions.checkNotNull(broadcastStates); this.timerService = Preconditions.checkNotNull(timerService); } @@ -192,9 +199,21 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> public long currentWatermark() { return timerService.currentWatermark(); } + + @Override + public <VS, S extends State> void applyToKeyedState( + final StateDescriptor<S, VS> stateDescriptor, + final KeyedStateFunction<KS, S> function) throws Exception { + + keyedStateBackend.applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + Preconditions.checkNotNull(stateDescriptor), + Preconditions.checkNotNull(function)); + } } - private class ReadOnlyContextImpl extends KeyedBroadcastProcessFunction<IN1, IN2, OUT>.KeyedReadOnlyContext { + private class ReadOnlyContextImpl extends KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.KeyedReadOnlyContext { private final Map<MapStateDescriptor<?, ?>, BroadcastState<?, ?>> states; @@ -203,7 +222,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> private StreamRecord<IN1> element; ReadOnlyContextImpl( - final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function, + final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function, final Map<MapStateDescriptor<?, ?>, BroadcastState<?, ?>> broadcastStates, final TimerService timerService) { @@ -256,7 +275,7 @@ public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> } } - private class OnTimerContextImpl extends KeyedBroadcastProcessFunction<IN1, IN2, OUT>.OnTimerContext { + private class OnTimerContextImpl extends KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.OnTimerContext { private final Map<MapStateDescriptor<?, ?>, BroadcastState<?, ?>> states; @@ -267,7 +286,7 @@ public class 
CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> private InternalTimer<KS, VoidNamespace> timer; OnTimerContextImpl( - final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function, + final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function, final Map<MapStateDescriptor<?, ?>, BroadcastState<?, ?>> broadcastStates, final TimerService timerService) { http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java index 59f54b5..bcbbfd6 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java @@ -813,7 +813,7 @@ public class DataStreamTest extends TestLogger { } } - private static class TestBroadcastProcessFunction extends KeyedBroadcastProcessFunction<Long, String, String> { + private static class TestBroadcastProcessFunction extends KeyedBroadcastProcessFunction<Long, Long, String, String> { private final Map<Long, String> expectedState; @@ -837,7 +837,7 @@ public class DataStreamTest extends TestLogger { } @Override - public void processBroadcastElement(String value, Context ctx, Collector<String> out) throws Exception { + public void processBroadcastElement(String value, KeyedContext ctx, Collector<String> out) throws Exception { long key = Long.parseLong(value.split(":")[1]); ctx.getBroadcastState(DESCRIPTOR).put(key, value); } @@ -925,10 +925,10 @@ public class DataStreamTest extends TestLogger { BroadcastStream<String, Long, String> broadcast = srcTwo.broadcast(descriptor); srcOne.connect(broadcast) - .process(new KeyedBroadcastProcessFunction<Long, String, String>() { + .process(new 
KeyedBroadcastProcessFunction<String, Long, String, String>() { @Override - public void processBroadcastElement(String value, Context ctx, Collector<String> out) throws Exception { + public void processBroadcastElement(String value, KeyedContext ctx, Collector<String> out) throws Exception { // do nothing } http://git-wip-us.apache.org/repos/asf/flink/blob/26918c95/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java index 3398d14..3fa439f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java @@ -18,11 +18,14 @@ package org.apache.flink.streaming.api.operators.co; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -63,6 +66,99 @@ public class CoBroadcastWithKeyedOperatorTest { BasicTypeInfo.INT_TYPE_INFO ); + /** Test the iteration over the keyed state on the broadcast 
side. */ + @Test + public void testAccessToKeyedStateIt() throws Exception { + final List<String> test1content = new ArrayList<>(); + test1content.add("test1"); + test1content.add("test1"); + + final List<String> test2content = new ArrayList<>(); + test2content.add("test2"); + test2content.add("test2"); + test2content.add("test2"); + test2content.add("test2"); + + final List<String> test3content = new ArrayList<>(); + test3content.add("test3"); + test3content.add("test3"); + test3content.add("test3"); + + final Map<String, List<String>> expectedState = new HashMap<>(); + expectedState.put("test1", test1content); + expectedState.put("test2", test2content); + expectedState.put("test3", test3content); + + try ( + TwoInputStreamOperatorTestHarness<String, Integer, String> testHarness = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new StatefulFunctionWithKeyedStateAccessedOnBroadcast(expectedState)) + ) { + + // send elements to the keyed state + testHarness.processElement1(new StreamRecord<>("test1", 12L)); + testHarness.processElement1(new StreamRecord<>("test1", 12L)); + + testHarness.processElement1(new StreamRecord<>("test2", 13L)); + testHarness.processElement1(new StreamRecord<>("test2", 13L)); + testHarness.processElement1(new StreamRecord<>("test2", 13L)); + + testHarness.processElement1(new StreamRecord<>("test3", 14L)); + testHarness.processElement1(new StreamRecord<>("test3", 14L)); + testHarness.processElement1(new StreamRecord<>("test3", 14L)); + + testHarness.processElement1(new StreamRecord<>("test2", 13L)); + + // this is the element on the broadcast side that will trigger the verification + // check the StatefulFunctionWithKeyedStateAccessedOnBroadcast#processBroadcastElement() + testHarness.processElement2(new StreamRecord<>(1, 13L)); + } + } + + /** + * Simple {@link KeyedBroadcastProcessFunction} that adds all incoming elements in the non-broadcast + * side to a listState and at the broadcast side it 
verifies that the stored data of each key matches the expected one. + */ + private static class StatefulFunctionWithKeyedStateAccessedOnBroadcast + extends KeyedBroadcastProcessFunction<String, String, Integer, String> { + + private static final long serialVersionUID = 7496674620398203933L; + + private final ListStateDescriptor<String> listStateDesc = + new ListStateDescriptor<>("listStateTest", BasicTypeInfo.STRING_TYPE_INFO); + + private final Map<String, List<String>> expectedKeyedStates; + + StatefulFunctionWithKeyedStateAccessedOnBroadcast(Map<String, List<String>> expectedKeyedState) { + this.expectedKeyedStates = Preconditions.checkNotNull(expectedKeyedState); + } + + @Override + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector<String> out) throws Exception { + final Map<String, List<String>> actualKeyedStates = new HashMap<>(); + ctx.applyToKeyedState(listStateDesc, new KeyedStateFunction<String, ListState<String>>() { + @Override + public void process(String key, ListState<String> state) throws Exception { + final Iterator<String> it = state.get().iterator(); + final List<String> list = new ArrayList<>(); + while (it.hasNext()) { + list.add(it.next()); + } + actualKeyedStates.put(key, list); + } + }); + + // comparing whole maps also fails if some key was never visited + Assert.assertEquals(expectedKeyedStates, actualKeyedStates); + } + + @Override + public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String> out) throws Exception { + getRuntimeContext().getListState(listStateDesc).add(value); + } + } + @Test public void testFunctionWithTimer() throws Exception { @@ -102,7 +198,7 @@ public class CoBroadcastWithKeyedOperatorTest { * {@link KeyedBroadcastProcessFunction} that registers a timer and emits * for every element the watermark and the timestamp of the element.
*/ - private static class FunctionWithTimerOnKeyed extends KeyedBroadcastProcessFunction<String, Integer, String> { + private static class FunctionWithTimerOnKeyed extends KeyedBroadcastProcessFunction<String, String, Integer, String> { private static final long serialVersionUID = 7496674620398203933L; @@ -113,7 +209,7 @@ public class CoBroadcastWithKeyedOperatorTest { } @Override - public void processBroadcastElement(Integer value, Context ctx, Collector<String> out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector<String> out) throws Exception { out.collect("BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); } @@ -172,7 +268,7 @@ public class CoBroadcastWithKeyedOperatorTest { /** * {@link KeyedBroadcastProcessFunction} that emits elements on side outputs. */ - private static class FunctionWithSideOutput extends KeyedBroadcastProcessFunction<String, Integer, String> { + private static class FunctionWithSideOutput extends KeyedBroadcastProcessFunction<String, String, Integer, String> { private static final long serialVersionUID = 7496674620398203933L; @@ -185,7 +281,7 @@ public class CoBroadcastWithKeyedOperatorTest { }; @Override - public void processBroadcastElement(Integer value, Context ctx, Collector<String> out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector<String> out) throws Exception { ctx.output(BROADCAST_TAG, "BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); } @@ -254,7 +350,7 @@ public class CoBroadcastWithKeyedOperatorTest { } } - private static class FunctionWithBroadcastState extends KeyedBroadcastProcessFunction<String, Integer, String> { + private static class FunctionWithBroadcastState extends KeyedBroadcastProcessFunction<String, String, Integer, String> { private static final long serialVersionUID = 7496674620398203933L; @@ -273,7 +369,7 @@ public class CoBroadcastWithKeyedOperatorTest { } 
@Override - public void processBroadcastElement(Integer value, Context ctx, Collector<String> out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector<String> out) throws Exception { // put an element in the broadcast state final String key = value + "." + keyPostfix; ctx.getBroadcastState(STATE_DESCRIPTOR).put(key, value); @@ -501,7 +597,7 @@ public class CoBroadcastWithKeyedOperatorTest { } } - private static class TestFunctionWithOutput extends KeyedBroadcastProcessFunction<String, Integer, String> { + private static class TestFunctionWithOutput extends KeyedBroadcastProcessFunction<String, String, Integer, String> { private static final long serialVersionUID = 7496674620398203933L; @@ -512,7 +608,7 @@ public class CoBroadcastWithKeyedOperatorTest { } @Override - public void processBroadcastElement(Integer value, Context ctx, Collector<String> out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector<String> out) throws Exception { // put an element in the broadcast state for (String k : keysToRegister) { ctx.getBroadcastState(STATE_DESCRIPTOR).put(k, value); @@ -536,14 +632,14 @@ public class CoBroadcastWithKeyedOperatorTest { TwoInputStreamOperatorTestHarness<String, Integer, String> testHarness = getInitializedTestHarness( BasicTypeInfo.STRING_TYPE_INFO, new IdentityKeySelector<>(), - new KeyedBroadcastProcessFunction<String, Integer, String>() { + new KeyedBroadcastProcessFunction<String, String, Integer, String>() { private static final long serialVersionUID = -1725365436500098384L; private final ValueStateDescriptor<String> valueState = new ValueStateDescriptor<>("any", BasicTypeInfo.STRING_TYPE_INFO); @Override - public void processBroadcastElement(Integer value, Context ctx, Collector<String> out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector<String> out) throws Exception { 
getRuntimeContext().getState(valueState).value(); // this should fail } @@ -575,10 +671,10 @@ public class CoBroadcastWithKeyedOperatorTest { } } - private static <KEY, IN1, IN2, K, V, OUT> TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness( + private static <KEY, IN1, IN2, OUT> TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness( final TypeInformation<KEY> keyTypeInfo, final KeySelector<IN1, KEY> keyKeySelector, - final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function) throws Exception { + final KeyedBroadcastProcessFunction<KEY, IN1, IN2, OUT> function) throws Exception { return getInitializedTestHarness( keyTypeInfo, @@ -589,10 +685,10 @@ public class CoBroadcastWithKeyedOperatorTest { 0); } - private static <KEY, IN1, IN2, K, V, OUT> TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness( + private static <KEY, IN1, IN2, OUT> TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness( final TypeInformation<KEY> keyTypeInfo, final KeySelector<IN1, KEY> keyKeySelector, - final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function, + final KeyedBroadcastProcessFunction<KEY, IN1, IN2, OUT> function, final int maxParallelism, final int numTasks, final int taskIdx) throws Exception { @@ -607,10 +703,10 @@ public class CoBroadcastWithKeyedOperatorTest { null); } - private static <KEY, IN1, IN2, K, V, OUT> TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness( + private static <KEY, IN1, IN2, OUT> TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> getInitializedTestHarness( final TypeInformation<KEY> keyTypeInfo, final KeySelector<IN1, KEY> keyKeySelector, - final KeyedBroadcastProcessFunction<IN1, IN2, OUT> function, + final KeyedBroadcastProcessFunction<KEY, IN1, IN2, OUT> function, final int maxParallelism, final int numTasks, final int taskIdx,
