This is an automated email from the ASF dual-hosted git repository.

lucasbru pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 4817eb9227b KAFKA-15344: Streams task should cache consumer 
nextOffsets (#17091)
4817eb9227b is described below

commit 4817eb9227bc54225f9d24dcdc496f545538a67e
Author: Alieh Saeedi <[email protected]>
AuthorDate: Tue Oct 29 09:30:11 2024 +0100

    KAFKA-15344: Streams task should cache consumer nextOffsets (#17091)
    
    This PR augments Streams messages with leader epoch. In case of empty 
buffer queues, the last offset and leader epoch are retrieved from the streams 
task's cache of nextOffsets.
    
    Co-authored-by: Lucas Brutschy <[email protected]>
    Reviewers: Lucas Brutschy <[email protected]>, Matthias J. Sax 
<[email protected]>
---
 .../internals/AbstractPartitionGroup.java          |   3 +
 .../processor/internals/PartitionGroup.java        |  13 +++
 .../streams/processor/internals/ReadOnlyTask.java  |   5 +
 .../processor/internals/RecordDeserializer.java    |   3 +-
 .../streams/processor/internals/RecordQueue.java   |  12 +++
 .../streams/processor/internals/StampedRecord.java |   6 ++
 .../streams/processor/internals/StreamTask.java    |  42 ++++----
 .../streams/processor/internals/StreamThread.java  |   3 +
 .../internals/SynchronizedPartitionGroup.java      |   6 ++
 .../kafka/streams/processor/internals/Task.java    |   3 +
 .../streams/processor/internals/TaskManager.java   |  30 ++++--
 .../processor/internals/PartitionGroupTest.java    |  63 +++++++++--
 .../processor/internals/ReadOnlyTaskTest.java      |   5 +
 .../internals/RecordDeserializerTest.java          |   3 +-
 .../processor/internals/RecordQueueTest.java       |  33 ++++--
 .../processor/internals/StreamTaskTest.java        | 117 +++++++++++++++------
 .../apache/kafka/streams/TopologyTestDriver.java   |   4 +-
 17 files changed, 268 insertions(+), 83 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractPartitionGroup.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractPartitionGroup.java
index 8d98eaf4979..1dadbf496a2 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractPartitionGroup.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractPartitionGroup.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 
+import java.util.Optional;
 import java.util.Set;
 import java.util.function.Function;
 
@@ -56,6 +57,8 @@ abstract class AbstractPartitionGroup {
 
     abstract Long headRecordOffset(final TopicPartition partition);
 
+    abstract Optional<Integer> headRecordLeaderEpoch(final TopicPartition 
partition);
+
     abstract int numBuffered();
 
     abstract int numBuffered(TopicPartition tp);
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
index ac85a17ca0e..5e57efb9628 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
@@ -30,6 +30,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.PriorityQueue;
 import java.util.Set;
@@ -322,6 +323,18 @@ class PartitionGroup extends AbstractPartitionGroup {
         return recordQueue.headRecordOffset();
     }
 
+    @Override
+    Optional<Integer> headRecordLeaderEpoch(final TopicPartition partition) {
+        final RecordQueue recordQueue = partitionQueues.get(partition);
+
+        if (recordQueue == null) {
+            throw new IllegalStateException("Partition " + partition + " not 
found.");
+        }
+
+        return recordQueue.headRecordLeaderEpoch();
+    }
+
+
     /**
      * @throws IllegalStateException if the record's partition does not belong 
to this partition group
      */
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ReadOnlyTask.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ReadOnlyTask.java
index 6402f7a98a4..a895b71e4e9 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ReadOnlyTask.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ReadOnlyTask.java
@@ -149,6 +149,11 @@ public class ReadOnlyTask implements Task {
         throw new UnsupportedOperationException("This task is read-only");
     }
 
+    @Override
+    public void updateNextOffsets(final TopicPartition partition, final 
OffsetAndMetadata offsetAndMetadata) {
+        throw new UnsupportedOperationException("This task is read-only");
+    }
+
     @Override
     public boolean process(final long wallClockTime) {
         throw new UnsupportedOperationException("This task is read-only");
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
index 1ef72a1714c..5ddafe654e9 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
@@ -29,7 +29,6 @@ import 
org.apache.kafka.streams.processor.api.ProcessorContext;
 import org.slf4j.Logger;
 
 import java.util.Objects;
-import java.util.Optional;
 
 import static 
org.apache.kafka.streams.StreamsConfig.DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG;
 
@@ -69,7 +68,7 @@ public class RecordDeserializer {
                 sourceNode.deserializeKey(rawRecord.topic(), 
rawRecord.headers(), rawRecord.key()),
                 sourceNode.deserializeValue(rawRecord.topic(), 
rawRecord.headers(), rawRecord.value()),
                 rawRecord.headers(),
-                Optional.empty()
+                rawRecord.leaderEpoch()
             );
         } catch (final RuntimeException deserializationException) {
             handleDeserializationFailure(deserializationExceptionHandler, 
processorContext, deserializationException, rawRecord, log, 
droppedRecordsSensor, sourceNode().name());
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
index a6b30a07ef9..ea03b2f8a0e 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
@@ -30,6 +30,7 @@ import 
org.apache.kafka.streams.processor.internals.metrics.TopicMetrics;
 import org.slf4j.Logger;
 
 import java.util.ArrayDeque;
+import java.util.Optional;
 
 import static 
org.apache.kafka.streams.processor.internals.ClientUtils.consumerRecordSizeInBytes;
 
@@ -181,6 +182,17 @@ public class RecordQueue {
         return headRecord == null ? null : headRecord.offset();
     }
 
+    /**
+     * Returns the leader epoch of the head record if it exists
+     *
+     * @return An Optional containing the leader epoch of the head record, or 
null if the queue is empty. The Optional.empty()
+     * is reserved for the case when the leader epoch is not set for the head 
record of the queue.
+     */
+    @SuppressWarnings("OptionalAssignedToNull")
+    public Optional<Integer> headRecordLeaderEpoch() {
+        return headRecord == null ? null : headRecord.leaderEpoch();
+    }
+
     /**
      * Clear the fifo queue of its elements
      */
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java
index 71e3ca2e3ce..c8ed35a9a8f 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java
@@ -19,6 +19,8 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.header.Headers;
 
+import java.util.Optional;
+
 public class StampedRecord extends Stamped<ConsumerRecord<?, ?>> {
 
     public StampedRecord(final ConsumerRecord<?, ?> record, final long 
timestamp) {
@@ -45,6 +47,10 @@ public class StampedRecord extends Stamped<ConsumerRecord<?, 
?>> {
         return value.offset();
     }
 
+    public Optional<Integer> leaderEpoch() {
+        return value.leaderEpoch();
+    }
+
     public Headers headers() {
         return value.headers();
     }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 771f8bcd5f0..d0a73a1de95 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -88,6 +88,7 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
     private final RecordCollector recordCollector;
     private final AbstractPartitionGroup.RecordInfo recordInfo;
     private final Map<TopicPartition, Long> consumedOffsets;
+    private final Map<TopicPartition, OffsetAndMetadata> 
nextOffsetsAndMetadataToBeConsumed = new HashMap<>();
     private final Map<TopicPartition, Long> committedOffsets;
     private final Map<TopicPartition, Long> highWatermark;
     private final Set<TopicPartition> resetOffsetsForPartitions;
@@ -462,23 +463,27 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
         }
     }
 
-    private Long findOffset(final TopicPartition partition) {
+    private OffsetAndMetadata findOffsetAndMetadata(final TopicPartition 
partition) {
         Long offset = partitionGroup.headRecordOffset(partition);
+        Optional<Integer> leaderEpoch = 
partitionGroup.headRecordLeaderEpoch(partition);
+        final long partitionTime = 
partitionGroup.partitionTimestamp(partition);
         if (offset == null) {
             try {
-                offset = mainConsumer.position(partition);
-            } catch (final TimeoutException error) {
-                // the `consumer.position()` call should never block, because 
we know that we did process data
-                // for the requested partition and thus the consumer should 
have a valid local position
-                // that it can return immediately
-
-                // hence, a `TimeoutException` indicates a bug and thus we 
rethrow it as fatal `IllegalStateException`
-                throw new IllegalStateException(error);
+                if (nextOffsetsAndMetadataToBeConsumed.containsKey(partition)) 
{
+                    final OffsetAndMetadata offsetAndMetadata = 
nextOffsetsAndMetadataToBeConsumed.get(partition);
+                    offset = offsetAndMetadata.offset();
+                    leaderEpoch = offsetAndMetadata.leaderEpoch();
+                } else {
+                    // This indicates a bug and thus we rethrow it as fatal 
`IllegalStateException`
+                    throw new IllegalStateException("Stream task " + id + " 
does not know the partition: " + partition);
+                }
             } catch (final KafkaException fatal) {
                 throw new StreamsException(fatal);
             }
         }
-        return offset;
+        return new OffsetAndMetadata(offset,
+                leaderEpoch,
+                new TopicPartitionMetadata(partitionTime, 
processorContext.processorMetadata()).encode());
     }
 
     private Map<TopicPartition, OffsetAndMetadata> 
committableOffsetsAndMetadata() {
@@ -493,7 +498,6 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
 
             case RUNNING:
             case SUSPENDED:
-                final Map<TopicPartition, Long> partitionTimes = 
extractPartitionTimes();
 
                 // If there's processor metadata to be committed. We need to 
commit them to all
                 // input partitions
@@ -502,10 +506,7 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
                 committableOffsets = new 
HashMap<>(partitionsNeedCommit.size());
 
                 for (final TopicPartition partition : partitionsNeedCommit) {
-                    final Long offset = findOffset(partition);
-                    final long partitionTime = partitionTimes.get(partition);
-                    committableOffsets.put(partition, new 
OffsetAndMetadata(offset,
-                        new TopicPartitionMetadata(partitionTime, 
processorContext.processorMetadata()).encode()));
+                    committableOffsets.put(partition, 
findOffsetAndMetadata(partition));
                 }
                 break;
 
@@ -561,13 +562,6 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
         processorContext.processorMetadata().setNeedsCommit(false);
     }
 
-    private Map<TopicPartition, Long> extractPartitionTimes() {
-        final Map<TopicPartition, Long> partitionTimes = new HashMap<>();
-        for (final TopicPartition partition : partitionGroup.partitions()) {
-            partitionTimes.put(partition, 
partitionGroup.partitionTimestamp(partition));
-        }
-        return partitionTimes;
-    }
 
     @Override
     public void closeClean() {
@@ -1125,6 +1119,10 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
         }
     }
 
+    public void updateNextOffsets(final TopicPartition partition, final 
OffsetAndMetadata offsetAndMetadata) {
+        nextOffsetsAndMetadataToBeConsumed.put(partition, offsetAndMetadata);
+    }
+
     /**
      * Schedules a punctuation for the processor
      *
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 28815c91e28..e45021f25c3 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -1240,6 +1240,9 @@ public class StreamThread extends Thread implements 
ProcessingThread {
             pollRecordsSensor.record(numRecords, now);
             taskManager.addRecordsToTasks(records);
         }
+        if (!records.nextOffsets().isEmpty()) {
+            taskManager.updateNextOffsets(records.nextOffsets());
+        }
 
         while (!nonFatalExceptionsToHandle.isEmpty()) {
             
streamsUncaughtExceptionHandler.accept(nonFatalExceptionsToHandle.poll(), true);
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroup.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroup.java
index 48e425ff1de..cee6442b663 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroup.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroup.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 
+import java.util.Optional;
 import java.util.Set;
 import java.util.function.Function;
 
@@ -70,6 +71,11 @@ class SynchronizedPartitionGroup extends 
AbstractPartitionGroup {
         return wrapped.headRecordOffset(partition);
     }
 
+    @Override
+    Optional<Integer> headRecordLeaderEpoch(final TopicPartition partition) {
+        return Optional.empty();
+    }
+
     @Override
     synchronized int numBuffered() {
         return wrapped.numBuffered();
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
index f4bf7c3db33..484c1ca574b 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
@@ -177,6 +177,9 @@ public interface Task {
 
     void addRecords(TopicPartition partition, Iterable<ConsumerRecord<byte[], 
byte[]>> records);
 
+    default void updateNextOffsets(final TopicPartition partition, final 
OffsetAndMetadata offsetAndMetadata) {
+    }
+
     default boolean process(final long wallClockTime) {
         return false;
     }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 0d222b61600..5384b6a72ba 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -1805,16 +1805,32 @@ public class TaskManager {
      */
     void addRecordsToTasks(final ConsumerRecords<byte[], byte[]> records) {
         for (final TopicPartition partition : records.partitions()) {
-            final Task activeTask = 
tasks.activeTasksForInputPartition(partition);
+            final Task activeTask = getActiveTask(partition);
+            activeTask.addRecords(partition, records.records(partition));
+        }
+    }
 
-            if (activeTask == null) {
-                log.error("Unable to locate active task for received-record 
partition {}. Current tasks: {}",
-                    partition, toString(">"));
-                throw new NullPointerException("Task was unexpectedly missing 
for partition " + partition);
-            }
+    /**
+     * Update the next offsets for each task
+     *
+     * @param nextOffsets A map of offsets keyed by partition
+     */
+    void updateNextOffsets(final Map<TopicPartition, OffsetAndMetadata> 
nextOffsets) {
+        for (final Map.Entry<TopicPartition, OffsetAndMetadata> entry : 
nextOffsets.entrySet()) {
+            final Task activeTask = getActiveTask(entry.getKey());
+            activeTask.updateNextOffsets(entry.getKey(), entry.getValue());
+        }
+    }
 
-            activeTask.addRecords(partition, records.records(partition));
+    private Task getActiveTask(final TopicPartition partition) {
+        final Task activeTask = tasks.activeTasksForInputPartition(partition);
+
+        if (activeTask == null) {
+            log.error("Unable to locate active task for received-record 
partition {}. Current tasks: {}",
+                partition, toString(">"));
+            throw new NullPointerException("Task was unexpectedly missing for 
partition " + partition);
         }
+        return activeTask;
     }
 
     private void maybeLockTasks(final Set<TaskId> ids) {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
index 704c20ff98e..1f4fec19484 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
@@ -19,9 +19,11 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.metrics.stats.Value;
+import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.IntegerDeserializer;
 import org.apache.kafka.common.serialization.IntegerSerializer;
@@ -45,6 +47,7 @@ import org.junit.jupiter.api.Test;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.Set;
 import java.util.UUID;
@@ -133,24 +136,38 @@ public class PartitionGroupTest {
         return new TopicPartition(topics[0], 2);
     }
 
+    private ConsumerRecord<byte[], byte[]> createConsumerRecord(
+            final int partition,
+            final long offset,
+            final byte[] key,
+            final byte[] value,
+            final int leaderEpoch
+    ) {
+        return new ConsumerRecord<>(topics[0], partition, offset, 
ConsumerRecord.NO_TIMESTAMP,
+                TimestampType.NO_TIMESTAMP_TYPE, ConsumerRecord.NULL_SIZE, 
ConsumerRecord.NULL_SIZE,
+                key, value, new RecordHeaders(), Optional.of(leaderEpoch));
+    }
+
     private void testFirstBatch(final PartitionGroup group) {
         StampedRecord record;
         final PartitionGroup.RecordInfo info = new RecordInfo();
         assertThat(group.numBuffered(), is(0));
+        assertNull(group.headRecordLeaderEpoch(partition1));
+        assertNull(group.headRecordLeaderEpoch(partition2));
 
         // add three 3 records with timestamp 1, 3, 5 to partition-1
         final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
-                new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue),
-                new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue),
-                new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue));
+                createConsumerRecord(1, 1L, recordKey, recordValue, 0),
+                createConsumerRecord(1, 3L, recordKey, recordValue, 0),
+                createConsumerRecord(1, 5L, recordKey, recordValue, 2));
 
         group.addRawRecords(partition1, list1);
 
         // add three 3 records with timestamp 2, 4, 6 to partition-2
         final List<ConsumerRecord<byte[], byte[]>> list2 = Arrays.asList(
-                new ConsumerRecord<>("topic", 2, 2L, recordKey, recordValue),
-                new ConsumerRecord<>("topic", 2, 4L, recordKey, recordValue),
-                new ConsumerRecord<>("topic", 2, 6L, recordKey, recordValue));
+                createConsumerRecord(2, 2L, recordKey, recordValue, 1),
+                createConsumerRecord(2, 4L, recordKey, recordValue, 4),
+                createConsumerRecord(2, 6L, recordKey, recordValue, 4));
 
         group.addRawRecords(partition2, list2);
         // 1:[1, 3, 5]
@@ -162,6 +179,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), 
is(RecordQueue.UNKNOWN));
         assertThat(group.headRecordOffset(partition1), is(1L));
         assertThat(group.headRecordOffset(partition2), is(2L));
+        assertThat(group.headRecordLeaderEpoch(partition1), 
is(Optional.of(0)));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(1)));
         assertThat(group.streamTime(), is(RecordQueue.UNKNOWN));
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0));
 
