This is an automated email from the ASF dual-hosted git repository.

martijnvisser pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-connector-kafka.git

commit f04a6abf7f75cd095526db747375fe273ed3110d
Author: Zakelly <zakelly....@gmail.com>
AuthorDate: Wed Oct 5 22:44:16 2022 +0800

    [FLINK-29437] Align the partition of data before and after the Kafka Shuffle
---
 .../internals/KafkaTopicPartitionAssigner.java     |  9 ++++--
 .../kafka/shuffle/FlinkKafkaShuffle.java           |  4 ++-
 .../kafka/shuffle/FlinkKafkaShuffleProducer.java   | 15 ++++++++-
 .../kafka/shuffle/KafkaShuffleTestBase.java        | 37 ++++++++++++++++++----
 4 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartitionAssigner.java b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartitionAssigner.java
index 9d12a80..be61e8a 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartitionAssigner.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartitionAssigner.java
@@ -47,12 +47,15 @@ public class KafkaTopicPartitionAssigner {
      * @return index of the target subtask that the Kafka partition should be 
assigned to.
      */
     public static int assign(KafkaTopicPartition partition, int 
numParallelSubtasks) {
-        int startIndex =
-                ((partition.getTopic().hashCode() * 31) & 0x7FFFFFFF) % 
numParallelSubtasks;
+        return assign(partition.getTopic(), partition.getPartition(), 
numParallelSubtasks);
+    }
+
+    public static int assign(String topic, int partition, int 
numParallelSubtasks) {
+        int startIndex = ((topic.hashCode() * 31) & 0x7FFFFFFF) % 
numParallelSubtasks;
 
         // here, the assumption is that the id of Kafka partitions are always 
ascending
         // starting from 0, and therefore can be used directly as the offset 
clockwise from the
         // start index
-        return (startIndex + partition.getPartition()) % numParallelSubtasks;
+        return (startIndex + partition) % numParallelSubtasks;
     }
 }
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
index 83c372f..58b09c9 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
@@ -343,7 +343,9 @@ public class FlinkKafkaShuffle {
         int numberOfPartitions =
                 PropertiesUtil.getInt(kafkaProperties, PARTITION_NUMBER, 
Integer.MIN_VALUE);
         DataStream<T> outputDataStream =
-                
env.addSource(kafkaConsumer).setParallelism(numberOfPartitions);
+                env.addSource(kafkaConsumer)
+                        .setParallelism(numberOfPartitions)
+                        .setMaxParallelism(numberOfPartitions);
 
         return DataStreamUtils.reinterpretAsKeyedStream(outputDataStream, 
keySelector);
     }
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
index d6632d3..e05e8f9 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.connectors.kafka.FlinkKafkaException;
 import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.PropertiesUtil;
 
@@ -32,6 +33,8 @@ import org.apache.kafka.clients.producer.ProducerRecord;
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Properties;
 
 import static 
org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PARTITION_NUMBER;
@@ -46,6 +49,8 @@ public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
     private final KeySelector<IN, KEY> keySelector;
     private final int numberOfPartitions;
 
+    private final Map<Integer, Integer> subtaskToPartitionMap;
+
     FlinkKafkaShuffleProducer(
             String defaultTopicId,
             TypeSerializer<IN> typeSerializer,
@@ -67,6 +72,7 @@ public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
                 props.getProperty(PARTITION_NUMBER) != null,
                 "Missing partition number for Kafka Shuffle");
         numberOfPartitions = PropertiesUtil.getInt(props, PARTITION_NUMBER, 
Integer.MIN_VALUE);
+        subtaskToPartitionMap = new HashMap<>();
     }
 
     /**
@@ -89,9 +95,10 @@ public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
         int[] partitions = getPartitions(transaction);
         int partitionIndex;
         try {
-            partitionIndex =
+            int subtaskIndex =
                     KeyGroupRangeAssignment.assignKeyToParallelOperator(
                             keySelector.getKey(next), partitions.length, 
partitions.length);
+            partitionIndex = subtaskToPartitionMap.get(subtaskIndex);
         } catch (Exception e) {
             throw new RuntimeException("Fail to assign a partition number to 
record", e);
         }
@@ -142,6 +149,12 @@ public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
         if (partitions == null) {
             partitions = getPartitionsByTopic(defaultTopicId, 
transaction.getProducer());
             topicPartitionsMap.put(defaultTopicId, partitions);
+            for (int i = 0; i < partitions.length; i++) {
+                subtaskToPartitionMap.put(
+                        KafkaTopicPartitionAssigner.assign(
+                                defaultTopicId, partitions[i], 
partitions.length),
+                        partitions[i]);
+            }
         }
 
         Preconditions.checkArgument(partitions.length == numberOfPartitions);
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java b/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java
index a3d2beb..064aebd 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java
@@ -21,6 +21,8 @@ import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple;
 import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -34,7 +36,6 @@ import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
 import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
 import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
 import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
-import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
 import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
 import org.apache.flink.test.util.SuccessException;
 import org.apache.flink.util.Collector;
@@ -168,6 +169,7 @@ public class KafkaShuffleTestBase extends KafkaConsumerTestBase {
         private final KeySelector<Tuple3<Integer, Long, Integer>, Tuple> 
keySelector;
         private final int numberOfPartitions;
         private final String topic;
+        private KeyGroupRange keyGroupRange;
 
         private int previousPartition;
 
@@ -181,24 +183,45 @@ public class KafkaShuffleTestBase extends KafkaConsumerTestBase {
             this.previousPartition = -1;
         }
 
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            super.open(parameters);
+            this.keyGroupRange =
+                    
KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
+                            
getRuntimeContext().getMaxNumberOfParallelSubtasks(),
+                            numberOfPartitions,
+                            getRuntimeContext().getIndexOfThisSubtask());
+        }
+
         @Override
         public void processElement(
                 Tuple3<Integer, Long, Integer> in,
                 Context ctx,
                 Collector<Tuple3<Integer, Long, Integer>> out)
                 throws Exception {
-            int expectedPartition =
+            int expectedSubtask =
                     KeyGroupRangeAssignment.assignKeyToParallelOperator(
                             keySelector.getKey(in), numberOfPartitions, 
numberOfPartitions);
+            int expectedPartition = -1;
+            // This is how Kafka assign partition to subTask;
+            for (int i = 0; i < numberOfPartitions; i++) {
+                if (KafkaTopicPartitionAssigner.assign(topic, i, 
numberOfPartitions)
+                        == expectedSubtask) {
+                    expectedPartition = i;
+                }
+            }
             int indexOfThisSubtask = 
getRuntimeContext().getIndexOfThisSubtask();
-            KafkaTopicPartition partition = new KafkaTopicPartition(topic, 
expectedPartition);
 
-            // This is how Kafka assign partition to subTask;
             boolean rightAssignment =
-                    KafkaTopicPartitionAssigner.assign(partition, 
numberOfPartitions)
-                            == indexOfThisSubtask;
+                    (expectedSubtask == indexOfThisSubtask)
+                            && keyGroupRange.contains(
+                                    KeyGroupRangeAssignment.assignToKeyGroup(
+                                            keySelector.getKey(in),
+                                            
getRuntimeContext().getMaxNumberOfParallelSubtasks()));
             boolean samePartition =
-                    (previousPartition == expectedPartition) || 
(previousPartition == -1);
+                    (expectedPartition != -1)
+                            && ((previousPartition == expectedPartition)
+                                    || (previousPartition == -1));
             previousPartition = expectedPartition;
 
             if (!(rightAssignment && samePartition)) {

Reply via email to