Repository: flink Updated Branches: refs/heads/master de03e0cea -> fa664e5b9
[FLINK-4907] Add State Reshuffling in Operator Test Harness Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/94c65fbe Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/94c65fbe Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/94c65fbe Branch: refs/heads/master Commit: 94c65fbe59b77b32ea19983a0a96da41234daebc Parents: e396a5a Author: Aljoscha Krettek <[email protected]> Authored: Wed Oct 26 16:15:44 2016 +0200 Committer: Aljoscha Krettek <[email protected]> Committed: Wed Oct 26 23:26:28 2016 +0200 ---------------------------------------------------------------------- .../util/AbstractStreamOperatorTestHarness.java | 78 ++++++++++++++++++-- .../KeyedOneInputStreamOperatorTestHarness.java | 25 ++++++- 2 files changed, 95 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/94c65fbe/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java index dfc0af0..af1a7ba 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java @@ -24,12 +24,16 @@ import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.FSDataOutputStream; +import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner; +import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner; +import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.ClosableRegistry; +import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.OperatorStateHandle; @@ -52,6 +56,7 @@ import org.apache.flink.util.FutureUtil; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; @@ -87,6 +92,9 @@ public class AbstractStreamOperatorTestHarness<OUT> { private final Object checkpointLock; + private final OperatorStateRepartitioner operatorStateRepartitioner = + RoundRobinOperatorStateRepartitioner.INSTANCE; + /** * Whether setup() was called on the operator. This is reset when calling close(). */ @@ -233,12 +241,72 @@ public class AbstractStreamOperatorTestHarness<OUT> { * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}. * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} * if it was not called before. + * + * <p>This will reshape the state handles to include only those key-group states + * in the local key-group range and the operator states that would be assigned to the local + * subtask. */ public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception { if (!setupCalled) { setup(); } - operator.initializeState(operatorStateHandles); + + if (operatorStateHandles != null) { + int numKeyGroups = getEnvironment().getTaskInfo().getNumberOfKeyGroups(); + int numSubtasks = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(); + int subtaskIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask(); + + // create a new OperatorStateHandles that only contains the state for our key-groups + + List<KeyGroupRange> keyGroupPartitions = StateAssignmentOperation.createKeyGroupPartitions( + numKeyGroups, + numSubtasks); + + KeyGroupRange localKeyGroupRange = + keyGroupPartitions.get(subtaskIndex); + + List<KeyGroupsStateHandle> localManagedKeyGroupState = null; + if (operatorStateHandles.getManagedKeyedState() != null) { + localManagedKeyGroupState = StateAssignmentOperation.getKeyGroupsStateHandles( + operatorStateHandles.getManagedKeyedState(), + localKeyGroupRange); + } + + List<KeyGroupsStateHandle> localRawKeyGroupState = null; + if (operatorStateHandles.getRawKeyedState() != null) { + localRawKeyGroupState = StateAssignmentOperation.getKeyGroupsStateHandles( + operatorStateHandles.getRawKeyedState(), + localKeyGroupRange); + } + + List<OperatorStateHandle> managedOperatorState = new ArrayList<>(); + if (operatorStateHandles.getManagedOperatorState() != null) { + managedOperatorState.addAll(operatorStateHandles.getManagedOperatorState()); + } + Collection<OperatorStateHandle> localManagedOperatorState = operatorStateRepartitioner.repartitionState( + managedOperatorState, + numSubtasks).get(subtaskIndex); + + List<OperatorStateHandle> rawOperatorState = new ArrayList<>(); + if (operatorStateHandles.getRawOperatorState() != null) { + rawOperatorState.addAll(operatorStateHandles.getRawOperatorState()); + } + Collection<OperatorStateHandle> localRawOperatorState = operatorStateRepartitioner.repartitionState( + rawOperatorState, + numSubtasks).get(subtaskIndex); + + OperatorStateHandles massagedOperatorStateHandles = new OperatorStateHandles( + 0, + null, + localManagedKeyGroupState, + localRawKeyGroupState, + localManagedOperatorState, + localRawOperatorState); + + operator.initializeState(massagedOperatorStateHandles); + } else { + operator.initializeState(null); + } initializeCalled = true; } @@ -275,10 +343,10 @@ public class AbstractStreamOperatorTestHarness<OUT> { OperatorStateHandles handles = new OperatorStateHandles( 0, null, - Collections.singletonList(keyedManaged), - Collections.singletonList(keyedRaw), - Collections.singletonList(opManaged), - Collections.singletonList(opRaw)); + keyedManaged != null ? Collections.singletonList(keyedManaged) : null, + keyedRaw != null ? Collections.singletonList(keyedRaw) : null, + opManaged != null ? Collections.singletonList(opManaged) : null, + opRaw != null ? Collections.singletonList(opRaw) : null); return handles; } http://git-wip-us.apache.org/repos/asf/flink/blob/94c65fbe/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java index 7d87eb8..0bdf5da 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.ClosureCleaner; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; @@ -38,8 +39,8 @@ import org.mockito.stubbing.Answer; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; -import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.concurrent.RunnableFuture; import static org.mockito.Matchers.any; @@ -59,7 +60,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT> // when we restore we keep the state here so that we can call restore // when the operator requests the keyed state backend - private Collection<KeyGroupsStateHandle> restoredKeyedState = null; + private List<KeyGroupsStateHandle> restoredKeyedState = null; public KeyedOneInputStreamOperatorTestHarness( OneInputStreamOperator<IN, OUT> operator, @@ -186,7 +187,25 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT> @Override public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception { if (operatorStateHandles != null) { - restoredKeyedState = operatorStateHandles.getManagedKeyedState(); + int numKeyGroups = getEnvironment().getTaskInfo().getNumberOfKeyGroups(); + int numSubtasks = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(); + int subtaskIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask(); + + // create a new OperatorStateHandles that only contains the state for our key-groups + + List<KeyGroupRange> keyGroupPartitions = StateAssignmentOperation.createKeyGroupPartitions( + numKeyGroups, + numSubtasks); + + KeyGroupRange localKeyGroupRange = + keyGroupPartitions.get(subtaskIndex); + + restoredKeyedState = null; + if (operatorStateHandles.getManagedKeyedState() != null) { + restoredKeyedState = StateAssignmentOperation.getKeyGroupsStateHandles( + operatorStateHandles.getManagedKeyedState(), + localKeyGroupRange); + } } super.initializeState(operatorStateHandles);