@@ -175,6 +194,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), 
is(RecordQueue.UNKNOWN));
         assertThat(group.headRecordOffset(partition1), is(3L));
         assertThat(group.headRecordOffset(partition2), is(2L));
+        assertThat(group.headRecordLeaderEpoch(partition1), 
is(Optional.of(0)));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(1)));
         verifyTimes(record, 1L, 1L, group);
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0));
 
@@ -188,6 +209,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), is(2L));
         assertThat(group.headRecordOffset(partition1), is(3L));
         assertThat(group.headRecordOffset(partition2), is(4L));
+        assertThat(group.headRecordLeaderEpoch(partition1), 
is(Optional.of(0)));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(4)));
         verifyTimes(record, 2L, 2L, group);
         verifyBuffered(4, 2, 2, group);
         assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
@@ -199,8 +222,8 @@ public class PartitionGroupTest {
 
         // add 2 more records with timestamp 2, 4 to partition-1
         final List<ConsumerRecord<byte[], byte[]>> list3 = Arrays.asList(
-                new ConsumerRecord<>("topic", 1, 2L, recordKey, recordValue),
-                new ConsumerRecord<>("topic", 1, 4L, recordKey, recordValue));
+                createConsumerRecord(1, 2L, recordKey, recordValue, 5),
+                createConsumerRecord(1, 4L, recordKey, recordValue, 6));
 
         group.addRawRecords(partition1, list3);
         // 1:[3, 5, 2, 4]
