StefanRRichter closed pull request #6537: [FLINK-10122] KafkaConsumer should use partitionable state over union state if partition discovery is not active
URL: https://github.com/apache/flink/pull/6537
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
index cfb5b6d510d..3857a968dd5 100644
--- a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
+++ b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
@@ -105,8 +105,11 @@
     /** Configuration key to define the consumer's partition discovery interval, in milliseconds. */
     public static final String KEY_PARTITION_DISCOVERY_INTERVAL_MILLIS = "flink.partition-discovery.interval-millis";

+    /** For backwards compatibility. */
+    private static final String OLD_OFFSETS_STATE_NAME = "topic-partition-offset-states";
+
     /** State name of the consumer's partition offset states. */
-    private static final String OFFSETS_STATE_NAME = "topic-partition-offset-states";
+    private static final String OFFSETS_STATE_NAME = "kafka-consumer-offsets";

     // ------------------------------------------------------------------------
     //  configuration state, set on the client relevant for all subtasks
@@ -180,13 +183,7 @@
     private transient volatile TreeMap<KafkaTopicPartition, Long> restoredState;

     /** Accessor for state in the operator state backend. */
-    private transient ListState<Tuple2<KafkaTopicPartition, Long>> unionOffsetStates;
-
-    /**
-     * Flag indicating whether the consumer is restored from older state written with Flink 1.1 or 1.2.
-     * When the current run is restored from older state, partition discovery is disabled.
-     */
-    private boolean restoredFromOldState;
+    private transient ListState<Tuple2<KafkaTopicPartition, Long>> offsetsState;

     /** Discovery loop, executed in a separate thread. */
     private transient volatile Thread discoveryLoopThread;
@@ -480,7 +477,7 @@ public void open(Configuration configuration) throws Exception {
                 }

                 for (Map.Entry<KafkaTopicPartition, Long> restoredStateEntry : restoredState.entrySet()) {
-                    if (!restoredFromOldState) {
+                    if (discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED) {
                         // seed the partition discoverer with the union state while filtering out
                         // restored partitions that should not be subscribed by this subtask
                         if (KafkaTopicPartitionAssigner.assign(
@@ -489,8 +486,7 @@ public void open(Configuration configuration) throws Exception {
                             subscribedPartitionsToStartOffsets.put(restoredStateEntry.getKey(), restoredStateEntry.getValue());
                         }
                     } else {
-                        // when restoring from older 1.1 / 1.2 state, the restored state would not be the union state;
-                        // in this case, just use the restored state as the subscribed partitions
+                        // just restore from assigned partitions
                         subscribedPartitionsToStartOffsets.put(restoredStateEntry.getKey(), restoredStateEntry.getValue());
                     }
                 }
@@ -783,30 +779,26 @@ public final void initializeState(FunctionInitializationContext context) throws
         OperatorStateStore stateStore = context.getOperatorStateStore();

-        ListState<Tuple2<KafkaTopicPartition, Long>> oldRoundRobinListState =
-            stateStore.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+        final TypeInformation<Tuple2<KafkaTopicPartition, Long>> offsetStateTypeInfo =
+            TypeInformation.of(new TypeHint<Tuple2<KafkaTopicPartition, Long>>() {});

-        this.unionOffsetStates = stateStore.getUnionListState(new ListStateDescriptor<>(
-            OFFSETS_STATE_NAME,
-            TypeInformation.of(new TypeHint<Tuple2<KafkaTopicPartition, Long>>() {})));
+        ListStateDescriptor<Tuple2<KafkaTopicPartition, Long>> offsetStateDescriptor =
+            new ListStateDescriptor<>(OFFSETS_STATE_NAME, offsetStateTypeInfo);

-        if (context.isRestored() && !restoredFromOldState) {
-            restoredState = new TreeMap<>(new KafkaTopicPartition.Comparator());
+        this.offsetsState =
+            discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED ?
+                stateStore.getUnionListState(offsetStateDescriptor) : stateStore.getListState(offsetStateDescriptor);

-            // migrate from 1.2 state, if there is any
-            for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : oldRoundRobinListState.get()) {
-                restoredFromOldState = true;
-                unionOffsetStates.add(kafkaOffset);
-            }
-            oldRoundRobinListState.clear();
+        if (context.isRestored()) {

-            if (restoredFromOldState && discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED) {
-                throw new IllegalArgumentException(
-                    "Topic / partition discovery cannot be enabled if the job is restored from a savepoint from Flink 1.2.x.");
-            }
+            restoredState = new TreeMap<>(new KafkaTopicPartition.Comparator());
+
+            // backwards compatibility
+            handleMigration_1_2(stateStore);
+            handleMigration_1_6(stateStore, offsetStateTypeInfo);

             // populate actual holder for restored state
-            for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : unionOffsetStates.get()) {
+            for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : offsetsState.get()) {
                 restoredState.put(kafkaOffset.f0, kafkaOffset.f1);
             }
@@ -816,19 +808,60 @@ public final void initializeState(FunctionInitializationContext context) throws
         }
     }

+    private void handleMigration_1_2(
+        OperatorStateStore stateStore) throws Exception {
+        boolean restoredFromOldState = false;
+        ListState<Tuple2<KafkaTopicPartition, Long>> oldRoundRobinListState =
+            stateStore.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+        for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : oldRoundRobinListState.get()) {
+            restoredFromOldState = true;
+            offsetsState.add(kafkaOffset);
+        }
+
+        // we remove this state again immediately so it will no longer exist in future check/savepoints
+        oldRoundRobinListState.clear();
+        stateStore.removeOperatorState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+
+        if (restoredFromOldState) {
+            if (discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED) {
+                throw new IllegalArgumentException(
+                    "Topic / partition discovery cannot be enabled if the job is restored from a savepoint from Flink 1.2.x.");
+            }
+        }
+    }
+
+    private void handleMigration_1_6(
+        OperatorStateStore stateStore,
+        TypeInformation<Tuple2<KafkaTopicPartition, Long>> offsetStateTypeInfo) throws Exception {
+
+        ListStateDescriptor<Tuple2<KafkaTopicPartition, Long>> oldUnionStateDescriptor =
+            new ListStateDescriptor<>(OLD_OFFSETS_STATE_NAME, offsetStateTypeInfo);
+
+        ListState<Tuple2<KafkaTopicPartition, Long>> oldUnionListState =
+            stateStore.getUnionListState(oldUnionStateDescriptor);
+
+        for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : oldUnionListState.get()) {
+            offsetsState.add(kafkaOffset);
+        }
+
+        // we remove this state again immediately so it will no longer exist in future check/savepoints
+        oldUnionListState.clear();
+        stateStore.removeOperatorState(oldUnionStateDescriptor.getName());
+    }
+
     @Override
     public final void snapshotState(FunctionSnapshotContext context) throws Exception {
         if (!running) {
             LOG.debug("snapshotState() called on closed source");
         } else {
-            unionOffsetStates.clear();
+            offsetsState.clear();

             final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
             if (fetcher == null) {
                 // the fetcher has not yet been initialized, which means we need to return the
                 // originally restored offsets or the assigned partitions
                 for (Map.Entry<KafkaTopicPartition, Long> subscribedPartition : subscribedPartitionsToStartOffsets.entrySet()) {
-                    unionOffsetStates.add(Tuple2.of(subscribedPartition.getKey(), subscribedPartition.getValue()));
+                    offsetsState.add(Tuple2.of(subscribedPartition.getKey(), subscribedPartition.getValue()));
                 }

                 if (offsetCommitMode == OffsetCommitMode.ON_CHECKPOINTS) {
@@ -846,7 +879,7 @@ public final void snapshotState(FunctionSnapshotContext context) throws Exceptio
                 }

                 for (Map.Entry<KafkaTopicPartition, Long> kafkaTopicPartitionLongEntry : currentOffsets.entrySet()) {
-                    unionOffsetStates.add(
+                    offsetsState.add(
                         Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
                 }
             }
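The heart of the change above is that initializeState() now picks the state flavor based on whether partition discovery is active: union list state (every subtask sees all entries on restore) only when the discoverer needs the full offset map, and cheaper partitionable list state otherwise. A minimal, self-contained sketch of that pattern outside the consumer follows; the class and field names are illustrative, not from this PR.

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.api.common.state.ListStateDescriptor;
    import org.apache.flink.api.common.typeinfo.TypeHint;
    import org.apache.flink.api.common.typeinfo.TypeInformation;
    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.runtime.state.FunctionInitializationContext;
    import org.apache.flink.runtime.state.FunctionSnapshotContext;
    import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

    /** Hypothetical state holder illustrating the union-vs-partitionable choice made in this PR. */
    public class OffsetStateExample implements CheckpointedFunction {

        private final boolean discoveryEnabled;
        private transient ListState<Tuple2<String, Long>> offsetsState;

        public OffsetStateExample(boolean discoveryEnabled) {
            this.discoveryEnabled = discoveryEnabled;
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            ListStateDescriptor<Tuple2<String, Long>> descriptor = new ListStateDescriptor<>(
                "offsets",
                TypeInformation.of(new TypeHint<Tuple2<String, Long>>() {}));

            // Union state replicates all entries to every subtask on restore, which is only
            // needed when subtasks must re-derive their assignment (discovery enabled).
            offsetsState = discoveryEnabled
                ? context.getOperatorStateStore().getUnionListState(descriptor)
                : context.getOperatorStateStore().getListState(descriptor);
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            offsetsState.clear();
            offsetsState.add(Tuple2.of("partition-0", 42L)); // placeholder offset
        }
    }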
diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
index c9b52415a3e..26eff023b29 100644
--- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
+++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
@@ -643,7 +643,7 @@ private void testRescaling(
                 Collections.singletonList("dummy-topic"),
                 null,
                 (KeyedDeserializationSchema<T>) mock(KeyedDeserializationSchema.class),
-                PARTITION_DISCOVERY_DISABLED,
+                1L,
                 false);

             this.testFetcher = testFetcher;
@@ -884,16 +884,16 @@ public OperatorID getOperatorID() {
     private static class MockOperatorStateStore implements OperatorStateStore {

-        private final ListState<?> mockRestoredUnionListState;
+        private final ListState<?> mockListState;

         private MockOperatorStateStore(ListState<?> restoredUnionListState) {
-            this.mockRestoredUnionListState = restoredUnionListState;
+            this.mockListState = restoredUnionListState;
         }

         @Override
         @SuppressWarnings("unchecked")
         public <S> ListState<S> getUnionListState(ListStateDescriptor<S> stateDescriptor) throws Exception {
-            return (ListState<S>) mockRestoredUnionListState;
+            return (ListState<S>) mockListState;
         }

         @Override
@@ -914,9 +914,20 @@ private MockOperatorStateStore(ListState<?> restoredUnionListState) {
             throw new UnsupportedOperationException();
         }

+        @SuppressWarnings("unchecked")
         @Override
         public <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception {
-            throw new UnsupportedOperationException();
+            return (ListState<S>) mockListState;
+        }
+
+        @Override
+        public void removeOperatorState(String name) throws Exception {
+        }
+
+        @Override
+        public void removeBroadcastState(String name) throws Exception {
         }

         @Override
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
index 7a998e6149c..7f0cd6ad0a2 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
@@ -75,6 +75,16 @@
      */
     <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception;

+    /**
+     * Removes the operator state with the given name from the state store, if it exists.
+     */
+    void removeOperatorState(String name) throws Exception;
+
+    /**
+     * Removes the broadcast state with the given name from the state store, if it exists.
+     */
+    void removeBroadcastState(String name) throws Exception;
+
     /**
      * Creates (or restores) a list state. Each state is registered under a unique name.
      * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
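The new removeOperatorState hook is what lets the consumer drop legacy state right after draining it (see handleMigration_1_2/_1_6 above). A minimal sketch of that one-shot migration pattern, assuming stateStore is the OperatorStateStore obtained in initializeState() and with purely illustrative state names:

    // Drain a legacy state into the new one, then delete it so future
    // checkpoints no longer carry it. "old-offsets" / "new-offsets" are
    // stand-in names, not from this PR.
    ListStateDescriptor<Long> oldDescriptor = new ListStateDescriptor<>("old-offsets", Long.class);
    ListStateDescriptor<Long> newDescriptor = new ListStateDescriptor<>("new-offsets", Long.class);

    ListState<Long> oldState = stateStore.getUnionListState(oldDescriptor);
    ListState<Long> newState = stateStore.getListState(newDescriptor);

    for (Long offset : oldState.get()) {
        newState.add(offset);
    }

    // clear + remove so the legacy entry disappears from future check/savepoints
    oldState.clear();
    stateStore.removeOperatorState(oldDescriptor.getName());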
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java
index 090f48a3c87..ce2d40309dd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java
@@ -20,7 +20,6 @@
 import org.apache.flink.runtime.state.OperatorStateHandle;

-import java.util.Collection;
 import java.util.List;

 /**
@@ -36,7 +35,7 @@
      * @return List with one entry per parallel subtask. Each subtask receives now one collection of states that build
      * of the new total state for this subtask.
      */
-    List<Collection<OperatorStateHandle>> repartitionState(
+    List<List<OperatorStateHandle>> repartitionState(
         List<OperatorStateHandle> previousParallelSubtaskStates,
         int newParallelism);
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
index e6fa687fd14..4705265430e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
@@ -26,11 +26,11 @@
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.EnumMap;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;

 /**
  * Current default implementation of {@link OperatorStateRepartitioner} that redistributes state in round robin fashion.
@@ -41,7 +41,7 @@
     private static final boolean OPTIMIZE_MEMORY_USE = false;

     @Override
-    public List<Collection<OperatorStateHandle>> repartitionState(
+    public List<List<OperatorStateHandle>> repartitionState(
         List<OperatorStateHandle> previousParallelSubtaskStates,
         int newParallelism) {

@@ -56,7 +56,7 @@
         }

         // Assemble result from all merge maps
-        List<Collection<OperatorStateHandle>> result = new ArrayList<>(newParallelism);
+        List<List<OperatorStateHandle>> result = new ArrayList<>(newParallelism);

         // Do the actual repartitioning for all named states
         List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
@@ -93,20 +93,19 @@ private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previ
                 continue;
             }

-            for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
-                psh.getStateNameToPartitionOffsets().entrySet()) {
+            final Set<Map.Entry<String, OperatorStateHandle.StateMetaInfo>> partitionOffsetEntries =
+                psh.getStateNameToPartitionOffsets().entrySet();
+
+            for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : partitionOffsetEntries) {
                 OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();

                 Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState =
                     nameToStateByMode.get(metaInfo.getDistributionMode());

                 List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations =
-                    nameToState.get(e.getKey());
-
-                if (stateLocations == null) {
-                    stateLocations = new ArrayList<>();
-                    nameToState.put(e.getKey(), stateLocations);
-                }
+                    nameToState.computeIfAbsent(
+                        e.getKey(),
+                        k -> new ArrayList<>(previousParallelSubtaskStates.size() * partitionOffsetEntries.size()));

                 stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue()));
             }
@@ -203,7 +202,9 @@ private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previ
                     Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
                     OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0);
                     if (operatorStateHandle == null) {
-                        operatorStateHandle = new OperatorStreamStateHandle(new HashMap<>(), handleWithOffsets.f0);
+                        operatorStateHandle = new OperatorStreamStateHandle(
+                            new HashMap<>(distributeNameToState.size()),
+                            handleWithOffsets.f0);
                         mergeMap.put(handleWithOffsets.f0, operatorStateHandle);
                     }
                     operatorStateHandle.getStateNameToPartitionOffsets().put(
@@ -229,7 +230,9 @@ private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previ
                 for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : e.getValue()) {
                     OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
                     if (operatorStateHandle == null) {
-                        operatorStateHandle = new OperatorStreamStateHandle(new HashMap<>(), handleWithMetaInfo.f0);
+                        operatorStateHandle = new OperatorStreamStateHandle(
+                            new HashMap<>(broadcastNameToState.size()),
+                            handleWithMetaInfo.f0);
                         mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
                     }
                     operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1);
@@ -256,7 +259,9 @@ private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previ
                     OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
                     if (operatorStateHandle == null) {
-                        operatorStateHandle = new OperatorStreamStateHandle(new HashMap<>(), handleWithMetaInfo.f0);
+                        operatorStateHandle = new OperatorStreamStateHandle(
+                            new HashMap<>(uniformBroadcastNameToState.size()),
+                            handleWithMetaInfo.f0);
                         mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
                     }
                     operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1);
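The repartitioner changes above are a recurring micro-refactor: the manual get / null-check / put dance becomes Map#computeIfAbsent, and collections are pre-sized where the final size is known. A small self-contained Java sketch of that refactor (names are stand-ins, not Flink classes):

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class ComputeIfAbsentSketch {
        public static void main(String[] args) {
            Map<String, List<Integer>> nameToState = new HashMap<>(2); // known key count

            // before: List<Integer> locations = nameToState.get("op");
            //         if (locations == null) { locations = new ArrayList<>(); nameToState.put("op", locations); }
            List<Integer> locations = nameToState.computeIfAbsent("op", k -> new ArrayList<>(4));
            locations.add(42);

            System.out.println(nameToState); // {op=[42]}
        }
    }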
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 592489f2baf..b0173886d57 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -24,11 +24,11 @@
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.util.Preconditions;

 import org.slf4j.Logger;
@@ -83,7 +83,7 @@ public boolean assignStates() throws Exception {
             // find the states of all operators belonging to this task
             List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();
             List<OperatorID> altOperatorIDs = executionJobVertex.getUserDefinedOperatorIDs();
-            List<OperatorState> operatorStates = new ArrayList<>();
+            List<OperatorState> operatorStates = new ArrayList<>(operatorIDs.size());
             boolean statelessTask = true;
             for (int x = 0; x < operatorIDs.size(); x++) {
                 OperatorID operatorID = altOperatorIDs.get(x) == null
@@ -124,7 +124,9 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<Oper
             executionJobVertex.getMaxParallelism(),
             newParallelism);

-        /**
+        final int expectedNumberOfSubTasks = newParallelism * operatorIDs.size();
+
+        /*
          * Redistribute ManagedOperatorStates and RawOperatorStates from old parallelism to new parallelism.
          *
          * The old ManagedOperatorStates with old parallelism 3:
@@ -143,8 +145,10 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<Oper
          *   op2   state2,0 state2,1 state2,2 state2,3
          *   op3   state3,0 state3,1 state3,2 state3,3
          */
-        Map<OperatorInstanceID, List<OperatorStateHandle>> newManagedOperatorStates = new HashMap<>();
-        Map<OperatorInstanceID, List<OperatorStateHandle>> newRawOperatorStates = new HashMap<>();
+        Map<OperatorInstanceID, List<OperatorStateHandle>> newManagedOperatorStates =
+            new HashMap<>(expectedNumberOfSubTasks);
+        Map<OperatorInstanceID, List<OperatorStateHandle>> newRawOperatorStates =
+            new HashMap<>(expectedNumberOfSubTasks);

         reDistributePartitionableStates(
             operatorStates,
@@ -153,8 +157,10 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<Oper
             newManagedOperatorStates,
             newRawOperatorStates);

-        Map<OperatorInstanceID, List<KeyedStateHandle>> newManagedKeyedState = new HashMap<>();
-        Map<OperatorInstanceID, List<KeyedStateHandle>> newRawKeyedState = new HashMap<>();
+        Map<OperatorInstanceID, List<KeyedStateHandle>> newManagedKeyedState =
+            new HashMap<>(expectedNumberOfSubTasks);
+        Map<OperatorInstanceID, List<KeyedStateHandle>> newRawKeyedState =
+            new HashMap<>(expectedNumberOfSubTasks);

         reDistributeKeyedStates(
             operatorStates,
@@ -164,7 +170,7 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<Oper
             newManagedKeyedState,
             newRawKeyedState);

-        /**
+        /*
          * An executionJobVertex's all state handles needed to restore are something like a matrix
          *
          *            parallelism0 parallelism1 parallelism2 parallelism3
@@ -198,7 +204,7 @@ private void assignTaskStateToExecutionJobVertices(
             Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIndex]
                 .getCurrentExecutionAttempt();

-            TaskStateSnapshot taskState = new TaskStateSnapshot();
+            TaskStateSnapshot taskState = new TaskStateSnapshot(operatorIDs.size());
             boolean statelessTask = true;

             for (OperatorID operatorID : operatorIDs) {
@@ -276,38 +282,34 @@ private void reDistributeKeyedStates(
             for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
                 OperatorInstanceID instanceID = OperatorInstanceID.of(subTaskIndex, newOperatorIDs.get(operatorIndex));
                 if (isHeadOperator(operatorIndex, newOperatorIDs)) {
-                    Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
+                    Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
                         operatorState,
                         newKeyGroupPartitions,
                         subTaskIndex,
                         newParallelism,
                         oldParallelism);
-                    newManagedKeyedState
-                        .computeIfAbsent(instanceID, key -> new ArrayList<>())
-                        .addAll(subKeyedStates.f0);
-                    newRawKeyedState
-                        .computeIfAbsent(instanceID, key -> new ArrayList<>())
-                        .addAll(subKeyedStates.f1);
+                    newManagedKeyedState.put(instanceID, subKeyedStates.f0);
+                    newRawKeyedState.put(instanceID, subKeyedStates.f1);
                 }
             }
         }
     }
     // TODO rewrite based on operator id
-    private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates(
+    private Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> reAssignSubKeyedStates(
         OperatorState operatorState,
         List<KeyGroupRange> keyGroupPartitions,
         int subTaskIndex,
         int newParallelism,
         int oldParallelism) {

-        Collection<KeyedStateHandle> subManagedKeyedState;
-        Collection<KeyedStateHandle> subRawKeyedState;
+        List<KeyedStateHandle> subManagedKeyedState;
+        List<KeyedStateHandle> subRawKeyedState;

         if (newParallelism == oldParallelism) {
             if (operatorState.getState(subTaskIndex) != null) {
-                subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState();
-                subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState();
+                subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState().asList();
+                subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState().asList();
             } else {
                 subManagedKeyedState = Collections.emptyList();
                 subRawKeyedState = Collections.emptyList();
@@ -336,8 +338,8 @@ private void reDistributePartitionableStates(
             "This method still depends on the order of the new and old operators");

         //collect the old partitionable state
-        List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>();
-        List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>();
+        List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>(oldOperatorStates.size());
+        List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>(oldOperatorStates.size());

         collectPartionableStates(oldOperatorStates, oldManagedOperatorStates, oldRawOperatorStates);

@@ -368,24 +370,29 @@ private void collectPartionableStates(
         List<List<OperatorStateHandle>> rawOperatorStates) {

         for (OperatorState operatorState : operatorStates) {
+
+            final int parallelism = operatorState.getParallelism();
+
             List<OperatorStateHandle> managedOperatorState = null;
             List<OperatorStateHandle> rawOperatorState = null;

-            for (int i = 0; i < operatorState.getParallelism(); i++) {
+            for (int i = 0; i < parallelism; i++) {
                 OperatorSubtaskState operatorSubtaskState = operatorState.getState(i);
                 if (operatorSubtaskState != null) {
+
+                    StateObjectCollection<OperatorStateHandle> managed = operatorSubtaskState.getManagedOperatorState();
+                    StateObjectCollection<OperatorStateHandle> raw = operatorSubtaskState.getRawOperatorState();
+
                     if (managedOperatorState == null) {
-                        managedOperatorState = new ArrayList<>();
+                        managedOperatorState = new ArrayList<>(parallelism * managed.size());
                     }
-                    managedOperatorState.addAll(operatorSubtaskState.getManagedOperatorState());
+                    managedOperatorState.addAll(managed);

                     if (rawOperatorState == null) {
-                        rawOperatorState = new ArrayList<>();
+                        rawOperatorState = new ArrayList<>(parallelism * raw.size());
                     }
-                    rawOperatorState.addAll(operatorSubtaskState.getRawOperatorState());
+                    rawOperatorState.addAll(raw);
                 }
-
             }
             managedOperatorStates.add(managedOperatorState);
             rawOperatorStates.add(rawOperatorState);
@@ -404,12 +411,19 @@ private void collectPartionableStates(
         OperatorState operatorState,
         KeyGroupRange subtaskKeyGroupRange) {

-        List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
+        final int parallelism = operatorState.getParallelism();

-        for (int i = 0; i < operatorState.getParallelism(); i++) {
+        List<KeyedStateHandle> subtaskKeyedStateHandles = null;
+
+        for (int i = 0; i < parallelism; i++) {
             if (operatorState.getState(i) != null) {

                 Collection<KeyedStateHandle> keyedStateHandles = operatorState.getState(i).getManagedKeyedState();
+
+                if (subtaskKeyedStateHandles == null) {
+                    subtaskKeyedStateHandles = new ArrayList<>(parallelism * keyedStateHandles.size());
+                }
+
                 extractIntersectingState(
                     keyedStateHandles,
                     subtaskKeyGroupRange,
@@ -432,11 +446,19 @@ private void collectPartionableStates(
         OperatorState operatorState,
         KeyGroupRange subtaskKeyGroupRange) {

-        List<KeyedStateHandle> extractedKeyedStateHandles = new ArrayList<>();
+        final int parallelism = operatorState.getParallelism();

-        for (int i = 0; i < operatorState.getParallelism(); i++) {
+        List<KeyedStateHandle> extractedKeyedStateHandles = null;
+
+        for (int i = 0; i < parallelism; i++) {
             if (operatorState.getState(i) != null) {
+
                 Collection<KeyedStateHandle> rawKeyedState = operatorState.getState(i).getRawKeyedState();
+
+                if (extractedKeyedStateHandles == null) {
+                    extractedKeyedStateHandles = new ArrayList<>(parallelism * rawKeyedState.size());
+                }
+
                 extractIntersectingState(
                     rawKeyedState,
                     subtaskKeyGroupRange,
@@ -565,19 +587,18 @@ private static void checkStateMappingCompleteness(
         List<OperatorStateHandle> chainOpParallelStates,
         int oldParallelism,
         int newParallelism) {

-        Map<OperatorInstanceID, List<OperatorStateHandle>> result = new HashMap<>();
-        List<Collection<OperatorStateHandle>> states = applyRepartitioner(
+        List<List<OperatorStateHandle>> states = applyRepartitioner(
             opStateRepartitioner,
             chainOpParallelStates,
             oldParallelism,
             newParallelism);

+        Map<OperatorInstanceID, List<OperatorStateHandle>> result = new HashMap<>(states.size());
+
         for (int subtaskIndex = 0; subtaskIndex < states.size(); subtaskIndex++) {
             checkNotNull(states.get(subtaskIndex) != null, "states.get(subtaskIndex) is null");
-            result
-                .computeIfAbsent(OperatorInstanceID.of(subtaskIndex, operatorID), key -> new ArrayList<>())
-                .addAll(states.get(subtaskIndex));
+            result.put(OperatorInstanceID.of(subtaskIndex, operatorID), states.get(subtaskIndex));
         }

         return result;
@@ -594,7 +615,7 @@ private static void checkStateMappingCompleteness(
      * @return repartitioned state
      */
     // TODO rewrite based on operator id
-    public static List<Collection<OperatorStateHandle>> applyRepartitioner(
+    public static List<List<OperatorStateHandle>> applyRepartitioner(
         OperatorStateRepartitioner opStateRepartitioner,
         List<OperatorStateHandle> chainOpParallelStates,
         int oldParallelism,
@@ -611,7 +632,7 @@ private static void checkStateMappingCompleteness(
                 chainOpParallelStates,
                 newParallelism);
         } else {
-            List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
+            List<List<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
             for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
                 if (operatorStateHandle != null) {
@@ -645,7 +666,7 @@ private static void checkStateMappingCompleteness(
         Collection<? extends KeyedStateHandle> keyedStateHandles,
         KeyGroupRange subtaskKeyGroupRange) {

-        List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
+        List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>(keyedStateHandles.size());

         for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
             KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateObjectCollection.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateObjectCollection.java
index 38e3d15da29..30768477eed 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateObjectCollection.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateObjectCollection.java
@@ -27,6 +27,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
+import java.util.List;
 import java.util.function.Predicate;

 /**
@@ -178,6 +179,14 @@ public String toString() {
         return "StateObjectCollection{" + stateObjects + '}';
     }

+    public List<T> asList() {
+        return stateObjects instanceof List ?
+            (List<T>) stateObjects :
+            stateObjects != null ?
+                new ArrayList<>(stateObjects) :
+                Collections.emptyList();
+    }
+
     // ------------------------------------------------------------------------
     //  Helper methods.
     // ------------------------------------------------------------------------
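Worth noting about asList() above: it only copies when the backing collection is not already a List, so the common path is allocation-free. For readers parsing the nested ternaries, an equivalent unrolled form (a sketch assuming the same stateObjects field) is:

    public List<T> asList() {
        if (stateObjects instanceof List) {
            return (List<T>) stateObjects;        // fast path: no copy
        }
        if (stateObjects != null) {
            return new ArrayList<>(stateObjects); // defensive copy of a non-List collection
        }
        return Collections.emptyList();
    }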
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java
index a44a508ecbc..76ab7a5f9c3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java
@@ -25,8 +25,8 @@
 import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.util.Preconditions;

-import java.util.ArrayList;
-import java.util.List;
+import java.util.HashMap;
+import java.util.Map;

 /**
  * A helper for KvState registrations of a single task.
@@ -43,7 +43,7 @@
     private final JobVertexID jobVertexId;

     /** Map of all registered KvState instances of this task, keyed by registration name. */
-    private final List<KvStateInfo> registeredKvStates = new ArrayList<>();
+    private final Map<String, KvStateInfo> registeredKvStates = new HashMap<>();

     TaskKvStateRegistry(KvStateRegistry registry, JobID jobId, JobVertexID jobVertexId) {
         this.registry = Preconditions.checkNotNull(registry, "KvStateRegistry");
@@ -61,19 +61,35 @@
      * @param kvState The
      */
     public void registerKvState(KeyGroupRange keyGroupRange, String registrationName, InternalKvState<?, ?, ?> kvState) {
+        unregisterKvState(registrationName);
         KvStateID kvStateId = registry.registerKvState(jobId, jobVertexId, keyGroupRange, registrationName, kvState);
-        registeredKvStates.add(new KvStateInfo(keyGroupRange, registrationName, kvStateId));
+        registeredKvStates.put(registrationName, new KvStateInfo(keyGroupRange, registrationName, kvStateId));
+    }
+
+    /**
+     * Unregisters the KvState instance registered under the given name, if present.
+     *
+     * @param registrationName name under which the KvState instance was registered
+     */
+    public void unregisterKvState(String registrationName) {
+        KvStateInfo kvStateInfo = registeredKvStates.remove(registrationName);
+        if (kvStateInfo != null) {
+            unregisterInternal(kvStateInfo);
+        }
     }

     /**
      * Unregisters all registered KvState instances from the KvStateRegistry.
      */
     public void unregisterAll() {
-        for (KvStateInfo kvState : registeredKvStates) {
-            registry.unregisterKvState(jobId, jobVertexId, kvState.keyGroupRange, kvState.registrationName, kvState.kvStateId);
+        for (KvStateInfo kvState : registeredKvStates.values()) {
+            unregisterInternal(kvState);
         }
     }

+    private void unregisterInternal(KvStateInfo kvState) {
+        registry.unregisterKvState(jobId, jobVertexId, kvState.keyGroupRange, kvState.registrationName, kvState.kvStateId);
+    }
+
     /**
      * 3-tuple holding registered KvState meta data.
      */
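The registry change above switches from a list of registrations to a map keyed by registration name, so re-registering under an existing name replaces the old instance instead of accumulating duplicates, and per-name unregistration becomes O(1). A generic, self-contained Java sketch of that pattern (the registration/handle types are stand-ins, not Flink classes):

    import java.util.HashMap;
    import java.util.Map;

    public class NamedRegistrySketch {
        private final Map<String, Long> registeredIds = new HashMap<>();
        private long nextId = 0;

        public void register(String name) {
            unregister(name);                  // drop a previous registration, if any
            registeredIds.put(name, nextId++);
        }

        public void unregister(String name) {
            Long id = registeredIds.remove(name);
            if (id != null) {
                System.out.println("unregistered " + name + " (id=" + id + ")");
            }
        }

        public static void main(String[] args) {
            NamedRegistrySketch registry = new NamedRegistrySketch();
            registry.register("window-contents");
            registry.register("window-contents"); // replaces; prints unregistration of id=0
        }
    }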
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index 1c2d2a3ecaf..d80b76bec85 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -62,7 +62,7 @@
     private int currentKeyGroup;

     /** So that we can give out state when the user uses the same key. */
-    private final HashMap<String, InternalKvState<K, ?, ?>> keyValueStatesByName;
+    protected final HashMap<String, InternalKvState<K, ?, ?>> keyValueStatesByName;

     /** For caching the last accessed partitioned state. */
     private String lastName;
@@ -319,5 +319,4 @@ StreamCompressionDecorator getKeyGroupCompressionDecorator() {
     public boolean requiresLegacySynchronousTimerSnapshots() {
         return false;
     }
-
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index d9fc41e6529..c80f516bcbe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -180,6 +180,35 @@ public void dispose() {
     // -------------------------------------------------------------------------------------------
     //  State access methods
     // -------------------------------------------------------------------------------------------

+    @Override
+    public void removeBroadcastState(String name) {
+        restoredBroadcastStateMetaInfos.remove(name);
+        if (registeredBroadcastStates.remove(name) != null) {
+            accessedBroadcastStatesByName.remove(name);
+        }
+    }
+
+    @Override
+    public void removeOperatorState(String name) {
+        restoredOperatorStateMetaInfos.remove(name);
+        if (registeredOperatorStates.remove(name) != null) {
+            accessedStatesByName.remove(name);
+        }
+    }
+
+    public void deleteBroadCastState(String name) {
+        restoredBroadcastStateMetaInfos.remove(name);
+        if (registeredBroadcastStates.remove(name) != null) {
+            accessedBroadcastStatesByName.remove(name);
+        }
+    }
+
+    public void deleteOperatorState(String name) {
+        restoredOperatorStateMetaInfos.remove(name);
+        if (registeredOperatorStates.remove(name) != null) {
+            accessedStatesByName.remove(name);
+        }
+    }

     @SuppressWarnings("unchecked")
     @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
index 7ba14b3d007..14c2d842dcf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
@@ -24,6 +24,8 @@
 import org.apache.flink.runtime.state.heap.InternalKeyContext;
 import org.apache.flink.util.Disposable;

+import javax.annotation.Nonnull;
+
 import java.util.stream.Stream;

 /**
@@ -104,6 +106,16 @@
         TypeSerializer<N> namespaceSerializer,
         StateDescriptor<S, ?> stateDescriptor) throws Exception;

+    /**
+     * Removes the keyed state with the given name from the state store, if it exists.
+     */
+    void removeKeyedState(@Nonnull String stateName) throws Exception;
+
+    /**
+     * Removes the queue state with the given name from the state store, if it exists.
+     */
+    void removeQueueState(@Nonnull String name) throws Exception;
+
     @Override
     void dispose();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java
index 2245e72bce6..e41805b1ca2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java
@@ -38,7 +38,7 @@
      * @return the queue with the specified unique name.
      */
     @Nonnull
-    <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> create(
+    <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createQueueState(
         @Nonnull String stateName,
-        @Nonnull TypeSerializer<T> byteOrderedElementSerializer);
+        @Nonnull TypeSerializer<T> byteOrderedElementSerializer) throws Exception;
 }
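The rename from create to createQueueState (and the new throws Exception) pairs the factory method with removeQueueState on KeyedStateBackend, giving queue state an explicit create-use-remove lifecycle. A hedged sketch of that lifecycle, assuming backend is an AbstractKeyedStateBackend<Integer> and using the test helpers that appear later in this diff; the enclosing method must declare throws Exception:

    KeyGroupedInternalPriorityQueue<InternalPriorityQueueTestBase.TestElement> queue =
        backend.createQueueState("timer-queue", InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE);

    queue.add(new InternalPriorityQueueTestBase.TestElement(42L, 0L)); // enqueue one element

    // Dropping the queue state removes its registration and meta info, so it
    // will not be written into future snapshots.
    backend.removeQueueState("timer-queue");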
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index bc1e0f52507..e6a7a1be4d7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -62,10 +62,9 @@
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateSnapshot;
-import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
-import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
@@ -185,7 +184,7 @@ public HeapKeyedStateBackend(
     @SuppressWarnings("unchecked")
     @Nonnull
     @Override
-    public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> create(
+    public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createQueueState(
         @Nonnull String stateName,
         @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {

@@ -228,12 +227,29 @@ public HeapKeyedStateBackend(
         }
     }

+    @Override
+    public void removeQueueState(@Nonnull String stateName) {
+        restoredStateMetaInfo.remove(StateUID.of(stateName, StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE));
+        registeredPQStates.remove(stateName);
+    }
+
+    @Override
+    public void removeKeyedState(@Nonnull String stateName) {
+        restoredStateMetaInfo.remove(StateUID.of(stateName, StateMetaInfoSnapshot.BackendStateType.KEY_VALUE));
+        if (registeredKVStates.remove(stateName) != null) {
+            if (kvStateRegistry != null) {
+                kvStateRegistry.unregisterKvState(stateName);
+            }
+            keyValueStatesByName.remove(stateName);
+        }
+    }
+
     @Nonnull
     private <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createInternal(
         RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {

         final String stateName = metaInfo.getName();
-        final HeapPriorityQueueSet<T> priorityQueue = priorityQueueSetFactory.create(
+        final HeapPriorityQueueSet<T> priorityQueue = priorityQueueSetFactory.createQueueState(
             stateName,
             metaInfo.getElementSerializer());

@@ -312,7 +328,7 @@ private boolean hasRegisteredState() {
     public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
         @Nonnull TypeSerializer<N> namespaceSerializer,
         @Nonnull StateDescriptor<S, SV> stateDesc,
-        @Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
+        @Nonnull StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {

         StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
         if (stateFactory == null) {
             String message = String.format("State %s is not supported by %s",
@@ -327,7 +343,7 @@ private boolean hasRegisteredState() {
     @SuppressWarnings("unchecked")
     private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
         StateDescriptor<?, SV> stateDesc,
-        StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+        StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {

         Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
         if (original.isPresent()) {
             if (stateDesc instanceof ListStateDescriptor) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
index 80d79ac1fc1..b32f4deaff0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
@@ -55,7 +55,7 @@ public HeapPriorityQueueSetFactory(

     @Nonnull
     @Override
-    public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> HeapPriorityQueueSet<T> create(
+    public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> HeapPriorityQueueSet<T> createQueueState(
         @Nonnull String stateName,
         @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 1b2062a7481..b113e12ef69 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -53,11 +53,12 @@
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
-import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.TestLogger;

+import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
+
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
@@ -2741,7 +2742,7 @@ public void testReplicateModeStateHandle() {
         OperatorStateHandle osh = new OperatorStreamStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[150]));

         OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
-        List<Collection<OperatorStateHandle>> repartitionedStates =
+        List<List<OperatorStateHandle>> repartitionedStates =
             repartitioner.repartitionState(Collections.singletonList(osh), 3);

         Map<String, Integer> checkCounts = new HashMap<>(3);
@@ -3331,7 +3332,7 @@ private void doTestPartitionableStateRepartitioning(
         OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;

-        List<Collection<OperatorStateHandle>> pshs =
+        List<List<OperatorStateHandle>> pshs =
             repartitioner.repartitionState(previousParallelOpInstanceStates, newParallelism);

         Map<StreamStateHandle, Map<String, List<Long>>> actual = new HashMap<>();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index d8918e78478..61ab3cd7170 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -54,6 +54,7 @@
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.CancellationException;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
@@ -946,6 +947,50 @@ static MutableType of(int value) {
         }
     }

+    @Test
+    public void testDeleteBroadcastState() throws Exception {
+        final OperatorStateBackend operatorStateBackend =
+            new DefaultOperatorStateBackend(classLoader, new ExecutionConfig(), false);
+
+        MapStateDescriptor<Integer, Integer> broadcastStateDesc1 = new MapStateDescriptor<>(
+            "test-broadcast-1", IntSerializer.INSTANCE, IntSerializer.INSTANCE);
+
+        MapStateDescriptor<Integer, Integer> broadcastStateDesc2 = new MapStateDescriptor<>(
+            "test-broadcast-2", IntSerializer.INSTANCE, IntSerializer.INSTANCE);
+
+        operatorStateBackend.getBroadcastState(broadcastStateDesc1);
+        operatorStateBackend.getBroadcastState(broadcastStateDesc2);
+
+        Assert.assertEquals(2, operatorStateBackend.getRegisteredBroadcastStateNames().size());
+
+        operatorStateBackend.removeBroadcastState(broadcastStateDesc2.getName());
+        Assert.assertEquals(1, operatorStateBackend.getRegisteredBroadcastStateNames().size());
+        Assert.assertTrue(operatorStateBackend.getRegisteredBroadcastStateNames().contains(broadcastStateDesc1.getName()));
+    }
+
+    @Test
+    public void testDeleteOperatorState() throws Exception {
+        final OperatorStateBackend operatorStateBackend =
+            new DefaultOperatorStateBackend(classLoader, new ExecutionConfig(), false);
+
+        ListStateDescriptor<Integer> listStateDesc1 = new ListStateDescriptor<>("test-list-1", IntSerializer.INSTANCE);
+        ListStateDescriptor<Integer> listStateDesc2 = new ListStateDescriptor<>("test-list-2", IntSerializer.INSTANCE);
+
+        operatorStateBackend.getListState(listStateDesc1);
+        operatorStateBackend.getUnionListState(listStateDesc2);
+
+        Set<String> registeredStateNames = operatorStateBackend.getRegisteredStateNames();
+
+        Assert.assertEquals(2, registeredStateNames.size());
+
+        operatorStateBackend.removeOperatorState(listStateDesc1.getName());
+        Assert.assertEquals(1, registeredStateNames.size());
+        Assert.assertTrue(registeredStateNames.contains(listStateDesc2.getName()));
+
+        operatorStateBackend.removeOperatorState(listStateDesc2.getName());
+        Assert.assertEquals(0, registeredStateNames.size());
+    }
+
     // ------------------------------------------------------------------------
     //  utilities
     // ------------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 059a706c6a8..66fec859560 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -1160,7 +1160,7 @@ public void testPriorityQueueSerializerUpdates() throws Exception {
             InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE;

         KeyGroupedInternalPriorityQueue<InternalPriorityQueueTestBase.TestElement> priorityQueue =
-            keyedBackend.create(stateName, serializer);
+            keyedBackend.createQueueState(stateName, serializer);

         priorityQueue.add(new InternalPriorityQueueTestBase.TestElement(42L, 0L));
@@ -1177,7 +1177,7 @@ public void testPriorityQueueSerializerUpdates() throws Exception {
             serializer = new ModifiedTestElementSerializer();

-            priorityQueue = keyedBackend.create(stateName, serializer);
+            priorityQueue = keyedBackend.createQueueState(stateName, serializer);

             final InternalPriorityQueueTestBase.TestElement checkElement =
                 new InternalPriorityQueueTestBase.TestElement(4711L, 1L);
@@ -1192,7 +1192,7 @@ public void testPriorityQueueSerializerUpdates() throws Exception {
             // test that the modified serializer was actually used ---------------------------
             keyedBackend = restoreKeyedBackend(IntSerializer.INSTANCE, keyedStateHandle);
-            priorityQueue = keyedBackend.create(stateName, serializer);
+            priorityQueue = keyedBackend.createQueueState(stateName, serializer);

             priorityQueue.poll();
@@ -1217,7 +1217,7 @@ public void testPriorityQueueSerializerUpdates() throws Exception {
             try {
                 // this is expected to fail, because the old and new serializer should be incompatible through
                 // different revision numbers.
-                keyedBackend.create("test", serializer);
+                keyedBackend.createQueueState("test", serializer);
                 Assert.fail("Expected exception from incompatible serializer.");
             } catch (Exception e) {
                 Assert.assertTrue("Exception was not caused by state migration: " + e,
@@ -4129,6 +4129,39 @@ public void testCheckConcurrencyProblemWhenPerformingCheckpointAsync() throws Ex
         }
     }

+    @Test
+    public void testDeleteKeyedState() throws Exception {
+
+        Environment env = new DummyEnvironment();
+        AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
+        try {
+            ValueStateDescriptor<Integer> kv1 = new ValueStateDescriptor<>("kv_1", IntSerializer.INSTANCE);
+            ValueStateDescriptor<Integer> kv2 = new ValueStateDescriptor<>("kv_2", IntSerializer.INSTANCE);
+            ValueState<Integer> state1 = backend.getOrCreateKeyedState(VoidNamespaceSerializer.INSTANCE, kv1);
+            ValueState<Integer> state2 = backend.getOrCreateKeyedState(VoidNamespaceSerializer.INSTANCE, kv2);
+
+            backend.removeKeyedState(kv2.getName());
+        } finally {
+            backend.dispose();
+        }
+    }
+
+    @Test
+    public void testDeletePriorityQueueState() throws Exception {
+
+        Environment env = new DummyEnvironment();
+        AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
+        try {
+            String state1 = "state_1";
+            String state2 = "state_2";
+            backend.createQueueState(state1, InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE);
+            backend.createQueueState(state2, InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE);
+            backend.removeQueueState(state2);
+        } finally {
+            backend.dispose();
+        }
+    }
+
     protected Future<SnapshotResult<KeyedStateHandle>> runSnapshotAsync(
         ExecutorService executorService,
         RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunnableFuture) throws Exception {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 0b5931ce1d4..5eb9c5f8a0d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -170,6 +170,16 @@ public void notifyCheckpointComplete(long checkpointId) {
             .map(Map.Entry::getKey);
     }

+    @Override
+    public void removeKeyedState(@Nonnull String stateName) {
+    }
+
+    @Override
+    public void removeQueueState(@Nonnull String name) {
+    }
+
     @Override
     public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
         long checkpointId,
@@ -229,7 +239,7 @@ public void restore(Collection<KeyedStateHandle> state) {
     @Nonnull
     @Override
     public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T>
-        create(
+        createQueueState(
             @Nonnull String stateName,
             @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
         return new HeapPriorityQueueSet<>(
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 0fd11252b2f..37864214cd5 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -85,7 +85,6 @@
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
-import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -353,6 +352,33 @@ private static void checkAndCreateDirectory(File directory) throws IOException {
         }
     }

+    @Override
+    public void removeQueueState(@Nonnull String stateName) throws RocksDBException {
+        removeInternal(stateName);
+    }
+
+    @Override
+    public void removeKeyedState(@Nonnull String stateName) throws RocksDBException {
+        removeInternal(stateName);
+    }
+
+    private void removeInternal(@Nonnull String name) throws RocksDBException {
+        restoredKvStateMetaInfos.remove(name);
+        Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> removedStateMetaInfo = kvStateInformation.remove(name);
+        if (removedStateMetaInfo != null) {
+            if (kvStateRegistry != null) {
+                kvStateRegistry.unregisterKvState(name);
+            }
+            keyValueStatesByName.remove(name);
+            ColumnFamilyHandle removeColumnFamily = removedStateMetaInfo.f0;
+            try {
+                db.dropColumnFamily(removeColumnFamily);
+            } finally {
+                IOUtils.closeQuietly(removeColumnFamily);
+            }
+        }
+    }
+
     @SuppressWarnings("unchecked")
     @Override
     public <N> Stream<K> getKeys(String state, N namespace) {
@@ -444,10 +470,10 @@ public void dispose() {
     @Nonnull
     @Override
     public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T>
-        create(
+        createQueueState(
             @Nonnull String stateName,
-            @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
-        return priorityQueueFactory.create(stateName, byteOrderedElementSerializer);
+            @Nonnull TypeSerializer<T> byteOrderedElementSerializer) throws Exception {
+        return priorityQueueFactory.createQueueState(stateName, byteOrderedElementSerializer);
     }

     private void cleanInstanceBasePath() {
@@ -1381,7 +1407,7 @@ private ColumnFamilyHandle createColumnFamily(String stateName) {
     public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
         @Nonnull TypeSerializer<N> namespaceSerializer,
         @Nonnull StateDescriptor<S, SV> stateDesc,
-        @Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
+        @Nonnull StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {

         StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
         if (stateFactory == null) {
             String message = String.format("State %s is not supported by %s",
@@ -1396,7 +1422,7 @@ private ColumnFamilyHandle createColumnFamily(String stateName) {
     @SuppressWarnings("unchecked")
     private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
         StateDescriptor<?, SV> stateDesc,
-        StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+        StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {

         if (stateDesc instanceof ListStateDescriptor) {
             Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
             return original.map(est -> createRocksDBListStateTransformer(stateDesc, est)).orElse(null);
@@ -2756,7 +2782,7 @@ private static RocksIteratorWrapper getRocksIterator(
     @Nonnull
     @Override
     public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T>
-        create(@Nonnull String stateName, @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
+        createQueueState(@Nonnull String stateName, @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
         final Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> metaInfoTuple =
             tryRegisterPriorityQueueMetaInfo(stateName, byteOrderedElementSerializer);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index 52ba3e4f1f5..5313ed780ac 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -729,7 +729,7 @@ public void close() {
     public <K, N> InternalTimerService<N> getInternalTimerService(
         String name,
         TypeSerializer<N> namespaceSerializer,
-        Triggerable<K, N> triggerable) {
+        Triggerable<K, N> triggerable) throws Exception {

         checkTimerServiceInitialization();
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
index ff48c3fae03..1939a8daccd 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
@@ -82,7 +82,7 @@
     public <N> InternalTimerService<N> getInternalTimerService(
         String name,
         TimerSerializer<K, N> timerSerializer,
-        Triggerable<K, N> triggerable) {
+        Triggerable<K, N> triggerable) throws Exception {

         InternalTimerServiceImpl<K, N> timerService = registerOrGetTimerService(name, timerSerializer);

@@ -95,7 +95,10 @@
     }

     @SuppressWarnings("unchecked")
-    <N> InternalTimerServiceImpl<K, N> registerOrGetTimerService(String name, TimerSerializer<K, N> timerSerializer) {
+    <N> InternalTimerServiceImpl<K, N> registerOrGetTimerService(
+        String name,
+        TimerSerializer<K, N> timerSerializer) throws Exception {
+
         InternalTimerServiceImpl<K, N> timerService = (InternalTimerServiceImpl<K, N>) timerServices.get(name);
         if (timerService == null) {
@@ -117,8 +120,8 @@
     private <N> KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> createTimerPriorityQueue(
         String name,
-        TimerSerializer<K, N> timerSerializer) {
-        return priorityQueueSetFactory.create(
+        TimerSerializer<K, N> timerSerializer) throws Exception {
+        return priorityQueueSetFactory.createQueueState(
             name,
             timerSerializer);
     }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java
index dea17f98665..ca9d9c0376c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java
@@ -104,9 +104,14 @@ protected void read(DataInputView in, boolean wasVersioned) throws IOException {
                 .getReaderForVersion(readerVersion, userCodeClassLoader)
                 .readTimersSnapshot(in);

-            InternalTimerServiceImpl<K, ?> timerService = registerOrGetTimerService(
-                serviceName,
-                restoredTimersSnapshot);
+            InternalTimerServiceImpl<K, ?> timerService;
+            try {
+                timerService = registerOrGetTimerService(
+                    serviceName,
+                    restoredTimersSnapshot);
+            } catch (Exception e) {
+                throw new IOException("Could not create timer service in restore.", e);
+            }

             timerService.restoreTimersForKeyGroup(restoredTimersSnapshot, keyGroupIdx);
         }
@@ -114,7 +119,7 @@ protected void read(DataInputView in, boolean wasVersioned) throws IOException {
     @SuppressWarnings("unchecked")
     private <N> InternalTimerServiceImpl<K, N> registerOrGetTimerService(
-        String serviceName, InternalTimersSnapshot<?, ?> restoredTimersSnapshot) {
+        String serviceName, InternalTimersSnapshot<?, ?> restoredTimersSnapshot) throws Exception {

         final TypeSerializer<K> keySerializer = (TypeSerializer<K>) restoredTimersSnapshot.getKeySerializer();
         final TypeSerializer<N> namespaceSerializer = (TypeSerializer<N>) restoredTimersSnapshot.getNamespaceSerializer();
         TimerSerializer<K, N> timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java
index f2da6da3b05..c9f37e5d21a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java
@@ -79,7 +79,7 @@ public InternalTimerServiceImplTest(int startKeyGroup, int endKeyGroup, int maxP
     }

     @Test
-    public void testKeyGroupStartIndexSetting() {
+    public void testKeyGroupStartIndexSetting() throws Exception {

         int startKeyGroupIdx = 7;
         int endKeyGroupIdx = 21;
@@ -101,7 +101,7 @@ public void testKeyGroupStartIndexSetting() {
     }

     @Test
-    public void testTimerAssignmentToKeyGroups() {
+    public void testTimerAssignmentToKeyGroups() throws Exception {
         int totalNoOfTimers = 100;

         int totalNoOfKeyGroups = 100;
@@ -811,7 +811,7 @@ private static int getKeyInKeyGroupRange(KeyGroupRange range, int maxParallelism
             KeyContext keyContext,
             ProcessingTimeService processingTimeService,
             KeyGroupRange keyGroupList,
-            PriorityQueueSetFactory priorityQueueSetFactory) {
+            PriorityQueueSetFactory priorityQueueSetFactory) throws Exception {

         InternalTimerServiceImpl<Integer, String> service = createInternalTimerService(
             keyGroupList,
             keyContext,
@@ -892,7 +892,7 @@ protected PriorityQueueSetFactory createQueueFactory(KeyGroupRange keyGroupRange
             ProcessingTimeService processingTimeService,
             TypeSerializer<K> keySerializer,
             TypeSerializer<N> namespaceSerializer,
-            PriorityQueueSetFactory priorityQueueSetFactory) {
+            PriorityQueueSetFactory priorityQueueSetFactory) throws Exception {

         TimerSerializer<K, N> timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer);

@@ -907,8 +907,8 @@ protected PriorityQueueSetFactory createQueueFactory(KeyGroupRange keyGroupRange
     private static <K, N> KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> createTimerQueue(
         String name,
         TimerSerializer<K, N> timerSerializer,
-        PriorityQueueSetFactory priorityQueueSetFactory) {
-        return priorityQueueSetFactory.create(
+        PriorityQueueSetFactory priorityQueueSetFactory) throws Exception {
+        return priorityQueueSetFactory.createQueueState(
             name,
             timerSerializer);
     }
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services