This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch release-1.8
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.8 by this push:
     new 531d727  [FLINK-12296][StateBackend] Fix local state directory 
collision with state loss for chained keyed operators
531d727 is described below

commit 531d727f9b32c310d8d63b253019b8cc4a23a3eb
Author: klion26 <qcx978132...@gmail.com>
AuthorDate: Wed Apr 24 04:52:03 2019 +0200

    [FLINK-12296][StateBackend] Fix local state directory collision with state 
loss for chained keyed operators
    
    - Change
    Will change the local data path from
    
`.../local_state_root/allocation_id/job_id/jobvertex_id_subtask_id/chk_id/rocksdb`
    to
    
`.../local_state_root/allocation_id/job_id/jobvertex_id_subtask_id/chk_id/operator_id`
    
    When preparing the local directory, Flink deletes each subtask's local directory if it already exists.
    If more than one stateful operator is chained in a single task, they share the same local directory path,
    so the local directory will be deleted unexpectedly and we'll get data loss.
    
    This closes #8263.
    
    (cherry picked from commit ee60846dc588b1a832a497ff9522d7a3a282c350)
---
 .../CheckpointStreamWithResultProviderTest.java    |   3 +
 .../state/StateSnapshotCompressionTest.java        |   2 +-
 .../ttl/mock/MockKeyedStateBackendBuilder.java     |   1 +
 .../runtime/state/ttl/mock/MockStateBackend.java   |   2 +-
 .../state/RocksDBKeyedStateBackendBuilder.java     |   1 +
 .../snapshot/RocksIncrementalSnapshotStrategy.java |  17 +-
 .../tasks/OneInputStreamTaskTestHarness.java       |  50 +++-
 .../runtime/tasks/StreamConfigChainer.java         |  23 +-
 .../runtime/tasks/StreamMockEnvironment.java       |   8 +-
 .../runtime/tasks/StreamTaskTestHarness.java       |  21 +-
 .../state/StatefulOperatorChainedTaskTest.java     | 260 +++++++++++++++++++++
 11 files changed, 369 insertions(+), 19 deletions(-)

diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java
index 2af25d9..57653e2 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java
@@ -35,6 +35,9 @@ import java.io.Closeable;
 import java.io.File;
 import java.io.IOException;
 
+/**
+ * Test for CheckpointStreamWithResultProvider.
+ */
 public class CheckpointStreamWithResultProviderTest extends TestLogger {
 
        private static TemporaryFolder temporaryFolder;
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
index a10be26..de687ff 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.commons.io.IOUtils;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
@@ -34,6 +33,7 @@ import 
org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.TestLogger;
 
+import org.apache.commons.io.IOUtils;
 import org.junit.Assert;
 import org.junit.Test;
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java
index efef923..2196dc9 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java
@@ -30,6 +30,7 @@ import 
org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 
 import javax.annotation.Nonnull;
+
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
index bdf07bf..f50f1b6 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
@@ -35,8 +35,8 @@ import 
org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointedStateScope;
 import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation;
-import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java
index 9b98e6e..0393155 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java
@@ -231,6 +231,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends 
AbstractKeyedStateBacken
                }
        }
 
+       @Override
        public RocksDBKeyedStateBackend<K> build() throws 
