frankvicky commented on code in PR #20292:
URL: https://github.com/apache/kafka/pull/20292#discussion_r2610811787


##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java:
##########
@@ -113,6 +115,22 @@ public static Sensor processLatencySensor(final String threadId,
         );
     }
 
+    public static Sensor totalInputBufferBytesSensor(final String threadId,
+        final String taskId,
+        final StreamsMetricsImpl streamsMetrics) {

Review Comment:
   style consistency nit:
   ```suggestion
       public static Sensor totalInputBufferBytesSensor(final String threadId,
                                                        final String taskId,
                                                        final StreamsMetricsImpl streamsMetrics) {
   ```



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java:
##########
@@ -1130,7 +1138,9 @@ public void addRecords(final TopicPartition partition, final Iterable<ConsumerRe
 
         // if after adding these records, its partition queue's buffered size has been
         // increased beyond the threshold, we can then pause the consumption for this partition
-        if (newQueueSize > maxBufferedSize) {
+        // We do this only if the deprecated config buffered.records.per.partition is set
+        if (maxBufferedSize != -1 && newQueueSize > maxBufferedSize) {

Review Comment:
   `-1` is a magic number; it's hard to tell what it means at the call site.
   Could we introduce a named constant for it?
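   For illustration, a minimal sketch of what such a constant could look like (the name `UNBOUNDED_BUFFERED_RECORDS` is hypothetical, not from this PR):
   ```java
   // Hypothetical constant name, for illustration only: -1 means the deprecated
   // buffered.records.per.partition config is unset, so no per-partition
   // record-count bound is enforced.
   private static final int UNBOUNDED_BUFFERED_RECORDS = -1;

   // The guard above would then read:
   // if (maxBufferedSize != UNBOUNDED_BUFFERED_RECORDS && newQueueSize > maxBufferedSize) { ... }
   ```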



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java:
##########
@@ -912,6 +925,81 @@ public void shouldIdleAsSpecifiedWhenLagIsZero() {
         }
     }
 
+    @Test
+    public void shouldUpdateTotalBytesBufferedOnRecordsAdditionAndConsumption() {
+        final PartitionGroup group = getBasicGroup();
+
+        assertEquals(0, group.numBuffered());
+        assertEquals(0L, group.totalBytesBuffered());
+
+        // add three records with timestamps 1, 5, 3 to partition-1
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 1L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 5L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 3L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()));
+
+        long partition1TotalBytes = getBytesBufferedForRawRecords(list1);
+        group.addRawRecords(partition1, list1);
+
+        verifyBuffered(3, 3, 0, group);
+        assertEquals(group.totalBytesBuffered(), partition1TotalBytes);
+        assertEquals(-1L, group.streamTime());
+        assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition1TotalBytes));
+
+        StampedRecord record;
+        final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
+
+        // get first two records from partition 1
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 1L);

Review Comment:
   ```suggestion
           assertEquals(1L, record.timestamp);
   ```
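   (JUnit's `assertEquals` takes the expected value first, so swapping the arguments keeps failure messages accurate.)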



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java:
##########
@@ -912,6 +925,81 @@ public void shouldIdleAsSpecifiedWhenLagIsZero() {
         }
     }
 
+    @Test
+    public void shouldUpdateTotalBytesBufferedOnRecordsAdditionAndConsumption() {
+        final PartitionGroup group = getBasicGroup();
+
+        assertEquals(0, group.numBuffered());
+        assertEquals(0L, group.totalBytesBuffered());
+
+        // add three records with timestamps 1, 5, 3 to partition-1
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 1L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 5L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 3L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()));
+
+        long partition1TotalBytes = getBytesBufferedForRawRecords(list1);
+        group.addRawRecords(partition1, list1);
+
+        verifyBuffered(3, 3, 0, group);
+        assertEquals(group.totalBytesBuffered(), partition1TotalBytes);
+        assertEquals(-1L, group.streamTime());
+        assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition1TotalBytes));
+
+        StampedRecord record;
+        final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
+
+        // get first two records from partition 1
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 1L);
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 5L);

Review Comment:
   ```suggestion
           assertEquals(5L, record.timestamp);
   ```



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java:
##########
@@ -912,6 +925,81 @@ public void shouldIdleAsSpecifiedWhenLagIsZero() {
         }
     }
 
+    @Test
+    public void shouldUpdateTotalBytesBufferedOnRecordsAdditionAndConsumption() {
+        final PartitionGroup group = getBasicGroup();
+
+        assertEquals(0, group.numBuffered());
+        assertEquals(0L, group.totalBytesBuffered());
+
+        // add three records with timestamps 1, 5, 3 to partition-1
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(

Review Comment:
   nit: `List.of`
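   For context: `List.of` (Java 9+) returns an immutable list and rejects `null` elements, which suits fixed test fixtures like this one. A sketch of the suggested change, where `record1`..`record3` stand for the three `ConsumerRecord` constructor calls above:
   ```java
   // List.of is the idiomatic immutable factory; Arrays.asList returns
   // a fixed-size view backed by the array, which this test never needs.
   final List<ConsumerRecord<byte[], byte[]>> list1 = List.of(record1, record2, record3);
   ```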



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java:
##########
@@ -4237,6 +4258,281 @@ t2p1, new PartitionInfo(t2p1.topic(), t2p1.partition(), null, new Node[0], new N
         );
     }
 
+    @Test
+    public void shouldPauseNonEmptyPartitionsWhenTotalBufferSizeExceedsMaxBufferSize() {
+        // Set up consumer mock
+        @SuppressWarnings("unchecked")
+        final Consumer<byte[], byte[]> consumer = mock(Consumer.class);
+        final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class);
+        when(consumer.groupMetadata()).thenReturn(consumerGroupMetadata);
+        when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
+
+        // Create records for polling
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> polledRecords = new HashMap<>();
+        final List<ConsumerRecord<byte[], byte[]>> t1p1Records = new ArrayList<>();
+
+        t1p1Records.add(new ConsumerRecord<>(
+            t1p1.topic(),
+            t1p1.partition(),
+            0,
+            mockTime.milliseconds(),
+            TimestampType.CREATE_TIME,
+            2,
+            6,
+            new byte[2],
+            new byte[6],
+            new RecordHeaders(),
+            Optional.empty()));
+
+        t1p1Records.add(new ConsumerRecord<>(
+            t1p1.topic(),
+            t1p1.partition(),
+            1,
+            mockTime.milliseconds(),
+            TimestampType.CREATE_TIME,
+            2,
+            6,
+            new byte[2],
+            new byte[6],
+            new RecordHeaders(),
+            Optional.empty()));
+
+        final List<ConsumerRecord<byte[], byte[]>> t2p1Records = Collections.singletonList(
+            new ConsumerRecord<>(
+                t2p1.topic(),
+                t2p1.partition(),
+                0,
+                mockTime.milliseconds(),
+                TimestampType.CREATE_TIME,
+                2,
+                6,
+                new byte[2],
+                new byte[6],
+                new RecordHeaders(),
+                Optional.empty()));
+
+        polledRecords.put(t1p1, t1p1Records);
+        polledRecords.put(t2p1, t2p1Records);
+
+        // Set up consumer behavior
+        final Set<TopicPartition> partitionSet = new HashSet<>(Arrays.asList(t1p1, t2p1));
+
+        // First poll returns records
+        when(consumer.poll(any())).thenReturn(new ConsumerRecords<>(polledRecords, Map.of()))
+            // Second and third polls return empty
+            .thenReturn(new ConsumerRecords<>(Map.of(), Map.of()))
+            .thenReturn(new ConsumerRecords<>(Map.of(), Map.of()));
+
+        // Mock paused partitions behavior
+        when(consumer.paused()).thenReturn(partitionSet) // After pause
+            .thenReturn(partitionSet) // Before resume
+            .thenReturn(Collections.emptySet()); // After resume
+
+        // Set up task mock
+        final Task task1 = mock(Task.class);
+        when(task1.inputPartitions()).thenReturn(Set.of(t1p1));
+        when(task1.committedOffsets()).thenReturn(new HashMap<>());
+        when(task1.highWaterMark()).thenReturn(new HashMap<>());
+        when(task1.timeCurrentIdlingStarted()).thenReturn(Optional.empty());
+
+        // Set up TaskManager mock
+        final TaskManager taskManager = mock(TaskManager.class);
+        when(taskManager.activeTaskMap()).thenReturn(mkMap(mkEntry(new TaskId(0, 0), task1)));
+        when(taskManager.standbyTaskMap()).thenReturn(new HashMap<>());
+        when(taskManager.producerClientIds()).thenReturn("producerClientId");
+
+        // Mock buffer size behavior
+        when(taskManager.getInputBufferSizeInBytes())
+            .thenReturn(18L)  // After first poll
+            .thenReturn(12L)  // After first process
+            .thenReturn(6L)   // After second process
+            .thenReturn(0L);  // After third process
+
+        when(taskManager.nonEmptyPartitions()).thenReturn(partitionSet);
+        when(taskManager.process(anyInt(), any()))
+            .thenReturn(1)
+            .thenReturn(1)
+            .thenReturn(1)
+            .thenReturn(0);
+
+        // Create configuration and thread
+        final Properties props = configProps(false, false, false);
+        final StreamsConfig config = new StreamsConfig(props);
+        final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config);
+        topologyMetadata.buildAndRewriteTopology();
+
+        final StreamsMetricsImpl streamsMetrics =
+            new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime);
+
+        thread = new StreamThread(
+            mockTime,
+            config,
+            null,
+            consumer,
+            consumer,
+            changelogReader,
+            null,
+            taskManager,
+            null,
+            streamsMetrics,
+            topologyMetadata,
+            PROCESS_ID,
+            CLIENT_ID,
+            new LogContext(""),
+            new AtomicInteger(),
+            new AtomicLong(Long.MAX_VALUE),
+            new LinkedList<>(),
+            null,
+            HANDLER,
+            null,
+            Optional.empty(),
+            null,
+            10L // maxBufferSize set to 10
+        ).updateThreadMetadata(adminClientId(CLIENT_ID));
+
+        thread.setState(State.STARTING);
+        thread.setState(State.PARTITIONS_ASSIGNED);
+        thread.setState(State.RUNNING);
+
+        // Run the test
+        thread.runOnceWithoutProcessingThreads();
+        thread.runOnceWithoutProcessingThreads();
+        thread.runOnceWithoutProcessingThreads();
+
+        // Verify behavior
+        verify(consumer).pause(partitionSet);
+        verify(consumer).resume(partitionSet);
+        verify(taskManager, times(1)).addRecordsToTasks(any());
+    }
+
+    @Test
+    public void shouldNotPausePartitionsWhenMaxBufferSizeIsSetToNegative() {
+        // Set up consumer mock
+        @SuppressWarnings("unchecked")
+        final Consumer<byte[], byte[]> consumer = mock(Consumer.class);
+        final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class);
+        when(consumer.groupMetadata()).thenReturn(consumerGroupMetadata);
+        when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
+
+        // Create records for polling
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> polledRecords = new HashMap<>();
+        final List<ConsumerRecord<byte[], byte[]>> t1p1Records = new ArrayList<>();
+        t1p1Records.add(new ConsumerRecord<>(
+            t1p1.topic(),
+            t1p1.partition(),
+            0,
+            mockTime.milliseconds(),
+            TimestampType.CREATE_TIME,
+            2,
+            6,
+            new byte[2],
+            new byte[6],
+            new RecordHeaders(),
+            Optional.empty()));
+        t1p1Records.add(new ConsumerRecord<>(
+            t1p1.topic(),
+            t1p1.partition(),
+            1,
+            mockTime.milliseconds(),
+            TimestampType.CREATE_TIME,
+            2,
+            6,
+            new byte[2],
+            new byte[6],
+            new RecordHeaders(),
+            Optional.empty()));
+        final List<ConsumerRecord<byte[], byte[]>> t2p1Records = Collections.singletonList(

Review Comment:
   nit: `List.of`



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java:
##########
@@ -249,12 +255,18 @@ StampedRecord nextRecord(final RecordInfo info, final long wallClockTime) {
         info.queue = queue;
 
         if (queue != null) {
+

Review Comment:
   nit: redundant blank line



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java:
##########
@@ -113,6 +115,22 @@ public static Sensor processLatencySensor(final String threadId,
         );
     }
 
+    public static Sensor totalInputBufferBytesSensor(final String threadId,
+        final String taskId,
+        final StreamsMetricsImpl streamsMetrics) {
+        final String name = INPUT_BUFFER_BYTES_TOTAL;
+        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, name, RecordingLevel.INFO);

Review Comment:
   IIRC, the KIP said this should be `DEBUG` level?
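   If the KIP does specify `DEBUG`, the change would be a one-liner (sketch, assuming the rest of the sensor setup stays the same):
   ```java
   // Recorded only when metrics.recording.level is set to DEBUG
   final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, name, RecordingLevel.DEBUG);
   ```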



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java:
##########
@@ -790,7 +792,13 @@ record = partitionGroup.nextRecord(recordInfo, wallClockTime);
 
             // after processing this record, if its partition queue's buffered size has been
             // decreased to the threshold, we can then resume the consumption on this partition
-            if (recordInfo.queue().size() <= maxBufferedSize) {
+            // TODO the second part of OR condition would be removed once

Review Comment:
   It would be nice to add a KIP number here.
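   For example (assuming this PR implements KIP-770, which deprecated `buffered.records.per.partition`; please double-check the number):
   ```java
   // TODO: remove the second part of the OR condition once
   // buffered.records.per.partition (deprecated by KIP-770) is removed
   ```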



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java:
##########
@@ -912,6 +925,81 @@ public void shouldIdleAsSpecifiedWhenLagIsZero() {
         }
     }
 
+    @Test
+    public void shouldUpdateTotalBytesBufferedOnRecordsAdditionAndConsumption() {
+        final PartitionGroup group = getBasicGroup();
+
+        assertEquals(0, group.numBuffered());
+        assertEquals(0L, group.totalBytesBuffered());
+
+        // add three records with timestamps 1, 5, 3 to partition-1
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 1L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 5L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 3L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()));
+
+        long partition1TotalBytes = getBytesBufferedForRawRecords(list1);
+        group.addRawRecords(partition1, list1);
+
+        verifyBuffered(3, 3, 0, group);
+        assertEquals(group.totalBytesBuffered(), partition1TotalBytes);
+        assertEquals(-1L, group.streamTime());
+        assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition1TotalBytes));
+
+        StampedRecord record;
+        final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
+
+        // get first two records from partition 1
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 1L);
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 5L);
+
+        partition1TotalBytes -= getBytesBufferedForRawRecords(Arrays.asList(list1.get(0), list1.get(1)));
+        assertEquals(group.totalBytesBuffered(), partition1TotalBytes);
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition1TotalBytes));
+
+        // add three records with timestamps 2, 4, 6 to partition-2
+        final List<ConsumerRecord<byte[], byte[]>> list2 = Arrays.asList(
+            new ConsumerRecord<>("topic", 2, 2L, record.timestamp, 
TimestampType.CREATE_TIME, recordKey.length, recordValue.length, recordKey, 
recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 2, 4L, record.timestamp, 
TimestampType.CREATE_TIME, recordKey.length, recordValue.length, recordKey, 
recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 2, 6L, record.timestamp, 
TimestampType.CREATE_TIME, recordKey.length, recordValue.length, recordKey, 
recordValue, new RecordHeaders(), Optional.empty()));
+
+        long partition2TotalBytes = getBytesBufferedForRawRecords(list2);
+        group.addRawRecords(partition2, list2);
+        // 1:[3]
+        // 2:[2, 4, 6]
+        assertEquals(group.totalBytesBuffered(), partition2TotalBytes + partition1TotalBytes);
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition2TotalBytes + partition1TotalBytes));
+
+        // get one record, next record should be ts=2 from partition 2
+        record = group.nextRecord(info, time.milliseconds());
+        // 1:[3]
+        // 2:[4, 6]
+        partition2TotalBytes -= getBytesBufferedForRawRecords(Collections.singletonList(list2.get(0)));

Review Comment:
   nit: `List.of`



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java:
##########
@@ -912,6 +925,81 @@ public void shouldIdleAsSpecifiedWhenLagIsZero() {
         }
     }
 
+    @Test
+    public void shouldUpdateTotalBytesBufferedOnRecordsAdditionAndConsumption() {
+        final PartitionGroup group = getBasicGroup();
+
+        assertEquals(0, group.numBuffered());
+        assertEquals(0L, group.totalBytesBuffered());
+
+        // add three records with timestamps 1, 5, 3 to partition-1
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 1L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 5L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 3L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()));
+
+        long partition1TotalBytes = getBytesBufferedForRawRecords(list1);
+        group.addRawRecords(partition1, list1);
+
+        verifyBuffered(3, 3, 0, group);
+        assertEquals(group.totalBytesBuffered(), partition1TotalBytes);
+        assertEquals(-1L, group.streamTime());
+        assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition1TotalBytes));
+
+        StampedRecord record;
+        final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
+
+        // get first two records from partition 1
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 1L);
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 5L);
+
+        partition1TotalBytes -= getBytesBufferedForRawRecords(Arrays.asList(list1.get(0), list1.get(1)));

Review Comment:
   nit: `List.of`



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java:
##########
@@ -912,6 +925,81 @@ public void shouldIdleAsSpecifiedWhenLagIsZero() {
         }
     }
 
+    @Test
+    public void shouldUpdateTotalBytesBufferedOnRecordsAdditionAndConsumption() {
+        final PartitionGroup group = getBasicGroup();
+
+        assertEquals(0, group.numBuffered());
+        assertEquals(0L, group.totalBytesBuffered());
+
+        // add three records with timestamps 1, 5, 3 to partition-1
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 1L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 5L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 3L, new 
MockTime().milliseconds(), TimestampType.CREATE_TIME, recordKey.length, 
recordValue.length, recordKey, recordValue, new RecordHeaders(), 
Optional.empty()));
+
+        long partition1TotalBytes = getBytesBufferedForRawRecords(list1);
+        group.addRawRecords(partition1, list1);
+
+        verifyBuffered(3, 3, 0, group);
+        assertEquals(group.totalBytesBuffered(), partition1TotalBytes);
+        assertEquals(-1L, group.streamTime());
+        assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition1TotalBytes));
+
+        StampedRecord record;
+        final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
+
+        // get first two records from partition 1
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 1L);
+        record = group.nextRecord(info, time.milliseconds());
+        assertEquals(record.timestamp, 5L);
+
+        partition1TotalBytes -= getBytesBufferedForRawRecords(Arrays.asList(list1.get(0), list1.get(1)));
+        assertEquals(group.totalBytesBuffered(), partition1TotalBytes);
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition1TotalBytes));
+
+        // add three records with timestamps 2, 4, 6 to partition-2
+        final List<ConsumerRecord<byte[], byte[]>> list2 = Arrays.asList(
+            new ConsumerRecord<>("topic", 2, 2L, record.timestamp, 
TimestampType.CREATE_TIME, recordKey.length, recordValue.length, recordKey, 
recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 2, 4L, record.timestamp, 
TimestampType.CREATE_TIME, recordKey.length, recordValue.length, recordKey, 
recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 2, 6L, record.timestamp, 
TimestampType.CREATE_TIME, recordKey.length, recordValue.length, recordKey, 
recordValue, new RecordHeaders(), Optional.empty()));
+
+        long partition2TotalBytes = getBytesBufferedForRawRecords(list2);
+        group.addRawRecords(partition2, list2);
+        // 1:[3]
+        // 2:[2, 4, 6]
+        assertEquals(group.totalBytesBuffered(), partition2TotalBytes + partition1TotalBytes);
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition2TotalBytes + partition1TotalBytes));
+
+        // get one record, next record should be ts=2 from partition 2
+        record = group.nextRecord(info, time.milliseconds());
+        // 1:[3]
+        // 2:[4, 6]
+        partition2TotalBytes -= getBytesBufferedForRawRecords(Collections.singletonList(list2.get(0)));
+        assertEquals(group.totalBytesBuffered(), partition2TotalBytes + partition1TotalBytes);
+        assertThat(metrics.metric(totalBytesValue).metricValue(), is((double) partition2TotalBytes + partition1TotalBytes));
+        assertEquals(record.timestamp, 2L);
+
+        // get one record, next up should have ts=3 from partition 1 (even though it has seen a larger max timestamp =5)
+        record = group.nextRecord(info, time.milliseconds());
+        // 1:[]
+        // 2:[4, 6]
+        partition1TotalBytes -= getBytesBufferedForRawRecords(Collections.singletonList(list1.get(2)));

Review Comment:
   nit: `List.of`



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java:
##########
@@ -4237,6 +4258,281 @@ t2p1, new PartitionInfo(t2p1.topic(), t2p1.partition(), null, new Node[0], new N
         );
     }
 
+    @Test
+    public void shouldPauseNonEmptyPartitionsWhenTotalBufferSizeExceedsMaxBufferSize() {
+        // Set up consumer mock
+        @SuppressWarnings("unchecked")
+        final Consumer<byte[], byte[]> consumer = mock(Consumer.class);
+        final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class);
+        when(consumer.groupMetadata()).thenReturn(consumerGroupMetadata);
+        when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
+
+        // Create records for polling
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> polledRecords = new HashMap<>();
+        final List<ConsumerRecord<byte[], byte[]>> t1p1Records = new ArrayList<>();
+
+        t1p1Records.add(new ConsumerRecord<>(
+            t1p1.topic(),
+            t1p1.partition(),
+            0,
+            mockTime.milliseconds(),
+            TimestampType.CREATE_TIME,
+            2,
+            6,
+            new byte[2],
+            new byte[6],
+            new RecordHeaders(),
+            Optional.empty()));
+
+        t1p1Records.add(new ConsumerRecord<>(
+            t1p1.topic(),
+            t1p1.partition(),
+            1,
+            mockTime.milliseconds(),
+            TimestampType.CREATE_TIME,
+            2,
+            6,
+            new byte[2],
+            new byte[6],
+            new RecordHeaders(),
+            Optional.empty()));
+
+        final List<ConsumerRecord<byte[], byte[]>> t2p1Records = Collections.singletonList(
+            new ConsumerRecord<>(
+                t2p1.topic(),
+                t2p1.partition(),
+                0,
+                mockTime.milliseconds(),
+                TimestampType.CREATE_TIME,
+                2,
+                6,
+                new byte[2],
+                new byte[6],
+                new RecordHeaders(),
+                Optional.empty()));
+
+        polledRecords.put(t1p1, t1p1Records);
+        polledRecords.put(t2p1, t2p1Records);
+
+        // Set up consumer behavior
+        final Set<TopicPartition> partitionSet = new HashSet<>(Arrays.asList(t1p1, t2p1));

Review Comment:
   nit: `List.of`



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java:
##########
@@ -4237,6 +4258,281 @@ t2p1, new PartitionInfo(t2p1.topic(), t2p1.partition(), null, new Node[0], new N
         );
     }
 
+    @Test
+    public void shouldPauseNonEmptyPartitionsWhenTotalBufferSizeExceedsMaxBufferSize() {
+        // Set up consumer mock
+        @SuppressWarnings("unchecked")
+        final Consumer<byte[], byte[]> consumer = mock(Consumer.class);
+        final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class);
+        when(consumer.groupMetadata()).thenReturn(consumerGroupMetadata);
+        when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
+
+        // Create records for polling
+        final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> polledRecords = new HashMap<>();
+        final List<ConsumerRecord<byte[], byte[]>> t1p1Records = new ArrayList<>();
+
+        t1p1Records.add(new ConsumerRecord<>(
+            t1p1.topic(),
+            t1p1.partition(),
+            0,
+            mockTime.milliseconds(),
+            TimestampType.CREATE_TIME,
+            2,
+            6,
+            new byte[2],
+            new byte[6],
+            new RecordHeaders(),
+            Optional.empty()));
+
+        t1p1Records.add(new ConsumerRecord<>(
+            t1p1.topic(),
+            t1p1.partition(),
+            1,
+            mockTime.milliseconds(),
+            TimestampType.CREATE_TIME,
+            2,
+            6,
+            new byte[2],
+            new byte[6],
+            new RecordHeaders(),
+            Optional.empty()));
+
+        final List<ConsumerRecord<byte[], byte[]>> t2p1Records = Collections.singletonList(

Review Comment:
   nit: `List.of`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
