AHeise commented on a change in pull request #11725:
URL: https://github.com/apache/flink/pull/11725#discussion_r425636250



##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {

Review comment:
       I'd probably extract an `AbstractKafkaFetcher` from `KafkaFetcher` with 
an abstract `handleRecord`.
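    For illustration, roughly the shape I have in mind (class and method names below are just placeholders; the real base class would extend `AbstractFetcher` and keep the existing `Handover`/`KafkaConsumerThread` plumbing):

```java
import java.util.List;

import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;

// Hypothetical sketch: the shared poll loop lives in an abstract base class,
// and KafkaFetcher / KafkaShuffleFetcher only decide how a single record is handled.
abstract class AbstractKafkaFetcherSketch<T> {

	// Called from the common fetch loop for every record of a partition batch.
	protected abstract void handleRecord(
			KafkaTopicPartitionState<TopicPartition> partition,
			ConsumerRecord<byte[], byte[]> record) throws Exception;

	// Simplified stand-in for the shared part of runFetchLoop(): iterate a polled
	// batch and delegate each record to the subclass hook.
	protected void processBatch(
			KafkaTopicPartitionState<TopicPartition> partition,
			List<ConsumerRecord<byte[], byte[]>> partitionRecords) throws Exception {
		for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
			handleRecord(partition, record);
		}
	}
}
```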

##########
File path: 
flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
##########
@@ -264,7 +264,7 @@ public FlinkKafkaConsumerBase(
         * @param properties - Kafka configuration properties to be adjusted
         * @param offsetCommitMode offset commit mode

Review comment:
       Please revise first line of commit message.

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,379 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+       private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       /** The handler to check and generate watermarks from fetched records. **/
+       private final WatermarkHandler watermarkHandler;
+
+       /** The schema to convert between Kafka's byte messages, and Flink's objects. */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Serializer to serialize record. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+                       SourceFunction.SourceContext<T> sourceContext,
+                       Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+                       SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+                       SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+                       ProcessingTimeService processingTimeProvider,
+                       long autoWatermarkInterval,
+                       ClassLoader userCodeClassLoader,
+                       String taskNameWithSubtasks,
+                       TypeSerializer<T> serializer,
+                       Properties kafkaProperties,
+                       long pollTimeout,
+                       MetricGroup subtaskMetricGroup,
+                       MetricGroup consumerMetricGroup,
+                       boolean useMetrics,
+                       int producerParallelism) throws Exception {
+               super(
+                       sourceContext,
+                       assignedPartitionsWithInitialOffsets,
+                       watermarksPeriodic,
+                       watermarksPunctuated,
+                       processingTimeProvider,
+                       autoWatermarkInterval,
+                       userCodeClassLoader,
+                       consumerMetricGroup,
+                       useMetrics);
+               this.deserializer = new KafkaShuffleElementDeserializer<>();
+               this.serializer = serializer;
+               this.handover = new Handover();
+               this.consumerThread = new KafkaConsumerThread(
+                       LOG,
+                       handover,
+                       kafkaProperties,
+                       unassignedPartitionsQueue,
+                       getFetcherName() + " for " + taskNameWithSubtasks,
+                       pollTimeout,
+                       useMetrics,
+                       consumerMetricGroup,
+                       subtaskMetricGroup);
+               this.watermarkHandler = new WatermarkHandler(producerParallelism);
+       }
+
+       // ------------------------------------------------------------------------
+       //  Fetcher work methods
+       // ------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+                                               records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+                                               final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check the end of stream if reaching the end watermark?
+
+                                               if (element.isRecord()) {
+                                                       // timestamp is inherent from upstream
+                                                       // If using ProcessTime, timestamp is going to be ignored (upstream does not include timestamp as well)
+                                                       // If using IngestionTime, timestamp is going to be overwritten
+                                                       // If using EventTime, timestamp is going to be used
+                                                       synchronized (checkpointLock) {
+                                                               KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+                                                               sourceContext.collectWithTimestamp(
+                                                                       elementAsRecord.value,
+                                                                       elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+                                                               partition.setOffset(record.offset());
+                                                       }
+                                               } else if (element.isWatermark()) {
+                                                       final KafkaShuffleWatermark watermark = element.asWatermark();
+                                                       Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+                                                       newWatermark.ifPresent(sourceContext::emitWatermark);
+                                               }
+                                       }
+                               }
+                       }
+               }
+               finally {
+                       // this signals the consumer thread that no more work is to be done
+                       consumerThread.shutdown();
+               }
+
+               // on a clean exit, wait for the runner thread
+               try {
+                       consumerThread.join();
+               }
+               catch (InterruptedException e) {
+                       // may be the result of a wake-up interruption after an exception.
+                       // we ignore this here and only restore the interruption state
+                       Thread.currentThread().interrupt();
+               }
+       }
+
+       @Override
+       public void cancel() {
+               // flag the main thread to exit. A thread interrupt will come anyways.
+               running = false;
+               handover.close();
+               consumerThread.shutdown();
+       }
+
+       @Override
+       protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition partition) {
+               return new TopicPartition(partition.getTopic(), partition.getPartition());
+       }
+
+       @Override
+       protected void doCommitInternalOffsetsToKafka(
+                       Map<KafkaTopicPartition, Long> offsets,
+                       @Nonnull KafkaCommitCallback commitCallback) throws Exception {
+               @SuppressWarnings("unchecked")
+               List<KafkaTopicPartitionState<TopicPartition>> partitions = subscribedPartitionStates();
+
+               Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>(partitions.size());
+
+               for (KafkaTopicPartitionState<TopicPartition> partition : partitions) {
+                       Long lastProcessedOffset = offsets.get(partition.getKafkaTopicPartition());
+                       if (lastProcessedOffset != null) {
+                               checkState(lastProcessedOffset >= 0, "Illegal offset value to commit");
+
+                               // committed offsets through the KafkaConsumer need to be 1 more than the last processed offset.
+                               // This does not affect Flink's checkpoints/saved state.
+                               long offsetToCommit = lastProcessedOffset + 1;
+
+                               offsetsToCommit.put(partition.getKafkaPartitionHandle(), new OffsetAndMetadata(offsetToCommit));
+                               partition.setCommittedOffset(offsetToCommit);
+                       }
+               }
+
+               // record the work to be committed by the main consumer thread and make sure the consumer notices that
+               consumerThread.setOffsetsToCommit(offsetsToCommit, commitCallback);
+       }
+
+       private String getFetcherName() {
+               return "Kafka Shuffle Fetcher";
+       }
+
+       /**
+        * An element in a KafkaShuffle. Can be a record or a Watermark.
+        */
+       @VisibleForTesting
+       public abstract static class KafkaShuffleElement<T> {
+
+               public boolean isRecord() {
+                       return getClass() == KafkaShuffleRecord.class;
+               }
+
+               public boolean isWatermark() {
+                       return getClass() == KafkaShuffleWatermark.class;
+               }
+
+               public KafkaShuffleRecord<T> asRecord() {
+                       return (KafkaShuffleRecord<T>) this;
+               }
+
+               public KafkaShuffleWatermark asWatermark() {
+                       return (KafkaShuffleWatermark) this;
+               }
+       }
+
+       /**
+        * A watermark element in a KafkaShuffle. It includes
+        * - subtask index where the watermark is coming from
+        * - watermark timestamp
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleWatermark<T> extends KafkaShuffleElement<T> {
+               final int subtask;
+               final long watermark;
+
+               KafkaShuffleWatermark(int subtask, long watermark) {
+                       this.subtask = subtask;
+                       this.watermark = watermark;
+               }
+
+               public int getSubtask() {
+                       return subtask;
+               }
+
+               public long getWatermark() {
+                       return watermark;
+               }
+       }
+
+       /**
+        * One value with Type T in a KafkaShuffle. This stores the value and an optional associated timestamp.
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleRecord<T> extends KafkaShuffleElement<T> {
+               final T value;
+               final Long timestamp;
+
+               KafkaShuffleRecord(T value) {
+                       this.value = value;
+                       this.timestamp = null;
+               }
+
+               KafkaShuffleRecord(long timestamp, T value) {
+                       this.value = value;
+                       this.timestamp = timestamp;
+               }
+
+               public T getValue() {
+                       return value;
+               }
+
+               public Long getTimestamp() {
+                       return timestamp;
+               }
+       }
+
+       /**
+        * Deserializer for KafkaShuffleElement.
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleElementDeserializer<T> implements Serializable {

Review comment:
       Would be good to provide a `serialVersionUID`.
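    For example (the concrete value is only a placeholder):

```java
public static class KafkaShuffleElementDeserializer<T> implements java.io.Serializable {
	// explicit version UID so serialized instances stay compatible across class changes
	private static final long serialVersionUID = 1L;
}
```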

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Experimental;
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+@Experimental
+class FlinkKafkaShuffle {
+       static final String PRODUCER_PARALLELISM = "producer parallelism";
+       static final String PARTITION_NUMBER = "partition number";
+
+       /**
+        * Writes to and reads from a kafka shuffle with the partition decided by keys.
+        *
+        * <p>Consumers should read partitions equal to the key group indices they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param producerParallelism Parallelism of producer
+        * @param numberOfPartitions Number of partitions
+        * @param properties Kafka properties
+        * @param fields Key positions from inputStream
+        * @param <T> Input type
+        */
+       public static <T> KeyedStream<T, Tuple> persistentKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       int producerParallelism,
+                       int numberOfPartitions,
+                       Properties properties,
+                       int... fields) {
+               return persistentKeyBy(
+                       inputStream,
+                       topic,
+                       producerParallelism,
+                       numberOfPartitions,
+                       properties,
+                       keySelector(inputStream, fields));
+       }
+
+       /**
+        * Writes to and reads from a kafka shuffle with the partition decided by keys.
+        *
+        * <p>Consumers should read partitions equal to the key group indices they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param producerParallelism Parallelism of producer
+        * @param numberOfPartitions Number of partitions
+        * @param properties Kafka properties
+        * @param keySelector key(K) based on inputStream(T)
+        * @param <T> Input type
+        * @param <K> Key type
+        */
+       public static <T, K> KeyedStream<T, K> persistentKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       int producerParallelism,
+                       int numberOfPartitions,
+                       Properties properties,
+                       KeySelector<T, K> keySelector) {
+               // KafkaProducer#propsToMap uses Properties purely as a HashMap without considering the default properties
+               // So we have to flatten the default property to first level elements.
+               Properties kafkaProperties = PropertiesUtil.flatten(properties);
+               kafkaProperties.setProperty(PRODUCER_PARALLELISM, String.valueOf(producerParallelism));
+               kafkaProperties.setProperty(PARTITION_NUMBER, String.valueOf(numberOfPartitions));
+
+               StreamExecutionEnvironment env = inputStream.getExecutionEnvironment();
+               TypeInformationSerializationSchema<T> schema =
+                       new TypeInformationSerializationSchema<>(inputStream.getType(), env.getConfig());
+
+               writeKeyBy(inputStream, topic, kafkaProperties, keySelector);
+               return readKeyBy(topic, env, schema, kafkaProperties, keySelector);
+       }
+
+       /**
+        * Writes to a kafka shuffle with the partition decided by keys.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param kafkaProperties Kafka properties
+        * @param fields Key positions from inputStream
+        * @param <T> Input type
+        */
+       public static <T> void writeKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       Properties kafkaProperties,
+                       int... fields) {
+               writeKeyBy(inputStream, topic, kafkaProperties, keySelector(inputStream, fields));
+       }
+
+       /**
+        * Writes to a kafka shuffle with the partition decided by keys.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param kafkaProperties Kafka properties
+        * @param keySelector Key(K) based on input(T)
+        * @param <T> Input type
+        * @param <K> Key type
+        */
+       public static <T, K> void writeKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       Properties kafkaProperties,
+                       KeySelector<T, K> keySelector) {
+               StreamExecutionEnvironment env = inputStream.getExecutionEnvironment();
+               TypeInformationSerializationSchema<T> schema =
+                       new TypeInformationSerializationSchema<>(inputStream.getType(), env.getConfig());
+
+               // write data to Kafka
+               FlinkKafkaShuffleProducer<T, K> kafkaProducer = new FlinkKafkaShuffleProducer<>(
+                       topic,
+                       schema.getSerializer(),
+                       kafkaProperties,
+                       env.clean(keySelector),
+                       FlinkKafkaProducer.Semantic.EXACTLY_ONCE,
+                       FlinkKafkaProducer.DEFAULT_KAFKA_PRODUCERS_POOL_SIZE);
+
+               // make sure the sink parallelism is set to producerParallelism
+               Preconditions.checkArgument(
+                       kafkaProperties.getProperty(PRODUCER_PARALLELISM) != null,
+                       "Missing producer parallelism for Kafka Shuffle");
+               int producerParallelism = PropertiesUtil.getInt(kafkaProperties, PRODUCER_PARALLELISM, Integer.MIN_VALUE);
+
+               addKafkaShuffle(inputStream, kafkaProducer, producerParallelism);
+       }
+
+       /**
+        * Reads data from a Kafka Shuffle.

Review comment:
       Reads data from a Kafka Shuffle previously written by `writeKeyBy`.

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Experimental;
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+@Experimental
+class FlinkKafkaShuffle {
+       static final String PRODUCER_PARALLELISM = "producer parallelism";
+       static final String PARTITION_NUMBER = "partition number";
+
+       /**
+        * Writes to and reads from a kafka shuffle with the partition decided 
by keys.
+        *
+        * <p>Consumers should read partitions equal to the key group indices 
they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving 
operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param producerParallelism Parallelism of producer
+        * @param numberOfPartitions Number of partitions
+        * @param properties Kafka properties
+        * @param fields Key positions from inputStream
+        * @param <T> Input type
+        */
+       public static <T> KeyedStream<T, Tuple> persistentKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       int producerParallelism,
+                       int numberOfPartitions,
+                       Properties properties,
+                       int... fields) {
+               return persistentKeyBy(
+                       inputStream,
+                       topic,
+                       producerParallelism,
+                       numberOfPartitions,
+                       properties,
+                       keySelector(inputStream, fields));
+       }
+
+       /**
+        * Writes to and reads from a kafka shuffle with the partition decided 
by keys.
+        *
+        * <p>Consumers should read partitions equal to the key group indices 
they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving 
operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param producerParallelism Parallelism of producer
+        * @param numberOfPartitions Number of partitions
+        * @param properties Kafka properties
+        * @param keySelector key(K) based on inputStream(T)
+        * @param <T> Input type
+        * @param <K> Key type
+        */
+       public static <T, K> KeyedStream<T, K> persistentKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       int producerParallelism,
+                       int numberOfPartitions,
+                       Properties properties,
+                       KeySelector<T, K> keySelector) {
+               // KafkaProducer#propsToMap uses Properties purely as a HashMap 
without considering the default properties
+               // So we have to flatten the default property to first level 
elements.
+               Properties kafkaProperties = PropertiesUtil.flatten(properties);
+               kafkaProperties.setProperty(PRODUCER_PARALLELISM, 
String.valueOf(producerParallelism));
+               kafkaProperties.setProperty(PARTITION_NUMBER, 
String.valueOf(numberOfPartitions));
+
+               StreamExecutionEnvironment env = 
inputStream.getExecutionEnvironment();
+               TypeInformationSerializationSchema<T> schema =
+                       new 
TypeInformationSerializationSchema<>(inputStream.getType(), env.getConfig());
+
+               writeKeyBy(inputStream, topic, kafkaProperties, keySelector);
+               return readKeyBy(topic, env, schema, kafkaProperties, 
keySelector);
+       }
+
+       /**
+        * Writes to a kafka shuffle with the partition decided by keys.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param kafkaProperties Kafka properties
+        * @param fields Key positions from inputStream
+        * @param <T> Input type
+        */
+       public static <T> void writeKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       Properties kafkaProperties,
+                       int... fields) {
+               writeKeyBy(inputStream, topic, kafkaProperties, 
keySelector(inputStream, fields));
+       }
+
+       /**
+        * Writes to a kafka shuffle with the partition decided by keys.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param kafkaProperties Kafka properties
+        * @param keySelector Key(K) based on input(T)
+        * @param <T> Input type
+        * @param <K> Key type
+        */
+       public static <T, K> void writeKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       Properties kafkaProperties,
+                       KeySelector<T, K> keySelector) {
+               StreamExecutionEnvironment env = 
inputStream.getExecutionEnvironment();
+               TypeInformationSerializationSchema<T> schema =
+                       new 
TypeInformationSerializationSchema<>(inputStream.getType(), env.getConfig());
+
+               // write data to Kafka
+               FlinkKafkaShuffleProducer<T, K> kafkaProducer = new 
FlinkKafkaShuffleProducer<>(
+                       topic,
+                       schema.getSerializer(),
+                       kafkaProperties,
+                       env.clean(keySelector),
+                       FlinkKafkaProducer.Semantic.EXACTLY_ONCE,
+                       FlinkKafkaProducer.DEFAULT_KAFKA_PRODUCERS_POOL_SIZE);
+
+               // make sure the sink parallelism is set to producerParallelism
+               Preconditions.checkArgument(
+                       kafkaProperties.getProperty(PRODUCER_PARALLELISM) != 
null,
+                       "Missing producer parallelism for Kafka Shuffle");
+               int producerParallelism = 
PropertiesUtil.getInt(kafkaProperties, PRODUCER_PARALLELISM, Integer.MIN_VALUE);
+
+               addKafkaShuffle(inputStream, kafkaProducer, 
producerParallelism);
+       }
+
+       /**
+        * Reads data from a Kafka Shuffle.
+        *
+        * <p>Consumers should read partitions equal to the key group indices they are assigned.
+        * The number of partitions is equivalent to the key group sizes. Each consumer task reads from
+        * one or multiple partitions. Any two consumer tasks can not read from the same partition.
+        * Hence, the maximum parallelism of the receiving operator is the number of partitions.
+        * This version only supports numberOfPartitions = consumerParallelism
+        *
+        * @param topic Kafka topic
+        * @param env Streaming execution environment. readKeyBy's environment can be different from writeKeyBy's
+        * @param schema The record schema to read
+        * @param kafkaProperties Kafka properties
+        * @param keySelector Key(K) based on schema(T)
+        * @param <T> Schema type
+        * @param <K> Key type
+        * @return Keyed data stream
+        */
+       public static <T, K> KeyedStream<T, K> readKeyBy(
+                       String topic,
+                       StreamExecutionEnvironment env,
+                       TypeInformationSerializationSchema<T> schema,

Review comment:
       I'd go with `TypeInformation` here.  
`TypeInformationSerializationSchema` is rather technical and can be easily 
derived from `TypeInformation`.
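    For illustration, the signature could look roughly like this, deriving the schema internally and delegating to the current overload (parameter order is only an assumption):

```java
// Hypothetical variant of readKeyBy that only exposes TypeInformation to callers.
public static <T, K> KeyedStream<T, K> readKeyBy(
		String topic,
		StreamExecutionEnvironment env,
		TypeInformation<T> typeInformation,
		Properties kafkaProperties,
		KeySelector<T, K> keySelector) {
	// derive the (rather technical) serialization schema internally
	TypeInformationSerializationSchema<T> schema =
		new TypeInformationSerializationSchema<>(typeInformation, env.getConfig());
	return readKeyBy(topic, env, schema, kafkaProperties, keySelector);
}
```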

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleExactlyOnceITCase.java
##########
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.connectors.kafka.testutils.FailingIdentityMapper;
+import org.apache.flink.streaming.connectors.kafka.testutils.ValidatingExactlyOnceSink;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Failure Recovery IT Test for KafkaShuffle.
+ */
+public class KafkaShuffleExactlyOnceITCase extends KafkaShuffleTestBase {
+
+       @Rule
+       public final Timeout timeout = Timeout.millis(600000L);
+
+       /**
+        * Failure Recovery after processing 2/3 data with time characteristic: 
ProcessingTime.
+        *
+        * <p>Producer Parallelism = 1; Kafka Partition # = 1; Consumer 
Parallelism = 1.
+        */
+       @Test
+       public void testFailureRecoveryProcessingTime() throws Exception {
+               testKafkaShuffleFailureRecovery("failure_recovery", 1000, 
ProcessingTime);
+       }
+
+       /**
+        * Failure Recovery after processing 2/3 data with time characteristic: 
IngestionTime.
+        *
+        * <p>Producer Parallelism = 1; Kafka Partition # = 1; Consumer 
Parallelism = 1.
+        */
+       @Test
+       public void testFailureRecoveryIngestionTime() throws Exception {
+               testKafkaShuffleFailureRecovery("failure_recovery", 1000, 
IngestionTime);
+       }
+
+       /**
+        * Failure Recovery after processing 2/3 data with time characteristic: 
EventTime.
+        *
+        * <p>Producer Parallelism = 1; Kafka Partition # = 1; Consumer 
Parallelism = 1.
+        */
+       @Test
+       public void testFailureRecoveryEventTime() throws Exception {
+               testKafkaShuffleFailureRecovery("failure_recovery", 1000, 
EventTime);
+       }
+
+       /**
+        * Failure Recovery after data is repartitioned with time 
characteristic: ProcessingTime.
+        *
+        * <p>Producer Parallelism = 2; Kafka Partition # = 3; Consumer 
Parallelism = 3.
+        */
+       @Test
+       public void testAssignedToPartitionFailureRecoveryProcessingTime() 
throws Exception {
+               
testAssignedToPartitionFailureRecovery("partition_failure_recovery", 500, 
ProcessingTime);
+       }
+
+       /**
+        * Failure Recovery after data is repartitioned with time 
characteristic: IngestionTime.
+        *
+        * <p>Producer Parallelism = 2; Kafka Partition # = 3; Consumer 
Parallelism = 3.
+        */
+       @Test
+       public void testAssignedToPartitionFailureRecoveryIngestionTime() 
throws Exception {
+               
testAssignedToPartitionFailureRecovery("partition_failure_recovery", 500, 
IngestionTime);
+       }
+
+       /**
+        * Failure Recovery after data is repartitioned with time 
characteristic: EventTime.
+        *
+        * <p>Producer Parallelism = 2; Kafka Partition # = 3; Consumer 
Parallelism = 3.
+        */
+       @Test
+       public void testAssignedToPartitionFailureRecoveryEventTime() throws 
Exception {
+               
testAssignedToPartitionFailureRecovery("partition_failure_recovery", 500, 
EventTime);
+       }
+
+       /**
+        * To test failure recovery after processing 2/3 data.
+        *
+        * <p>Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer 
Parallelism = 1
+        */
+       private void testKafkaShuffleFailureRecovery(
+                       String prefix, int numElementsPerProducer, 
TimeCharacteristic timeCharacteristic) throws Exception {
+               String topic = topic(prefix, timeCharacteristic);
+               final int numberOfPartitions = 1;
+               final int producerParallelism = 1;
+               final int failAfterElements = numElementsPerProducer * 
numberOfPartitions * 2 / 3;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env =
+                       createEnvironment(producerParallelism, 
timeCharacteristic).enableCheckpointing(500);
+
+               createKafkaShuffle(
+                       env, topic, numElementsPerProducer, 
producerParallelism, timeCharacteristic, numberOfPartitions)
+                       .map(new 
FailingIdentityMapper<>(failAfterElements)).setParallelism(1)
+                       .map(new 
ToInteger(producerParallelism)).setParallelism(1)
+                       .addSink(new 
ValidatingExactlyOnceSink(numElementsPerProducer * 
producerParallelism)).setParallelism(1);
+
+               FailingIdentityMapper.failedBefore = false;
+
+               tryExecute(env, topic);
+
+               deleteTestTopic(topic);
+       }
+
+       /**
+        * To test failure recovery with partition assignment after processing 
2/3 data.
+        *
+        * <p>Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer 
Parallelism = 3
+        */
+       private void testAssignedToPartitionFailureRecovery(
+                       String prefix,
+                       int numElementsPerProducer,
+                       TimeCharacteristic timeCharacteristic) throws Exception 
{
+               String topic = topic(prefix, timeCharacteristic);
+               final int numberOfPartitions = 3;
+               final int producerParallelism = 2;
+               final int failAfterElements = numElementsPerProducer * 
producerParallelism * 2 / 3;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env = 
createEnvironment(producerParallelism, timeCharacteristic);
+
+               KeyedStream<Tuple3<Integer, Long, Integer>, Tuple> keyedStream 
= createKafkaShuffle(
+                       env,
+                       topic,
+                       numElementsPerProducer,
+                       producerParallelism,
+                       timeCharacteristic,
+                       numberOfPartitions);
+               keyedStream
+                       .process(new 
PartitionValidator(keyedStream.getKeySelector(), numberOfPartitions, topic))
+                       .setParallelism(numberOfPartitions)
+                       .map(new 
ToInteger(producerParallelism)).setParallelism(numberOfPartitions)
+                       .map(new 
FailingIdentityMapper<>(failAfterElements)).setParallelism(1)
+                       .addSink(new 
ValidatingExactlyOnceSink(numElementsPerProducer * 
producerParallelism)).setParallelism(1);
+
+               FailingIdentityMapper.failedBefore = false;
+
+               tryExecute(env, topic);
+
+               deleteTestTopic(topic);
+       }
+
+       private StreamExecutionEnvironment createEnvironment(
+                       int producerParallelism, TimeCharacteristic 
timeCharacteristic) {
+               final StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(producerParallelism);
+               env.setStreamTimeCharacteristic(timeCharacteristic);
+               env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 
0));
+               env.setBufferTimeout(0);
+               env.enableCheckpointing(500);
+
+               return env;
+       }
+
+       private static class ToInteger implements MapFunction<Tuple3<Integer, Long, Integer>, Integer> {
+               private final int producerParallelism;
+
+               ToInteger(int producerParallelism) {
+                       this.producerParallelism = producerParallelism;
+               }
+
+               @Override
+               public Integer map(Tuple3<Integer, Long, Integer> element) throws Exception {
+                       int addedInteger = element.f0 * producerParallelism + element.f2;
+                       System.out.println("<" + element.f0 + "," + element.f2 + "> " + addedInteger);

Review comment:
       Remove or translate into LOG.debug
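    For example, roughly like this (assuming an SLF4J `LOG` defined on the test class, and that the method returns `addedInteger`):

```java
@Override
public Integer map(Tuple3<Integer, Long, Integer> element) throws Exception {
	int addedInteger = element.f0 * producerParallelism + element.f2;
	// debug-level log instead of printing to stdout
	LOG.debug("<{},{}> {}", element.f0, element.f2, addedInteger);
	return addedInteger;
}
```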

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleConsumer.java
##########
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer;
+import org.apache.flink.streaming.connectors.kafka.config.OffsetCommitMode;
+import org.apache.flink.streaming.connectors.kafka.internal.KafkaShuffleFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+import org.apache.flink.util.SerializedValue;
+
+import java.util.Map;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PRODUCER_PARALLELISM;
+
+/**
+ * Flink Kafka Shuffle Consumer Function.
+ */
+@Internal
+public class FlinkKafkaShuffleConsumer<T> extends FlinkKafkaConsumer<T> {
+       private final TypeSerializer<T> serializer;
+       private final int producerParallelism;
+
+       FlinkKafkaShuffleConsumer(String topic, TypeInformationSerializationSchema<T> schema, Properties props) {

Review comment:
       If `KafkaShuffleElementDeserializer` implemented `KafkaDeserializationSchema`, you could pass it directly. That would reuse more of the abstractions that are already there.
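    A rough sketch of what I mean (generics and the placeholder bodies are assumptions, not the actual PR code):

```java
// Hypothetical: let the element deserializer plug into the existing
// KafkaDeserializationSchema abstraction so it can be handed to the consumer directly.
public static class KafkaShuffleElementDeserializer<T>
		implements KafkaDeserializationSchema<KafkaShuffleElement<T>> {

	private static final long serialVersionUID = 1L;

	@Override
	public boolean isEndOfStream(KafkaShuffleElement<T> nextElement) {
		return false;
	}

	@Override
	public KafkaShuffleElement<T> deserialize(ConsumerRecord<byte[], byte[]> record) {
		// the existing tag/value/timestamp decoding from the PR would move here
		throw new UnsupportedOperationException("sketch only");
	}

	@Override
	public TypeInformation<KafkaShuffleElement<T>> getProducedType() {
		// a real implementation needs proper TypeInformation for the element type
		throw new UnsupportedOperationException("sketch only");
	}
}
```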

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Experimental;
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+@Experimental
+class FlinkKafkaShuffle {
+       static final String PRODUCER_PARALLELISM = "producer parallelism";
+       static final String PARTITION_NUMBER = "partition number";
+
+       /**
+        * Write to and read from a kafka shuffle with the partition decided by keys.
+        * Consumers should read partitions equal to the key group indices they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving operator.
+        * This version only supports numberOfPartitions = consumerParallelism.

Review comment:
       > the producer parallelism does not matter with the max key group size?
   
   If it doesn't matter, why can it be set by the user? Please give users some 
guidance on how to use your function. How should they know how it's implemented 
internally? You could also add it as a class comment to outline the approach.
   
    In general, there are quite a few comments and variable names referring to the producer, but there is no explanation of what the producer is in KafkaShuffle. I'd harmonize source/consumer and sink/producer and consistently use only one term of each pair.
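    For example, a class-level comment along these lines could outline the approach (wording is only a suggestion based on what the quoted code currently does):

```java
/**
 * FlinkKafkaShuffle uses Kafka as the shuffle fabric: writeKeyBy adds a Kafka sink
 * (the "producer" side) that partitions records by key, and readKeyBy adds a Kafka
 * source (the "consumer" side) that reads those partitions back as a KeyedStream.
 *
 * <p>The producer parallelism only controls how many sink subtasks write to Kafka;
 * the number of partitions belongs to the consumer side and currently must equal
 * the consumer parallelism (numberOfPartitions = consumerParallelism).
 */
```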

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,379 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> 
{
+       private static final Logger LOG = 
LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       /** The handler to check and generate watermarks from fetched records. 
**/
+       private final WatermarkHandler watermarkHandler;
+
+       /** The schema to convert between Kafka's byte messages, and Flink's 
objects. */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Serializer to serialize record. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and 
the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hand the record 
batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+                       SourceFunction.SourceContext<T> sourceContext,
+                       Map<KafkaTopicPartition, Long> 
assignedPartitionsWithInitialOffsets,
+                       SerializedValue<AssignerWithPeriodicWatermarks<T>> 
watermarksPeriodic,
+                       SerializedValue<AssignerWithPunctuatedWatermarks<T>> 
watermarksPunctuated,
+                       ProcessingTimeService processingTimeProvider,
+                       long autoWatermarkInterval,
+                       ClassLoader userCodeClassLoader,
+                       String taskNameWithSubtasks,
+                       TypeSerializer<T> serializer,
+                       Properties kafkaProperties,
+                       long pollTimeout,
+                       MetricGroup subtaskMetricGroup,
+                       MetricGroup consumerMetricGroup,
+                       boolean useMetrics,
+                       int producerParallelism) throws Exception {
+               super(
+                       sourceContext,
+                       assignedPartitionsWithInitialOffsets,
+                       watermarksPeriodic,
+                       watermarksPunctuated,
+                       processingTimeProvider,
+                       autoWatermarkInterval,
+                       userCodeClassLoader,
+                       consumerMetricGroup,
+                       useMetrics);
+               this.deserializer = new KafkaShuffleElementDeserializer<>();
+               this.serializer = serializer;
+               this.handover = new Handover();
+               this.consumerThread = new KafkaConsumerThread(
+                       LOG,
+                       handover,
+                       kafkaProperties,
+                       unassignedPartitionsQueue,
+                       getFetcherName() + " for " + taskNameWithSubtasks,
+                       pollTimeout,
+                       useMetrics,
+                       consumerMetricGroup,
+                       subtaskMetricGroup);
+               this.watermarkHandler = new 
WatermarkHandler(producerParallelism);
+       }
+
+       // 
------------------------------------------------------------------------
+       //  Fetcher work methods
+       // 
------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions 
encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = 
handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> 
partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> 
partitionRecords =
+                                               
records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> 
record : partitionRecords) {
+                                               final KafkaShuffleElement<T> 
element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check 
the end of stream if reaching the end watermark?
+
+                                               if (element.isRecord()) {
+                                                       // timestamp is 
inherent from upstream
+                                                       // If using 
ProcessTime, timestamp is going to be ignored (upstream does not include 
timestamp as well)
+                                                       // If using 
IngestionTime, timestamp is going to be overwritten
+                                                       // If using EventTime, 
timestamp is going to be used
+                                                       synchronized 
(checkpointLock) {
+                                                               
KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+                                                               
sourceContext.collectWithTimestamp(
+                                                                       
elementAsRecord.value,
+                                                                       
elementAsRecord.timestamp == null ? record.timestamp() : 
elementAsRecord.timestamp);
+                                                               
partition.setOffset(record.offset());
+                                                       }
+                                               } else if 
(element.isWatermark()) {
+                                                       final 
KafkaShuffleWatermark watermark = element.asWatermark();
+                                                       Optional<Watermark> 
newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+                                                       
newWatermark.ifPresent(sourceContext::emitWatermark);
+                                               }
+                                       }
+                               }
+                       }
+               }
+               finally {
+                       // this signals the consumer thread that no more work 
is to be done
+                       consumerThread.shutdown();
+               }
+
+               // on a clean exit, wait for the runner thread
+               try {
+                       consumerThread.join();
+               }
+               catch (InterruptedException e) {
+                       // may be the result of a wake-up interruption after an 
exception.
+                       // we ignore this here and only restore the 
interruption state
+                       Thread.currentThread().interrupt();
+               }
+       }
+
+       @Override
+       public void cancel() {
+               // flag the main thread to exit. A thread interrupt will come 
anyways.
+               running = false;
+               handover.close();
+               consumerThread.shutdown();
+       }
+
+       @Override
+       protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition 
partition) {
+               return new TopicPartition(partition.getTopic(), 
partition.getPartition());
+       }
+
+       @Override
+       protected void doCommitInternalOffsetsToKafka(
+                       Map<KafkaTopicPartition, Long> offsets,
+                       @Nonnull KafkaCommitCallback commitCallback) throws 
Exception {
+               @SuppressWarnings("unchecked")
+               List<KafkaTopicPartitionState<TopicPartition>> partitions = 
subscribedPartitionStates();
+
+               Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new 
HashMap<>(partitions.size());
+
+               for (KafkaTopicPartitionState<TopicPartition> partition : 
partitions) {
+                       Long lastProcessedOffset = 
offsets.get(partition.getKafkaTopicPartition());
+                       if (lastProcessedOffset != null) {
+                               checkState(lastProcessedOffset >= 0, "Illegal 
offset value to commit");
+
+                               // committed offsets through the KafkaConsumer 
need to be 1 more than the last processed offset.
+                               // This does not affect Flink's 
checkpoints/saved state.
+                               long offsetToCommit = lastProcessedOffset + 1;
+
+                               
offsetsToCommit.put(partition.getKafkaPartitionHandle(), new 
OffsetAndMetadata(offsetToCommit));
+                               partition.setCommittedOffset(offsetToCommit);
+                       }
+               }
+
+               // record the work to be committed by the main consumer thread 
and make sure the consumer notices that
+               consumerThread.setOffsetsToCommit(offsetsToCommit, 
commitCallback);
+       }
+
+       private String getFetcherName() {
+               return "Kafka Shuffle Fetcher";
+       }
+
+       /**
+        * An element in a KafkaShuffle. Can be a record or a Watermark.
+        */
+       @VisibleForTesting
+       public abstract static class KafkaShuffleElement<T> {
+
+               public boolean isRecord() {
+                       return getClass() == KafkaShuffleRecord.class;
+               }
+
+               public boolean isWatermark() {
+                       return getClass() == KafkaShuffleWatermark.class;
+               }
+
+               public KafkaShuffleRecord<T> asRecord() {
+                       return (KafkaShuffleRecord<T>) this;
+               }
+
+               public KafkaShuffleWatermark asWatermark() {
+                       return (KafkaShuffleWatermark) this;
+               }
+       }
+
+       /**
+        * A watermark element in a KafkaShuffle. It includes
+        * - subtask index where the watermark is coming from
+        * - watermark timestamp
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleWatermark<T> extends 
KafkaShuffleElement<T> {
+               final int subtask;
+               final long watermark;
+
+               KafkaShuffleWatermark(int subtask, long watermark) {
+                       this.subtask = subtask;
+                       this.watermark = watermark;
+               }
+
+               public int getSubtask() {
+                       return subtask;
+               }
+
+               public long getWatermark() {
+                       return watermark;
+               }
+       }
+
+       /**
+        * One value with Type T in a KafkaShuffle. This stores the value and 
an optional associated timestamp.
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleRecord<T> extends 
KafkaShuffleElement<T> {
+               final T value;
+               final Long timestamp;
+
+               KafkaShuffleRecord(T value) {
+                       this.value = value;
+                       this.timestamp = null;
+               }
+
+               KafkaShuffleRecord(long timestamp, T value) {
+                       this.value = value;
+                       this.timestamp = timestamp;
+               }
+
+               public T getValue() {
+                       return value;
+               }
+
+               public Long getTimestamp() {
+                       return timestamp;
+               }
+       }
+
+       /**
+        * Deserializer for KafkaShuffleElement.
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleElementDeserializer<T> implements 
Serializable {
+               private transient DataInputDeserializer dis;
+
+               @VisibleForTesting
+               public KafkaShuffleElementDeserializer() {
+                       this.dis = new DataInputDeserializer();
+               }
+
+               @VisibleForTesting
+               public KafkaShuffleElement<T> deserialize(TypeSerializer<T> 
serializer, ConsumerRecord<byte[], byte[]> record)

Review comment:
       Pass `serializer` as a ctor parameter instead of passing it to every 
`deserialize` call.
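   
   A sketch of the resulting shape (the `deserialize` body is reconstructed from 
the producer-side format in this PR, so treat it as illustrative rather than the 
final code):
   ```
   public static class KafkaShuffleElementDeserializer<T> implements Serializable {
       private final TypeSerializer<T> serializer;
       private transient DataInputDeserializer dis;

       public KafkaShuffleElementDeserializer(TypeSerializer<T> serializer) {
           this.serializer = serializer;
       }

       public KafkaShuffleElement<T> deserialize(ConsumerRecord<byte[], byte[]> record) throws IOException {
           if (dis == null) {
               dis = new DataInputDeserializer();
           }
           dis.setBuffer(record.value());

           int tag = dis.readInt();
           if (tag == TAG_REC_WITHOUT_TIMESTAMP) {
               return new KafkaShuffleRecord<>(serializer.deserialize(dis));
           } else if (tag == TAG_REC_WITH_TIMESTAMP) {
               return new KafkaShuffleRecord<>(dis.readLong(), serializer.deserialize(dis));
           } else if (tag == TAG_WATERMARK) {
               // assumes the producer writes (subtask, watermark) after the tag
               return new KafkaShuffleWatermark<>(dis.readInt(), dis.readLong());
           }
           throw new UnsupportedOperationException("Unsupported tag format: " + tag);
       }
   }
   ```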

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,379 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import 
org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static 
org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static 
org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static 
org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> 
{
+       private static final Logger LOG = 
LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       /** The handler to check and generate watermarks from fetched records. 
**/
+       private final WatermarkHandler watermarkHandler;
+
+       /** The schema to convert between Kafka's byte messages, and Flink's 
objects. */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Serializer to serialize record. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and 
the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hand the record 
batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+                       SourceFunction.SourceContext<T> sourceContext,
+                       Map<KafkaTopicPartition, Long> 
assignedPartitionsWithInitialOffsets,
+                       SerializedValue<AssignerWithPeriodicWatermarks<T>> 
watermarksPeriodic,
+                       SerializedValue<AssignerWithPunctuatedWatermarks<T>> 
watermarksPunctuated,
+                       ProcessingTimeService processingTimeProvider,
+                       long autoWatermarkInterval,
+                       ClassLoader userCodeClassLoader,
+                       String taskNameWithSubtasks,
+                       TypeSerializer<T> serializer,
+                       Properties kafkaProperties,
+                       long pollTimeout,
+                       MetricGroup subtaskMetricGroup,
+                       MetricGroup consumerMetricGroup,
+                       boolean useMetrics,
+                       int producerParallelism) throws Exception {
+               super(
+                       sourceContext,
+                       assignedPartitionsWithInitialOffsets,
+                       watermarksPeriodic,
+                       watermarksPunctuated,
+                       processingTimeProvider,
+                       autoWatermarkInterval,
+                       userCodeClassLoader,
+                       consumerMetricGroup,
+                       useMetrics);
+               this.deserializer = new KafkaShuffleElementDeserializer<>();
+               this.serializer = serializer;
+               this.handover = new Handover();
+               this.consumerThread = new KafkaConsumerThread(
+                       LOG,
+                       handover,
+                       kafkaProperties,
+                       unassignedPartitionsQueue,
+                       getFetcherName() + " for " + taskNameWithSubtasks,
+                       pollTimeout,
+                       useMetrics,
+                       consumerMetricGroup,
+                       subtaskMetricGroup);
+               this.watermarkHandler = new 
WatermarkHandler(producerParallelism);
+       }
+
+       // 
------------------------------------------------------------------------
+       //  Fetcher work methods
+       // 
------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions 
encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = 
handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> 
partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> 
partitionRecords =
+                                               
records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> 
record : partitionRecords) {
+                                               final KafkaShuffleElement<T> 
element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check 
the end of stream if reaching the end watermark?
+
+                                               if (element.isRecord()) {
+                                                       // timestamp is 
inherent from upstream
+                                                       // If using 
ProcessTime, timestamp is going to be ignored (upstream does not include 
timestamp as well)
+                                                       // If using 
IngestionTime, timestamp is going to be overwritten
+                                                       // If using EventTime, 
timestamp is going to be used
+                                                       synchronized 
(checkpointLock) {
+                                                               
KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+                                                               
sourceContext.collectWithTimestamp(
+                                                                       
elementAsRecord.value,
+                                                                       
elementAsRecord.timestamp == null ? record.timestamp() : 
elementAsRecord.timestamp);
+                                                               
partition.setOffset(record.offset());
+                                                       }
+                                               } else if 
(element.isWatermark()) {
+                                                       final 
KafkaShuffleWatermark watermark = element.asWatermark();
+                                                       Optional<Watermark> 
newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+                                                       
newWatermark.ifPresent(sourceContext::emitWatermark);
+                                               }
+                                       }
+                               }
+                       }
+               }
+               finally {
+                       // this signals the consumer thread that no more work 
is to be done
+                       consumerThread.shutdown();
+               }
+
+               // on a clean exit, wait for the runner thread
+               try {
+                       consumerThread.join();
+               }
+               catch (InterruptedException e) {
+                       // may be the result of a wake-up interruption after an 
exception.
+                       // we ignore this here and only restore the 
interruption state
+                       Thread.currentThread().interrupt();
+               }
+       }
+
+       @Override
+       public void cancel() {
+               // flag the main thread to exit. A thread interrupt will come 
anyways.
+               running = false;
+               handover.close();
+               consumerThread.shutdown();
+       }
+
+       @Override
+       protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition 
partition) {
+               return new TopicPartition(partition.getTopic(), 
partition.getPartition());
+       }
+
+       @Override
+       protected void doCommitInternalOffsetsToKafka(
+                       Map<KafkaTopicPartition, Long> offsets,
+                       @Nonnull KafkaCommitCallback commitCallback) throws 
Exception {
+               @SuppressWarnings("unchecked")
+               List<KafkaTopicPartitionState<TopicPartition>> partitions = 
subscribedPartitionStates();
+
+               Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new 
HashMap<>(partitions.size());
+
+               for (KafkaTopicPartitionState<TopicPartition> partition : 
partitions) {
+                       Long lastProcessedOffset = 
offsets.get(partition.getKafkaTopicPartition());
+                       if (lastProcessedOffset != null) {
+                               checkState(lastProcessedOffset >= 0, "Illegal 
offset value to commit");
+
+                               // committed offsets through the KafkaConsumer 
need to be 1 more than the last processed offset.
+                               // This does not affect Flink's 
checkpoints/saved state.
+                               long offsetToCommit = lastProcessedOffset + 1;
+
+                               
offsetsToCommit.put(partition.getKafkaPartitionHandle(), new 
OffsetAndMetadata(offsetToCommit));
+                               partition.setCommittedOffset(offsetToCommit);
+                       }
+               }
+
+               // record the work to be committed by the main consumer thread 
and make sure the consumer notices that
+               consumerThread.setOffsetsToCommit(offsetsToCommit, 
commitCallback);
+       }
+
+       private String getFetcherName() {
+               return "Kafka Shuffle Fetcher";
+       }
+
+       /**
+        * An element in a KafkaShuffle. Can be a record or a Watermark.
+        */
+       @VisibleForTesting
+       public abstract static class KafkaShuffleElement<T> {
+
+               public boolean isRecord() {
+                       return getClass() == KafkaShuffleRecord.class;
+               }
+
+               public boolean isWatermark() {
+                       return getClass() == KafkaShuffleWatermark.class;
+               }
+
+               public KafkaShuffleRecord<T> asRecord() {
+                       return (KafkaShuffleRecord<T>) this;
+               }
+
+               public KafkaShuffleWatermark asWatermark() {
+                       return (KafkaShuffleWatermark) this;
+               }
+       }
+
+       /**
+        * A watermark element in a KafkaShuffle. It includes
+        * - subtask index where the watermark is coming from
+        * - watermark timestamp
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleWatermark<T> extends 
KafkaShuffleElement<T> {

Review comment:
       `KafkaShuffleWatermark` is always used as a raw type, so please get rid 
of the type parameter. 
   
   I also think that `KafkaShuffleElement<T>` should drop its type parameter and 
that only `asRecord` should carry it:
   ```
        public <T> KafkaShuffleRecord<T> asRecord() {
                        return (KafkaShuffleRecord<T>) this;
        }
   ```
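   
   Put differently, the whole hierarchy could look roughly like this (just a 
sketch of the shape):
   ```
   public abstract static class KafkaShuffleElement {

       public boolean isRecord() {
           return getClass() == KafkaShuffleRecord.class;
       }

       public boolean isWatermark() {
           return getClass() == KafkaShuffleWatermark.class;
       }

       @SuppressWarnings("unchecked")
       public <T> KafkaShuffleRecord<T> asRecord() {
           return (KafkaShuffleRecord<T>) this;
       }

       public KafkaShuffleWatermark asWatermark() {
           return (KafkaShuffleWatermark) this;
       }
   }
   ```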

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
##########
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import 
org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaException;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import org.apache.kafka.clients.producer.ProducerRecord;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Properties;
+
+import static 
org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PARTITION_NUMBER;
+
+/**
+ * Flink Kafka Shuffle Producer Function.
+ * It is different from {@link FlinkKafkaProducer} in the way handling 
elements and watermarks
+ */
+@Internal
+public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> 
{
+       private final KafkaSerializer<IN> kafkaSerializer;
+       private final KeySelector<IN, KEY> keySelector;
+       private final int numberOfPartitions;
+
+       FlinkKafkaShuffleProducer(
+                       String defaultTopicId,
+                       TypeInformationSerializationSchema<IN> schema,
+                       Properties props,
+                       KeySelector<IN, KEY> keySelector,
+                       Semantic semantic,
+                       int kafkaProducersPoolSize) {
+               super(defaultTopicId, (element, timestamp) -> null, props, 
semantic, kafkaProducersPoolSize);
+
+               this.kafkaSerializer = new 
KafkaSerializer<>(schema.getSerializer());
+               this.keySelector = keySelector;
+
+               Preconditions.checkArgument(
+                       props.getProperty(PARTITION_NUMBER) != null,
+                       "Missing partition number for Kafka Shuffle");
+               numberOfPartitions = PropertiesUtil.getInt(props, 
PARTITION_NUMBER, Integer.MIN_VALUE);
+       }
+
+       /**
+        * This is the function invoked to handle each element.
+        * @param transaction transaction state;
+        *                    elements are written to Kafka in transactions to 
guarantee different level of data consistency
+        * @param next element to handle
+        * @param context context needed to handle the element
+        * @throws FlinkKafkaException for kafka error
+        */
+       @Override
+       public void invoke(KafkaTransactionState transaction, IN next, Context 
context) throws FlinkKafkaException {
+               checkErroneous();
+
+               // write timestamp to Kafka if timestamp is available
+               Long timestamp = context.timestamp();
+
+               int[] partitions = getPartitions(transaction);
+               int partitionIndex;
+               try {
+                       partitionIndex = KeyGroupRangeAssignment
+                               
.assignKeyToParallelOperator(keySelector.getKey(next), partitions.length, 
partitions.length);
+               } catch (Exception e) {
+                       throw new RuntimeException("Fail to assign a partition 
number to record");
+               }
+
+               ProducerRecord<byte[], byte[]> record = new ProducerRecord<>(
+                       defaultTopicId, partitionIndex, timestamp, null, 
kafkaSerializer.serializeRecord(next, timestamp));
+               pendingRecords.incrementAndGet();
+               transaction.getProducer().send(record, callback);
+       }
+
+       /**
+        * This is the function invoked to handle each watermark.
+        * @param watermark watermark to handle
+        * @throws FlinkKafkaException for kafka error
+        */
+       public void invoke(Watermark watermark) throws FlinkKafkaException {
+               checkErroneous();
+               KafkaTransactionState transaction = currentTransaction();
+
+               int[] partitions = getPartitions(transaction);
+               int subtask = getRuntimeContext().getIndexOfThisSubtask();
+
+               // broadcast watermark
+               long timestamp = watermark.getTimestamp();
+               for (int partition : partitions) {
+                       ProducerRecord<byte[], byte[]> record = new 
ProducerRecord<>(
+                               defaultTopicId, partition, timestamp, null, 
kafkaSerializer.serializeWatermark(watermark, subtask));
+                       pendingRecords.incrementAndGet();
+                       transaction.getProducer().send(record, callback);
+               }
+       }
+
+       private int[] getPartitions(KafkaTransactionState transaction) {
+               int[] partitions = topicPartitionsMap.get(defaultTopicId);
+               if (partitions == null) {
+                       partitions = getPartitionsByTopic(defaultTopicId, 
transaction.getProducer());
+                       topicPartitionsMap.put(defaultTopicId, partitions);
+               }
+
+               Preconditions.checkArgument(partitions.length == 
numberOfPartitions);
+
+               return partitions;
+       }
+
+       /**
+        * Flink Kafka Shuffle Serializer.
+        */
+       public static final class KafkaSerializer<IN> implements Serializable {
+               public static final int TAG_REC_WITH_TIMESTAMP = 0;
+               public static final int TAG_REC_WITHOUT_TIMESTAMP = 1;
+               public static final int TAG_WATERMARK = 2;
+
+               private final TypeSerializer<IN> serializer;
+
+               private transient DataOutputSerializer dos;
+
+               KafkaSerializer(TypeSerializer<IN> serializer) {
+                       this.serializer = serializer;
+               }
+
+               /**
+                * Format: TAG, (timestamp), record.
+                */
+               byte[] serializeRecord(IN record, Long timestamp) {
+                       if (dos == null) {
+                               dos = new DataOutputSerializer(16);
+                       }
+
+                       try {
+                               if (timestamp == null) {
+                                       dos.writeInt(TAG_REC_WITHOUT_TIMESTAMP);

Review comment:
       `StreamElementSerializer` uses only 1 byte for its tag.
   
   > for now
   
    Once this is out, we cannot change the format easily without breaking 
existing setups. That reminds me: it would be good to prepend a version tag as 
well, so that we have a way to evolve the format later.
   
   ```
   |version (byte)|tag (byte)|payload|
   ```
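   
   A minimal sketch of what that could look like on the write path (the version 
constant is illustrative, and the `TAG_*` constants would shrink to byte values):
   ```
   // hypothetical format version, prepended before the tag
   private static final byte VERSION = 1;

   byte[] serializeRecord(IN record, Long timestamp) throws IOException {
       if (dos == null) {
           dos = new DataOutputSerializer(16);
       }

       dos.writeByte(VERSION);
       if (timestamp == null) {
           dos.writeByte(TAG_REC_WITHOUT_TIMESTAMP);
       } else {
           dos.writeByte(TAG_REC_WITH_TIMESTAMP);
           dos.writeLong(timestamp);
       }
       serializer.serialize(record, dos);

       byte[] serialized = dos.getCopyOfBuffer();
       dos.clear();
       return serialized;
   }
   ```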

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleExactlyOnceITCase.java
##########
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.connectors.kafka.testutils.FailingIdentityMapper;
+import 
org.apache.flink.streaming.connectors.kafka.testutils.ValidatingExactlyOnceSink;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Failure Recovery IT Test for KafkaShuffle.
+ */
+public class KafkaShuffleExactlyOnceITCase extends KafkaShuffleTestBase {
+
+       @Rule
+       public final Timeout timeout = Timeout.millis(600000L);
+
+       /**
+        * Failure Recovery after processing 2/3 data with time characteristic: 
ProcessingTime.
+        *
+        * <p>Producer Parallelism = 1; Kafka Partition # = 1; Consumer 
Parallelism = 1.
+        */
+       @Test
+       public void testFailureRecoveryProcessingTime() throws Exception {
+               testKafkaShuffleFailureRecovery("failure_recovery", 1000, 
ProcessingTime);
+       }
+
+       /**
+        * Failure Recovery after processing 2/3 data with time characteristic: 
IngestionTime.
+        *
+        * <p>Producer Parallelism = 1; Kafka Partition # = 1; Consumer 
Parallelism = 1.
+        */
+       @Test
+       public void testFailureRecoveryIngestionTime() throws Exception {
+               testKafkaShuffleFailureRecovery("failure_recovery", 1000, 
IngestionTime);
+       }
+
+       /**
+        * Failure Recovery after processing 2/3 data with time characteristic: 
EventTime.
+        *
+        * <p>Producer Parallelism = 1; Kafka Partition # = 1; Consumer 
Parallelism = 1.
+        */
+       @Test
+       public void testFailureRecoveryEventTime() throws Exception {
+               testKafkaShuffleFailureRecovery("failure_recovery", 1000, 
EventTime);
+       }
+
+       /**
+        * Failure Recovery after data is repartitioned with time 
characteristic: ProcessingTime.
+        *
+        * <p>Producer Parallelism = 2; Kafka Partition # = 3; Consumer 
Parallelism = 3.
+        */
+       @Test
+       public void testAssignedToPartitionFailureRecoveryProcessingTime() 
throws Exception {
+               
testAssignedToPartitionFailureRecovery("partition_failure_recovery", 500, 
ProcessingTime);
+       }
+
+       /**
+        * Failure Recovery after data is repartitioned with time 
characteristic: IngestionTime.
+        *
+        * <p>Producer Parallelism = 2; Kafka Partition # = 3; Consumer 
Parallelism = 3.
+        */
+       @Test
+       public void testAssignedToPartitionFailureRecoveryIngestionTime() 
throws Exception {
+               
testAssignedToPartitionFailureRecovery("partition_failure_recovery", 500, 
IngestionTime);
+       }
+
+       /**
+        * Failure Recovery after data is repartitioned with time 
characteristic: EventTime.
+        *
+        * <p>Producer Parallelism = 2; Kafka Partition # = 3; Consumer 
Parallelism = 3.
+        */
+       @Test
+       public void testAssignedToPartitionFailureRecoveryEventTime() throws 
Exception {
+               
testAssignedToPartitionFailureRecovery("partition_failure_recovery", 500, 
EventTime);
+       }
+
+       /**
+        * To test failure recovery after processing 2/3 data.
+        *
+        * <p>Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer 
Parallelism = 1
+        */
+       private void testKafkaShuffleFailureRecovery(
+                       String prefix, int numElementsPerProducer, 
TimeCharacteristic timeCharacteristic) throws Exception {
+               String topic = topic(prefix, timeCharacteristic);
+               final int numberOfPartitions = 1;
+               final int producerParallelism = 1;
+               final int failAfterElements = numElementsPerProducer * 
numberOfPartitions * 2 / 3;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env =
+                       createEnvironment(producerParallelism, 
timeCharacteristic).enableCheckpointing(500);
+
+               createKafkaShuffle(
+                       env, topic, numElementsPerProducer, 
producerParallelism, timeCharacteristic, numberOfPartitions)
+                       .map(new 
FailingIdentityMapper<>(failAfterElements)).setParallelism(1)
+                       .map(new 
ToInteger(producerParallelism)).setParallelism(1)
+                       .addSink(new 
ValidatingExactlyOnceSink(numElementsPerProducer * 
producerParallelism)).setParallelism(1);
+
+               FailingIdentityMapper.failedBefore = false;
+
+               tryExecute(env, topic);
+
+               deleteTestTopic(topic);
+       }
+
+       /**
+        * To test failure recovery with partition assignment after processing 
2/3 data.
+        *
+        * <p>Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer 
Parallelism = 3
+        */
+       private void testAssignedToPartitionFailureRecovery(
+                       String prefix,
+                       int numElementsPerProducer,
+                       TimeCharacteristic timeCharacteristic) throws Exception 
{
+               String topic = topic(prefix, timeCharacteristic);
+               final int numberOfPartitions = 3;
+               final int producerParallelism = 2;
+               final int failAfterElements = numElementsPerProducer * 
producerParallelism * 2 / 3;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env = 
createEnvironment(producerParallelism, timeCharacteristic);
+
+               KeyedStream<Tuple3<Integer, Long, Integer>, Tuple> keyedStream 
= createKafkaShuffle(
+                       env,
+                       topic,
+                       numElementsPerProducer,
+                       producerParallelism,
+                       timeCharacteristic,
+                       numberOfPartitions);
+               keyedStream
+                       .process(new 
PartitionValidator(keyedStream.getKeySelector(), numberOfPartitions, topic))
+                       .setParallelism(numberOfPartitions)
+                       .map(new 
ToInteger(producerParallelism)).setParallelism(numberOfPartitions)
+                       .map(new 
FailingIdentityMapper<>(failAfterElements)).setParallelism(1)
+                       .addSink(new 
ValidatingExactlyOnceSink(numElementsPerProducer * 
producerParallelism)).setParallelism(1);
+
+               FailingIdentityMapper.failedBefore = false;
+
+               tryExecute(env, topic);
+
+               deleteTestTopic(topic);
+       }
+
+       private StreamExecutionEnvironment createEnvironment(
+                       int producerParallelism, TimeCharacteristic 
timeCharacteristic) {

Review comment:
       chop (i.e. put each parameter on its own line)

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Experimental;
+import org.apache.flink.api.common.operators.Keys;
+import 
org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair 
together.
+ */
+@Experimental
+class FlinkKafkaShuffle {
+       static final String PRODUCER_PARALLELISM = "producer parallelism";
+       static final String PARTITION_NUMBER = "partition number";
+
+       /**
+        * Writes to and reads from a kafka shuffle with the partition decided 
by keys.
+        *
+        * <p>Consumers should read partitions equal to the key group indices 
they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving 
operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param producerParallelism Parallelism of producer
+        * @param numberOfPartitions Number of partitions
+        * @param properties Kafka properties
+        * @param fields Key positions from inputStream
+        * @param <T> Input type
+        */
+       public static <T> KeyedStream<T, Tuple> persistentKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       int producerParallelism,
+                       int numberOfPartitions,
+                       Properties properties,
+                       int... fields) {
+               return persistentKeyBy(
+                       inputStream,
+                       topic,
+                       producerParallelism,
+                       numberOfPartitions,
+                       properties,
+                       keySelector(inputStream, fields));
+       }
+
+       /**
+        * Writes to and reads from a kafka shuffle with the partition decided 
by keys.
+        *
+        * <p>Consumers should read partitions equal to the key group indices 
they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving 
operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param producerParallelism Parallelism of producer
+        * @param numberOfPartitions Number of partitions
+        * @param properties Kafka properties
+        * @param keySelector key(K) based on inputStream(T)
+        * @param <T> Input type
+        * @param <K> Key type
+        */
+       public static <T, K> KeyedStream<T, K> persistentKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       int producerParallelism,
+                       int numberOfPartitions,
+                       Properties properties,
+                       KeySelector<T, K> keySelector) {
+               // KafkaProducer#propsToMap uses Properties purely as a HashMap 
without considering the default properties
+               // So we have to flatten the default property to first level 
elements.
+               Properties kafkaProperties = PropertiesUtil.flatten(properties);
+               kafkaProperties.setProperty(PRODUCER_PARALLELISM, 
String.valueOf(producerParallelism));
+               kafkaProperties.setProperty(PARTITION_NUMBER, 
String.valueOf(numberOfPartitions));
+
+               StreamExecutionEnvironment env = 
inputStream.getExecutionEnvironment();
+               TypeInformationSerializationSchema<T> schema =
+                       new 
TypeInformationSerializationSchema<>(inputStream.getType(), env.getConfig());
+
+               writeKeyBy(inputStream, topic, kafkaProperties, keySelector);
+               return readKeyBy(topic, env, schema, kafkaProperties, 
keySelector);
+       }
+
+       /**
+        * Writes to a kafka shuffle with the partition decided by keys.

Review comment:
       What does "Writes to a kafka shuffle" mean? I think you need to add more 
explanation.
   This function is also only meaningful in combination with `readKeyBy`, so add 
a `@see` reference.
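   
   For example, something along these lines (wording is only a suggestion):
   ```
   /**
    * Writes the input stream to the Kafka topic backing the shuffle. Each record is written to
    * the partition that owns the key group of its key, and watermarks are broadcast to all
    * partitions. The data only becomes useful once it is read back on the consumer side.
    *
    * @see #readKeyBy
    */
   ```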

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleTestBase.java
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+
+/**
+ * Base Test Class for KafkaShuffle.
+ */
+public class KafkaShuffleTestBase extends KafkaConsumerTestBase {
+       static final long INIT_TIMESTAMP = System.currentTimeMillis();

Review comment:
       Can't we just use a random static number here? I was thinking that test 
results might be easier to compare if it's not a real timestamp.
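   
   E.g. a fixed, arbitrary constant (the value below is just an example) so that 
runs stay reproducible:
   ```
   static final long INIT_TIMESTAMP = 1_000_000_000_000L;
   ```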

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Experimental;
+import org.apache.flink.api.common.operators.Keys;
+import 
org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair 
together.
+ */
+@Experimental
+class FlinkKafkaShuffle {
+       static final String PRODUCER_PARALLELISM = "producer parallelism";
+       static final String PARTITION_NUMBER = "partition number";
+
+       /**
+        * Writes to and reads from a kafka shuffle with the partition decided 
by keys.
+        *
+        * <p>Consumers should read partitions equal to the key group indices 
they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving 
operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream Input stream to the kafka
+        * @param topic Kafka topic
+        * @param producerParallelism Parallelism of producer
+        * @param numberOfPartitions Number of partitions
+        * @param properties Kafka properties
+        * @param fields Key positions from inputStream
+        * @param <T> Input type
+        */
+       public static <T> KeyedStream<T, Tuple> persistentKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       int producerParallelism,
+                       int numberOfPartitions,
+                       Properties properties,
+                       int... fields) {
+               return persistentKeyBy(
+                       inputStream,
+                       topic,
+                       producerParallelism,
+                       numberOfPartitions,
+                       properties,
+                       keySelector(inputStream, fields));
+       }
+
+       /**
+        * Writes to and reads from a kafka shuffle with the partition decided 
by keys.
+        *
+        * <p>Consumers should read partitions equal to the key group indices 
they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving 
operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream Input stream to write to Kafka
+        * @param topic Kafka topic
+        * @param producerParallelism Parallelism of producer
+        * @param numberOfPartitions Number of partitions
+        * @param properties Kafka properties
+        * @param keySelector Key selector extracting the key (K) from the input stream elements (T)
+        * @param <T> Input type
+        * @param <K> Key type
+        */
+       public static <T, K> KeyedStream<T, K> persistentKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       int producerParallelism,
+                       int numberOfPartitions,
+                       Properties properties,
+                       KeySelector<T, K> keySelector) {
+               // KafkaProducer#propsToMap uses Properties purely as a HashMap, ignoring the default properties,
+               // so we have to flatten the defaults into first-level entries.
+               Properties kafkaProperties = PropertiesUtil.flatten(properties);
+               kafkaProperties.setProperty(PRODUCER_PARALLELISM, 
String.valueOf(producerParallelism));
+               kafkaProperties.setProperty(PARTITION_NUMBER, 
String.valueOf(numberOfPartitions));
+
+               StreamExecutionEnvironment env = 
inputStream.getExecutionEnvironment();
+               TypeInformationSerializationSchema<T> schema =
+                       new 
TypeInformationSerializationSchema<>(inputStream.getType(), env.getConfig());
+
+               writeKeyBy(inputStream, topic, kafkaProperties, keySelector);
+               return readKeyBy(topic, env, schema, kafkaProperties, 
keySelector);
+       }
+
+       /**
+        * Writes to a Kafka shuffle, with the partition determined by the keys.
+        *
+        * @param inputStream Input stream to write to Kafka
+        * @param topic Kafka topic
+        * @param kafkaProperties Kafka properties
+        * @param fields Key positions from inputStream
+        * @param <T> Input type
+        */
+       public static <T> void writeKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       Properties kafkaProperties,
+                       int... fields) {
+               writeKeyBy(inputStream, topic, kafkaProperties, 
keySelector(inputStream, fields));
+       }
+
+       /**
+        * Writes to a Kafka shuffle, with the partition determined by the keys.
+        *
+        * @param inputStream Input stream to write to Kafka
+        * @param topic Kafka topic
+        * @param kafkaProperties Kafka properties
+        * @param keySelector Key selector extracting the key (K) from the input elements (T)
+        * @param <T> Input type
+        * @param <K> Key type
+        */
+       public static <T, K> void writeKeyBy(
+                       DataStream<T> inputStream,
+                       String topic,
+                       Properties kafkaProperties,
+                       KeySelector<T, K> keySelector) {
+               StreamExecutionEnvironment env = 
inputStream.getExecutionEnvironment();
+               TypeInformationSerializationSchema<T> schema =
+                       new 
TypeInformationSerializationSchema<>(inputStream.getType(), env.getConfig());
+
+               // write data to Kafka
+               FlinkKafkaShuffleProducer<T, K> kafkaProducer = new 
FlinkKafkaShuffleProducer<>(
+                       topic,
+                       schema.getSerializer(),
+                       kafkaProperties,
+                       env.clean(keySelector),
+                       FlinkKafkaProducer.Semantic.EXACTLY_ONCE,
+                       FlinkKafkaProducer.DEFAULT_KAFKA_PRODUCERS_POOL_SIZE);
+
+               // make sure the sink parallelism is set to producerParallelism
+               Preconditions.checkArgument(
+                       kafkaProperties.getProperty(PRODUCER_PARALLELISM) != 
null,
+                       "Missing producer parallelism for Kafka Shuffle");
+               int producerParallelism = 
PropertiesUtil.getInt(kafkaProperties, PRODUCER_PARALLELISM, Integer.MIN_VALUE);
+
+               addKafkaShuffle(inputStream, kafkaProducer, 
producerParallelism);
+       }
+
+       /**
+        * Reads data from a Kafka Shuffle.
+        *
+        * <p>Consumers should read partitions equal to the key group indices 
they are assigned.
+        * The number of partitions is equal to the number of key groups. Each consumer task reads from
+        * one or more partitions; no two consumer tasks read from the same partition.
+        * Hence, the maximum parallelism of the receiving operator is the 
number of partitions.
+        * This version only supports numberOfPartitions = consumerParallelism

Review comment:
       This is a good comment. Something like that should also exist for the other public API methods.
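
       For context, a minimal usage sketch of `persistentKeyBy` as declared above (broker
       address, topic name, parallelism values and the key field are illustrative
       assumptions; it relies on the imports already present in this file plus
       `org.apache.flink.api.java.tuple.Tuple2`, and on being called from the same package
       while the class is still package-private):

       ```java
       StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

       Properties props = new Properties();
       props.setProperty("bootstrap.servers", "localhost:9092");

       DataStream<Tuple2<Integer, String>> input = env.fromElements(
               Tuple2.of(1, "a"), Tuple2.of(2, "b"), Tuple2.of(3, "c"));

       // numberOfPartitions must match the parallelism of the operator consuming the shuffle
       KeyedStream<Tuple2<Integer, String>, Tuple> keyed = FlinkKafkaShuffle.persistentKeyBy(
               input,
               "shuffle-topic",    // Kafka topic backing the shuffle
               2,                  // producer parallelism
               4,                  // number of partitions
               props,
               0);                 // key by the first tuple field

       keyed.print().setParallelism(4);
       env.execute();
       ```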

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import 
org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static 
org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static 
org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static 
org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> 
{
+       private static final Logger LOG = 
LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       private final WatermarkHandler watermarkHandler;
+       // 
------------------------------------------------------------------------
+
+       /** The deserializer used to convert Kafka's byte messages into Flink's shuffle elements. */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Serializer for the record type. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and 
the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hands the record batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+                       SourceFunction.SourceContext<T> sourceContext,
+                       Map<KafkaTopicPartition, Long> 
assignedPartitionsWithInitialOffsets,
+                       SerializedValue<AssignerWithPeriodicWatermarks<T>> 
watermarksPeriodic,
+                       SerializedValue<AssignerWithPunctuatedWatermarks<T>> 
watermarksPunctuated,
+                       ProcessingTimeService processingTimeProvider,
+                       long autoWatermarkInterval,
+                       ClassLoader userCodeClassLoader,
+                       String taskNameWithSubtasks,
+                       TypeSerializer<T> serializer,
+                       Properties kafkaProperties,
+                       long pollTimeout,
+                       MetricGroup subtaskMetricGroup,
+                       MetricGroup consumerMetricGroup,
+                       boolean useMetrics,
+                       int producerParallelism) throws Exception {
+               super(
+                       sourceContext,
+                       assignedPartitionsWithInitialOffsets,
+                       watermarksPeriodic,
+                       watermarksPunctuated,
+                       processingTimeProvider,
+                       autoWatermarkInterval,
+                       userCodeClassLoader,
+                       consumerMetricGroup,
+                       useMetrics);
+               this.deserializer = new KafkaShuffleElementDeserializer<>();
+               this.serializer = serializer;
+               this.handover = new Handover();
+               this.consumerThread = new KafkaConsumerThread(
+                       LOG,
+                       handover,
+                       kafkaProperties,
+                       unassignedPartitionsQueue,
+                       getFetcherName() + " for " + taskNameWithSubtasks,
+                       pollTimeout,
+                       useMetrics,
+                       consumerMetricGroup,
+                       subtaskMetricGroup);
+               this.watermarkHandler = new 
WatermarkHandler(producerParallelism);
+       }
+
+       // 
------------------------------------------------------------------------
+       //  Fetcher work methods
+       // 
------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions 
encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = 
handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> 
partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> 
partitionRecords =
+                                               
records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> 
record : partitionRecords) {
+                                               final KafkaShuffleElement<T> 
element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check for end of stream when the end watermark is reached?
+
+                                               if (element.isRecord()) {
+                                                       // the timestamp is inherited from upstream
+                                                       // with ProcessingTime, the timestamp is ignored (upstream does not include one either)
+                                                       // with IngestionTime, the timestamp is overwritten
+                                                       // with EventTime, the timestamp is used as-is
+                                                       synchronized 
(checkpointLock) {
+                                                               
KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+                                                               
sourceContext.collectWithTimestamp(
+                                                                       
elementAsRecord.value,
+                                                                       
elementAsRecord.timestamp == null ? record.timestamp() : 
elementAsRecord.timestamp);
+                                                               
partition.setOffset(record.offset());
+                                                       }
+                                               } else if 
(element.isWatermark()) {
+                                                       final 
KafkaShuffleWatermark watermark = element.asWatermark();
+                                                       Optional<Watermark> 
newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+                                                       
newWatermark.ifPresent(sourceContext::emitWatermark);
+                                               }
+                                       }
+                               }
+                       }
+               }
+               finally {
+                       // this signals the consumer thread that no more work 
is to be done
+                       consumerThread.shutdown();
+               }
+
+               // on a clean exit, wait for the runner thread
+               try {
+                       consumerThread.join();
+               }
+               catch (InterruptedException e) {
+                       // may be the result of a wake-up interruption after an 
exception.
+                       // we ignore this here and only restore the 
interruption state
+                       Thread.currentThread().interrupt();
+               }
+       }
+
+       @Override
+       public void cancel() {
+               // flag the main thread to exit. A thread interrupt will come anyway.
+               running = false;
+               handover.close();
+               consumerThread.shutdown();
+       }
+
+       @Override
+       protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition 
partition) {
+               return new TopicPartition(partition.getTopic(), 
partition.getPartition());
+       }
+
+       @Override
+       protected void doCommitInternalOffsetsToKafka(
+                       Map<KafkaTopicPartition, Long> offsets,
+                       @Nonnull KafkaCommitCallback commitCallback) throws 
Exception {
+               @SuppressWarnings("unchecked")
+               List<KafkaTopicPartitionState<TopicPartition>> partitions = 
subscribedPartitionStates();
+
+               Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new 
HashMap<>(partitions.size());
+
+               for (KafkaTopicPartitionState<TopicPartition> partition : 
partitions) {
+                       Long lastProcessedOffset = 
offsets.get(partition.getKafkaTopicPartition());
+                       if (lastProcessedOffset != null) {
+                               checkState(lastProcessedOffset >= 0, "Illegal 
offset value to commit");
+
+                               // committed offsets through the KafkaConsumer 
need to be 1 more than the last processed offset.
+                               // This does not affect Flink's 
checkpoints/saved state.
+                               long offsetToCommit = lastProcessedOffset + 1;
+
+                               
offsetsToCommit.put(partition.getKafkaPartitionHandle(), new 
OffsetAndMetadata(offsetToCommit));
+                               partition.setCommittedOffset(offsetToCommit);
+                       }
+               }
+
+               // record the work to be committed by the main consumer thread 
and make sure the consumer notices that
+               consumerThread.setOffsetsToCommit(offsetsToCommit, 
commitCallback);
+       }
+
+       private String getFetcherName() {
+               return "Kafka Shuffle Fetcher";
+       }
+
+       /**
+        * An element in a KafkaShuffle. Can be a record or a Watermark.
+        */
+       @VisibleForTesting
+       public abstract static class KafkaShuffleElement<T> {
+
+               public boolean isRecord() {
+                       return getClass() == KafkaShuffleRecord.class;
+               }
+
+               public boolean isWatermark() {
+                       return getClass() == KafkaShuffleWatermark.class;
+               }
+
+               public KafkaShuffleRecord<T> asRecord() {
+                       return (KafkaShuffleRecord<T>) this;
+               }
+
+               public KafkaShuffleWatermark asWatermark() {
+                       return (KafkaShuffleWatermark) this;
+               }
+       }
+
+       /**
+        * A watermark element in a KafkaShuffle. It includes:
+        * - the index of the subtask the watermark comes from
+        * - the watermark timestamp
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleWatermark<T> extends 
KafkaShuffleElement<T> {
+               final int subtask;
+               final long watermark;
+
+               KafkaShuffleWatermark(int subtask, long watermark) {
+                       this.subtask = subtask;
+                       this.watermark = watermark;
+               }
+
+               public int getSubtask() {
+                       return subtask;
+               }
+
+               public long getWatermark() {
+                       return watermark;
+               }
+       }
+
+       /**
+        * One value with Type T in a KafkaShuffle. This stores the value and 
an optional associated timestamp.
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleRecord<T> extends 
KafkaShuffleElement<T> {
+               final T value;
+               final Long timestamp;
+
+               KafkaShuffleRecord(T value) {
+                       this.value = value;
+                       this.timestamp = null;
+               }
+
+               KafkaShuffleRecord(long timestamp, T value) {
+                       this.value = value;
+                       this.timestamp = timestamp;
+               }
+
+               public T getValue() {
+                       return value;
+               }
+
+               public Long getTimestamp() {
+                       return timestamp;
+               }
+       }
+
+       /**
+        * Deserializer for KafkaShuffleElement.
+        * @param <T> Type of the deserialized record value.
+        */
+       @VisibleForTesting
+       public static class KafkaShuffleElementDeserializer<T> implements 
Serializable {
+               private transient DataInputDeserializer dis;
+
+               @VisibleForTesting
+               public KafkaShuffleElementDeserializer() {
+                       this.dis = new DataInputDeserializer();
+               }
+
+               @VisibleForTesting
+               public KafkaShuffleElement<T> deserialize(TypeSerializer<T> 
serializer, ConsumerRecord<byte[], byte[]> record)
+                       throws Exception {
+                       byte[] value = record.value();
+                       dis.setBuffer(value);

Review comment:
       My guess is that in your test code it's actually never deserialized. 
Just provide `readObject`. 
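
       For illustration, one possible shape of that hook (a sketch only, assuming the
       transient `dis` buffer just needs to be re-created after Java deserialization):

       ```java
       private void readObject(java.io.ObjectInputStream in)
                       throws java.io.IOException, ClassNotFoundException {
               in.defaultReadObject();
               // re-create the transient buffer so a deserialized instance is usable
               this.dis = new DataInputDeserializer();
       }
       ```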




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to