@@ -211,6 +234,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), is(2L));
         assertThat(group.headRecordOffset(partition1), is(3L));
         assertThat(group.headRecordOffset(partition2), is(4L));
+        assertThat(group.headRecordLeaderEpoch(partition1), 
is(Optional.of(0)));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(4)));
         assertThat(group.streamTime(), is(2L));
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0));
 
@@ -224,6 +249,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), is(2L));
         assertThat(group.headRecordOffset(partition1), is(5L));
         assertThat(group.headRecordOffset(partition2), is(4L));
+        assertThat(group.headRecordLeaderEpoch(partition1), 
is(Optional.of(2)));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(4)));
         verifyTimes(record, 3L, 3L, group);
         verifyBuffered(5, 3, 2, group);
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0));
@@ -238,6 +265,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), is(4L));
         assertThat(group.headRecordOffset(partition1), is(5L));
         assertThat(group.headRecordOffset(partition2), is(6L));
+        assertThat(group.headRecordLeaderEpoch(partition1), 
is(Optional.of(2)));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(4)));
         verifyTimes(record, 4L, 4L, group);
         verifyBuffered(4, 3, 1, group);
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0));
@@ -252,6 +281,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), is(4L));
         assertThat(group.headRecordOffset(partition1), is(2L));
         assertThat(group.headRecordOffset(partition2), is(6L));
