GitHub user tillrohrmann commented on a diff in the pull request:
https://github.com/apache/flink/pull/5239#discussion_r165326694
--- Diff:
flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/LocalStateForwardingTest.java
---
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.tasks;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.StateObjectCollection;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.DoneFuture;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.TaskLocalStateStore;
+import org.apache.flink.runtime.state.TaskStateManagerImpl;
+import org.apache.flink.runtime.state.TestTaskStateManager;
+import org.apache.flink.runtime.taskmanager.TestCheckpointResponder;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.Future;
+import java.util.concurrent.RunnableFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.mockito.Mockito.mock;
+
+/**
+ * Test for forwarding of state reporting to and from {@link
org.apache.flink.runtime.state.TaskStateManager}.
+ */
+public class LocalStateForwardingTest {
+
+ /**
+ * Tests the forwarding of JM and TM-local state from the futures reported
+ * by the backends, through the async checkpointing runnable, to the
+ * {@link org.apache.flink.runtime.state.TaskStateManager}.
+ *
+ * <p>A single operator is given an {@link OperatorSnapshotFutures} whose
+ * managed/raw keyed and managed/raw operator state futures all come from
+ * {@code createSnapshotResult}. A {@code StreamTask.AsyncCheckpointRunnable}
+ * is then run directly on the test thread. Afterwards the last JM and TM
+ * snapshots recorded by the {@link TestTaskStateManager} are handed, per
+ * state kind, to {@code performCheck} together with the originating future.
+ */
+ @Test
+ public void testForwardingFromSnapshotToTaskStateManager() throws
Exception {
+
+ TestTaskStateManager taskStateManager = new
TestTaskStateManager();
+
+ StreamMockEnvironment streamMockEnvironment = new
StreamMockEnvironment(
+ new Configuration(),
+ new Configuration(),
+ new ExecutionConfig(),
+ 1024*1024,
+ new MockInputSplitProvider(),
+ 0,
+ taskStateManager);
+
+ StreamTask testStreamTask = new
StreamTaskTest.NoOpStreamTask(streamMockEnvironment);
+ CheckpointMetaData checkpointMetaData = new
CheckpointMetaData(0L, 0L);
+ CheckpointMetrics checkpointMetrics = new CheckpointMetrics();
+
+ Map<OperatorID, OperatorSnapshotFutures> snapshots = new
HashMap<>(1);
+ OperatorSnapshotFutures osFuture = new
OperatorSnapshotFutures();
+
+ // One snapshot future per state kind (keyed/operator x managed/raw).
+
osFuture.setKeyedStateManagedFuture(createSnapshotResult(KeyedStateHandle.class));
+
osFuture.setKeyedStateRawFuture(createSnapshotResult(KeyedStateHandle.class));
+
osFuture.setOperatorStateManagedFuture(createSnapshotResult(OperatorStateHandle.class));
+
osFuture.setOperatorStateRawFuture(createSnapshotResult(OperatorStateHandle.class));
+
+ OperatorID operatorID = new OperatorID();
+ snapshots.put(operatorID, osFuture);
+
+ StreamTask.AsyncCheckpointRunnable checkpointRunnable =
+ new StreamTask.AsyncCheckpointRunnable(
+ testStreamTask,
+ snapshots,
+ checkpointMetaData,
+ checkpointMetrics,
+ 0L);
+
+ // Invoked directly on the test thread rather than via an executor;
+ // the reads below rely on run() having reported to taskStateManager
+ // by the time it returns.
+ checkpointRunnable.run();
+
+ TaskStateSnapshot lastJobManagerTaskStateSnapshot =
taskStateManager.getLastJobManagerTaskStateSnapshot();
+ TaskStateSnapshot lastTaskManagerTaskStateSnapshot =
taskStateManager.getLastTaskManagerTaskStateSnapshot();
+
+ OperatorSubtaskState jmState =
+
lastJobManagerTaskStateSnapshot.getSubtaskStateByOperatorID(operatorID);
+
+ OperatorSubtaskState tmState =
+
lastTaskManagerTaskStateSnapshot.getSubtaskStateByOperatorID(operatorID);
+
+ // For each state kind, hand the originating future plus the JM and TM
+ // collections to performCheck, which presumably asserts that both
+ // match the future's result (its body is outside this excerpt).
+ performCheck(osFuture.getKeyedStateManagedFuture(),
jmState.getManagedKeyedState(), tmState.getManagedKeyedState());
+ performCheck(osFuture.getKeyedStateRawFuture(),
jmState.getRawKeyedState(), tmState.getRawKeyedState());
+ performCheck(osFuture.getOperatorStateManagedFuture(),
jmState.getManagedOperatorState(), tmState.getManagedOperatorState());
+ performCheck(osFuture.getOperatorStateRawFuture(),
jmState.getRawOperatorState(), tmState.getRawOperatorState());
+ }
+
+ /**
+ * Tests that state reported to the
+ * {@link org.apache.flink.runtime.state.TaskStateManager} is also reported
+ * to the {@link org.apache.flink.runtime.taskmanager.CheckpointResponder}
+ * and to the {@link TaskLocalStateStore}.
+ *
+ * <p>Both collaborators are anonymous subclasses that assert on the
+ * arguments they receive and then set a flag. The test calls
+ * {@code reportTaskStateSnapshots} on a {@link TaskStateManagerImpl} and
+ * finally asserts that both flags were set.
+ */
+ @Test
+ public void
testForwardingFromTaskStateManagerToResponderAndTaskLocalStateStore() {
+
+ final JobID jobID = new JobID();
+ final ExecutionAttemptID executionAttemptID = new
ExecutionAttemptID();
+ final CheckpointMetaData checkpointMetaData = new
CheckpointMetaData(42L, 4711L);
+ final CheckpointMetrics checkpointMetrics = new
CheckpointMetrics();
+ final int subtaskIdx = 42;
+ JobVertexID jobVertexID = new JobVertexID();
+
+ // NOTE(review): jm and tm are both empty snapshots, so they are
+ // presumably equal under TaskStateSnapshot#equals. The TM assertion in
+ // storeLocalState would then not detect jm being forwarded in place of
+ // tm. Consider giving them distinguishable contents - TODO confirm.
+ TaskStateSnapshot jm = new TaskStateSnapshot();
+ TaskStateSnapshot tm = new TaskStateSnapshot();
+
+ final AtomicBoolean jmReported = new AtomicBoolean(false);
+ final AtomicBoolean tmReported = new AtomicBoolean(false);
+
+ TestCheckpointResponder checkpointResponder = new
TestCheckpointResponder() {
+
+ @Override
+ public void acknowledgeCheckpoint(
+ JobID lJobID,
+ ExecutionAttemptID lExecutionAttemptID,
+ long lCheckpointId,
+ CheckpointMetrics lCheckpointMetrics,
+ TaskStateSnapshot lSubtaskState) {
+
+ // NOTE(review): lSubtaskState is never asserted, so the JM
+ // snapshot itself is not verified here; consider adding
+ // Assert.assertEquals(jm, lSubtaskState).
+ Assert.assertEquals(jobID, lJobID);
+ Assert.assertEquals(executionAttemptID,
lExecutionAttemptID);
+
Assert.assertEquals(checkpointMetaData.getCheckpointId(), lCheckpointId);
+ Assert.assertEquals(checkpointMetrics,
lCheckpointMetrics);
+ // Set only after the assertions above pass, so a failed
+ // assertion also shows up as an unset flag at the end.
+ jmReported.set(true);
+ }
+ };
+
+ TaskLocalStateStore taskLocalStateStore = new
TaskLocalStateStore(jobID, jobVertexID, subtaskIdx) {
+ @Override
+ public void storeLocalState(
+ // NOTE(review): shadows the outer checkpointMetaData local
+ // and is not asserted against it.
+ @Nonnull CheckpointMetaData checkpointMetaData,
+ @Nullable TaskStateSnapshot localState) {
+
+ Assert.assertEquals(tm, localState);
+ // Set only after the assertion passes (see responder above).
+ tmReported.set(true);
+ }
+ };
+
+ TaskStateManagerImpl taskStateManager =
+ new TaskStateManagerImpl(
+ jobID,
+ executionAttemptID,
+ taskLocalStateStore,
+ null,
+ checkpointResponder);
+
+ taskStateManager.reportTaskStateSnapshots(
+ checkpointMetaData,
+ checkpointMetrics,
+ jm,
+ tm);
+
+ Assert.assertTrue("Reporting for JM state was not called.",
jmReported.get());
+ Assert.assertTrue("Reporting for TM state was not called.",
tmReported.get());
+ }
+
+ private static <T extends StateObject> void performCheck(
+ Future<SnapshotResult<T>> cc,
+ StateObjectCollection<T> jm,
+ StateObjectCollection<T> tm) {
--- End diff --
Naming of variables could be a bit more expressive (e.g. `cc`, `jm` and `tm` in `performCheck`).
---