This is an automated email from the ASF dual-hosted git repository.

sewen pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.9 by this push:
     new 92eb0b8  [FLINK-13541][state-processor-api] State Processor Api sets the wrong key selector when writing savepoints
92eb0b8 is described below

commit 92eb0b80ed3e761f51825f4e56329085436f39e3
Author: Seth Wiesman <[email protected]>
AuthorDate: Thu Aug 1 15:22:23 2019 -0500

    [FLINK-13541][state-processor-api] State Processor Api sets the wrong key selector when writing savepoints
    
    This closes #9324
---
 .../flink/state/api/BootstrapTransformation.java   | 46 ++++++++++++++--------
 .../state/api/BootstrapTransformationTest.java     | 35 ++++++++++++++++
 2 files changed, 64 insertions(+), 17 deletions(-)

diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/BootstrapTransformation.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/BootstrapTransformation.java
index a5163f3..5e2a7c2 100644
--- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/BootstrapTransformation.java
+++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/BootstrapTransformation.java
@@ -68,7 +68,11 @@ public class BootstrapTransformation<T> {
 
        /** Partitioner for the bootstrapping data set. Only relevant if this 
bootstraps partitioned state. */
        @Nullable
-       private final HashSelector<T> keySelector;
+       private final KeySelector<T, ?> originalKeySelector;
+
+       /** Partitioner for distributing data by key group. Only relevant if 
this bootstraps partitioned state. */
+       @Nullable
+       private final HashSelector<T> hashKeySelector;
 
        /** Type information for the key of the bootstrapped state. Only 
relevant if this bootstraps partitioned state. */
        @Nullable
@@ -84,7 +88,8 @@ public class BootstrapTransformation<T> {
                this.dataSet = dataSet;
                this.operatorMaxParallelism = operatorMaxParallelism;
                this.factory = factory;
-               this.keySelector = null;
+               this.originalKeySelector = null;
+               this.hashKeySelector = null;
                this.keyType = null;
        }
 
@@ -97,7 +102,8 @@ public class BootstrapTransformation<T> {
                this.dataSet = dataSet;
                this.operatorMaxParallelism = operatorMaxParallelism;
                this.factory = factory;
-               this.keySelector = new HashSelector<>(keySelector);
+               this.originalKeySelector = keySelector;
+               this.hashKeySelector = new HashSelector<>(keySelector);
                this.keyType = keyType;
        }
 
@@ -135,16 +141,8 @@ public class BootstrapTransformation<T> {
                int localMaxParallelism) {
 
                DataSet<T> input = dataSet;
-               if (keySelector != null) {
-                       input = dataSet.partitionCustom(new 
KeyGroupRangePartitioner(localMaxParallelism), keySelector);
-               }
-
-               final StreamConfig config;
-               if (keyType == null) {
-                       config = new BoundedStreamConfig();
-               } else {
-                       TypeSerializer<?> keySerializer = 
keyType.createSerializer(dataSet.getExecutionEnvironment().getConfig());
-                       config = new BoundedStreamConfig(keySerializer, 
keySelector);
+               if (originalKeySelector != null) {
+                       input = dataSet.partitionCustom(new 
KeyGroupRangePartitioner(localMaxParallelism), hashKeySelector);
                }
 
                StreamOperator<TaggedOperatorSubtaskState> operator = 
factory.createOperator(
@@ -152,11 +150,8 @@ public class BootstrapTransformation<T> {
                        savepointPath);
 
                operator = dataSet.clean(operator);
-               config.setStreamOperator(operator);
 
-               config.setOperatorName(operatorID.toHexString());
-               config.setOperatorID(operatorID);
-               config.setStateBackend(stateBackend);
+               final StreamConfig config = getConfig(operatorID, stateBackend, 
operator);
 
                BoundedOneInputStreamTaskRunner<T> operatorRunner = new 
BoundedOneInputStreamTaskRunner<>(
                        config,
@@ -178,6 +173,23 @@ public class BootstrapTransformation<T> {
                return subtaskStates;
        }
 
+       @VisibleForTesting
+       StreamConfig getConfig(OperatorID operatorID, StateBackend 
stateBackend, StreamOperator<TaggedOperatorSubtaskState> operator) {
+               final StreamConfig config;
+               if (keyType == null) {
+                       config = new BoundedStreamConfig();
+               } else {
+                       TypeSerializer<?> keySerializer = 
keyType.createSerializer(dataSet.getExecutionEnvironment().getConfig());
+                       config = new BoundedStreamConfig(keySerializer, 
originalKeySelector);
+               }
+
+               config.setStreamOperator(operator);
+               config.setOperatorName(operatorID.toHexString());
+               config.setOperatorID(operatorID);
+               config.setStateBackend(stateBackend);
+               return config;
+       }
+
        private static <T> int getParallelism(MapPartitionOperator<T, 
TaggedOperatorSubtaskState> subtaskStates) {
                int parallelism = subtaskStates.getParallelism();
                if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT) {
diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/BootstrapTransformationTest.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/BootstrapTransformationTest.java
index 64994c4..72cb218 100644
--- a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/BootstrapTransformationTest.java
+++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/BootstrapTransformationTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.state.api;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.operators.DataSource;
 import org.apache.flink.api.java.operators.Operator;
 import org.apache.flink.core.fs.Path;
@@ -28,9 +29,11 @@ import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.state.api.functions.BroadcastStateBootstrapFunction;
+import org.apache.flink.state.api.functions.KeyedStateBootstrapFunction;
 import org.apache.flink.state.api.functions.StateBootstrapFunction;
 import org.apache.flink.state.api.output.TaggedOperatorSubtaskState;
 import org.apache.flink.state.api.runtime.OperatorIDGenerator;
+import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.test.util.AbstractTestBase;
 
 import org.junit.Assert;
@@ -39,6 +42,7 @@ import org.junit.Test;
 /**
  * Tests for bootstrap transformations.
  */
+@SuppressWarnings("serial")
 public class BootstrapTransformationTest extends AbstractTestBase {
 
        @Test
@@ -138,6 +142,30 @@ public class BootstrapTransformationTest extends AbstractTestBase {
                Assert.assertEquals("The parallelism of a data set should be 
constrained my the savepoint max parallelism", 1, getParallelism(result));
        }
 
+       @Test
+       public void testStreamConfig() {
+               ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               DataSource<String> input = env.fromElements("");
+
+               BootstrapTransformation<String> transformation = 
OperatorTransformation
+                       .bootstrapWith(input)
+                       .keyBy(new CustomKeySelector())
+                       .transform(new ExampleKeyedStateBootstrapFunction());
+
+               StreamConfig config = 
transformation.getConfig(OperatorIDGenerator.fromUid("uid"), new 
MemoryStateBackend(), null);
+               KeySelector selector = config.getStatePartitioner(0, 
Thread.currentThread().getContextClassLoader());
+
+               Assert.assertEquals("Incorrect key selector forwarded to stream 
operator", CustomKeySelector.class, selector.getClass());
+       }
+
+       private static class CustomKeySelector implements KeySelector<String, 
String> {
+
+               @Override
+               public String getKey(String value) throws Exception {
+                       return value;
+               }
+       }
+
        private static <T> int getParallelism(DataSet<T> dataSet) {
                //All concrete implementations of DataSet are operators so this 
should always be safe.
                return ((Operator) dataSet).getParallelism();
@@ -164,4 +192,11 @@ public class BootstrapTransformationTest extends AbstractTestBase {
                public void initializeState(FunctionInitializationContext 
context) throws Exception {
                }
        }
+
+       private static class ExampleKeyedStateBootstrapFunction extends 
KeyedStateBootstrapFunction<String, String> {
+
+               @Override
+               public void processElement(String value, Context ctx) throws 
Exception {
+               }
+       }
 }

Reply via email to