+        assertThat(group.headRecordLeaderEpoch(partition1), 
is(Optional.of(5)));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(4)));
         verifyTimes(record, 5L, 5L, group);
         verifyBuffered(3, 2, 1, group);
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0));
@@ -266,6 +297,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), is(4L));
         assertThat(group.headRecordOffset(partition1), is(4L));
         assertThat(group.headRecordOffset(partition2), is(6L));
+        assertThat(group.headRecordLeaderEpoch(partition1), 
is(Optional.of(6)));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(4)));
         verifyTimes(record, 2L, 5L, group);
         verifyBuffered(2, 1, 1, group);
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(3.0));
@@ -280,6 +313,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), is(4L));
         assertNull(group.headRecordOffset(partition1));
         assertThat(group.headRecordOffset(partition2), is(6L));
+        assertThat(group.headRecordLeaderEpoch(partition1), is(nullValue()));
+        assertThat(group.headRecordLeaderEpoch(partition2), 
is(Optional.of(4)));
         verifyTimes(record, 4L, 5L, group);
         verifyBuffered(1, 0, 1, group);
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(1.0));
@@ -294,6 +329,8 @@ public class PartitionGroupTest {
         assertThat(group.partitionTimestamp(partition2), is(6L));
         assertNull(group.headRecordOffset(partition1));
         assertNull(group.headRecordOffset(partition2));
+        assertNull(group.headRecordLeaderEpoch(partition1));
+        assertNull(group.headRecordLeaderEpoch(partition2));
         verifyTimes(record, 6L, 6L, group);
         verifyBuffered(0, 0, 0, group);
         assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0));
