This is an automated email from the ASF dual-hosted git repository.

martijnvisser pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-connector-kafka.git
commit f04a6abf7f75cd095526db747375fe273ed3110d
Author: Zakelly <zakelly....@gmail.com>
AuthorDate: Wed Oct 5 22:44:16 2022 +0800

    [FLINK-29437] Align the partition of data before and after the Kafka Shuffle
---
 .../internals/KafkaTopicPartitionAssigner.java     |  9 ++++--
 .../kafka/shuffle/FlinkKafkaShuffle.java           |  4 ++-
 .../kafka/shuffle/FlinkKafkaShuffleProducer.java   | 15 ++++++++-
 .../kafka/shuffle/KafkaShuffleTestBase.java        | 37 ++++++++++++++++++----
 4 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartitionAssigner.java b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartitionAssigner.java
index 9d12a80..be61e8a 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartitionAssigner.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartitionAssigner.java
@@ -47,12 +47,15 @@ public class KafkaTopicPartitionAssigner {
      * @return index of the target subtask that the Kafka partition should be assigned to.
      */
     public static int assign(KafkaTopicPartition partition, int numParallelSubtasks) {
-        int startIndex =
-                ((partition.getTopic().hashCode() * 31) & 0x7FFFFFFF) % numParallelSubtasks;
+        return assign(partition.getTopic(), partition.getPartition(), numParallelSubtasks);
+    }
+
+    public static int assign(String topic, int partition, int numParallelSubtasks) {
+        int startIndex = ((topic.hashCode() * 31) & 0x7FFFFFFF) % numParallelSubtasks;
         // here, the assumption is that the id of Kafka partitions are always ascending
         // starting from 0, and therefore can be used directly as the offset clockwise from the
         // start index
-        return (startIndex + partition.getPartition()) % numParallelSubtasks;
+        return (startIndex + partition) % numParallelSubtasks;
     }
 }
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
index 83c372f..58b09c9 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
@@ -343,7 +343,9 @@ public class FlinkKafkaShuffle {
         int numberOfPartitions =
                 PropertiesUtil.getInt(kafkaProperties, PARTITION_NUMBER, Integer.MIN_VALUE);
         DataStream<T> outputDataStream =
-                env.addSource(kafkaConsumer).setParallelism(numberOfPartitions);
+                env.addSource(kafkaConsumer)
+                        .setParallelism(numberOfPartitions)
+                        .setMaxParallelism(numberOfPartitions);
 
         return DataStreamUtils.reinterpretAsKeyedStream(outputDataStream, keySelector);
     }
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
index d6632d3..e05e8f9 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.connectors.kafka.FlinkKafkaException;
 import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.PropertiesUtil;
 
@@ -32,6 +33,8 @@ import org.apache.kafka.clients.producer.ProducerRecord;
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Properties;
 
 import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PARTITION_NUMBER;
@@ -46,6 +49,8 @@ public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
     private final KeySelector<IN, KEY> keySelector;
     private final int numberOfPartitions;
 
+    private final Map<Integer, Integer> subtaskToPartitionMap;
+
     FlinkKafkaShuffleProducer(
             String defaultTopicId,
             TypeSerializer<IN> typeSerializer,
@@ -67,6 +72,7 @@ public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
                 props.getProperty(PARTITION_NUMBER) != null,
                 "Missing partition number for Kafka Shuffle");
         numberOfPartitions = PropertiesUtil.getInt(props, PARTITION_NUMBER, Integer.MIN_VALUE);
+        subtaskToPartitionMap = new HashMap<>();
     }
 
     /**
@@ -89,9 +95,10 @@ public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
         int[] partitions = getPartitions(transaction);
         int partitionIndex;
         try {
-            partitionIndex =
+            int subtaskIndex =
                     KeyGroupRangeAssignment.assignKeyToParallelOperator(
                             keySelector.getKey(next), partitions.length, partitions.length);
+            partitionIndex = subtaskToPartitionMap.get(subtaskIndex);
         } catch (Exception e) {
             throw new RuntimeException("Fail to assign a partition number to record", e);
         }
@@ -142,6 +149,12 @@ public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
         if (partitions == null) {
             partitions = getPartitionsByTopic(defaultTopicId, transaction.getProducer());
             topicPartitionsMap.put(defaultTopicId, partitions);
+            for (int i = 0; i < partitions.length; i++) {
+                subtaskToPartitionMap.put(
+                        KafkaTopicPartitionAssigner.assign(
+                                defaultTopicId, partitions[i], partitions.length),
+                        partitions[i]);
+            }
         }
 
         Preconditions.checkArgument(partitions.length == numberOfPartitions);
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java b/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java
index a3d2beb..064aebd 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java
@@ -21,6 +21,8 @@ import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple;
 import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -34,7 +36,6 @@ import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
 import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
 import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
 import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
-import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
 import org.apache.flink.test.util.SuccessException;
 import org.apache.flink.util.Collector;
@@ -168,6 +169,7 @@ public class KafkaShuffleTestBase extends KafkaConsumerTestBase {
         private final KeySelector<Tuple3<Integer, Long, Integer>, Tuple> keySelector;
         private final int numberOfPartitions;
         private final String topic;
+        private KeyGroupRange keyGroupRange;
 
         private int previousPartition;
 
@@ -181,24 +183,45 @@ public class KafkaShuffleTestBase extends KafkaConsumerTestBase {
             this.previousPartition = -1;
         }
 
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            super.open(parameters);
+            this.keyGroupRange =
+                    KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
+                            getRuntimeContext().getMaxNumberOfParallelSubtasks(),
+                            numberOfPartitions,
+                            getRuntimeContext().getIndexOfThisSubtask());
+        }
+
         @Override
         public void processElement(
                 Tuple3<Integer, Long, Integer> in,
                 Context ctx,
                 Collector<Tuple3<Integer, Long, Integer>> out)
                 throws Exception {
-            int expectedPartition =
+            int expectedSubtask =
                     KeyGroupRangeAssignment.assignKeyToParallelOperator(
                             keySelector.getKey(in), numberOfPartitions, numberOfPartitions);
+            int expectedPartition = -1;
+            // This is how Kafka assign partition to subTask;
+            for (int i = 0; i < numberOfPartitions; i++) {
+                if (KafkaTopicPartitionAssigner.assign(topic, i, numberOfPartitions)
+                        == expectedSubtask) {
+                    expectedPartition = i;
+                }
+            }
             int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
-            KafkaTopicPartition partition = new KafkaTopicPartition(topic, expectedPartition);
 
-            // This is how Kafka assign partition to subTask;
             boolean rightAssignment =
-                    KafkaTopicPartitionAssigner.assign(partition, numberOfPartitions)
-                            == indexOfThisSubtask;
+                    (expectedSubtask == indexOfThisSubtask)
+                            && keyGroupRange.contains(
+                                    KeyGroupRangeAssignment.assignToKeyGroup(
+                                            keySelector.getKey(in),
+                                            getRuntimeContext().getMaxNumberOfParallelSubtasks()));
             boolean samePartition =
-                    (previousPartition == expectedPartition) || (previousPartition == -1);
+                    (expectedPartition != -1)
+                            && ((previousPartition == expectedPartition)
+                                    || (previousPartition == -1));
             previousPartition = expectedPartition;
 
             if (!(rightAssignment && samePartition)) {
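
For context on the fix: before this change, FlinkKafkaShuffleProducer wrote each record to the Kafka partition whose number equals the key's subtask index, while the consuming side assigns partitions to subtasks via KafkaTopicPartitionAssigner, which starts from a topic-hash-dependent offset. The two mappings only coincide when that offset happens to be zero, so data could land on a partition read by a different subtask than its key group. The commit aligns them by inverting the assigner on the producer side. Below is a minimal standalone sketch of that inversion, not the connector code itself: the assign() formula is copied from the diff, while the class name, topic name, and partition count are illustrative.

import java.util.HashMap;
import java.util.Map;

public class ShuffleAlignmentSketch {

    // Same formula as KafkaTopicPartitionAssigner.assign(topic, partition, numParallelSubtasks):
    // a topic-dependent start index, then partitions handed out clockwise from it.
    static int assign(String topic, int partition, int numParallelSubtasks) {
        int startIndex = ((topic.hashCode() * 31) & 0x7FFFFFFF) % numParallelSubtasks;
        return (startIndex + partition) % numParallelSubtasks;
    }

    public static void main(String[] args) {
        String topic = "shuffle-topic"; // hypothetical topic name
        int numPartitions = 4;          // parallelism == partition count, as in the shuffle

        // Producer side after the fix: invert assign() once the partition list is known,
        // so a record routed to subtask s is written to the partition that subtask s reads.
        Map<Integer, Integer> subtaskToPartition = new HashMap<>();
        for (int p = 0; p < numPartitions; p++) {
            subtaskToPartition.put(assign(topic, p, numPartitions), p);
        }

        for (int subtask = 0; subtask < numPartitions; subtask++) {
            // Before the fix the producer effectively used partition == subtask,
            // which disagrees with assign() whenever startIndex != 0.
            System.out.printf("subtask %d reads partition %d%n",
                    subtask, subtaskToPartition.get(subtask));
        }
    }
}

Because assign() is a rotation of 0..n-1 for a fixed topic, the inversion is a bijection, so every subtask gets exactly one partition and the map lookup in invoke() can never miss.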
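The setMaxParallelism(numberOfPartitions) call added in FlinkKafkaShuffle, together with the keyGroupRange check added to the test, covers the other half of the invariant: the producer routes keys with assignKeyToParallelOperator(key, numberOfPartitions, numberOfPartitions), so the consumer must run with the same max parallelism or its key groups map to different subtasks. The following sketch of that invariant uses the same KeyGroupRangeAssignment methods the diff relies on; it assumes flink-runtime on the classpath, and the class name and key values are arbitrary.

import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;

public class KeyGroupInvariantSketch {
    public static void main(String[] args) {
        int numberOfPartitions = 4; // parallelism == maxParallelism == partition count

        for (int key = 0; key < 100; key++) {
            // Producer side: the subtask this key is routed to.
            int subtask =
                    KeyGroupRangeAssignment.assignKeyToParallelOperator(
                            key, numberOfPartitions, numberOfPartitions);

            // Consumer side: the key groups owned by that subtask, computed the
            // same way as in the test's open() method.
            KeyGroupRange range =
                    KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                            numberOfPartitions, numberOfPartitions, subtask);

            int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, numberOfPartitions);

            // Holds for every key precisely because parallelism and max parallelism agree.
            if (!range.contains(keyGroup)) {
                throw new AssertionError("key " + key + " escaped its key-group range");
            }
        }
        System.out.println("producer routing and consumer key-group ranges agree");
    }
}

If max parallelism were left to Flink's default instead of being pinned to the partition count, the two sides would compute key groups over different ranges, which is why the test now verifies both the subtask index and the key-group containment.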