http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 3f1cfae..f8f26b5 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -19,15 +19,14 @@ package org.apache.flink.streaming.api.operators; import org.apache.flink.api.common.ExecutionConfig; -import org.apache.flink.api.common.state.OperatorState; -import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; import org.apache.flink.streaming.api.graph.StreamConfig; -import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.streaming.runtime.operators.Triggerable; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; @@ -37,7 +36,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.HashMap; -import java.util.Map; /** * Base class for all stream operators. 
Operators that contain a user function should extend the class @@ -81,22 +79,16 @@ public abstract class AbstractStreamOperator<OUT> /** The runtime context for UDFs */ private transient StreamingRuntimeContext runtimeContext; - + // ---------------- key/value state ------------------ - + /** key selector used to get the key for the state. Non-null only is the operator uses key/value state */ - private transient KeySelector<?, ?> stateKeySelector; - - private transient KvState<?, ?, ?>[] keyValueStates; - - private transient HashMap<String, KvState<?, ?, ?>> keyValueStatesByName; - - private transient TypeSerializer<?> keySerializer; - - private transient HashMap<String, KvStateSnapshot<?, ?, ?>> keyValueStateSnapshots; + private transient KeySelector<?, ?> stateKeySelector1; + private transient KeySelector<?, ?> stateKeySelector2; + + /** The state backend that stores the state and checkpoints for this task */ + private AbstractStateBackend stateBackend = null; - private long recoveryTimestamp; - // ------------------------------------------------------------------------ // Life Cycle // ------------------------------------------------------------------------ @@ -107,6 +99,19 @@ public abstract class AbstractStreamOperator<OUT> this.config = config; this.output = output; this.runtimeContext = new StreamingRuntimeContext(this, container.getEnvironment(), container.getAccumulatorMap()); + + stateKeySelector1 = config.getStatePartitioner(0, getUserCodeClassloader()); + stateKeySelector2 = config.getStatePartitioner(1, getUserCodeClassloader()); + + try { + TypeSerializer<Object> keySerializer = config.getStateKeySerializer(getUserCodeClassloader()); + // if the keySerializer is null we still need to create the state backend + // for the non-partitioned state features it provides, such as the state output streams + String operatorIdentifier = getClass().getSimpleName() + "_" + config.getVertexID() + "_" + runtimeContext.getIndexOfThisSubtask(); + stateBackend = 
container.createStateBackend(operatorIdentifier, keySerializer); + } catch (Exception e) { + throw new RuntimeException("Could not initialize state backend. ", e); + } } /** @@ -144,9 +149,12 @@ public abstract class AbstractStreamOperator<OUT> */ @Override public void dispose() { - if (keyValueStates != null) { - for (KvState<?, ?, ?> state : keyValueStates) { - state.dispose(); + if (stateBackend != null) { + try { + stateBackend.close(); + stateBackend.dispose(); + } catch (Exception e) { + throw new RuntimeException("Error while closing/disposing state backend.", e); } } } @@ -160,37 +168,33 @@ public abstract class AbstractStreamOperator<OUT> // here, we deal with key/value state snapshots StreamTaskState state = new StreamTaskState(); - if (keyValueStates != null) { - HashMap<String, KvStateSnapshot<?, ?, ?>> snapshots = new HashMap<>(keyValueStatesByName.size()); - - for (Map.Entry<String, KvState<?, ?, ?>> entry : keyValueStatesByName.entrySet()) { - KvStateSnapshot<?, ?, ?> snapshot = entry.getValue().snapshot(checkpointId, timestamp); - snapshots.put(entry.getKey(), snapshot); + + if (stateBackend != null) { + HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> partitionedSnapshots = + stateBackend.snapshotPartitionedState(checkpointId, timestamp); + if (partitionedSnapshots != null) { + state.setKvStates(partitionedSnapshots); } - - state.setKvStates(snapshots); } - + + return state; } @Override + @SuppressWarnings("rawtypes,unchecked") public void restoreState(StreamTaskState state, long recoveryTimestamp) throws Exception { // restore the key/value state. 
the actual restore happens lazily, when the function requests // the state again, because the restore method needs information provided by the user function - keyValueStateSnapshots = state.getKvStates(); - this.recoveryTimestamp = recoveryTimestamp; + if (stateBackend != null) { + stateBackend.injectKeyValueStateSnapshots((HashMap)state.getKvStates(), recoveryTimestamp); + } } @Override public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { - // We check whether the KvStates require notifications - if (keyValueStates != null) { - for (KvState<?, ?, ?> kvstate : keyValueStates) { - if (kvstate instanceof CheckpointNotifier) { - ((CheckpointNotifier) kvstate).notifyCheckpointComplete(checkpointId); - } - } + if (stateBackend != null) { + stateBackend.notifyOfCompletedCheckpoint(checkpointId); } } @@ -229,8 +233,8 @@ public abstract class AbstractStreamOperator<OUT> return runtimeContext; } - public StateBackend<?> getStateBackend() { - return container.getStateBackend(); + public AbstractStateBackend getStateBackend() { + return stateBackend; } /** @@ -245,122 +249,50 @@ public abstract class AbstractStreamOperator<OUT> } /** - * Creates a key/value state handle, using the state backend configured for this task. - * - * @param stateType The type information for the state type, used for managed memory and state snapshots. - * @param defaultValue The default value that the state should return for keys that currently have - * no value associated with them - * - * @param <V> The type of the state value. - * - * @return The key/value state for this operator. - * + * Creates a partitioned state handle, using the state backend configured for this task. + * * @throws IllegalStateException Thrown, if the key/value state was already initialized. * @throws Exception Thrown, if the state backend cannot create the key/value state. 
*/ - protected <V> OperatorState<V> createKeyValueState( - String name, TypeInformation<V> stateType, V defaultValue) throws Exception - { - return createKeyValueState(name, stateType.createSerializer(getExecutionConfig()), defaultValue); + protected <S extends State> S getPartitionedState(StateDescriptor<S> stateDescriptor) throws Exception { + return getStateBackend().getPartitionedState(null, VoidSerializer.INSTANCE, stateDescriptor); } - + /** - * Creates a key/value state handle, using the state backend configured for this task. - * - * @param valueSerializer The type serializer for the state type, used for managed memory and state snapshots. - * @param defaultValue The default value that the state should return for keys that currently have - * no value associated with them - * - * @param <K> The type of the state key. - * @param <V> The type of the state value. - * @param <Backend> The type of the state backend that creates the key/value state. - * - * @return The key/value state for this operator. - * + * Creates a partitioned state handle, using the state backend configured for this task. + * * @throws IllegalStateException Thrown, if the key/value state was already initialized. * @throws Exception Thrown, if the state backend cannot create the key/value state. 
*/ @SuppressWarnings("unchecked") - protected <K, V, Backend extends StateBackend<Backend>> OperatorState<V> createKeyValueState( - String name, TypeSerializer<V> valueSerializer, V defaultValue) throws Exception - { - if (name == null || name.isEmpty()) { - throw new IllegalArgumentException(); - } - if (keyValueStatesByName != null && keyValueStatesByName.containsKey(name)) { - throw new IllegalStateException("The key/value state has already been created"); - } - - TypeSerializer<K> keySerializer; - - // first time state access, make sure we load the state partitioner - if (stateKeySelector == null) { - stateKeySelector = config.getStatePartitioner(getUserCodeClassloader()); - if (stateKeySelector == null) { - throw new UnsupportedOperationException("The function or operator is not executed " + - "on a KeyedStream and can hence not access the key/value state"); - } - - keySerializer = config.getStateKeySerializer(getUserCodeClassloader()); - if (keySerializer == null) { - throw new Exception("State key serializer has not been configured in the config."); - } - this.keySerializer = keySerializer; - } - else if (this.keySerializer != null) { - keySerializer = (TypeSerializer<K>) this.keySerializer; - } - else { - // should never happen, this is merely a safeguard - throw new RuntimeException(); - } - - Backend stateBackend = (Backend) container.getStateBackend(); - - KvState<K, V, Backend> kvstate = null; - - // check whether we restore the key/value state from a snapshot, or create a new blank one - if (keyValueStateSnapshots != null) { - KvStateSnapshot<K, V, Backend> snapshot = (KvStateSnapshot<K, V, Backend>) keyValueStateSnapshots.remove(name); + protected <S extends State, N> S getPartitionedState(N namespace, TypeSerializer<N> namespaceSerializer, StateDescriptor<S> stateDescriptor) throws Exception { + return getStateBackend().getPartitionedState(namespace, (TypeSerializer<Object>) namespaceSerializer, + stateDescriptor); + } - if (snapshot != null) { - 
kvstate = snapshot.restoreState( - stateBackend, keySerializer, valueSerializer, defaultValue, getUserCodeClassloader(), recoveryTimestamp); - } - } - - if (kvstate == null) { - // create unique state id from operator id + state name - String stateId = name + "_" + getOperatorConfig().getVertexID(); - // create a new blank key/value state - kvstate = stateBackend.createKvState(stateId ,name , keySerializer, valueSerializer, defaultValue); - } - if (keyValueStatesByName == null) { - keyValueStatesByName = new HashMap<>(); + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void setKeyContextElement1(StreamRecord record) throws Exception { + if (stateKeySelector1 != null) { + Object key = ((KeySelector) stateKeySelector1).getKey(record.getValue()); + getStateBackend().setCurrentKey(key); } - keyValueStatesByName.put(name, kvstate); - keyValueStates = keyValueStatesByName.values().toArray(new KvState[keyValueStatesByName.size()]); - return kvstate; } - + @Override @SuppressWarnings({"unchecked", "rawtypes"}) - public void setKeyContextElement(StreamRecord record) throws Exception { - if (stateKeySelector != null && keyValueStates != null) { - KeySelector selector = stateKeySelector; - for (KvState kv : keyValueStates) { - kv.setCurrentKey(selector.getKey(record.getValue())); - } + public void setKeyContextElement2(StreamRecord record) throws Exception { + if (stateKeySelector2 != null) { + Object key = ((KeySelector) stateKeySelector2).getKey(record.getValue()); + getStateBackend().setCurrentKey(key); } } @SuppressWarnings({"unchecked", "rawtypes"}) public void setKeyContext(Object key) { - if (keyValueStates != null) { - for (KvState kv : keyValueStates) { - kv.setCurrentKey(key); - } + if (stateKeySelector1 != null) { + stateBackend.setCurrentKey(key); } }
http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index c205445..37dd6ab 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -26,10 +26,10 @@ import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.state.StateHandle; -import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; +import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.graph.StreamConfig; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.runtime.tasks.StreamTaskState; @@ -98,6 +98,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> extends @Override public void dispose() { + super.dispose(); if (!functionsClosed) { functionsClosed = true; try { @@ -131,7 +132,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> extends if (udfState != null) { try { - StateBackend<?> stateBackend = getStateBackend(); + AbstractStateBackend stateBackend = getStateBackend(); 
StateHandle<Serializable> handle = stateBackend.checkpointStateSerializable(udfState, checkpointId, timestamp); state.setFunctionState(handle); @@ -172,8 +173,8 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> extends public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { super.notifyOfCompletedCheckpoint(checkpointId); - if (userFunction instanceof CheckpointNotifier) { - ((CheckpointNotifier) userFunction).notifyCheckpointComplete(checkpointId); + if (userFunction instanceof CheckpointListener) { + ((CheckpointListener) userFunction).notifyCheckpointComplete(checkpointId); } } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedFold.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedFold.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedFold.java index c383935..e627ec8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedFold.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedFold.java @@ -23,7 +23,8 @@ import java.io.IOException; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.FoldFunction; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataInputViewStreamWrapper; @@ -40,7 +41,7 @@ public class StreamGroupedFold<IN, OUT, KEY> private static final String STATE_NAME = "_op_state"; // Grouped values - private transient 
OperatorState<OUT> values; + private transient ValueState<OUT> values; private transient OUT initialValue; @@ -66,7 +67,8 @@ public class StreamGroupedFold<IN, OUT, KEY> ByteArrayInputStream bais = new ByteArrayInputStream(serializedInitialValue); DataInputViewStreamWrapper in = new DataInputViewStreamWrapper(bais); initialValue = outTypeSerializer.deserialize(in); - values = createKeyValueState(STATE_NAME, outTypeSerializer, null); + ValueStateDescriptor<OUT> stateId = new ValueStateDescriptor<>(STATE_NAME, null, outTypeSerializer); + values = getPartitionedState(stateId); } @Override http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedReduce.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedReduce.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedReduce.java index ae15e92..c054563 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedReduce.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamGroupedReduce.java @@ -18,7 +18,8 @@ package org.apache.flink.streaming.api.operators; import org.apache.flink.api.common.functions.ReduceFunction; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -30,7 +31,7 @@ public class StreamGroupedReduce<IN> extends AbstractUdfStreamOperator<IN, Reduc private static final String STATE_NAME = "_op_state"; - private transient OperatorState<IN> values; + private transient 
ValueState<IN> values; private TypeSerializer<IN> serializer; @@ -43,7 +44,8 @@ public class StreamGroupedReduce<IN> extends AbstractUdfStreamOperator<IN, Reduc @Override public void open() throws Exception { super.open(); - values = createKeyValueState(STATE_NAME, serializer, null); + ValueStateDescriptor<IN> stateId = new ValueStateDescriptor<>(STATE_NAME, null, serializer); + values = getPartitionedState(stateId); } @Override http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index 96ddda1..a1f6f01 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -134,8 +134,10 @@ public interface StreamOperator<OUT> extends Serializable { // miscellaneous // ------------------------------------------------------------------------ - void setKeyContextElement(StreamRecord<?> record) throws Exception; - + void setKeyContextElement1(StreamRecord<?> record) throws Exception; + + void setKeyContextElement2(StreamRecord<?> record) throws Exception; + /** * An operator can return true here to disable copying of its input elements. 
This overrides * the object-reuse setting on the {@link org.apache.flink.api.common.ExecutionConfig} http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java index 46f2fef..dda92bc 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java @@ -21,7 +21,10 @@ package org.apache.flink.streaming.api.operators; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.api.common.functions.BroadcastVariableInitializer; import org.apache.flink.api.common.functions.util.AbstractRuntimeUDFContext; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.runtime.execution.Environment; @@ -30,7 +33,6 @@ import org.apache.flink.streaming.api.CheckpointingMode; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.operators.Triggerable; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -47,17 +49,9 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext { /** The task environment running the operator */ private final Environment 
taskEnvironment; - - /** The key/value state, if the user-function requests it */ - private HashMap<String, OperatorState<?>> keyValueStates; - - /** Type of the values stored in the state, to make sure repeated requests of the state are consistent */ - private HashMap<String, TypeInformation<?>> stateTypeInfos; - /** Stream configuration object. */ private final StreamConfig streamConfig; - public StreamingRuntimeContext(AbstractStreamOperator<?> operator, Environment env, Map<String, Accumulator<?, ?>> accumulators) { super(env.getTaskInfo(), @@ -112,7 +106,17 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext { // ------------------------------------------------------------------------ @Override - public <S> OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) { + public <S extends State> S getPartitionedState(StateDescriptor<S> stateDescriptor) { + try { + return operator.getPartitionedState(stateDescriptor); + } catch (Exception e) { + throw new RuntimeException("Error while getting state.", e); + } + } + + @Override + @Deprecated + public <S> ValueState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) { requireNonNull(stateType, "The state type class must not be null"); TypeInformation<S> typeInfo; @@ -120,62 +124,22 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext { typeInfo = TypeExtractor.getForClass(stateType); } catch (Exception e) { - throw new RuntimeException("Cannot analyze type '" + stateType.getName() + + throw new RuntimeException("Cannot analyze type '" + stateType.getName() + "' from the class alone, due to generic type parameters. 
" + "Please specify the TypeInformation directly.", e); } - + return getKeyValueState(name, typeInfo, defaultState); } @Override - public <S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) { + @Deprecated + public <S> ValueState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) { requireNonNull(name, "The name of the state must not be null"); requireNonNull(stateType, "The state type information must not be null"); - - OperatorState<?> previousState; - - // check if this is a repeated call to access the state - if (this.stateTypeInfos != null && this.keyValueStates != null && - (previousState = this.keyValueStates.get(name)) != null) { - - // repeated call - TypeInformation<?> previousType; - if (stateType.equals((previousType = this.stateTypeInfos.get(name)))) { - // valid case, same type requested again - @SuppressWarnings("unchecked") - OperatorState<S> previous = (OperatorState<S>) previousState; - return previous; - } - else { - // invalid case, different type requested this time - throw new IllegalStateException("Cannot initialize key/value state for type " + stateType + - " ; The key/value state has already been created and initialized for a different type: " + - previousType); - } - } - else { - // first time access to the key/value state - if (this.stateTypeInfos == null) { - this.stateTypeInfos = new HashMap<>(); - } - if (this.keyValueStates == null) { - this.keyValueStates = new HashMap<>(); - } - - try { - OperatorState<S> state = operator.createKeyValueState(name, stateType, defaultState); - this.keyValueStates.put(name, state); - this.stateTypeInfos.put(name, stateType); - return state; - } - catch (RuntimeException e) { - throw e; - } - catch (Exception e) { - throw new RuntimeException("Cannot initialize the key/value state", e); - } - } + + ValueStateDescriptor<S> stateDesc = new ValueStateDescriptor<>(name, defaultState, stateType.createSerializer(getExecutionConfig())); + 
return getPartitionedState(stateDesc); } // ------------------ expose (read only) relevant information from the stream config -------- // http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java index 30f0733..b065df6 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.api.transformations; import com.google.common.collect.Lists; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; @@ -41,6 +42,12 @@ public class TwoInputTransformation<IN1, IN2, OUT> extends StreamTransformation< private final TwoInputStreamOperator<IN1, IN2, OUT> operator; + private KeySelector<IN1, ?> stateKeySelector1; + + private KeySelector<IN2, ?> stateKeySelector2; + + private TypeInformation<?> stateKeyType; + /** * Creates a new {@code TwoInputTransformation} from the given inputs and operator. * @@ -99,6 +106,46 @@ public class TwoInputTransformation<IN1, IN2, OUT> extends StreamTransformation< return operator; } + /** + * Sets the {@link KeySelector KeySelectors} that must be used for partitioning keyed state of + * this transformation. 
+ * + * @param stateKeySelector1 The {@code KeySelector} to set for the first input + * @param stateKeySelector2 The {@code KeySelector} to set for the second input + */ + public void setStateKeySelectors(KeySelector<IN1, ?> stateKeySelector1, KeySelector<IN2, ?> stateKeySelector2) { + this.stateKeySelector1 = stateKeySelector1; + this.stateKeySelector2 = stateKeySelector2; + } + + /** + * Returns the {@code KeySelector} that must be used for partitioning keyed state in this + * Operation for the first input. + * + * @see #setStateKeySelectors + */ + public KeySelector<IN1, ?> getStateKeySelector1() { + return stateKeySelector1; + } + + /** + * Returns the {@code KeySelector} that must be used for partitioning keyed state in this + * Operation for the second input. + * + * @see #setStateKeySelectors + */ + public KeySelector<IN2, ?> getStateKeySelector2() { + return stateKeySelector2; + } + + public void setStateKeyType(TypeInformation<?> stateKeyType) { + this.stateKeyType = stateKeyType; + } + + public TypeInformation<?> getStateKeyType() { + return stateKeyType; + } + @Override public Collection<StreamTransformation<?>> getTransitivePredecessors() { List<StreamTransformation<?>> result = Lists.newArrayList(); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousEventTimeTrigger.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousEventTimeTrigger.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousEventTimeTrigger.java index 0454e85..b653be3 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousEventTimeTrigger.java +++
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousEventTimeTrigger.java @@ -18,7 +18,7 @@ package org.apache.flink.streaming.api.windowing.triggers; import com.google.common.annotations.VisibleForTesting; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; import org.apache.flink.streaming.api.windowing.time.Time; import org.apache.flink.streaming.api.windowing.windows.Window; @@ -42,7 +42,7 @@ public class ContinuousEventTimeTrigger<W extends Window> implements Trigger<Obj @Override public TriggerResult onElement(Object element, long timestamp, W window, TriggerContext ctx) throws Exception { - OperatorState<Boolean> first = ctx.getKeyValueState("first", true); + ValueState<Boolean> first = ctx.getKeyValueState("first", true); if (first.value()) { long start = timestamp - (timestamp % interval); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousProcessingTimeTrigger.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousProcessingTimeTrigger.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousProcessingTimeTrigger.java index 3576394..7f3e7ec 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousProcessingTimeTrigger.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/ContinuousProcessingTimeTrigger.java @@ -18,7 +18,7 @@ package org.apache.flink.streaming.api.windowing.triggers; import com.google.common.annotations.VisibleForTesting; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; import 
org.apache.flink.streaming.api.windowing.time.Time; import org.apache.flink.streaming.api.windowing.windows.Window; @@ -41,7 +41,7 @@ public class ContinuousProcessingTimeTrigger<W extends Window> implements Trigge public TriggerResult onElement(Object element, long timestamp, W window, TriggerContext ctx) throws Exception { long currentTime = System.currentTimeMillis(); - OperatorState<Long> fireState = ctx.getKeyValueState("fire-timestamp", 0L); + ValueState<Long> fireState = ctx.getKeyValueState("fire-timestamp", 0L); long nextFireTimestamp = fireState.value(); if (nextFireTimestamp == 0) { @@ -70,7 +70,7 @@ public class ContinuousProcessingTimeTrigger<W extends Window> implements Trigge @Override public TriggerResult onProcessingTime(long time, W window, TriggerContext ctx) throws Exception { - OperatorState<Long> fireState = ctx.getKeyValueState("fire-timestamp", 0L); + ValueState<Long> fireState = ctx.getKeyValueState("fire-timestamp", 0L); long nextFireTimestamp = fireState.value(); // only fire if an element didn't already fire http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/CountTrigger.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/CountTrigger.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/CountTrigger.java index efb62d7..d101fe1 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/CountTrigger.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/CountTrigger.java @@ -17,7 +17,7 @@ */ package org.apache.flink.streaming.api.windowing.triggers; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; import 
org.apache.flink.streaming.api.windowing.windows.Window; import java.io.IOException; @@ -38,7 +38,7 @@ public class CountTrigger<W extends Window> implements Trigger<Object, W> { @Override public TriggerResult onElement(Object element, long timestamp, W window, TriggerContext ctx) throws IOException { - OperatorState<Long> count = ctx.getKeyValueState("count", 0L); + ValueState<Long> count = ctx.getKeyValueState("count", 0L); long currentCount = count.value() + 1; count.update(currentCount); if (currentCount >= maxCount) { http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/DeltaTrigger.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/DeltaTrigger.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/DeltaTrigger.java index d791d28..37c8a45 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/DeltaTrigger.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/DeltaTrigger.java @@ -17,7 +17,7 @@ */ package org.apache.flink.streaming.api.windowing.triggers; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; import org.apache.flink.streaming.api.functions.windowing.delta.DeltaFunction; import org.apache.flink.streaming.api.windowing.windows.Window; @@ -46,7 +46,7 @@ public class DeltaTrigger<T extends Serializable, W extends Window> implements T @Override public TriggerResult onElement(T element, long timestamp, W window, TriggerContext ctx) throws Exception { - OperatorState<T> lastElementState = ctx.getKeyValueState("last-element", null); + ValueState<T> lastElementState = ctx.getKeyValueState("last-element", null); if (lastElementState.value() == null) { 
lastElementState.update(element); return TriggerResult.CONTINUE; http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/Trigger.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/Trigger.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/Trigger.java index cfb7945..aed393b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/Trigger.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/triggers/Trigger.java @@ -17,7 +17,7 @@ */ package org.apache.flink.streaming.api.windowing.triggers; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; import org.apache.flink.streaming.api.windowing.windows.Window; import java.io.Serializable; @@ -149,13 +149,13 @@ public interface Trigger<T, W extends Window> extends Serializable { void registerEventTimeTimer(long time); /** - * Retrieves an {@link OperatorState} object that can be used to interact with + * Retrieves an {@link ValueState} object that can be used to interact with * fault-tolerant state that is scoped to the window and key of the current * trigger invocation. * * @param name A unique key for the state. * @param defaultState The default value of the state. 
*/ - <S extends Serializable> OperatorState<S> getKeyValueState(final String name, final S defaultState); + <S extends Serializable> ValueState<S> getKeyValueState(final String name, final S defaultState); } } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java index e131cda..9dacc8d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java @@ -162,7 +162,7 @@ public class StreamInputProcessor<IN> { // now we can do the actual processing StreamRecord<IN> record = recordOrWatermark.asRecord(); synchronized (lock) { - streamOperator.setKeyContextElement(record); + streamOperator.setKeyContextElement1(record); streamOperator.processElement(record); } return true; http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java index 882037e..f639b4a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java @@ -186,6 +186,7 @@ public class StreamTwoInputProcessor<IN1, IN2> { } else { 
synchronized (lock) { + streamOperator.setKeyContextElement1(recordOrWatermark.<IN1>asRecord()); streamOperator.processElement1(recordOrWatermark.<IN1>asRecord()); } return true; @@ -200,6 +201,7 @@ public class StreamTwoInputProcessor<IN1, IN2> { } else { synchronized (lock) { + streamOperator.setKeyContextElement2(recordOrWatermark.<IN2>asRecord()); streamOperator.processElement2(recordOrWatermark.<IN2>asRecord()); } return true; http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java index b8c95aa..bc31791 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java @@ -29,7 +29,7 @@ import org.apache.flink.runtime.util.MathUtils; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.TimestampedCollector; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.streaming.runtime.operators.Triggerable; @@ -247,7 +247,7 @@ public abstract class AbstractAlignedProcessingTimeWindowOperator<KEY, 
IN, OUT, // we write the panes with the key/value maps into the stream, as well as when this state // should have triggered and slided - StateBackend.CheckpointStateOutputView out = + AbstractStateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView(checkpointId, timestamp); out.writeLong(nextEvaluationTime); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java index 2afa1e7..48ad387 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java @@ -19,12 +19,12 @@ package org.apache.flink.streaming.runtime.operators.windowing; import com.google.common.annotations.VisibleForTesting; import org.apache.flink.api.common.ExecutionConfig; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.typeutils.InputTypeConfigurable; import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; @@ -397,7 
+397,7 @@ public class NonKeyedWindowOperator<IN, OUT, W extends Window> } } - protected void writeToState(StateBackend.CheckpointStateOutputView out) throws IOException { + protected void writeToState(AbstractStateBackend.CheckpointStateOutputView out) throws IOException { windowSerializer.serialize(window, out); out.writeLong(watermarkTimer); out.writeLong(processingTimeTimer); @@ -414,8 +414,8 @@ public class NonKeyedWindowOperator<IN, OUT, W extends Window> } @SuppressWarnings("unchecked") - public <S extends Serializable> OperatorState<S> getKeyValueState(final String name, final S defaultState) { - return new OperatorState<S>() { + public <S extends Serializable> ValueState<S> getKeyValueState(final String name, final S defaultState) { + return new ValueState<S>() { @Override public S value() throws IOException { Serializable value = state.get(name); @@ -430,6 +430,11 @@ public class NonKeyedWindowOperator<IN, OUT, W extends Window> public void update(S value) throws IOException { state.put(name, value); } + + @Override + public void clear() { + state.remove(name); + } }; } @@ -523,7 +528,7 @@ public class NonKeyedWindowOperator<IN, OUT, W extends Window> StreamTaskState taskState = super.snapshotOperatorState(checkpointId, timestamp); // we write the panes with the key/value maps into the stream - StateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView(checkpointId, timestamp); + AbstractStateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView(checkpointId, timestamp); int numWindows = windows.size(); out.writeInt(numWindows); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java index 4a50efb..fd39481 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java @@ -19,13 +19,13 @@ package org.apache.flink.streaming.runtime.operators.windowing; import com.google.common.annotations.VisibleForTesting; import org.apache.flink.api.common.ExecutionConfig; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.typeutils.InputTypeConfigurable; import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.streaming.api.functions.windowing.WindowFunction; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; @@ -297,7 +297,7 @@ public class WindowOperator<K, IN, OUT, W extends Window> if (context.windowBuffer.size() > 0) { - setKeyContextElement(context.windowBuffer.getElements().iterator().next()); + setKeyContextElement1(context.windowBuffer.getElements().iterator().next()); userFunction.apply(context.key, context.window, @@ -436,7 +436,7 @@ public class WindowOperator<K, IN, OUT, W extends Window> /** * Constructs a new {@code Context} by reading from a {@link DataInputView} that * contains a serialized context that we wrote in - * {@link #writeToState(StateBackend.CheckpointStateOutputView)} + * {@link #writeToState(AbstractStateBackend.CheckpointStateOutputView)} */ @SuppressWarnings("unchecked") 
protected Context(DataInputView in, ClassLoader userClassloader) throws Exception { @@ -461,7 +461,7 @@ public class WindowOperator<K, IN, OUT, W extends Window> /** * Writes the {@code Context} to the given state checkpoint output. */ - protected void writeToState(StateBackend.CheckpointStateOutputView out) throws IOException { + protected void writeToState(AbstractStateBackend.CheckpointStateOutputView out) throws IOException { keySerializer.serialize(key, out); windowSerializer.serialize(window, out); out.writeLong(watermarkTimer); @@ -479,8 +479,8 @@ public class WindowOperator<K, IN, OUT, W extends Window> } @SuppressWarnings("unchecked") - public <S extends Serializable> OperatorState<S> getKeyValueState(final String name, final S defaultState) { - return new OperatorState<S>() { + public <S extends Serializable> ValueState<S> getKeyValueState(final String name, final S defaultState) { + return new ValueState<S>() { @Override public S value() throws IOException { Serializable value = state.get(name); @@ -495,6 +495,11 @@ public class WindowOperator<K, IN, OUT, W extends Window> public void update(S value) throws IOException { state.put(name, value); } + + @Override + public void clear() { + state.remove(name); + } }; } @@ -588,7 +593,7 @@ public class WindowOperator<K, IN, OUT, W extends Window> StreamTaskState taskState = super.snapshotOperatorState(checkpointId, timestamp); // we write the panes with the key/value maps into the stream - StateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView(checkpointId, timestamp); + AbstractStateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView(checkpointId, timestamp); int numKeys = windows.size(); out.writeInt(numKeys); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java ---------------------------------------------------------------------- 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java index ac27093..125279c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java @@ -268,7 +268,7 @@ public class OperatorChain<OUT> { @Override public void collect(StreamRecord<T> record) { try { - operator.setKeyContextElement(record); + operator.setKeyContextElement1(record); operator.processElement(record); } catch (Exception e) { @@ -312,7 +312,7 @@ public class OperatorChain<OUT> { StreamRecord<T> copy = new StreamRecord<>(serializer.copy(record.getValue()), record.getTimestamp()); - operator.setKeyContextElement(copy); + operator.setKeyContextElement1(copy); operator.processElement(copy); } catch (Exception e) { http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index c9624fc..cb6a468 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -27,6 +27,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import 
org.apache.flink.configuration.IllegalConfigurationException; @@ -34,14 +35,13 @@ import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.AsynchronousStateHandle; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; import org.apache.flink.runtime.util.event.EventListener; -import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.StreamOperator; -import org.apache.flink.runtime.state.StateBackend; import org.apache.flink.runtime.state.StateBackendFactory; import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory; @@ -122,9 +122,6 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> /** The class loader used to load dynamic classes of a job */ private ClassLoader userClassLoader; - /** The state backend that stores the state and checkpoints for this task */ - private StateBackend<?> stateBackend; - /** The executor service that schedules and calls the triggers of this task*/ private ScheduledExecutorService timerService; @@ -203,8 +200,6 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> LOG.debug("Invoking {}", getName()); // first order of business is to give operators back their state - stateBackend = createStateBackend(); - stateBackend.initializeForJob(getEnvironment()); restoreState(); // we need to make sure that any triggers scheduled in open() cannot be @@ -289,14 +284,6 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> 
if (!disposed) { disposeAllOperators(); } - - try { - if (stateBackend != null) { - stateBackend.close(); - } - } catch (Throwable t) { - LOG.error("Error while closing the state backend", t); - } } } @@ -557,11 +544,6 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> if (isRunning) { LOG.debug("Notification of complete checkpoint for task {}", getName()); - // We first notify the state backend if necessary - if (stateBackend instanceof CheckpointNotifier) { - ((CheckpointNotifier) stateBackend).notifyCheckpointComplete(checkpointId); - } - for (StreamOperator<?> operator : operatorChain.getAllOperators()) { if (operator != null) { operator.notifyOfCompletedCheckpoint(checkpointId); @@ -578,23 +560,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // State backend // ------------------------------------------------------------------------ - /** - * Gets the state backend used by this task. The state backend defines how to maintain the - * key/value state and how and where to store state snapshots. - * - * @return The state backend used by this task. 
- */ - public StateBackend<?> getStateBackend() { - return stateBackend; - } - - private StateBackend<?> createStateBackend() throws Exception { - StateBackend<?> configuredBackend = configuration.getStateBackend(userClassLoader); + public AbstractStateBackend createStateBackend(String operatorIdentifier, TypeSerializer<?> keySerializer) throws Exception { + AbstractStateBackend stateBackend = configuration.getStateBackend(userClassLoader); - if (configuredBackend != null) { + if (stateBackend != null) { // backend has been configured on the environment - LOG.info("Using user-defined state backend: " + configuredBackend); - return configuredBackend; + LOG.info("Using user-defined state backend: " + stateBackend); } else { // see if we have a backend specified in the configuration Configuration flinkConfig = getEnvironment().getTaskManagerInfo().getConfiguration(); @@ -609,13 +580,15 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> switch (backendName) { case "jobmanager": LOG.info("State backend is set to heap memory (checkpoint to jobmanager)"); - return MemoryStateBackend.defaultInstance(); + stateBackend = MemoryStateBackend.create(); + break; case "filesystem": FsStateBackend backend = new FsStateBackendFactory().createFromConfig(flinkConfig); - LOG.info("State backend is set to filesystem (checkpoints to filesystem \"" + LOG.info("State backend is set to heap memory (checkpoints to filesystem \"" + backend.getBasePath() + "\")"); - return backend; + stateBackend = backend; + break; default: try { @@ -623,7 +596,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> Class<? 
extends StateBackendFactory> clazz = Class.forName(backendName, false, userClassLoader).asSubclass(StateBackendFactory.class); - return clazz.newInstance().createFromConfig(flinkConfig); + stateBackend = ((StateBackendFactory<?>) clazz.newInstance()).createFromConfig(flinkConfig); } catch (ClassNotFoundException e) { throw new IllegalConfigurationException("Cannot find configured state backend: " + backendName); } catch (ClassCastException e) { @@ -635,6 +608,9 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> } } } + stateBackend.initializeForJob(getEnvironment(), operatorIdentifier, keySerializer); + return stateBackend; + } /** http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java index afeabd9..ace9cfd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java @@ -43,7 +43,7 @@ public class StreamTaskState implements Serializable { private StateHandle<Serializable> functionState; - private HashMap<String, KvStateSnapshot<?, ?, ?>> kvStates; + private HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> kvStates; // ------------------------------------------------------------------------ @@ -63,11 +63,11 @@ public class StreamTaskState implements Serializable { this.functionState = functionState; } - public HashMap<String, KvStateSnapshot<?, ?, ?>> getKvStates() { + public HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> getKvStates() { return kvStates; } - public void setKvStates(HashMap<String, 
KvStateSnapshot<?, ?, ?>> kvStates) { + public void setKvStates(HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> kvStates) { this.kvStates = kvStates; } @@ -92,7 +92,7 @@ public class StreamTaskState implements Serializable { public void discardState() throws Exception { StateHandle<?> operatorState = this.operatorState; StateHandle<?> functionState = this.functionState; - HashMap<String, KvStateSnapshot<?, ?, ?>> kvStates = this.kvStates; + HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> kvStates = this.kvStates; if (operatorState != null) { operatorState.discardState(); @@ -103,9 +103,9 @@ public class StreamTaskState implements Serializable { if (kvStates != null) { while (kvStates.size() > 0) { try { - Iterator<KvStateSnapshot<?, ?, ?>> values = kvStates.values().iterator(); + Iterator<KvStateSnapshot<?, ?, ?, ?, ?>> values = kvStates.values().iterator(); while (values.hasNext()) { - KvStateSnapshot<?, ?, ?> s = values.next(); + KvStateSnapshot<?, ?, ?, ?, ?> s = values.next(); s.discardState(); values.remove(); } @@ -121,4 +121,3 @@ public class StreamTaskState implements Serializable { this.kvStates = null; } } - \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java index e9f9ab6..e698db6 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java @@ -45,7 +45,7 @@ public class StreamTaskStateList implements StateHandle<StreamTaskState[]> { if (state != null) { StateHandle<?> 
operatorState = state.getOperatorState(); StateHandle<?> functionState = state.getFunctionState(); - HashMap<String, KvStateSnapshot<?, ?, ?>> kvStates = state.getKvStates(); + HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> kvStates = state.getKvStates(); if (operatorState != null) { sumStateSize += operatorState.getStateSize(); @@ -56,7 +56,7 @@ public class StreamTaskStateList implements StateHandle<StreamTaskState[]> { } if (kvStates != null) { - for (KvStateSnapshot<?, ?, ?> kvState : kvStates.values()) { + for (KvStateSnapshot<?, ?, ?, ?, ?> kvState : kvStates.values()) { if (kvState != null) { sumStateSize += kvState.getStateSize(); } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java index 169c93d..475a95d 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java @@ -653,7 +653,7 @@ public class DataStreamTest extends StreamingMultipleProgramsTestBase { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); DataStreamSink<Long> sink = env.generateSequence(1, 100).print(); - assertTrue(env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getStatePartitioner() == null); + assertTrue(env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getStatePartitioner1() == null); assertTrue(env.getStreamGraph().getStreamNode(sink.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof ForwardPartitioner); KeySelector<Long, Long> key1 = new KeySelector<Long, Long>() { @@ -668,10 +668,10 @@ public class 
DataStreamTest extends StreamingMultipleProgramsTestBase { DataStreamSink<Long> sink2 = env.generateSequence(1, 100).keyBy(key1).print(); - assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStatePartitioner()); + assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStatePartitioner1()); assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStateKeySerializer()); assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStateKeySerializer()); - assertEquals(key1, env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStatePartitioner()); + assertEquals(key1, env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStatePartitioner1()); assertTrue(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof HashPartitioner); KeySelector<Long, Long> key2 = new KeySelector<Long, Long>() { @@ -686,8 +686,8 @@ public class DataStreamTest extends StreamingMultipleProgramsTestBase { DataStreamSink<Long> sink3 = env.generateSequence(1, 100).keyBy(key2).print(); - assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner() != null); - assertEquals(key2, env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner()); + assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner1() != null); + assertEquals(key2, env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner1()); assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof HashPartitioner); } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/SelfConnectionTest.java 
---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/SelfConnectionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/SelfConnectionTest.java index d00dc67..8f04d41 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/SelfConnectionTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/SelfConnectionTest.java @@ -134,13 +134,13 @@ public class SelfConnectionTest extends StreamingMultipleProgramsTestBase { public Long map(Integer value) throws Exception { return Long.valueOf(value + 1); } - }).keyBy(new KeySelector<Long, Long>() { + }).keyBy(new KeySelector<Long, Integer>() { private static final long serialVersionUID = 1L; @Override - public Long getKey(Long value) throws Exception { - return value; + public Integer getKey(Long value) throws Exception { + return value.intValue(); } }); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java index 8722802..02bb8b7 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java @@ -33,7 +33,7 @@ import 
org.apache.flink.streaming.api.functions.windowing.RichWindowFunction; import org.apache.flink.streaming.api.functions.windowing.WindowFunction; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.streaming.runtime.operators.Triggerable; @@ -46,7 +46,6 @@ import org.junit.After; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import org.mockito.stubbing.OngoingStubbing; import java.util.ArrayList; import java.util.Arrays; @@ -795,18 +794,27 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { when(task.getName()).thenReturn("Test task name"); when(task.getExecutionConfig()).thenReturn(new ExecutionConfig()); - Environment env = mock(Environment.class); + final Environment env = mock(Environment.class); when(env.getTaskInfo()).thenReturn(new TaskInfo("Test task name", 0, 1, 0)); when(env.getUserClassLoader()).thenReturn(AggregatingAlignedProcessingTimeWindowOperatorTest.class.getClassLoader()); when(task.getEnvironment()).thenReturn(env); - // ugly java generic hacks to get the state backend into the mock - @SuppressWarnings("unchecked") - OngoingStubbing<StateBackend<?>> stubbing = - (OngoingStubbing<StateBackend<?>>) (OngoingStubbing<?>) when(task.getStateBackend()); - stubbing.thenReturn(MemoryStateBackend.defaultInstance()); - + try { + doAnswer(new Answer<AbstractStateBackend>() { + @Override + public AbstractStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { + final String operatorIdentifier = (String) invocationOnMock.getArguments()[0]; + final TypeSerializer<?> keySerializer = (TypeSerializer<?>) invocationOnMock.getArguments()[1]; + 
MemoryStateBackend backend = MemoryStateBackend.create(); + backend.initializeForJob(env, operatorIdentifier, keySerializer); + return backend; + } + }).when(task).createStateBackend(any(String.class), any(TypeSerializer.class)); + } catch (Exception e) { + e.printStackTrace(); + } + return task; } @@ -841,7 +849,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { private static StreamConfig createTaskConfig(KeySelector<?, ?> partitioner, TypeSerializer<?> keySerializer) { StreamConfig cfg = new StreamConfig(new Configuration()); - cfg.setStatePartitioner(partitioner); + cfg.setStatePartitioner(0, partitioner); cfg.setStateKeySerializer(keySerializer); return cfg; } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java index 611916e..35bd209 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java @@ -36,7 +36,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import 
org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.runtime.operators.Triggerable; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -47,7 +47,6 @@ import org.junit.After; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import org.mockito.stubbing.OngoingStubbing; import java.util.ArrayList; import java.util.Arrays; @@ -263,7 +262,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < numElements; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(i, i)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } Thread.sleep(1); @@ -327,7 +326,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { int val = ((int) nextTime) ^ ((int) (nextTime >>> 32)); StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(val, val)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); if (nextTime != previousNextTime) { @@ -388,7 +387,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < numElements; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(i, i)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } Thread.sleep(1); @@ -454,11 +453,11 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next1 = new StreamRecord<>(new Tuple2<>(1, 1)); - op.setKeyContextElement(next1); + op.setKeyContextElement1(next1); op.processElement(next1); StreamRecord<Tuple2<Integer, Integer>> next2 = new StreamRecord<>(new Tuple2<>(2, 2)); - op.setKeyContextElement(next2); + op.setKeyContextElement1(next2); op.processElement(next2); } @@ -517,14 +516,14 @@ public 
class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < 100; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(1, 1)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } } try { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(1, 1)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); fail("This fail with an exception"); } @@ -569,7 +568,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < numElementsFirst; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(i, i)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } Thread.sleep(1); @@ -592,7 +591,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(i, i)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } Thread.sleep(1); @@ -615,7 +614,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(i, i)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } Thread.sleep(1); @@ -677,7 +676,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < numElementsFirst; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(i, i)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } Thread.sleep(1); @@ -700,7 +699,7 @@ public class 
AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(i, i)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } Thread.sleep(1); @@ -724,7 +723,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { StreamRecord<Tuple2<Integer, Integer>> next = new StreamRecord<>(new Tuple2<>(i, i)); - op.setKeyContextElement(next); + op.setKeyContextElement1(next); op.processElement(next); } Thread.sleep(1); @@ -790,11 +789,11 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { synchronized (lock) { for (int i = 0; i < 10; i++) { StreamRecord<Tuple2<Integer, Integer>> next1 = new StreamRecord<>(new Tuple2<>(1, i)); - op.setKeyContextElement(next1); + op.setKeyContextElement1(next1); op.processElement(next1); StreamRecord<Tuple2<Integer, Integer>> next2 = new StreamRecord<>(new Tuple2<>(2, i)); - op.setKeyContextElement(next2); + op.setKeyContextElement1(next2); op.processElement(next2); } } @@ -859,13 +858,13 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { // because we do not release the lock between elements, they end up in the same windows synchronized (lock) { - op.setKeyContextElement(next1); + op.setKeyContextElement1(next1); op.processElement(next1); - op.setKeyContextElement(next2); + op.setKeyContextElement1(next2); op.processElement(next2); - op.setKeyContextElement(next3); + op.setKeyContextElement1(next3); op.processElement(next3); - op.setKeyContextElement(next4); + op.setKeyContextElement1(next4); op.processElement(next4); } @@ -969,18 +968,27 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { when(task.getName()).thenReturn("Test task name"); when(task.getExecutionConfig()).thenReturn(new ExecutionConfig()); - Environment env 
= mock(Environment.class); + final Environment env = mock(Environment.class); when(env.getTaskInfo()).thenReturn(new TaskInfo("Test task name", 0, 1, 0)); when(env.getUserClassLoader()).thenReturn(AggregatingAlignedProcessingTimeWindowOperatorTest.class.getClassLoader()); when(task.getEnvironment()).thenReturn(env); - // ugly java generic hacks to get the state backend into the mock - @SuppressWarnings("unchecked") - OngoingStubbing<StateBackend<?>> stubbing = - (OngoingStubbing<StateBackend<?>>) (OngoingStubbing<?>) when(task.getStateBackend()); - stubbing.thenReturn(MemoryStateBackend.defaultInstance()); - + try { + doAnswer(new Answer<AbstractStateBackend>() { + @Override + public AbstractStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { + final String operatorIdentifier = (String) invocationOnMock.getArguments()[0]; + final TypeSerializer<?> keySerializer = (TypeSerializer<?>) invocationOnMock.getArguments()[1]; + MemoryStateBackend backend = MemoryStateBackend.create(); + backend.initializeForJob(env, operatorIdentifier, keySerializer); + return backend; + } + }).when(task).createStateBackend(any(String.class), any(TypeSerializer.class)); + } catch (Exception e) { + e.printStackTrace(); + } + return task; } @@ -1015,7 +1023,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { private static StreamConfig createTaskConfig(KeySelector<?, ?> partitioner, TypeSerializer<?> keySerializer) { StreamConfig cfg = new StreamConfig(new Configuration()); - cfg.setStatePartitioner(partitioner); + cfg.setStatePartitioner(0, partitioner); cfg.setStateKeySerializer(keySerializer); return cfg; } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/state/StateBackendITCase.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/state/StateBackendITCase.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/state/StateBackendITCase.java index add532f..3a49331 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/state/StateBackendITCase.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/state/StateBackendITCase.java @@ -19,21 +19,27 @@ package org.apache.flink.streaming.runtime.state; import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.client.JobExecutionException; import org.apache.flink.runtime.execution.Environment; -import org.apache.flink.runtime.state.KvState; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase; + import org.junit.Test; import java.io.Serializable; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class StateBackendITCase extends StreamingMultipleProgramsTestBase { @@ -69,46 +75,47 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase { } }) .print(); - - boolean caughtSuccess = false; + try { see.execute(); - } catch (JobExecutionException e) { - if (e.getCause() instanceof SuccessException) { - caughtSuccess = true; - } else { + fail(); + } + 
catch (JobExecutionException e) { + Throwable t = e.getCause(); + if (!(t != null && t.getCause() instanceof SuccessException)) { throw e; } } - - assertTrue(caughtSuccess); } - public static class FailingStateBackend extends StateBackend<FailingStateBackend> { + public static class FailingStateBackend extends AbstractStateBackend { + private static final long serialVersionUID = 1L; @Override - public void initializeForJob(Environment env) throws Exception { + public void initializeForJob(Environment env, String operatorIdentifier, TypeSerializer<?> keySerializer) throws Exception { throw new SuccessException(); } @Override - public void disposeAllStateForCurrentJob() throws Exception { + public void disposeAllStateForCurrentJob() throws Exception {} - } + @Override + public void close() throws Exception {} @Override - public void close() throws Exception { + protected <N, T> ValueState<T> createValueState(TypeSerializer<N> namespaceSerializer, ValueStateDescriptor<T> stateDesc) throws Exception { + return null; + } + @Override + protected <N, T> ListState<T> createListState(TypeSerializer<N> namespaceSerializer, ListStateDescriptor<T> stateDesc) throws Exception { + return null; } @Override - public <K, V> KvState<K, V, FailingStateBackend> createKvState(String stateId, - String stateName, - TypeSerializer<K> keySerializer, - TypeSerializer<V> valueSerializer, - V defaultValue) throws Exception { + protected <N, T> ReducingState<T> createReducingState(TypeSerializer<N> namespaceSerializer, ReducingStateDescriptor<T> stateDesc) throws Exception { return null; } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/MockContext.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/MockContext.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/MockContext.java index 
0c708c6..675e7b6 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/MockContext.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/MockContext.java @@ -29,21 +29,22 @@ import java.util.concurrent.TimeUnit; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.runtime.operators.Triggerable; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import org.mockito.stubbing.OngoingStubbing; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; @@ -87,13 +88,13 @@ public class MockContext<IN, OUT> { StreamConfig config = new StreamConfig(new Configuration()); if (keySelector != null && keyType != null) { config.setStateKeySerializer(keyType.createSerializer(new ExecutionConfig())); - config.setStatePartitioner(keySelector); + config.setStatePartitioner(0, keySelector); } final ScheduledExecutorService timerService = 
Executors.newSingleThreadScheduledExecutor(); final Object lock = new Object(); final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock); - + operator.setup(mockTask, config, mockContext.output); try { operator.open(); @@ -102,7 +103,7 @@ public class MockContext<IN, OUT> { for (IN in: inputs) { record = record.replace(in); synchronized (lock) { - operator.setKeyContextElement(record); + operator.setKeyContextElement1(record); operator.processElement(record); } } @@ -148,12 +149,22 @@ public class MockContext<IN, OUT> { } }).when(task).registerTimer(anyLong(), any(Triggerable.class)); - // ugly Java generic hacks to get the generic state backend into the mock - @SuppressWarnings("unchecked") - OngoingStubbing<StateBackend<?>> stubbing = - (OngoingStubbing<StateBackend<?>>) (OngoingStubbing<?>) when(task.getStateBackend()); - stubbing.thenReturn(MemoryStateBackend.defaultInstance()); - + + try { + doAnswer(new Answer<AbstractStateBackend>() { + @Override + public AbstractStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { + final String operatorIdentifier = (String) invocationOnMock.getArguments()[0]; + final TypeSerializer<?> keySerializer = (TypeSerializer<?>) invocationOnMock.getArguments()[1]; + MemoryStateBackend backend = MemoryStateBackend.create(); + backend.initializeForJob(new DummyEnvironment("dummty", 1, 0), operatorIdentifier, keySerializer); + return backend; + } + }).when(task).createStateBackend(any(String.class), any(TypeSerializer.class)); + } catch (Exception e) { + e.printStackTrace(); + } + return task; } }