@@ -428,6 +465,16 @@ public class PartitionGroupTest {
         assertThat(errMessage, equalTo(exception.getMessage()));
     }
 
+    @Test
+    public void 
shouldThrowIllegalStateExceptionUponHeadRecordLeaderEpochIfPartitionUnknown() {
+        final PartitionGroup group = getBasicGroup();
+
+        final IllegalStateException exception = assertThrows(
+                IllegalStateException.class,
+                () -> group.headRecordLeaderEpoch(unknownPartition));
+        assertThat(errMessage, equalTo(exception.getMessage()));
+    }
+
     @Test
     public void shouldEmptyPartitionsOnClear() {
         final PartitionGroup group =
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ReadOnlyTaskTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ReadOnlyTaskTest.java
index 6d9984b380c..896e84baf5b 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ReadOnlyTaskTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ReadOnlyTaskTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.processor.TaskId;
@@ -27,6 +28,7 @@ import java.lang.reflect.Method;
 import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Optional;
 import java.util.Set;
 import java.util.function.Consumer;
 
@@ -204,6 +206,9 @@ class ReadOnlyTaskTest {
                 case "org.apache.kafka.common.TopicPartition":
                     parameters[i] = new TopicPartition("topic", 0);
                     break;
+                case "org.apache.kafka.clients.consumer.OffsetAndMetadata":
+                    parameters[i] = new OffsetAndMetadata(0, Optional.empty(), 
"");
+                    break;
                 case "java.lang.Exception":
                     parameters[i] = new IllegalStateException();
                     break;
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java
index 906422bcfeb..aa3bb57c7a6 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java
@@ -58,7 +58,7 @@ public class RecordDeserializerTest {
         new byte[0],
         new byte[0],
         headers,
-        Optional.empty());
+        Optional.of(5));
 
     private final InternalProcessorContext<Void, Void> context = new 
InternalMockProcessorContext<>();
 
@@ -86,6 +86,7 @@ public class RecordDeserializerTest {
             assertEquals(rawRecord.timestamp(), record.timestamp());
             assertEquals(TimestampType.CREATE_TIME, record.timestampType());
             assertEquals(rawRecord.headers(), record.headers());
+            assertEquals(rawRecord.leaderEpoch(), record.leaderEpoch());
         }
     }
 
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
index 076aeef7939..a55fe098608 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
@@ -171,79 +171,90 @@ public class RecordQueueTest {
 
         // add three 3 out-of-order records with timestamp 2, 1, 3
         final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
-            new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
-            new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
-            new ConsumerRecord<>("topic", 1, 3, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()));
+            new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(1)),
+            new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(1)),
+            new ConsumerRecord<>("topic", 1, 3, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(2)));
+
 
         queue.addRawRecords(list1);
 
         assertEquals(3, queue.size());
         assertEquals(2L, queue.headRecordTimestamp());
         assertEquals(2L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(1), queue.headRecordLeaderEpoch());
 
         // poll the first record, now with 1, 3
         assertEquals(2L, queue.poll(0).timestamp);
         assertEquals(2, queue.size());
         assertEquals(1L, queue.headRecordTimestamp());
         assertEquals(1L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(1), queue.headRecordLeaderEpoch());
 
         // poll the second record, now with 3
         assertEquals(1L, queue.poll(0).timestamp);
         assertEquals(1, queue.size());
         assertEquals(3L, queue.headRecordTimestamp());
         assertEquals(3L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(2), queue.headRecordLeaderEpoch());
 
         // add three 3 out-of-order records with timestamp 4, 1, 2
         // now with 3, 4, 1, 2
         final List<ConsumerRecord<byte[], byte[]>> list2 = Arrays.asList(
-            new ConsumerRecord<>("topic", 1, 4, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
-            new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
-            new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()));
+            new ConsumerRecord<>("topic", 1, 4, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(2)),
+            new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(1)),
+            new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(1)));
 
         queue.addRawRecords(list2);
 
         assertEquals(4, queue.size());
         assertEquals(3L, queue.headRecordTimestamp());
         assertEquals(3L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(2), queue.headRecordLeaderEpoch());
 
         // poll the third record, now with 4, 1, 2
         assertEquals(3L, queue.poll(0).timestamp);
         assertEquals(3, queue.size());
         assertEquals(4L, queue.headRecordTimestamp());
         assertEquals(4L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(2), queue.headRecordLeaderEpoch());
 
         // poll the rest records
         assertEquals(4L, queue.poll(0).timestamp);
         assertEquals(1L, queue.headRecordTimestamp());
         assertEquals(1L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(1), queue.headRecordLeaderEpoch());
 
         assertEquals(1L, queue.poll(0).timestamp);
         assertEquals(2L, queue.headRecordTimestamp());
         assertEquals(2L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(1), queue.headRecordLeaderEpoch());