BackendBuildingException {
                RocksDBWriteBatchWrapper writeBatchWrapper = null;
                ColumnFamilyHandle defaultColumnFamilyHandle = null;
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
index 889b18d..38d5e7a 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
@@ -106,6 +106,9 @@ public class RocksIncrementalSnapshotStrategy<K> extends 
RocksDBSnapshotStrategy
        /** The help class used to upload state files. */
        private final RocksDBStateUploader stateUploader;
 
+       /** The local directory name of the current snapshot strategy. */
+       private final String localDirectoryName;
+
        public RocksIncrementalSnapshotStrategy(
                @Nonnull RocksDB db,
                @Nonnull ResourceGuard rocksDBResourceGuard,
@@ -137,6 +140,7 @@ public class RocksIncrementalSnapshotStrategy<K> extends 
RocksDBSnapshotStrategy
                this.materializedSstFiles = materializedSstFiles;
                this.lastCompletedCheckpointId = lastCompletedCheckpointId;
                this.stateUploader = new 
RocksDBStateUploader(numberOfTransferingThreads);
+               this.localDirectoryName = 
backendUID.toString().replaceAll("[\\-]", "");
        }
 
        @Nonnull
@@ -184,17 +188,18 @@ public class RocksIncrementalSnapshotStrategy<K> extends 
RocksDBSnapshotStrategy
                        LocalRecoveryDirectoryProvider directoryProvider = 
localRecoveryConfig.getLocalStateDirectoryProvider();
                        File directory = 
directoryProvider.subtaskSpecificCheckpointDirectory(checkpointId);
 
-                       if (directory.exists()) {
-                               FileUtils.deleteDirectory(directory);
-                       }
-
-                       if (!directory.mkdirs()) {
+                       if (!directory.exists() && !directory.mkdirs()) {
                                throw new IOException("Local state base 
directory for checkpoint " + checkpointId +
                                        " already exists: " + directory);
                        }
 
                        // introduces an extra directory because RocksDB wants 
a non-existing directory for native checkpoints.
-                       File rdbSnapshotDir = new File(directory, "rocks_db");
+                       // append localDirectoryName here to solve directory 
collision problem when two stateful operators chained in one task.
+                       File rdbSnapshotDir = new File(directory, 
localDirectoryName);
+                       if (rdbSnapshotDir.exists()) {
+                               FileUtils.deleteDirectory(rdbSnapshotDir);
+                       }
+
                        Path path = new Path(rdbSnapshotDir.toURI());
                        // create a "permanent" snapshot directory because 
local recovery is active.
                        try {
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
index 89a4f81..7ac0cf3 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
@@ -25,7 +25,10 @@ import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.execution.Environment;
 import 
org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
+import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.TestLocalRecoveryConfig;
 
+import java.io.File;
 import java.io.IOException;
 import java.util.function.Function;
 
@@ -56,16 +59,48 @@ public class OneInputStreamTaskTestHarness<IN, OUT> extends 
StreamTaskTestHarnes
 
        /**
         * Creates a test harness with the specified number of input gates and 
specified number
-        * of channels per input gate.
+        * of channels per input gate and local recovery disabled.
+        */
+       public OneInputStreamTaskTestHarness(
+               Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory,
+               int numInputGates,
+               int numInputChannelsPerGate,
+               TypeInformation<IN> inputType,
+               TypeInformation<OUT> outputType) {
+               this(taskFactory, numInputGates, numInputChannelsPerGate, 
inputType, outputType, TestLocalRecoveryConfig.disabled());
+       }
+
+       public OneInputStreamTaskTestHarness(
+               Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory,
+               int numInputGates,
+               int numInputChannelsPerGate,
+               TypeInformation<IN> inputType,
+               TypeInformation<OUT> outputType,
+               File localRootDir) {
+               super(taskFactory, outputType, localRootDir);
+
+               this.inputType = inputType;
+               inputSerializer = inputType.createSerializer(executionConfig);
+
+               this.numInputGates = numInputGates;
+               this.numInputChannelsPerGate = numInputChannelsPerGate;
+
+               streamConfig.setStateKeySerializer(inputSerializer);
+       }
+
+       /**
+        * Creates a test harness with the specified number of input gates and 
specified number
+        * of channels per input gate and specified localRecoveryConfig.
         */
        public OneInputStreamTaskTestHarness(
                        Function<Environment, ? extends StreamTask<OUT, ?>> 
taskFactory,
                        int numInputGates,
                        int numInputChannelsPerGate,
                        TypeInformation<IN> inputType,
-                       TypeInformation<OUT> outputType) {
+                       TypeInformation<OUT> outputType,
+                       LocalRecoveryConfig localRecoveryConfig) {
 
-               super(taskFactory, outputType);
+               super(taskFactory, outputType, localRecoveryConfig);
 
                this.inputType = inputType;
                inputSerializer = inputType.createSerializer(executionConfig);
@@ -78,11 +113,10 @@ public class OneInputStreamTaskTestHarness<IN, OUT> 
extends StreamTaskTestHarnes
         * Creates a test harness with one input gate that has one input 
channel.
         */
        public OneInputStreamTaskTestHarness(
-                       Function<Environment, ? extends StreamTask<OUT, ?>> 
taskFactory,
-                       TypeInformation<IN> inputType,
-                       TypeInformation<OUT> outputType) {
-
-               this(taskFactory, 1, 1, inputType, outputType);
+               Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory,
+               TypeInformation<IN> inputType,
+               TypeInformation<OUT> outputType) {
+               this(taskFactory, 1, 1, inputType, outputType, 
TestLocalRecoveryConfig.disabled());
        }
 
        @Override
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java
index 10e50ce..747468e 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java
@@ -61,6 +61,14 @@ public class StreamConfigChainer {
        }
 
        public <T> StreamConfigChainer chain(
+               OperatorID operatorID,
+               OneInputStreamOperator<T, T> operator,
+               TypeSerializer<T> typeSerializer,
+               boolean createKeyedStateBackend) {
+               return chain(operatorID, operator, typeSerializer, 
typeSerializer, createKeyedStateBackend);
+       }
+
+       public <T> StreamConfigChainer chain(
                        OperatorID operatorID,
                        OneInputStreamOperator<T, T> operator,
                        TypeSerializer<T> typeSerializer) {
@@ -68,10 +76,19 @@ public class StreamConfigChainer {
        }
 
        public <IN, OUT> StreamConfigChainer chain(
+               OperatorID operatorID,
+               OneInputStreamOperator<IN, OUT> operator,
+               TypeSerializer<IN> inputSerializer,
+               TypeSerializer<OUT> outputSerializer) {
+               return chain(operatorID, operator, inputSerializer, 
outputSerializer, false);
+       }
+
+       public <IN, OUT> StreamConfigChainer chain(
                        OperatorID operatorID,
                        OneInputStreamOperator<IN, OUT> operator,
                        TypeSerializer<IN> inputSerializer,
-                       TypeSerializer<OUT> outputSerializer) {
+                       TypeSerializer<OUT> outputSerializer,
+                       boolean createKeyedStateBackend) {
                chainIndex++;
 
                tailConfig.setChainedOutputs(Collections.singletonList(
@@ -87,6 +104,10 @@ public class StreamConfigChainer {
                tailConfig.setOperatorID(checkNotNull(operatorID));
                tailConfig.setTypeSerializerIn1(inputSerializer);
                tailConfig.setTypeSerializerOut(outputSerializer);
+               if (createKeyedStateBackend) {
+                       // used to test multiple stateful operators chained in 
a single task.
+                       tailConfig.setStateKeySerializer(inputSerializer);
+               }
                tailConfig.setChainIndex(chainIndex);
 
                chainedConfigs.put(chainIndex, tailConfig);
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index 6cd7617..134218a 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -106,6 +106,8 @@ public class StreamMockEnvironment implements Environment {
 
        private TaskEventDispatcher taskEventDispatcher = 
mock(TaskEventDispatcher.class);
 
+       private TaskManagerRuntimeInfo taskManagerRuntimeInfo = new 
TestingTaskManagerRuntimeInfo();
+
        public StreamMockEnvironment(
                Configuration jobConfig,
                Configuration taskConfig,
@@ -332,7 +334,11 @@ public class StreamMockEnvironment implements Environment {
 
        @Override
        public TaskManagerRuntimeInfo getTaskManagerInfo() {
-               return new TestingTaskManagerRuntimeInfo();
+               return this.taskManagerRuntimeInfo;
+       }
+
+       public void setTaskManagerInfo(TaskManagerRuntimeInfo 
taskManagerRuntimeInfo) {
+               this.taskManagerRuntimeInfo = taskManagerRuntimeInfo;
        }
 
        @Override
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index b2f1b99..f46f91b 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.Configuration;
@@ -26,10 +27,14 @@ import 
org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.execution.Environment;
 import 
org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.LocalRecoveryDirectoryProviderImpl;
+import org.apache.flink.runtime.state.TestLocalRecoveryConfig;
 import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
@@ -47,6 +52,7 @@ import org.apache.flink.util.Preconditions;
 
 import org.junit.Assert;
 
+import java.io.File;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.LinkedList;
@@ -109,7 +115,20 @@ public class StreamTaskTestHarness<OUT> {
        public StreamTaskTestHarness(
                        Function<Environment, ? extends StreamTask<OUT, ?>> 
taskFactory,
                        TypeInformation<OUT> outputType) {
+               this(taskFactory, outputType, 
TestLocalRecoveryConfig.disabled());
+       }
 
+       public StreamTaskTestHarness(
+               Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory,
+               TypeInformation<OUT> outputType,
+               File localRootDir) {
+               this(taskFactory, outputType, new LocalRecoveryConfig(true, new 
LocalRecoveryDirectoryProviderImpl(localRootDir, new JobID(), new 
JobVertexID(), 0)));
+       }
+
+       public StreamTaskTestHarness(
+               Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory,
+               TypeInformation<OUT> outputType,
+               LocalRecoveryConfig localRecoveryConfig) {
                this.taskFactory = checkNotNull(taskFactory);
                this.memorySize = DEFAULT_MEMORY_MANAGER_SIZE;
                this.bufferSize = DEFAULT_NETWORK_BUFFER_SIZE;
@@ -123,7 +142,7 @@ public class StreamTaskTestHarness<OUT> {
                outputSerializer = outputType.createSerializer(executionConfig);
                outputStreamRecordSerializer = new 
StreamElementSerializer<OUT>(outputSerializer);
 
-               this.taskStateManager = new TestTaskStateManager();
+               this.taskStateManager = new 
TestTaskStateManager(localRecoveryConfig);
        }
 
        public ProcessingTimeService getProcessingTimeService() {
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/state/StatefulOperatorChainedTaskTest.java
 
b/flink-tests/src/test/java/org/apache/flink/test/state/StatefulOperatorChainedTaskTest.java
new file mode 100644
index 0000000..5651929
--- /dev/null
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/state/StatefulOperatorChainedTaskTest.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.state;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.TestTaskStateManager;
+import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
+import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness;
+import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment;
+import org.apache.flink.streaming.util.TestHarnessUtil;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+import static 
org.apache.flink.configuration.CheckpointingOptions.CHECKPOINTS_DIRECTORY;
+import static 
org.apache.flink.configuration.CheckpointingOptions.INCREMENTAL_CHECKPOINTS;
+import static 
org.apache.flink.configuration.CheckpointingOptions.STATE_BACKEND;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test for StatefulOperatorChainedTaskTest.
+ */
+public class StatefulOperatorChainedTaskTest {
+
+       private static final Set<OperatorID> RESTORED_OPERATORS = 
ConcurrentHashMap.newKeySet();
+       private TemporaryFolder temporaryFolder;
+
+       @Before
+       public void setup() throws IOException {
+               RESTORED_OPERATORS.clear();
+               temporaryFolder = new TemporaryFolder();
+               temporaryFolder.create();
+       }
+
+       @Test
+       public void testMultipleStatefulOperatorChainedSnapshotAndRestore() 
throws Exception {
+
+               OperatorID headOperatorID = new OperatorID(42L, 42L);
+               OperatorID tailOperatorID = new OperatorID(44L, 44L);
+
+               JobManagerTaskRestore restore = 
createRunAndCheckpointOperatorChain(
+                       headOperatorID,
+                       new CounterOperator("head"),
+                       tailOperatorID,
+                       new CounterOperator("tail"),
+                       Optional.empty());
+
+               TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
+
+               assertEquals(2, stateHandles.getSubtaskStateMappings().size());
+
+               createRunAndCheckpointOperatorChain(
+                       headOperatorID,
+                       new CounterOperator("head"),
+                       tailOperatorID,
+                       new CounterOperator("tail"),
+                       Optional.of(restore));
+
+               assertEquals(new HashSet<>(Arrays.asList(headOperatorID, 
tailOperatorID)), RESTORED_OPERATORS);
+       }
+
+       private JobManagerTaskRestore createRunAndCheckpointOperatorChain(
+               OperatorID headId,
+               OneInputStreamOperator<String, String> headOperator,
+               OperatorID tailId,
+               OneInputStreamOperator<String, String> tailOperator,
+               Optional<JobManagerTaskRestore> restore) throws Exception {
+
+               File localRootDir = temporaryFolder.newFolder();
+               final OneInputStreamTaskTestHarness<String, String> testHarness 
=
+                       new OneInputStreamTaskTestHarness<>(
+                               OneInputStreamTask::new,
+                               1, 1,
+                               BasicTypeInfo.STRING_TYPE_INFO,
+                               BasicTypeInfo.STRING_TYPE_INFO,
+                               localRootDir);
+
+               testHarness.setupOperatorChain(headId, headOperator)
+                       .chain(tailId, tailOperator, StringSerializer.INSTANCE, 
true)
+                       .finish();
+
+               if (restore.isPresent()) {
+                       JobManagerTaskRestore taskRestore = restore.get();
+                       testHarness.setTaskStateSnapshot(
+                               taskRestore.getRestoreCheckpointId(),
+                               taskRestore.getTaskStateSnapshot());
+               }
+
+               StreamMockEnvironment environment = new StreamMockEnvironment(
+                       testHarness.jobConfig,
+                       testHarness.taskConfig,
+                       testHarness.getExecutionConfig(),
+                       testHarness.memorySize,
+                       new MockInputSplitProvider(),
+                       testHarness.bufferSize,
+                       testHarness.getTaskStateManager());
+
+               Configuration configuration = new Configuration();
+               configuration.setString(STATE_BACKEND.key(), "rocksdb");
+               File file = temporaryFolder.newFolder();
+               configuration.setString(CHECKPOINTS_DIRECTORY.key(), 
file.toURI().toString());
+               configuration.setString(INCREMENTAL_CHECKPOINTS.key(), "true");
+               environment.setTaskManagerInfo(
+                       new TestingTaskManagerRuntimeInfo(
+                               configuration,
+                               System.getProperty("java.io.tmpdir").split(",|" 
+ File.pathSeparator)));
+               testHarness.invoke(environment);
+               testHarness.waitForTaskRunning();
+
+               OneInputStreamTask<String, String> streamTask = 
testHarness.getTask();
+
+               processRecords(testHarness);
+               triggerCheckpoint(testHarness, streamTask);
+
+               TestTaskStateManager taskStateManager = 
testHarness.getTaskStateManager();
+
+               JobManagerTaskRestore jobManagerTaskRestore = new 
JobManagerTaskRestore(
+                       taskStateManager.getReportedCheckpointId(),
+                       taskStateManager.getLastJobManagerTaskStateSnapshot());
+
+               testHarness.endInput();
+               testHarness.waitForTaskCompletion();
+               return jobManagerTaskRestore;
+       }
+
+       private void triggerCheckpoint(
+               OneInputStreamTaskTestHarness<String, String> testHarness,
+               OneInputStreamTask<String, String> streamTask) throws Exception 
{
+
+               long checkpointId = 1L;
+               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, 1L);
+
+               testHarness.getTaskStateManager().setWaitForReportLatch(new 
OneShotLatch());
+
+               while (!streamTask.triggerCheckpoint(checkpointMetaData, 
CheckpointOptions.forCheckpointWithDefaultLocation(), false)) {}
+
+               
testHarness.getTaskStateManager().getWaitForReportLatch().await();
+               long reportedCheckpointId = 
testHarness.getTaskStateManager().getReportedCheckpointId();
+
+               assertEquals(checkpointId, reportedCheckpointId);
+       }
+
+       private void processRecords(OneInputStreamTaskTestHarness<String, 
String> testHarness) throws Exception {
+               ConcurrentLinkedQueue<Object> expectedOutput = new 
ConcurrentLinkedQueue<>();
+
+               testHarness.processElement(new StreamRecord<>("10"), 0, 0);
+               testHarness.processElement(new StreamRecord<>("20"), 0, 0);
+               testHarness.processElement(new StreamRecord<>("30"), 0, 0);
+
+               testHarness.waitForInputProcessing();
+
+               expectedOutput.add(new StreamRecord<>("10"));
+               expectedOutput.add(new StreamRecord<>("20"));
+               expectedOutput.add(new StreamRecord<>("30"));
+               TestHarnessUtil.assertOutputEquals("Output was not correct.", 
expectedOutput, testHarness.getOutput());
+       }
+
+       private abstract static class RestoreWatchOperator<IN, OUT>
+               extends AbstractStreamOperator<OUT>
+               implements OneInputStreamOperator<IN, OUT> {
+
+               @Override
+               public void initializeState(StateInitializationContext context) 
throws Exception {
+                       if (context.isRestored()) {
+                               RESTORED_OPERATORS.add(getOperatorID());
+                       }
+               }
+       }
+
+       /**
+        * Operator that counts processed messages and keeps result on state.
+        */
+       private static class CounterOperator extends 
RestoreWatchOperator<String, String> {
+               private static final long serialVersionUID = 
2048954179291813243L;
+
+               private static long snapshotOutData = 0L;
+               private ValueState<Long> counterState;
+               private long counter = 0;
+               private String prefix;
+
+               CounterOperator(String prefix) {
+                       this.prefix = prefix;
+               }
+
+               @Override
+               public void processElement(StreamRecord<String> element) throws 
Exception {
+                       counter++;
+                       output.collect(element);
+               }
+
+               @Override
+               public void initializeState(StateInitializationContext context) 
throws Exception {
+                       super.initializeState(context);
+
+                       counterState = context
+                               .getKeyedStateStore()
+                               .getState(new ValueStateDescriptor<>(prefix + 
"counter-state", LongSerializer.INSTANCE));
+
+                       // set key manually to make RocksDBListState get the 
serialized key.
+                       setCurrentKey("10");
+
+                       if (context.isRestored()) {
+                               counter =  counterState.value();
+                               assertEquals(snapshotOutData, counter);
+                               counterState.clear();
+                       }
+               }
+
+               @Override
+               public void snapshotState(StateSnapshotContext context) throws 
Exception {
+                       counterState.update(counter);
+                       snapshotOutData = counter;
+               }
+       }
+}
+

Reply via email to