This is an automated email from the ASF dual-hosted git repository.

jolshan pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 0b2e410d619 KAFKA-19367: Fix InitProducerId with TV2 double-increments epoch if ongoing transaction is aborted (#19910)
0b2e410d619 is described below

commit 0b2e410d61970e66c6f73a18c75028df0a871777
Author: Ritika Reddy <[email protected]>
AuthorDate: Thu Jun 12 09:37:07 2025 -0700

    KAFKA-19367: Fix InitProducerId with TV2 double-increments epoch if ongoing transaction is aborted (#19910)
    
    When InitProducerId is handled on the transaction coordinator, the
    producer epoch is incremented (so that stale requests are fenced), and
    if a transaction is ongoing at that time, it is aborted. With
    transaction version 2 (a.k.a. KIP-890 part 2), the abort increments the
    producer epoch again (it is part of the new abort/commit protocol), so
    the epoch ends up incremented twice.
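    
    Roughly, the buggy flow condenses to the following sketch (illustrative
    pseudo-code, not the actual coordinator logic):
    
        // InitProducerId with an ongoing transaction under TV2:
        var epoch: Short = txnMetadata.producerEpoch
        epoch = (epoch + 1).toShort   // bump #1: fence stale requests
        // an ongoing transaction is found, so it is aborted; the TV2
        // abort/commit protocol bumps the epoch again:
        epoch = (epoch + 1).toShort   // bump #2: epoch advanced twice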
    
    In most cases this is benign, but when the epoch of the ongoing
    transaction is 32766, the first bump takes it to 32767, the maximum
    value of a short. The second bump then wraps it to a negative value,
    causing an IllegalArgumentException.
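    
    The wraparound is ordinary two's-complement short arithmetic; a minimal
    Scala illustration (the validity check is assumed here for clarity, not
    the actual coordinator code):
    
        val afterFirstBump: Short = Short.MaxValue          // 32767
        val afterSecondBump = (afterFirstBump + 1).toShort  // wraps to -32768
        // a check along these lines is what surfaces the bug as an
        // IllegalArgumentException:
        require(afterSecondBump >= 0, s"Illegal producer epoch: $afterSecondBump")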
    
    To fix this, we simply avoid bumping the epoch a second time.
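    
    After the fix, the epoch-fence branch in TransactionCoordinator reduces
    to the following (this is the post-patch code shown in the diff below):
    
        if (nextState == TransactionState.PREPARE_ABORT && isEpochFence) {
          // We should clear the pending state to make way for the transition to PrepareAbort
          txnMetadata.pendingState = None
          // For TV2+, don't manually set the epoch - let prepareAbortOrCommit handle it naturally.
        }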
    
    Reviewers: Justine Olshan <[email protected]>, Artem Livshits <[email protected]>
---
 .../transaction/TransactionCoordinator.scala       |   6 +-
 .../transaction/TransactionCoordinatorTest.scala   | 136 +++++++++++++++++++++
 2 files changed, 138 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 30f7fb6cf86..2764de5cd6c 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -819,11 +819,9 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
                 }
 
               if (nextState == TransactionState.PREPARE_ABORT && isEpochFence) {
-                // We should clear the pending state to make way for the transition to PrepareAbort and also bump
-                // the epoch in the transaction metadata we are about to append.
+                // We should clear the pending state to make way for the transition to PrepareAbort
                 txnMetadata.pendingState = None
-                txnMetadata.producerEpoch = producerEpoch
-                txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH
+                // For TV2+, don't manually set the epoch - let prepareAbortOrCommit handle it naturally.
               }
 
               nextProducerIdOrErrors.flatMap {
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index eea5db86bc6..12c36f61761 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -1267,6 +1267,142 @@ class TransactionCoordinatorTest {
       any())
   }
 
+  @Test
+  def shouldNotCauseEpochOverflowWhenInitPidDuringOngoingTxnV2(): Unit = {
+    // When InitProducerId is called with an ongoing transaction at epoch 32766 (Short.MaxValue - 1),
+    // it should not cause an epoch overflow by incrementing twice.
+    // The only true increment happens in prepareAbortOrCommit
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
+      (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_2)
+
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
+      .thenReturn(true)
+    when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
+    when(transactionManager.transactionVersionLevel()).thenReturn(TV_2)
+
+    // Capture the transition metadata to verify epoch increments
+    val capturedTxnTransitMetadata: ArgumentCaptor[TxnTransitMetadata] = ArgumentCaptor.forClass(classOf[TxnTransitMetadata])
+    when(transactionManager.appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      capturedTxnTransitMetadata.capture(),
+      capturedErrorsCallback.capture(),
+      any(),
+      any())
+    ).thenAnswer(invocation => {
+      val transitMetadata = invocation.getArgument[TxnTransitMetadata](2)
+      // Simulate the metadata update that would happen in the real appendTransactionToLog
+      txnMetadata.completeTransitionTo(transitMetadata)
+      capturedErrorsCallback.getValue.apply(Errors.NONE)
+    })
+
+    // Handle InitProducerId with ongoing transaction at epoch 32766
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
+
+    // Verify that the epoch did not overflow (should be Short.MaxValue = 32767, not negative)
+    assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
+    assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state)
+    
+    verify(transactionManager).validateTransactionTimeoutMs(anyBoolean(), anyInt())
+    verify(transactionManager, times(3)).getTransactionState(ArgumentMatchers.eq(transactionalId))
+    verify(transactionManager).appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      any[TxnTransitMetadata],
+      any(),
+      any(),
+      any())
+  }
+
+  @Test
+  def shouldHandleTimeoutAtEpochOverflowBoundaryCorrectlyTV2(): Unit = {
+    // Test the scenario where we have an ongoing transaction at epoch 32766 (Short.MaxValue - 1)
+    // and the producer crashes/times out. This test verifies that the timeout handling
+    // correctly manages the epoch overflow scenario without causing failures.
+
+    val epochAtMaxBoundary = (Short.MaxValue - 1).toShort // 32766
+    val now = time.milliseconds()
+
+    // Create transaction metadata at the epoch boundary that would cause overflow IFF double-incremented
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      prevProducerId = RecordBatch.NO_PRODUCER_ID,
+      nextProducerId = RecordBatch.NO_PRODUCER_ID,
+      producerEpoch = epochAtMaxBoundary,
+      lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
+      txnTimeoutMs = txnTimeoutMs,
+      state = TransactionState.ONGOING,
+      topicPartitions = partitions,
+      txnStartTimestamp = now,
+      txnLastUpdateTimestamp = now,
+      clientTransactionVersion = TV_2
+    )
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+
+    // Mock the transaction manager to return our test transaction as timed out
+    when(transactionManager.timedOutTransactions())
+      .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, epochAtMaxBoundary)))
+    when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
+    when(transactionManager.transactionVersionLevel()).thenReturn(TV_2)
+
+    // Mock the append operation to simulate successful write and update the metadata
+    when(transactionManager.appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      any[TxnTransitMetadata],
+      capturedErrorsCallback.capture(),
+      any(),
+      any())
+    ).thenAnswer(invocation => {
+      val transitMetadata = invocation.getArgument[TxnTransitMetadata](2)
+      // Simulate the metadata update that would happen in the real appendTransactionToLog
+      txnMetadata.completeTransitionTo(transitMetadata)
+      capturedErrorsCallback.getValue.apply(Errors.NONE)
+    })
+
+    // Track the actual behavior
+    var callbackInvoked = false
+    var resultError: Errors = null
+    var resultProducerId: Long = -1
+    var resultEpoch: Short = -1
+
+    def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)
+      (error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
+        callbackInvoked = true
+        resultError = error
+        resultProducerId = newProducerId
+        resultEpoch = newProducerEpoch
+      }
+
+    // Execute the timeout abort process
+    coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete)
+
+    assertTrue(callbackInvoked, "Callback should have been invoked")
+    assertEquals(Errors.NONE, resultError, "Expected no errors in the callback")
+    assertEquals(producerId, resultProducerId, "Expected producer ID to match")
+    assertEquals(Short.MaxValue, resultEpoch, "Expected producer epoch to be Short.MaxValue (32767) after a single epoch bump")
+    
+    // Verify the transaction metadata was correctly updated to the final epoch
+    assertEquals(Short.MaxValue, txnMetadata.producerEpoch,
+      s"Expected transaction metadata producer epoch to be ${Short.MaxValue} " +
+        s"after timeout handling, but was ${txnMetadata.producerEpoch}"
+    )
+
+    // Verify the basic flow was attempted
+    verify(transactionManager).timedOutTransactions()
+    verify(transactionManager, atLeast(1)).getTransactionState(ArgumentMatchers.eq(transactionalId))
+  }
+
   @Test
   def testInitProducerIdWithNoLastProducerData(): Unit = {
    // If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker
