http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index a73f3b2..0ca89ef 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -18,29 +18,35 @@ package org.apache.flink.streaming.api.operators; +import org.apache.commons.io.IOUtils; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.state.State; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.watermark.Watermark; -import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collection; +import java.util.concurrent.RunnableFuture; + /** * Base class for all stream operators. Operators that contain a user function should extend the class * {@link AbstractUdfStreamOperator} instead (which is a specialized subclass of this class). @@ -90,7 +96,12 @@ public abstract class AbstractStreamOperator<OUT> private transient KeySelector<?, ?> stateKeySelector2; /** Backend for keyed state. This might be empty if we're not on a keyed stream. */ - private transient KeyedStateBackend<?> keyedStateBackend; + private transient AbstractKeyedStateBackend<?> keyedStateBackend; + + /** Operator state backend */ + private transient OperatorStateBackend operatorStateBackend; + + private transient Collection<OperatorStateHandle> lazyRestoreStateHandles; protected transient MetricGroup metrics; @@ -116,9 +127,14 @@ public abstract class AbstractStreamOperator<OUT> return metrics; } + @Override + public void restoreState(Collection<OperatorStateHandle> stateHandles) { + this.lazyRestoreStateHandles = stateHandles; + } + /** * This method is called immediately before any elements are processed, it should contain the - * operator's initialization logic. + * operator's initialization logic, e.g. state initialization. * * <p>The default implementation does nothing. * @@ -126,24 +142,39 @@ public abstract class AbstractStreamOperator<OUT> */ @Override public void open() throws Exception { + initOperatorState(); + initKeyedState(); + } + + private void initKeyedState() { try { TypeSerializer<Object> keySerializer = config.getStateKeySerializer(getUserCodeClassloader()); // create a keyed state backend if there is keyed state, as indicated by the presence of a key serializer if (null != keySerializer) { - ExecutionConfig execConf = container.getEnvironment().getExecutionConfig();; KeyGroupRange subTaskKeyGroupRange = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( container.getEnvironment().getTaskInfo().getNumberOfKeyGroups(), container.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(), container.getIndexInSubtaskGroup()); - keyedStateBackend = container.createKeyedStateBackend( + this.keyedStateBackend = container.createKeyedStateBackend( keySerializer, container.getConfiguration().getNumberOfKeyGroups(getUserCodeClassloader()), subTaskKeyGroupRange); + } + + } catch (Exception e) { + throw new IllegalStateException("Could not initialize keyed state backend.", e); + } + } + + private void initOperatorState() { + try { + // create an operator state backend + this.operatorStateBackend = container.createOperatorStateBackend(this, lazyRestoreStateHandles); } catch (Exception e) { - throw new RuntimeException("Could not initialize keyed state backend.", e); + throw new IllegalStateException("Could not initialize operator state backend.", e); } } @@ -171,18 +202,25 @@ public abstract class AbstractStreamOperator<OUT> */ @Override public void dispose() throws Exception { + + if (operatorStateBackend != null) { + IOUtils.closeQuietly(operatorStateBackend); + operatorStateBackend.dispose(); + } + if (keyedStateBackend != null) { - keyedStateBackend.close(); + IOUtils.closeQuietly(keyedStateBackend); + keyedStateBackend.dispose(); } } @Override - public void snapshotState(FSDataOutputStream out, - long checkpointId, - long timestamp) throws Exception {} + public RunnableFuture<OperatorStateHandle> snapshotState( + long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception { - @Override - public void restoreState(FSDataInputStream in) throws Exception {} + return operatorStateBackend != null ? + operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory) : null; + } @Override public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {} @@ -223,10 +261,24 @@ public abstract class AbstractStreamOperator<OUT> } @SuppressWarnings("rawtypes, unchecked") - public <K> KeyedStateBackend<K> getStateBackend() { + public <K> KeyedStateBackend<K> getKeyedStateBackend() { + + if (null == keyedStateBackend) { + initKeyedState(); + } + return (KeyedStateBackend<K>) keyedStateBackend; } + public OperatorStateBackend getOperatorStateBackend() { + + if (null == operatorStateBackend) { + initOperatorState(); + } + + return operatorStateBackend; + } + /** * Returns the {@link TimeServiceProvider} responsible for getting the current * processing time and registering timers. @@ -268,18 +320,18 @@ public abstract class AbstractStreamOperator<OUT> @Override @SuppressWarnings({"unchecked", "rawtypes"}) public void setKeyContextElement1(StreamRecord record) throws Exception { - if (stateKeySelector1 != null) { - Object key = ((KeySelector) stateKeySelector1).getKey(record.getValue()); - getStateBackend().setCurrentKey(key); - } + setRawKeyContextElement(record, stateKeySelector1); } @Override @SuppressWarnings({"unchecked", "rawtypes"}) public void setKeyContextElement2(StreamRecord record) throws Exception { - if (stateKeySelector2 != null) { - Object key = ((KeySelector) stateKeySelector2).getKey(record.getValue()); + setRawKeyContextElement(record, stateKeySelector2); + } + private void setRawKeyContextElement(StreamRecord record, KeySelector<?, ?> selector) throws Exception { + if (selector != null) { + Object key = ((KeySelector) selector).getKey(record.getValue()); setKeyContext(key); } } @@ -290,7 +342,7 @@ public abstract class AbstractStreamOperator<OUT> try { // need to work around type restrictions @SuppressWarnings("unchecked,rawtypes") - KeyedStateBackend rawBackend = (KeyedStateBackend) keyedStateBackend; + AbstractKeyedStateBackend rawBackend = (AbstractKeyedStateBackend) keyedStateBackend; rawBackend.setCurrentKey(key); } catch (Exception e) {
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index 6ac73e7..f683d9a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -18,23 +18,31 @@ package org.apache.flink.streaming.api.operators; -import java.io.Serializable; - import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.streaming.api.checkpoint.Checkpointed; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.util.InstantiationUtil; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.RunnableFuture; + import static java.util.Objects.requireNonNull; /** @@ -50,7 +58,8 @@ import static java.util.Objects.requireNonNull; @PublicEvolving public abstract class AbstractUdfStreamOperator<OUT, F extends Function> extends AbstractStreamOperator<OUT> - implements OutputTypeConfigurable<OUT> { + implements OutputTypeConfigurable<OUT>, + StreamCheckpointedOperator { private static final long serialVersionUID = 1L; @@ -91,6 +100,28 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> super.open(); FunctionUtils.openFunction(userFunction, new Configuration()); + + if (userFunction instanceof CheckpointedFunction) { + ((CheckpointedFunction) userFunction).initializeState(getOperatorStateBackend()); + } else if (userFunction instanceof ListCheckpointed) { + @SuppressWarnings("unchecked") + ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction; + + ListState<Serializable> listState = + getOperatorStateBackend().getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR); + + List<Serializable> list = new ArrayList<>(); + + for (Serializable serializable : listState.get()) { + list.add(serializable); + } + + try { + listCheckpointedFun.restoreState(list); + } catch (Exception e) { + throw new Exception("Failed to restore state to function: " + e.getMessage(), e); + } + } } @Override @@ -115,7 +146,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> @Override public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception { - super.snapshotState(out, checkpointId, timestamp); if (userFunction instanceof Checkpointed) { @SuppressWarnings("unchecked") @@ -138,7 +168,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> @Override public void restoreState(FSDataInputStream in) throws Exception { - super.restoreState(in); if (userFunction instanceof Checkpointed) { @SuppressWarnings("unchecked") @@ -160,6 +189,32 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> } @Override + public RunnableFuture<OperatorStateHandle> snapshotState( + long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception { + + if (userFunction instanceof CheckpointedFunction) { + ((CheckpointedFunction) userFunction).prepareSnapshot(checkpointId, timestamp); + } + + if (userFunction instanceof ListCheckpointed) { + @SuppressWarnings("unchecked") + List<Serializable> partitionableState = + ((ListCheckpointed<Serializable>) userFunction).snapshotState(checkpointId, timestamp); + + ListState<Serializable> listState = + getOperatorStateBackend().getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR); + + listState.clear(); + + for (Serializable statePartition : partitionableState) { + listState.add(statePartition); + } + } + + return super.snapshotState(checkpointId, timestamp, streamFactory); + } + + @Override public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { super.notifyOfCompletedCheckpoint(checkpointId); http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java new file mode 100644 index 0000000..50cdc02 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.FSDataOutputStream; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.tasks.StreamTask; + +@Deprecated +public interface StreamCheckpointedOperator { + + /** + * Called to draw a state snapshot from the operator. This method snapshots the operator state + * (if the operator is stateful). + * + * @param out The stream to which we have to write our state. + * @param checkpointId The ID of the checkpoint. + * @param timestamp The timestamp of the checkpoint. + * + * @throws Exception Forwards exceptions that occur while drawing snapshots from the operator + * and the key/value state. + */ + void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception; + + /** + * Restores the operator state, if this operator's execution is recovering from a checkpoint. + * This method restores the operator state (if the operator is stateful) and the key/value state + * (if it had been used and was initialized when the snapshot occurred). + * + * <p>This method is called after {@link #setup(StreamTask, StreamConfig, Output)} + * and before {@link #open()}. + * + * @param in The stream from which we have to restore our state. + * + * @throws Exception Exceptions during state restore should be forwarded, so that the system can + * properly react to failed state restore and fail the execution attempt. + */ + void restoreState(FSDataInputStream in) throws Exception; + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index f1e8160..fae5fd0 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -17,16 +17,18 @@ package org.apache.flink.streaming.api.operators; -import java.io.Serializable; - import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.metrics.MetricGroup; -import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; +import java.io.Serializable; +import java.util.Collection; +import java.util.concurrent.RunnableFuture; + /** * Basic interface for stream operators. Implementers would implement one of * {@link org.apache.flink.streaming.api.operators.OneInputStreamOperator} or @@ -91,32 +93,27 @@ public interface StreamOperator<OUT> extends Serializable { // ------------------------------------------------------------------------ /** - * Called to draw a state snapshot from the operator. This method snapshots the operator state - * (if the operator is stateful). - * - * @param out The stream to which we have to write our state. - * @param checkpointId The ID of the checkpoint. - * @param timestamp The timestamp of the checkpoint. + * Called to draw a state snapshot from the operator. * - * @throws Exception Forwards exceptions that occur while drawing snapshots from the operator - * and the key/value state. + * @throws Exception Forwards exceptions that occur while preparing for the snapshot */ - void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception; /** - * Restores the operator state, if this operator's execution is recovering from a checkpoint. - * This method restores the operator state (if the operator is stateful) and the key/value state - * (if it had been used and was initialized when the snapshot occurred). + * Called to draw a state snapshot from the operator. * - * <p>This method is called after {@link #setup(StreamTask, StreamConfig, Output)} - * and before {@link #open()}. - * - * @param in The stream from which we have to restore our state. + * @return a runnable future to the state handle that points to the snapshotted state. For synchronous implementations, + * the runnable might already be finished. + * @throws Exception exception that happened during snapshotting. + */ + RunnableFuture<OperatorStateHandle> snapshotState( + long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception; + + /** + * Provides state handles to restore the operator state. * - * @throws Exception Exceptions during state restore should be forwarded, so that the system can - * properly react to failed state restore and fail the execution attempt. + * @param stateHandles state handles to the operator state. */ - void restoreState(FSDataInputStream in) throws Exception; + void restoreState(Collection<OperatorStateHandle> stateHandles); /** * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager. http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java index 4f85e3a..cc2e54b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java @@ -24,13 +24,10 @@ import org.apache.flink.api.common.functions.BroadcastVariableInitializer; import org.apache.flink.api.common.functions.util.AbstractRuntimeUDFContext; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.state.OperatorState; import org.apache.flink.api.common.state.ReducingState; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.streaming.api.CheckpointingMode; @@ -143,35 +140,6 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext { } } - @Override - @Deprecated - public <S> OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) { - requireNonNull(stateType, "The state type class must not be null"); - - TypeInformation<S> typeInfo; - try { - typeInfo = TypeExtractor.getForClass(stateType); - } - catch (Exception e) { - throw new RuntimeException("Cannot analyze type '" + stateType.getName() + - "' from the class alone, due to generic type parameters. " + - "Please specify the TypeInformation directly.", e); - } - - return getKeyValueState(name, typeInfo, defaultState); - } - - @Override - @Deprecated - public <S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) { - requireNonNull(name, "The name of the state must not be null"); - requireNonNull(stateType, "The state type information must not be null"); - - ValueStateDescriptor<S> stateProps = - new ValueStateDescriptor<>(name, stateType, defaultState); - return getState(stateProps); - } - // ------------------ expose (read only) relevant information from the stream config -------- // /** http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java index 35d1108..b5500b7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.io.disk.InputViewIterator; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.util.ReusingMutableToRegularIteratorWrapper; +import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.watermark.Watermark; @@ -51,7 +52,9 @@ import java.util.UUID; * * @param <IN> Type of the elements emitted by this sink */ -public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<IN> implements OneInputStreamOperator<IN, IN> { +public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<IN> + implements OneInputStreamOperator<IN, IN>, StreamCheckpointedOperator { + private static final long serialVersionUID = 1L; protected static final Logger LOG = LoggerFactory.getLogger(GenericWriteAheadSink.class); @@ -110,7 +113,6 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception { - super.snapshotState(out, checkpointId, timestamp); saveHandleInState(checkpointId, timestamp); @@ -119,7 +121,6 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I @Override public void restoreState(FSDataInputStream in) throws Exception { - super.restoreState(in); this.state = InstantiationUtil.deserializeObject(in, getUserCodeClassloader()); } @@ -151,11 +152,19 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I try { if (!committer.isCheckpointCommitted(pastCheckpointId)) { Tuple2<Long, StreamStateHandle> handle = state.pendingHandles.get(pastCheckpointId); - FSDataInputStream in = handle.f1.openInputStream(); - boolean success = sendValues(new ReusingMutableToRegularIteratorWrapper<>(new InputViewIterator<>(new DataInputViewStreamWrapper(in), serializer), serializer), handle.f0); - if (success) { //if the sending has failed we will retry on the next notify - committer.commitCheckpoint(pastCheckpointId); - checkpointsToRemove.add(pastCheckpointId); + try (FSDataInputStream in = handle.f1.openInputStream()) { + boolean success = sendValues( + new ReusingMutableToRegularIteratorWrapper<>( + new InputViewIterator<>( + new DataInputViewStreamWrapper( + in), + serializer), + serializer), + handle.f0); + if (success) { //if the sending has failed we will retry on the next notify + committer.commitCheckpoint(pastCheckpointId); + checkpointsToRemove.add(pastCheckpointId); + } } } else { checkpointsToRemove.add(pastCheckpointId); http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java index 4de7729..a838faa 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java @@ -88,7 +88,7 @@ public class EvictingWindowOperator<K, IN, OUT, W extends Window> extends Window element.getTimestamp(), windowAssignerContext); - final K key = (K) getStateBackend().getCurrentKey(); + final K key = (K) getKeyedStateBackend().getCurrentKey(); if (windowAssigner instanceof MergingWindowAssigner) { @@ -122,7 +122,7 @@ public class EvictingWindowOperator<K, IN, OUT, W extends Window> extends Window } // merge the merged state windows into the newly resulting state window - getStateBackend().mergePartitionedStates( + getKeyedStateBackend().mergePartitionedStates( stateWindowResult, mergedStateWindows, windowSerializer, http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java index e4939db..ffdf334 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java @@ -298,7 +298,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window> Collection<W> elementWindows = windowAssigner.assignWindows( element.getValue(), element.getTimestamp(), windowAssignerContext); - final K key = (K) getStateBackend().getCurrentKey(); + final K key = (K) getKeyedStateBackend().getCurrentKey(); if (windowAssigner instanceof MergingWindowAssigner) { MergingWindowSet<W> mergingWindows = getMergingWindowSet(); @@ -329,7 +329,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window> } // merge the merged state windows into the newly resulting state window - getStateBackend().mergePartitionedStates( + getKeyedStateBackend().mergePartitionedStates( stateWindowResult, mergedStateWindows, windowSerializer, @@ -554,18 +554,18 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window> */ @SuppressWarnings("unchecked") protected MergingWindowSet<W> getMergingWindowSet() throws Exception { - MergingWindowSet<W> mergingWindows = mergingWindowsByKey.get((K) getStateBackend().getCurrentKey()); + MergingWindowSet<W> mergingWindows = mergingWindowsByKey.get((K) getKeyedStateBackend().getCurrentKey()); if (mergingWindows == null) { // try to retrieve from state TupleSerializer<Tuple2<W, W>> tupleSerializer = new TupleSerializer<>((Class) Tuple2.class, new TypeSerializer[] {windowSerializer, windowSerializer} ); ListStateDescriptor<Tuple2<W, W>> mergeStateDescriptor = new ListStateDescriptor<>("merging-window-set", tupleSerializer); - ListState<Tuple2<W, W>> mergeState = getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor); + ListState<Tuple2<W, W>> mergeState = getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor); mergingWindows = new MergingWindowSet<>((MergingWindowAssigner<? super IN, W>) windowAssigner, mergeState); mergeState.clear(); - mergingWindowsByKey.put((K) getStateBackend().getCurrentKey(), mergingWindows); + mergingWindowsByKey.put((K) getKeyedStateBackend().getCurrentKey(), mergingWindows); } return mergingWindows; } @@ -709,7 +709,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window> public <S extends MergingState<?, ?>> void mergePartitionedState(StateDescriptor<S, ?> stateDescriptor) { if (mergedWindows != null && mergedWindows.size() > 0) { try { - WindowOperator.this.getStateBackend().mergePartitionedStates(window, + WindowOperator.this.getKeyedStateBackend().mergePartitionedStates(window, mergedWindows, windowSerializer, stateDescriptor); @@ -869,7 +869,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window> ListStateDescriptor<Tuple2<W, W>> mergeStateDescriptor = new ListStateDescriptor<>("merging-window-set", tupleSerializer); for (Map.Entry<K, MergingWindowSet<W>> key: mergingWindowsByKey.entrySet()) { setKeyContext(key.getKey()); - ListState<Tuple2<W, W>> mergeState = getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor); + ListState<Tuple2<W, W>> mergeState = getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor); mergeState.clear(); key.getValue().persist(mergeState); } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java index 0e24516..9e96f5d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java @@ -17,12 +17,6 @@ package org.apache.flink.streaming.runtime.tasks; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; @@ -35,20 +29,25 @@ import org.apache.flink.runtime.plugable.SerializationDelegate; import org.apache.flink.streaming.api.collector.selector.CopyingDirectedOutput; import org.apache.flink.streaming.api.collector.selector.DirectedOutput; import org.apache.flink.streaming.api.collector.selector.OutputSelector; -import org.apache.flink.streaming.api.watermark.Watermark; -import org.apache.flink.streaming.runtime.io.RecordWriterOutput; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.graph.StreamEdge; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.io.RecordWriterOutput; import org.apache.flink.streaming.runtime.io.StreamRecordWriter; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + /** * The {@code OperatorChain} contains all operators that are executed as one chain within a single * {@link StreamTask}. @@ -57,7 +56,7 @@ import org.slf4j.LoggerFactory; * head operator. */ @Internal -public class OperatorChain<OUT> { +public class OperatorChain<OUT, OP extends StreamOperator<OUT>> { private static final Logger LOG = LoggerFactory.getLogger(OperatorChain.class); @@ -66,16 +65,17 @@ public class OperatorChain<OUT> { private final RecordWriterOutput<?>[] streamOutputs; private final Output<StreamRecord<OUT>> chainEntryPoint; - - public OperatorChain(StreamTask<OUT, ?> containingTask, - StreamOperator<OUT> headOperator, - AccumulatorRegistry.Reporter reporter) { + private final OP headOperator; + + public OperatorChain(StreamTask<OUT, OP> containingTask, AccumulatorRegistry.Reporter reporter) { final ClassLoader userCodeClassloader = containingTask.getUserCodeClassLoader(); final StreamConfig configuration = containingTask.getConfiguration(); final boolean enableTimestamps = containingTask.isSerializingTimestamps(); + headOperator = configuration.getStreamOperator(userCodeClassloader); + // we read the chained configs, and the order of record writer registrations by output name Map<Integer, StreamConfig> chainedConfigs = configuration.getTransitiveChainedTaskConfigs(userCodeClassloader); chainedConfigs.put(configuration.getVertexID(), configuration); @@ -104,11 +104,15 @@ public class OperatorChain<OUT> { List<StreamOperator<?>> allOps = new ArrayList<>(chainedConfigs.size()); this.chainEntryPoint = createOutputCollector(containingTask, configuration, chainedConfigs, userCodeClassloader, streamOutputMap, allOps); + + if (headOperator != null) { + headOperator.setup(containingTask, configuration, getChainEntryPoint()); + } + + // add head operator to end of chain + allOps.add(headOperator); - this.allOperators = allOps.toArray(new StreamOperator<?>[allOps.size() + 1]); - - // add the head operator to the end of the list - this.allOperators[this.allOperators.length - 1] = headOperator; + this.allOperators = allOps.toArray(new StreamOperator<?>[allOps.size()]); success = true; } @@ -181,7 +185,15 @@ public class OperatorChain<OUT> { } } } - + + public OP getHeadOperator() { + return headOperator; + } + + public int getChainLength() { + return allOperators == null ? 0 : allOperators.length; + } + // ------------------------------------------------------------------------ // initialization utilities // ------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 7976f01..1725eca 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -26,14 +26,19 @@ import org.apache.flink.configuration.IllegalConfigurationException; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.metrics.Gauge; import org.apache.flink.runtime.execution.CancelTaskException; +import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.ClosableRegistry; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackendFactory; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FsStateBackend; @@ -41,24 +46,23 @@ import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; import org.apache.flink.streaming.api.TimeCharacteristic; +import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.runtime.io.RecordWriterOutput; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; - +import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.Closeable; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.RunnableFuture; @@ -70,19 +74,19 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; * the Task's operator chain. Operators that are chained together execute synchronously in the * same thread and hence on the same stream partition. A common case for these chains * are successive map/flatmap/filter tasks. - * - * <p>The task chain contains one "head" operator and multiple chained operators. + * + * <p>The task chain contains one "head" operator and multiple chained operators. * The StreamTask is specialized for the type of the head operator: one-input and two-input tasks, * as well as for sources, iteration heads and iteration tails. - * - * <p>The Task class deals with the setup of the streams read by the head operator, and the streams + * + * <p>The Task class deals with the setup of the streams read by the head operator, and the streams * produced by the operators at the ends of the operator chain. Note that the chain may fork and * thus have multiple ends. * - * The life cycle of the task is set up as follows: + * The life cycle of the task is set up as follows: * <pre>{@code - * -- restoreState() -> restores state of all operators in the chain - * + * -- getPartitionableState() -> restores state of all operators in the chain + * * -- invoke() * | * +----> Create basic utils (config, etc) and load the chain of operators @@ -99,35 +103,35 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; * <p> The {@code StreamTask} has a lock object called {@code lock}. All calls to methods on a * {@code StreamOperator} must be synchronized on this lock object to ensure that no methods * are called concurrently. - * + * * @param <OUT> - * @param <Operator> + * @param <OP> */ @Internal -public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> +public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> extends AbstractInvokable implements StatefulTask, AsyncExceptionHandler { /** The thread group that holds all trigger timer threads */ public static final ThreadGroup TRIGGER_THREAD_GROUP = new ThreadGroup("Triggers"); - + /** The logger used by the StreamTask and its subclasses */ private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class); - + // ------------------------------------------------------------------------ - + /** * All interaction with the {@code StreamOperator} must be synchronized on this lock object to ensure that * we don't have concurrent method calls that void consistent checkpoints. */ private final Object lock = new Object(); - + /** the head operator that consumes the input streams of this task */ - protected Operator headOperator; + protected OP headOperator; /** The chain of operators executed by this task */ - private OperatorChain<OUT> operatorChain; - + private OperatorChain<OUT, OP> operatorChain; + /** The configuration of this streaming task */ private StreamConfig configuration; @@ -135,7 +139,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> private AbstractStateBackend stateBackend; /** Keyed state backend for the head operator, if it is keyed. There can only ever be one. */ - private KeyedStateBackend<?> keyedStateBackend; + private AbstractKeyedStateBackend<?> keyedStateBackend; /** * The internal {@link TimeServiceProvider} used to define the current @@ -146,12 +150,14 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> /** The map of user-defined accumulators of this task */ private Map<String, Accumulator<?, ?>> accumulatorMap; - + /** The chained operator state to be restored once the initialization is done */ private ChainedStateHandle<StreamStateHandle> lazyRestoreChainedOperatorState; private List<KeyGroupsStateHandle> lazyRestoreKeyGroupStates; + private List<Collection<OperatorStateHandle>> lazyRestoreOperatorState; + /** * This field is used to forward an exception that is caught in the timer thread or other * asynchronous Threads. Subclasses must ensure that exceptions stored here get thrown on the @@ -159,12 +165,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> private volatile AsynchronousException asyncException; /** The currently active background materialization threads */ - private final Set<Closeable> cancelables = new HashSet<>(); - + private final ClosableRegistry cancelables = new ClosableRegistry(); + /** Flag to mark the task "in operation", in which case check * needs to be initialized to true, so that early cancel() before invoke() behaves correctly */ private volatile boolean isRunning; - + /** Flag to mark this task as canceled */ private volatile boolean canceled; @@ -178,11 +184,11 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // ------------------------------------------------------------------------ protected abstract void init() throws Exception; - + protected abstract void run() throws Exception; - + protected abstract void cleanup() throws Exception; - + protected abstract void cancelTask() throws Exception; // ------------------------------------------------------------------------ @@ -232,13 +238,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> timerService = DefaultTimeServiceProvider.create(this, executor, getCheckpointLock()); } - headOperator = configuration.getStreamOperator(getUserCodeClassLoader()); - operatorChain = new OperatorChain<>(this, headOperator, - getEnvironment().getAccumulatorRegistry().getReadWriteReporter()); - - if (headOperator != null) { - headOperator.setup(this, configuration, operatorChain.getChainEntryPoint()); - } + operatorChain = new OperatorChain<>(this, getEnvironment().getAccumulatorRegistry().getReadWriteReporter()); + headOperator = operatorChain.getHeadOperator(); getEnvironment().getMetricGroup().gauge("lastCheckpointSize", new Gauge<Long>() { @Override @@ -249,12 +250,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // task specific initialization init(); - + // save the work of reloadig state, etc, if the task is already canceled if (canceled) { throw new CancelTaskException(); } - + // -------- Invoke -------- LOG.debug("Invoking {}", getName()); @@ -278,7 +279,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> run(); LOG.debug("Finished task {}", getName()); - + // make sure no further checkpoint and notification actions happen. // we make sure that no other thread is currently in the locked scope before // we close the operators by trying to acquire the checkpoint scope lock @@ -286,13 +287,13 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // at the same time, this makes sure that during any "regular" exit where still synchronized (lock) { isRunning = false; - + // this is part of the main logic, so if this fails, the task is considered failed closeAllOperators(); } LOG.debug("Closed operators for task {}", getName()); - + // make sure all buffered data is flushed operatorChain.flushOutputs(); @@ -324,7 +325,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // stop all asynchronous checkpoint threads try { - closeAllClosables(); + cancelables.close(); shutdownAsyncThreads(); } catch (Throwable t) { @@ -371,13 +372,13 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> isRunning = false; canceled = true; cancelTask(); - closeAllClosables(); + cancelables.close(); } public final boolean isRunning() { return isRunning; } - + public final boolean isCanceled() { return canceled; } @@ -476,36 +477,14 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> } } - closeAllClosables(); - } - - private void closeAllClosables() { - // first, create a copy of the cancelables to prevent concurrent modifications - // and to not hold the lock for too long. the copy can be a cheap list - List<Closeable> localCancelables = null; - synchronized (cancelables) { - if (cancelables.size() > 0) { - localCancelables = new ArrayList<>(cancelables); - cancelables.clear(); - } - } - - if (localCancelables != null) { - for (Closeable cancelable : localCancelables) { - try { - cancelable.close(); - } catch (Throwable t) { - LOG.error("Error on canceling operation", t); - } - } - } + cancelables.close(); } boolean isSerializingTimestamps() { TimeCharacteristic tc = configuration.getTimeCharacteristic(); return tc == TimeCharacteristic.EventTime | tc == TimeCharacteristic.IngestionTime; } - + // ------------------------------------------------------------------------ // Access to properties and utilities // ------------------------------------------------------------------------ @@ -525,7 +504,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> public Object getCheckpointLock() { return lock; } - + public StreamConfig getConfiguration() { return configuration; } @@ -533,11 +512,11 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> public Map<String, Accumulator<?, ?>> getAccumulatorMap() { return accumulatorMap; } - + Output<StreamRecord<OUT>> getHeadOutput() { return operatorChain.getChainEntryPoint(); } - + RecordWriterOutput<?>[] getStreamOutputs() { return operatorChain.getStreamOutputs(); } @@ -547,40 +526,59 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // ------------------------------------------------------------------------ @Override - public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState, List<KeyGroupsStateHandle> keyGroupsState) { + public void setInitialState( + ChainedStateHandle<StreamStateHandle> chainedState, + List<KeyGroupsStateHandle> keyGroupsState, + List<Collection<OperatorStateHandle>> partitionableOperatorState) { + lazyRestoreChainedOperatorState = chainedState; lazyRestoreKeyGroupStates = keyGroupsState; + lazyRestoreOperatorState = partitionableOperatorState; } private void restoreState() throws Exception { final StreamOperator<?>[] allOperators = operatorChain.getAllOperators(); - try { - if (lazyRestoreChainedOperatorState != null) { + if (lazyRestoreChainedOperatorState != null) { + Preconditions.checkState(lazyRestoreChainedOperatorState.getLength() == allOperators.length, + "Invalid Invalid number of operator states. Found :" + lazyRestoreChainedOperatorState.getLength() + + ". Expected: " + allOperators.length); + } - synchronized (cancelables) { - cancelables.add(lazyRestoreChainedOperatorState); - } + if (lazyRestoreOperatorState != null) { + Preconditions.checkArgument(lazyRestoreOperatorState.isEmpty() + || lazyRestoreOperatorState.size() == allOperators.length, + "Invalid number of operator states. Found :" + lazyRestoreOperatorState.size() + + ". Expected: " + allOperators.length); + } - for (int i = 0; i < lazyRestoreChainedOperatorState.getLength(); i++) { + for (int i = 0; i < allOperators.length; i++) { + StreamOperator<?> operator = allOperators[i]; + + if (null != lazyRestoreOperatorState && !lazyRestoreOperatorState.isEmpty()) { + operator.restoreState(lazyRestoreOperatorState.get(i)); + } + + // TODO deprecated code path + if (operator instanceof StreamCheckpointedOperator) { + + if (lazyRestoreChainedOperatorState != null) { StreamStateHandle state = lazyRestoreChainedOperatorState.get(i); - if (state == null) { - continue; - } - StreamOperator<?> operator = allOperators[i]; - if (operator != null) { + if (state != null) { LOG.debug("Restore state of task {} in chain ({}).", i, getName()); - try (FSDataInputStream inputStream = state.openInputStream()) { - operator.restoreState(inputStream); + + FSDataInputStream is = state.openInputStream(); + try { + cancelables.registerClosable(is); + ((StreamCheckpointedOperator) operator).restoreState(is); + } finally { + cancelables.unregisterClosable(is); + is.close(); } } } } - } finally { - synchronized (cancelables) { - cancelables.remove(lazyRestoreChainedOperatorState); - } } } @@ -629,29 +627,58 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // Given this, we immediately emit the checkpoint barriers, so the downstream operators // can start their checkpoint work as soon as possible operatorChain.broadcastCheckpointBarrier(checkpointId, timestamp); - + // now draw the state snapshot final StreamOperator<?>[] allOperators = operatorChain.getAllOperators(); - final List<StreamStateHandle> nonPartitionedStates = Arrays.asList(new StreamStateHandle[allOperators.length]); + + final List<StreamStateHandle> nonPartitionedStates = + Arrays.asList(new StreamStateHandle[allOperators.length]); + + final List<OperatorStateHandle> operatorStates = + Arrays.asList(new OperatorStateHandle[allOperators.length]); for (int i = 0; i < allOperators.length; i++) { StreamOperator<?> operator = allOperators[i]; if (operator != null) { + + final String operatorId = createOperatorIdentifier(operator, configuration.getVertexID()); + CheckpointStreamFactory streamFactory = - stateBackend.createStreamFactory( - getEnvironment().getJobID(), - createOperatorIdentifier( - operator, - configuration.getVertexID())); + stateBackend.createStreamFactory(getEnvironment().getJobID(), operatorId); + + //TODO deprecated code path + if (operator instanceof StreamCheckpointedOperator) { + + CheckpointStreamFactory.CheckpointStateOutputStream outStream = + streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp); + + + cancelables.registerClosable(outStream); + + try { + ((StreamCheckpointedOperator) operator). + snapshotState(outStream, checkpointId, timestamp); + + nonPartitionedStates.set(i, outStream.closeAndGetHandle()); + } finally { + cancelables.unregisterClosable(outStream); + } + } - CheckpointStreamFactory.CheckpointStateOutputStream outStream = - streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp); + RunnableFuture<OperatorStateHandle> handleFuture = + operator.snapshotState(checkpointId, timestamp, streamFactory); - operator.snapshotState(outStream, checkpointId, timestamp); + if (null != handleFuture) { + //TODO for now we assume there are only synchrous snapshots, no need to start the runnable. + if (!handleFuture.isDone()) { + throw new IllegalStateException("Currently only supports synchronous snapshots!"); + } - nonPartitionedStates.set(i, outStream.closeAndGetHandle()); + operatorStates.set(i, handleFuture.get()); + } } + } RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture = null; @@ -659,16 +686,16 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> if (keyedStateBackend != null) { CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory( getEnvironment().getJobID(), - createOperatorIdentifier( - headOperator, - configuration.getVertexID())); - keyGroupsStateHandleFuture = keyedStateBackend.snapshot( - checkpointId, - timestamp, - streamFactory); + createOperatorIdentifier(headOperator, configuration.getVertexID())); + + keyGroupsStateHandleFuture = keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory); } - ChainedStateHandle<StreamStateHandle> chainedStateHandles = new ChainedStateHandle<>(nonPartitionedStates); + ChainedStateHandle<StreamStateHandle> chainedNonPartitionedStateHandles = + new ChainedStateHandle<>(nonPartitionedStates); + + ChainedStateHandle<OperatorStateHandle> chainedPartitionedStateHandles = + new ChainedStateHandle<>(operatorStates); LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}", checkpointId, getName()); @@ -679,7 +706,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> "checkpoint-" + checkpointId + "-" + timestamp, this, cancelables, - chainedStateHandles, + chainedNonPartitionedStateHandles, + chainedPartitionedStateHandles, keyGroupsStateHandleFuture, checkpointId, bytesBufferedAlignment, @@ -687,9 +715,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> syncDurationMillis, endOfSyncPart); - synchronized (cancelables) { - cancelables.add(asyncCheckpointRunnable); - } + cancelables.registerClosable(asyncCheckpointRunnable); asyncOperationsThreadPool.submit(asyncCheckpointRunnable); return true; } else { @@ -707,7 +733,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> synchronized (lock) { if (isRunning) { LOG.debug("Notification of complete checkpoint for task {}", getName()); - + for (StreamOperator<?> operator : operatorChain.getAllOperators()) { if (operator != null) { operator.notifyOfCompletedCheckpoint(checkpointId); @@ -760,7 +786,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> Class<? extends StateBackendFactory> clazz = Class.forName(backendName, false, getUserCodeClassLoader()).asSubclass(StateBackendFactory.class); - stateBackend = ((StateBackendFactory<?>) clazz.newInstance()).createFromConfig(flinkConfig); + stateBackend = clazz.newInstance().createFromConfig(flinkConfig); } catch (ClassNotFoundException e) { throw new IllegalConfigurationException("Cannot find configured state backend: " + backendName); } catch (ClassCastException e) { @@ -772,10 +798,26 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> } } } + return stateBackend; } - public <K> KeyedStateBackend<K> createKeyedStateBackend( + public OperatorStateBackend createOperatorStateBackend( + StreamOperator<?> op, Collection<OperatorStateHandle> restoreStateHandles) throws Exception { + + Environment env = getEnvironment(); + String opId = createOperatorIdentifier(op, configuration.getVertexID()); + + OperatorStateBackend newBackend = restoreStateHandles == null ? + stateBackend.createOperatorStateBackend(env, opId) + : stateBackend.restoreOperatorStateBackend(env, opId, restoreStateHandles); + + cancelables.registerClosable(newBackend); + + return newBackend; + } + + public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend( TypeSerializer<K> keySerializer, int numberOfKeyGroups, KeyGroupRange keyGroupRange) throws Exception { @@ -811,8 +853,10 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> getEnvironment().getTaskKvStateRegistry()); } + cancelables.registerClosable(keyedStateBackend); + @SuppressWarnings("unchecked") - KeyedStateBackend<K> typedBackend = (KeyedStateBackend<K>) keyedStateBackend; + AbstractKeyedStateBackend<K> typedBackend = (AbstractKeyedStateBackend<K>) keyedStateBackend; return typedBackend; } @@ -825,9 +869,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> public CheckpointStreamFactory createCheckpointStreamFactory(StreamOperator<?> operator) throws IOException { return stateBackend.createStreamFactory( getEnvironment().getJobID(), - createOperatorIdentifier( - operator, - configuration.getVertexID())); + createOperatorIdentifier(operator, configuration.getVertexID())); } @@ -867,7 +909,6 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> if (isRunning) { LOG.error("Asynchronous exception registered.", exception); } - if (this.asyncException == null) { this.asyncException = exception; } @@ -877,20 +918,23 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // Utilities // ------------------------------------------------------------------------ + @Override public String toString() { return getName(); } // ------------------------------------------------------------------------ - + private static class AsyncCheckpointRunnable implements Runnable, Closeable { private final StreamTask<?, ?> owner; - private final Set<Closeable> cancelables; + private final ClosableRegistry cancelables; + + private final ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles; - private final ChainedStateHandle<StreamStateHandle> chainedStateHandles; + private final ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles; private final RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture; @@ -909,8 +953,9 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> AsyncCheckpointRunnable( String name, StreamTask<?, ?> owner, - Set<Closeable> cancelables, - ChainedStateHandle<StreamStateHandle> chainedStateHandles, + ClosableRegistry cancelables, + ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles, + ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles, RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture, long checkpointId, long bytesBufferedInAlignment, @@ -921,7 +966,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> this.name = name; this.owner = owner; this.cancelables = cancelables; - this.chainedStateHandles = chainedStateHandles; + this.nonPartitionedStateHandles = nonPartitionedStateHandles; + this.partitioneableStateHandles = partitioneableStateHandles; this.keyGroupsStateHandleFuture = keyGroupsStateHandleFuture; this.checkpointId = checkpointId; this.bytesBufferedInAlignment = bytesBufferedInAlignment; @@ -952,13 +998,19 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> final long asyncEndNanos = System.nanoTime(); final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000; - if (chainedStateHandles.isEmpty() && keyedStates.isEmpty()) { + if (nonPartitionedStateHandles.isEmpty() && keyedStates.isEmpty()) { owner.getEnvironment().acknowledgeCheckpoint(checkpointId, syncDurationMillies, asyncDurationMillis, bytesBufferedInAlignment, alignmentDurationNanos); } else { + + CheckpointStateHandles allStateHandles = new CheckpointStateHandles( + nonPartitionedStateHandles, + partitioneableStateHandles, + keyedStates); + owner.getEnvironment().acknowledgeCheckpoint(checkpointId, - chainedStateHandles, keyedStates, + allStateHandles, syncDurationMillies, asyncDurationMillis, bytesBufferedInAlignment, alignmentDurationNanos); } @@ -974,9 +1026,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> owner.registerAsyncException(asyncException); } finally { - synchronized (cancelables) { - cancelables.remove(this); - } + cancelables.unregisterClosable(this); } } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java index fe09788..02409a3 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java @@ -36,8 +36,8 @@ import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.memory.MemoryStateBackend; @@ -188,15 +188,15 @@ public class StreamingRuntimeContextTest { public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable { ListStateDescriptor<String> descr = (ListStateDescriptor<String>) invocationOnMock.getArguments()[0]; - KeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend( + + AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend( new DummyEnvironment("test_task", 1, 0), new JobID(), "test_op", IntSerializer.INSTANCE, 1, new KeyGroupRange(0, 0), - new KvStateRegistry().createTaskRegistry(new JobID(), - new JobVertexID())); + new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID())); backend.setCurrentKey(0); return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr); } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java index b549ef8..5d68841 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java @@ -21,6 +21,7 @@ package org.apache.flink.streaming.runtime.io; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; +import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; @@ -28,15 +29,15 @@ import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.io.network.api.CheckpointBarrier; - import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import java.io.File; import java.util.Arrays; +import java.util.Collection; import java.util.List; import static org.junit.Assert.assertEquals; @@ -974,7 +975,8 @@ public class BarrierBufferTest { @Override public void setInitialState( ChainedStateHandle<StreamStateHandle> chainedState, - List<KeyGroupsStateHandle> keyGroupsState) throws Exception { + List<KeyGroupsStateHandle> keyGroupsState, + List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception { throw new UnsupportedOperationException("should never be called"); } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java index 314dcc4..f2f9092 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java @@ -19,21 +19,25 @@ package org.apache.flink.streaming.runtime.io; import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.io.network.api.CheckpointBarrier; - import org.junit.Test; import java.util.Arrays; +import java.util.Collection; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * Tests for the behavior of the barrier tracker. @@ -363,7 +367,8 @@ public class BarrierTrackerTest { @Override public void setInitialState( ChainedStateHandle<StreamStateHandle> chainedState, - List<KeyGroupsStateHandle> keyGroupsState) throws Exception { + List<KeyGroupsStateHandle> keyGroupsState, + List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception { throw new UnsupportedOperationException("should never be called"); } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java index f4ac5b2..32e8ea9 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java @@ -19,7 +19,6 @@ package org.apache.flink.streaming.runtime.operators; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.execution.Environment; @@ -27,8 +26,6 @@ import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; -import org.apache.flink.runtime.state.AbstractStateBackend; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.collector.selector.OutputSelector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SplitStream; @@ -42,19 +39,15 @@ import org.apache.flink.streaming.runtime.tasks.OperatorChain; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.junit.Assert; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.hamcrest.MatcherAssert.assertThat; /** * Tests for stream operator chaining behaviour. @@ -156,9 +149,8 @@ public class StreamOperatorChainingTest { StreamTask<Integer, StreamMap<Integer, Integer>> mockTask = createMockTask(streamConfig, chainedVertex.getName()); - OperatorChain<Integer> operatorChain = new OperatorChain<>( + OperatorChain<Integer, StreamMap<Integer, Integer>> operatorChain = new OperatorChain<>( mockTask, - headOperator, mock(AccumulatorRegistry.Reporter.class)); headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint()); @@ -299,9 +291,8 @@ public class StreamOperatorChainingTest { StreamTask<Integer, StreamMap<Integer, Integer>> mockTask = createMockTask(streamConfig, chainedVertex.getName()); - OperatorChain<Integer> operatorChain = new OperatorChain<>( + OperatorChain<Integer, StreamMap<Integer, Integer>> operatorChain = new OperatorChain<>( mockTask, - headOperator, mock(AccumulatorRegistry.Reporter.class)); headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint()); http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index 6a7b024..b5b6582 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -41,9 +41,9 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.AbstractCloseableHandle; import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; @@ -56,19 +56,23 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.util.SerializedValue; - import org.junit.Test; import java.io.EOFException; import java.io.IOException; import java.io.Serializable; import java.net.URL; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.Executor; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * This test checks that task restores that get stuck in the presence of interrupts @@ -121,6 +125,7 @@ public class InterruptSensitiveRestoreTest { ChainedStateHandle<StreamStateHandle> operatorState = new ChainedStateHandle<>(Collections.singletonList(state)); List<KeyGroupsStateHandle> keyGroupState = Collections.emptyList(); + List<Collection<OperatorStateHandle>> partitionableOperatorState = Collections.emptyList(); return new TaskDeploymentDescriptor( new JobID(), @@ -139,42 +144,47 @@ public class InterruptSensitiveRestoreTest { Collections.<URL>emptyList(), 0, operatorState, - keyGroupState); + keyGroupState, + partitionableOperatorState); } - + private static Task createTask(TaskDeploymentDescriptor tdd) throws IOException { NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); return new Task( - tdd, - mock(MemoryManager.class), - mock(IOManager.class), - networkEnvironment, - mock(BroadcastVariableManager.class), + tdd, + mock(MemoryManager.class), + mock(IOManager.class), + networkEnvironment, + mock(BroadcastVariableManager.class), mock(TaskManagerConnection.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), - new FallbackLibraryCacheManager(), - new FileCache(new Configuration()), - new TaskManagerRuntimeInfo( - "localhost", new Configuration(), EnvironmentInformation.getTemporaryFileDirectory()), - new UnregisteredTaskMetricsGroup(), - mock(ResultPartitionConsumableNotifier.class), - mock(PartitionStateChecker.class), - mock(Executor.class)); - + new FallbackLibraryCacheManager(), + new FileCache(new Configuration()), + new TaskManagerRuntimeInfo( + "localhost", new Configuration(), EnvironmentInformation.getTemporaryFileDirectory()), + new UnregisteredTaskMetricsGroup(), + mock(ResultPartitionConsumableNotifier.class), + mock(PartitionStateChecker.class), + mock(Executor.class)); + } // ------------------------------------------------------------------------ @SuppressWarnings("serial") - private static class InterruptLockingStateHandle extends AbstractCloseableHandle implements StreamStateHandle { + private static class InterruptLockingStateHandle implements StreamStateHandle { + + private volatile boolean closed; @Override public FSDataInputStream openInputStream() throws IOException { - ensureNotClosed(); + + closed = false; + FSDataInputStream is = new FSDataInputStream() { @Override @@ -191,8 +201,14 @@ public class InterruptSensitiveRestoreTest { block(); throw new EOFException(); } + + @Override + public void close() throws IOException { + super.close(); + closed = true; + } }; - registerCloseable(is); + return is; } @@ -207,7 +223,7 @@ public class InterruptSensitiveRestoreTest { } } catch (InterruptedException e) { - while (!isClosed()) { + while (!closed) { try { synchronized (this) { wait(); @@ -227,7 +243,7 @@ public class InterruptSensitiveRestoreTest { } // ------------------------------------------------------------------------ - + private static class TestSource implements SourceFunction<Object>, Checkpointed<Serializable> { private static final long serialVersionUID = 1L; @@ -250,4 +266,4 @@ public class InterruptSensitiveRestoreTest { fail("should never be called"); } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java index 88fb383..4003e59 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java @@ -21,7 +21,10 @@ package org.apache.flink.streaming.runtime.tasks; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FSDataInputStream; @@ -31,8 +34,12 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CheckpointStateHandles; +import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.graph.StreamEdge; import org.apache.flink.streaming.api.graph.StreamNode; @@ -56,17 +63,23 @@ import scala.concurrent.duration.FiniteDuration; import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.RunnableFuture; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * Tests for {@link OneInputStreamTask}. @@ -82,6 +95,9 @@ import static org.junit.Assert.assertTrue; @PowerMockIgnore({"javax.management.*", "com.sun.jndi.*"}) public class OneInputStreamTaskTest extends TestLogger { + private static final ListStateDescriptor<Integer> TEST_DESCRIPTOR = + new ListStateDescriptor<>("test", new IntSerializer()); + /** * This test verifies that open() and close() are correctly called. This test also verifies * that timestamps of emitted elements are correct. {@link StreamMap} assigns the input @@ -358,7 +374,7 @@ public class OneInputStreamTaskTest extends TestLogger { testHarness.invoke(env); testHarness.waitForTaskRunning(deadline.timeLeft().toMillis()); - streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp); + while(!streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp)); // since no state was set, there shouldn't be restore calls assertEquals(0, TestingStreamOperator.numberRestoreCalls); @@ -371,7 +387,7 @@ public class OneInputStreamTaskTest extends TestLogger { testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis()); final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<String, String>(); - restoredTask.setInitialState(env.getState(), env.getKeyGroupStates()); + restoredTask.setInitialState(env.getState(), env.getKeyGroupStates(), env.getPartitionableOperatorState()); final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<String, String>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO); @@ -465,6 +481,7 @@ public class OneInputStreamTaskTest extends TestLogger { private volatile long checkpointId; private volatile ChainedStateHandle<StreamStateHandle> state; private volatile List<KeyGroupsStateHandle> keyGroupStates; + private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState; private final OneShotLatch checkpointLatch = new OneShotLatch(); @@ -486,6 +503,10 @@ public class OneInputStreamTaskTest extends TestLogger { return result; } + List<Collection<OperatorStateHandle>> getPartitionableOperatorState() { + return partitionableOperatorState; + } + AcknowledgeStreamMockEnvironment( Configuration jobConfig, Configuration taskConfig, ExecutionConfig executionConfig, long memorySize, @@ -497,13 +518,21 @@ public class OneInputStreamTaskTest extends TestLogger { @Override public void acknowledgeCheckpoint( long checkpointId, - ChainedStateHandle<StreamStateHandle> state, List<KeyGroupsStateHandle> keyGroupStates, + CheckpointStateHandles checkpointStateHandles, long syncDuration, long asymcDuration, long alignmentByte, long alignmentDuration) { this.checkpointId = checkpointId; - this.state = state; - this.keyGroupStates = keyGroupStates; - + if(checkpointStateHandles != null) { + this.state = checkpointStateHandles.getNonPartitionedStateHandles(); + this.keyGroupStates = checkpointStateHandles.getKeyGroupsStateHandle(); + ChainedStateHandle<OperatorStateHandle> chainedStateHandle = checkpointStateHandles.getPartitioneableStateHandles(); + Collection<OperatorStateHandle>[] ia = new Collection[chainedStateHandle.getLength()]; + this.partitionableOperatorState = Arrays.asList(ia); + + for (int i = 0; i < chainedStateHandle.getLength(); ++i) { + partitionableOperatorState.set(i, Collections.singletonList(chainedStateHandle.get(i))); + } + } checkpointLatch.trigger(); } @@ -513,17 +542,56 @@ public class OneInputStreamTaskTest extends TestLogger { } private static class TestingStreamOperator<IN, OUT> - extends AbstractStreamOperator<OUT> implements OneInputStreamOperator<IN, OUT> { + extends AbstractStreamOperator<OUT> + implements OneInputStreamOperator<IN, OUT>, StreamCheckpointedOperator { private static final long serialVersionUID = 774614855940397174L; public static int numberRestoreCalls = 0; + public static int numberSnapshotCalls = 0; private final long seed; private final long recoveryTimestamp; private transient Random random; + @Override + public void open() throws Exception { + super.open(); + + ListState<Integer> partitionableState = getOperatorStateBackend().getPartitionableState(TEST_DESCRIPTOR); + + if (numberSnapshotCalls == 0) { + for (Integer v : partitionableState.get()) { + fail(); + } + } else { + Set<Integer> result = new HashSet<>(); + for (Integer v : partitionableState.get()) { + result.add(v); + } + + assertEquals(2, result.size()); + assertTrue(result.contains(42)); + assertTrue(result.contains(4711)); + } + } + + @Override + public RunnableFuture<OperatorStateHandle> snapshotState( + long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception { + + ListState<Integer> partitionableState = + getOperatorStateBackend().getPartitionableState(TEST_DESCRIPTOR); + partitionableState.clear(); + + partitionableState.add(42); + partitionableState.add(4711); + + ++numberSnapshotCalls; + return super.snapshotState(checkpointId, timestamp, streamFactory); + } + TestingStreamOperator(long seed, long recoveryTimestamp) { this.seed = seed; this.recoveryTimestamp = recoveryTimestamp;