+
 
         assertEquals(2L, queue.poll(0).timestamp);
         assertTrue(queue.isEmpty());
         assertEquals(0, queue.size());
         assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
         assertNull(queue.headRecordOffset());
+        assertNull(queue.headRecordLeaderEpoch());
 
         // add three more records with 4, 5, 6
         final List<ConsumerRecord<byte[], byte[]>> list3 = Arrays.asList(
-            new ConsumerRecord<>("topic", 1, 4, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
-            new ConsumerRecord<>("topic", 1, 5, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
-            new ConsumerRecord<>("topic", 1, 6, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()));
+            new ConsumerRecord<>("topic", 1, 4, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(2)),
+            new ConsumerRecord<>("topic", 1, 5, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(3)),
+            new ConsumerRecord<>("topic", 1, 6, 0L, TimestampType.CREATE_TIME, 
0, 0, recordKey, recordValue, new RecordHeaders(), Optional.of(3)));
 
         queue.addRawRecords(list3);
 
         assertEquals(3, queue.size());
         assertEquals(4L, queue.headRecordTimestamp());
         assertEquals(4L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(2), queue.headRecordLeaderEpoch());
 
         // poll one record again, the timestamp should advance now
         assertEquals(4L, queue.poll(0).timestamp);
         assertEquals(2, queue.size());
         assertEquals(5L, queue.headRecordTimestamp());
-        assertEquals(5L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(3), queue.headRecordLeaderEpoch());
 
         // clear the queue
         queue.clear();
@@ -252,6 +263,7 @@ public class RecordQueueTest {
         assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
         assertEquals(RecordQueue.UNKNOWN, queue.partitionTime());
         assertNull(queue.headRecordOffset());
+        assertNull(queue.headRecordLeaderEpoch());
 
         // re-insert the three records with 4, 5, 6
         queue.addRawRecords(list3);
@@ -259,6 +271,7 @@ public class RecordQueueTest {
         assertEquals(3, queue.size());
         assertEquals(4L, queue.headRecordTimestamp());
         assertEquals(4L, queue.headRecordOffset().longValue());
+        assertEquals(Optional.of(2), queue.headRecordLeaderEpoch());
     }
 
     @Test
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index f531d9fbaa2..9cda2b9ee7c 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -1213,7 +1213,11 @@ public class StreamTaskTest {
 
         assertFalse(task.commitNeeded());
 
-        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0)));
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> record 
= mkMap(
+                mkEntry(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0)))
+        );
+        task.addRecords(partition1, record.get(partition1));
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(1, 
Optional.empty(), ""));
         assertTrue(task.process(0L));
         assertTrue(task.commitNeeded());
 
@@ -1254,9 +1258,9 @@ public class StreamTaskTest {
         task.completeRestoration(noOpResetter -> { });
 
         task.addRecords(partition1, asList(
-            getConsumerRecordWithOffsetAsTimestamp(partition1, 0L),
-            getConsumerRecordWithOffsetAsTimestamp(partition1, 3L),
-            getConsumerRecordWithOffsetAsTimestamp(partition1, 5L)));
+            getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1, 
0L, 1),
+            getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1, 
3L, 1),
+            getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1, 
5L, 2)));
 
         task.process(0L);
         processorStreamTime.mockProcessor.addProcessorMetadata("key1", 100L);
@@ -1273,37 +1277,41 @@ public class StreamTaskTest {
             )
         );
 
-        assertThat(offsetsAndMetadata, equalTo(mkMap(mkEntry(partition1, new 
OffsetAndMetadata(5L, expected.encode())))));
+        assertThat(offsetsAndMetadata, equalTo(mkMap(mkEntry(partition1, new 
OffsetAndMetadata(5L, Optional.of(2), expected.encode())))));
     }
 
     @Test
-    public void shouldCommitConsumerPositionIfRecordQueueIsEmpty() {
+    public void shouldCommitFetchedNextOffsetIfRecordQueueIsEmpty() {
         when(stateManager.taskId()).thenReturn(taskId);
         when(stateManager.taskType()).thenReturn(TaskType.ACTIVE);
         task = createStatelessTask(createConfig());
         task.initializeIfNeeded();
         task.completeRestoration(noOpResetter -> { });
 
-        consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition1, 
0L));
-        consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition1, 
1L));
-        consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition1, 
2L));
-        consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition2, 
0L));
+        
consumer.addRecord(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1,
 0L, 0));
+        
consumer.addRecord(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1,
 1L, 1));
+        
consumer.addRecord(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1,
 2L, 2));
+        
consumer.addRecord(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition2,
 0L, 0));
         consumer.poll(Duration.ZERO);
 
-        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0L)));
-        task.addRecords(partition2, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition2, 0L)));
+        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1, 
0L, 0)));
+        task.addRecords(partition2, 
singletonList(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition2, 
0L, 0)));
+
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(3, 
Optional.of(2), ""));
+        task.updateNextOffsets(partition2, new OffsetAndMetadata(1, 
Optional.of(0), ""));
+
         task.process(0L);
 
+
         final TopicPartitionMetadata metadata = new TopicPartitionMetadata(0, 
new ProcessorMetadata());
 
         assertTrue(task.commitNeeded());
         assertThat(task.prepareCommit(), equalTo(
-            mkMap(
-                mkEntry(partition1,
-                    new OffsetAndMetadata(3L, metadata.encode())
+                mkMap(
+                        mkEntry(partition1, new OffsetAndMetadata(3L, 
Optional.of(2), metadata.encode()))
                 )
-            )
         ));
+
         task.postCommit(false);
 
         // the task should still be committed since the processed records have 
not reached the consumer position
@@ -1317,8 +1325,8 @@ public class StreamTaskTest {
         assertTrue(task.commitNeeded());
         assertThat(task.prepareCommit(), equalTo(
             mkMap(
-                mkEntry(partition1, new OffsetAndMetadata(3L, 
metadata.encode())),
-                mkEntry(partition2, new OffsetAndMetadata(1L, 
metadata.encode()))
+                mkEntry(partition1, new OffsetAndMetadata(3L, Optional.of(2), 
metadata.encode())),
+                mkEntry(partition2, new OffsetAndMetadata(1L, Optional.of(0), 
metadata.encode()))
             )
         ));
         task.postCommit(false);
@@ -1336,14 +1344,17 @@ public class StreamTaskTest {
 
         task.resumePollingForPartitionsWithAvailableSpace();
 
-        consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition1, 
0L));
-        consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition1, 
1L));
-        consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition2, 
0L));
-        consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition2, 
1L));
+        
consumer.addRecord(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1,
 0L, 0));
+        
consumer.addRecord(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1,
 1L, 1));
+        
consumer.addRecord(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition2,
 0L, 0));
+        
consumer.addRecord(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition2,
 1L, 1));
         consumer.poll(Duration.ZERO);
 
-        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0L)));
-        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 1L)));
+        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1, 
0L, 0)));
+        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(partition1, 
1L, 1)));
+
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(2, 
Optional.of(1), ""));
+        task.updateNextOffsets(partition2, new OffsetAndMetadata(2, 
Optional.of(1), ""));
 
         task.updateLags();
 
@@ -1370,8 +1381,8 @@ public class StreamTaskTest {
 
         assertThat(task.prepareCommit(), equalTo(
             mkMap(
-                mkEntry(partition1, new OffsetAndMetadata(1L, 
expectedMetadata1.encode())),
-                mkEntry(partition2, new OffsetAndMetadata(2L, 
expectedMetadata2.encode()))
+                mkEntry(partition1, new OffsetAndMetadata(1L,  Optional.of(1), 
expectedMetadata1.encode())),
+                mkEntry(partition2, new OffsetAndMetadata(2L, Optional.of(1), 
expectedMetadata2.encode()))
             )));
         task.postCommit(false);
 
@@ -1392,7 +1403,7 @@ public class StreamTaskTest {
 
         // Processor metadata not updated, we just need to commit to 
partition1 again with new offset
         assertThat(task.prepareCommit(), equalTo(
-            mkMap(mkEntry(partition1, new OffsetAndMetadata(2L, 
expectedMetadata3.encode())))
+                mkMap(mkEntry(partition1, new OffsetAndMetadata(2L, 
Optional.of(1), expectedMetadata3.encode())))
         ));
         task.postCommit(false);
 
@@ -1877,8 +1888,17 @@ public class StreamTaskTest {
         task.initializeIfNeeded();
         task.completeRestoration(noOpResetter -> { });
 
-        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 5L)));
-        task.addRecords(repartition, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(repartition, 10L)));
+
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> 
records = mkMap(
+                mkEntry(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 5L))),
+                mkEntry(repartition, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(repartition, 10L)))
+        );
+
+        task.addRecords(partition1, records.get(partition1));
+        task.addRecords(repartition, records.get(repartition));
+
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(6, 
Optional.empty(), ""));
+        task.updateNextOffsets(repartition, new OffsetAndMetadata(11, 
Optional.empty(), ""));
 
         task.resumePollingForPartitionsWithAvailableSpace();
         task.updateLags();
@@ -2137,7 +2157,11 @@ public class StreamTaskTest {
         task.initializeIfNeeded();
         task.completeRestoration(noOpResetter -> { });
 
-        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 
consumedOffset)));
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> record 
= mkMap(
+                mkEntry(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 5L)))
+        );
+        task.addRecords(partition1, record.get(partition1));
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(6, 
Optional.empty(), ""));
         task.process(100L);
         assertTrue(task.commitNeeded());
 
@@ -2164,7 +2188,10 @@ public class StreamTaskTest {
         task.initializeIfNeeded();
         task.completeRestoration(noOpResetter -> { }); // should checkpoint
 
-        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, offset)));
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> record 
= mkMap(
+                mkEntry(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, offset))));
+        task.addRecords(partition1, record.get(partition1));
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(offset + 1, 
Optional.empty(), ""));
         task.process(100L);
         assertTrue(task.commitNeeded());
 
@@ -2222,7 +2249,11 @@ public class StreamTaskTest {
         task = createOptimizedStatefulTask(createConfig("100"), consumer);
         task.initializeIfNeeded();
 
-        task.addRecords(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, offset)));
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> record 
= mkMap(
+                mkEntry(partition1, 
singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, offset))));
+        task.addRecords(partition1, record.get(partition1));
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(offset + 1, 
Optional.empty(), ""));
+
         task.process(100L);
         assertTrue(task.commitNeeded());
 
@@ -2526,6 +2557,7 @@ public class StreamTaskTest {
         task.resumePollingForPartitionsWithAvailableSpace();
         consumer.poll(Duration.ZERO);
         task.addRecords(partition1, records);
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(offset + 1, 
Optional.empty(), ""));
         task.updateLags();
 
         assertTrue(task.process(offset));
@@ -2560,6 +2592,7 @@ public class StreamTaskTest {
         task.resumePollingForPartitionsWithAvailableSpace();
         consumer.poll(Duration.ZERO);
         task.addRecords(partition1, records);
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(offset + 1, 
Optional.empty(), ""));
         task.updateLags();
 
         assertTrue(task.process(offset));
@@ -2591,6 +2624,8 @@ public class StreamTaskTest {
         task.resumePollingForPartitionsWithAvailableSpace();
         consumer.poll(Duration.ZERO);
         task.addRecords(partition1, records);
+        task.updateNextOffsets(partition1, new OffsetAndMetadata(offset + 1, 
Optional.empty(), ""));
+
         task.updateLags();
 
         assertTrue(task.process(offset));
@@ -3167,6 +3202,24 @@ public class StreamTaskTest {
         );
     }
 
+    private ConsumerRecord<byte[], byte[]> 
getConsumerRecordWithOffsetAsTimestampWithLeaderEpoch(final TopicPartition 
topicPartition,
+                                                                               
                  final long offset,
+                                                                               
                  final int leaderEpoch) {
+        return new ConsumerRecord<>(
+                topicPartition.topic(),
+                topicPartition.partition(),
+                offset,
+                offset, // use the offset as the timestamp
+                TimestampType.CREATE_TIME,
+                0,
+                0,
+                recordKey,
+                recordValue,
+                new RecordHeaders(),
+                Optional.of(leaderEpoch)
+        );
+    }
+
     private ConsumerRecord<byte[], byte[]> 
getCorruptedConsumerRecordWithOffsetAsTimestamp(final long offset) {
         return new ConsumerRecord<>(
             topic1,
diff --git 
a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
 
b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index 6c797a0a280..5f9f140d088 100644
--- 
a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ 
b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -511,7 +511,9 @@ public class TopologyTestDriver implements Closeable {
             task.initializeIfNeeded();
             task.completeRestoration(noOpResetter -> { });
             task.processorContext().setRecordContext(null);
-
+            for (final TopicPartition tp: task.inputPartitions()) {
+                task.updateNextOffsets(tp, new OffsetAndMetadata(0, 
Optional.empty(), ""));
+            }
         } else {
             task = null;
         }

Reply via email to