This is an automated email from the ASF dual-hosted git repository.

jolshan pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 783e5013778 KAFKA-20090: rotate producerId when epoch exhausted during 
transaction completion. (#21506)
783e5013778 is described below

commit 783e50137787d195ae6551149d2df8603d8e8f99
Author: ChickenchickenLove <[email protected]>
AuthorDate: Thu Mar 12 02:55:39 2026 +0900

    KAFKA-20090: rotate producerId when epoch exhausted during transaction 
completion. (#21506)
    
    ### Motivation
    In TV2, there is an edge-case race at the epoch exhaustion boundary
    (`Short.MaxValue - 1`) that can leave a transaction stuck in `ONGOING`.
    
    When a transaction times out at epoch `Short.MaxValue - 1`, the
    coordinator fences and aborts it. During this flow, the epoch is bumped
    to `Short.MaxValue` and the transaction proceeds toward abort
    completion.  If a delayed client `EndTxn(ABORT)` arrives with the old
    epoch, the request may be treated as a retry path.
    
    Before this change, `generateTxnTransitMetadataForTxnCompletion` skipped
    producer ID rotation for epoch-fence aborts due to the `!isEpochFence`
    guard.  That allowed the exhausted epoch path to continue without
    allocating `nextProducerId`, which could prevent definitive
    fencing/rotation at the overflow boundary.
    
    ### Changes
    - Removed the `!isEpochFence` guard in the TV2 `endTransaction`
    completion path when checking `txnMetadata.isProducerEpochExhausted`.
    
    This ensures that abort flows passing through epoch fencing (for
    example, timeout-driven aborts and fencing triggered during
`InitProducerId`) can still allocate `nextProducerId` when the epoch is
    exhausted.
    
    Metadata transition behavior:
    - In `PREPARE_ABORT`, `TxnTransitMetadata` carries `nextProducerId`
    (with the boundary bump behavior, e.g. epoch reaching `Short.MaxValue`).
    - In-memory `TransactionMetadata` keeps the original producer ID until
    marker completion.
    - On `COMPLETE_ABORT`, `TransactionMetadata` is rotated to the new
    producer ID with epoch `0`, fully fencing the old producer session.
    
    ### Testing
    Added TV2 boundary tests in `TransactionCoordinatorTest.scala`:
    -
    
    
`shouldHandleTimeoutAtEpochOverflowBoundaryCorrectlyAndLateClientAbortRequestTV2`
    -
    
    
`shouldRotateProducerIdWhenInitPidFencesOngoingTxnAtEpochOverflowBoundaryTV2`
    
    Reviewers: Justine Olshan <[email protected]>, Artem Livshits
     <[email protected]>
---
 .../transaction/TransactionCoordinator.scala       |   5 +-
 .../transaction/TransactionCoordinatorTest.scala   | 488 ++++++++++++++++++++-
 2 files changed, 482 insertions(+), 11 deletions(-)

diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 82a2bd7706b..fab67b3fcd0 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -813,9 +813,10 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
             val isRetry = retryOnEpochBump || retryOnOverflow
 
             def generateTxnTransitMetadataForTxnCompletion(nextState: 
TransactionState, noPartitionAdded: Boolean): ApiResult[(Int, 
TxnTransitMetadata)] = {
-              // Maybe allocate new producer ID if we are bumping epoch and 
epoch is exhausted
+              // EndTxn completion on TV2 bumps epoch, so rotate producer ID 
whenever the current epoch is exhausted.
+              // This must also apply to the epoch-fence path.
               val nextProducerIdOrErrors =
-                if (!isEpochFence && txnMetadata.isProducerEpochExhausted) {
+                if (txnMetadata.isProducerEpochExhausted) {
                   try {
                     Right(producerIdManager.generateProducerId())
                   } catch {
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index 8393382f8a9..3f0cd44989f 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -23,7 +23,7 @@ import org.apache.kafka.common.record.internal.RecordBatch
 import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, 
TransactionResult}
 import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch}
 import org.apache.kafka.coordinator.transaction.{ProducerIdManager, 
TransactionMetadata, TransactionState, TransactionStateManagerConfig, 
TxnTransitMetadata}
-import org.apache.kafka.server.common.TransactionVersion
+import org.apache.kafka.server.common.{RequestLocal, TransactionVersion}
 import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
 import org.apache.kafka.server.util.MockScheduler
 import org.junit.jupiter.api.Assertions._
@@ -35,7 +35,6 @@ import org.mockito.Mockito._
 import org.mockito.{ArgumentCaptor, ArgumentMatchers}
 
 import java.util
-
 import scala.jdk.CollectionConverters._
 
 class TransactionCoordinatorTest {
@@ -1333,6 +1332,11 @@ class TransactionCoordinatorTest {
 
     val epochAtMaxBoundary = (Short.MaxValue - 1).toShort // 32766
     val now = time.milliseconds()
+    
+    val bumpedUpProducerId = producerId + 1L
+    val bumpedUpEpoch = 0.toShort
+    when(pidGenerator.generateProducerId())
+      .thenReturn(bumpedUpProducerId)
 
     // Create transaction metadata at the epoch boundary that would cause 
overflow IFF double-incremented
     val txnMetadata = new TransactionMetadata(
@@ -1362,15 +1366,15 @@ class TransactionCoordinatorTest {
     when(transactionManager.appendTransactionToLog(
       ArgumentMatchers.eq(transactionalId),
       ArgumentMatchers.eq(coordinatorEpoch),
-      any[TxnTransitMetadata],
-      capturedErrorsCallback.capture(),
+      capturedTxnTransitMetadata.capture(),
+      any[Errors => Unit](),
       any(),
       any())
     ).thenAnswer(invocation => {
       val transitMetadata = invocation.getArgument[TxnTransitMetadata](2)
-      // Simulate the metadata update that would happen in the real 
appendTransactionToLog
+      val callback = invocation.getArgument[Errors => Unit](3)
       txnMetadata.completeTransitionTo(transitMetadata)
-      capturedErrorsCallback.getValue.apply(Errors.NONE)
+      callback.apply(Errors.NONE)
     })
 
     // Track the actual behavior
@@ -1392,11 +1396,13 @@ class TransactionCoordinatorTest {
 
     assertTrue(callbackInvoked, "Callback should have been invoked")
     assertEquals(Errors.NONE, resultError, "Expected no errors in the 
callback")
-    assertEquals(producerId, resultProducerId, "Expected producer ID to match")
-    assertEquals(Short.MaxValue, resultEpoch, "Expected producer epoch to be 
Short.MaxValue (32767) single epoch bump")
+    assertEquals(bumpedUpProducerId, resultProducerId, "Expected producer ID 
should be rotated because of epoch exhausted.")
+    assertEquals(bumpedUpEpoch, resultEpoch, "Expected producer epoch to be 0 
as a result of ProducerId rotation.")
     
     // Verify the transaction metadata was correctly updated to the final epoch
-    assertEquals(Short.MaxValue, txnMetadata.producerEpoch, 
+    assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state())
+    assertEquals(producerId, txnMetadata.producerId(), "Expected producer ID 
should not be rotated because txnMarker is not written yet.")
+    assertEquals(Short.MaxValue, txnMetadata.producerEpoch,
       s"Expected transaction metadata producer epoch to be ${Short.MaxValue} " 
+
         s"after timeout handling, but was ${txnMetadata.producerEpoch}"
     )
@@ -1404,6 +1410,470 @@ class TransactionCoordinatorTest {
     // Verify the basic flow was attempted
     verify(transactionManager).timedOutTransactions()
     verify(transactionManager, 
atLeast(1)).getTransactionState(ArgumentMatchers.eq(transactionalId))
+    verify(pidGenerator, times(1)).generateProducerId()
+  }
+
+  @Test
+  def 
shouldHandleTimeoutAtEpochOverflowBoundaryCorrectlyAndLateClientAbortRequestTV2():
 Unit = {
+    // 1. The transaction coordinator aborts the transaction due to a timeout 
at epoch 32766 
+    //    (timeout -> fenced -> prepare abort -> complete abort) 
+    // 2. The client sends an abort request later.
+
+    val epochAtMaxBoundary = (Short.MaxValue - 1).toShort // 32766
+    val now = time.milliseconds()
+
+    val rotatedProducerId = producerId + 1L
+    val rotatedEpoch = 0.toShort
+    when(pidGenerator.generateProducerId())
+      .thenReturn(rotatedProducerId)
+
+    // Create transaction metadata at the epoch boundary that would cause 
overflow IFF double-incremented
+    val txnMetadata = new TransactionMetadata(
+      transactionalId,
+      producerId,
+      RecordBatch.NO_PRODUCER_ID,
+      RecordBatch.NO_PRODUCER_ID,
+      epochAtMaxBoundary,
+      RecordBatch.NO_PRODUCER_EPOCH,
+      txnTimeoutMs,
+      TransactionState.ONGOING,
+      partitions,
+      now,
+      now,
+      TV_2
+    )
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+
+    // Mock the transaction manager to return our test transaction as timed out
+    when(transactionManager.timedOutTransactions())
+      .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, 
producerId, epochAtMaxBoundary)))
+    
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
+    when(transactionManager.transactionVersionLevel()).thenReturn(TV_2)
+
+    // Mock the append operation to simulate successful write and update the 
metadata
+    when(transactionManager.appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      capturedTxnTransitMetadata.capture(),
+      any[Errors => Unit](),
+      any(),
+      any())
+    ).thenAnswer(invocation => {
+      val transitMetadata = invocation.getArgument[TxnTransitMetadata](2)
+      val callback = invocation.getArgument[Errors => Unit](3)
+      txnMetadata.completeTransitionTo(transitMetadata)
+      callback.apply(Errors.NONE)
+    })
+
+    // Simulate marker write completion by appending COMPLETE_ABORT.
+    doAnswer(invocation => {
+      val markerCoordinatorEpoch = invocation.getArgument[Int](0)
+      val markerTxnMetadata = invocation.getArgument[TransactionMetadata](2)
+      val newTxnMetadata = invocation.getArgument[TxnTransitMetadata](3)
+      transactionManager.appendTransactionToLog(
+        markerTxnMetadata.transactionalId(),
+        markerCoordinatorEpoch,
+        newTxnMetadata,
+        _ => (),
+        _ == Errors.COORDINATOR_NOT_AVAILABLE,
+        RequestLocal.noCaching
+      )
+      null
+    }).when(transactionMarkerChannelManager).addTxnMarkersToSend(
+      ArgumentMatchers.eq(coordinatorEpoch),
+      ArgumentMatchers.eq(TransactionResult.ABORT),
+      ArgumentMatchers.eq(txnMetadata),
+      any[TxnTransitMetadata]()
+    )
+    
+    // Track the actual behavior
+    var callbackInvoked = false
+    var resultError: Errors = null
+    var resultProducerId: Long = -1
+    var resultEpoch: Short = -1
+
+    def checkOnEndTransactionComplete(txnIdAndPidEpoch: 
TransactionalIdAndProducerIdEpoch)
+       (error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
+      callbackInvoked = true
+      resultError = error
+      resultProducerId = newProducerId
+      resultEpoch = newProducerEpoch
+
+      // TransitMetadata should be rotated.
+      assertEquals(Errors.NONE, resultError, "Expected no errors in the 
callback")
+      assertEquals(rotatedProducerId, resultProducerId, "Expected producer ID 
should be rotated because of epoch exhausted.")
+      assertEquals(rotatedEpoch, resultEpoch, "Expected producer epoch to be 0 
as a result of ProducerId rotation.")
+
+      // The local transaction state is not updated yet.
+      assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state())
+      assertEquals(producerId, txnMetadata.producerId(), "Expected producer ID 
should not be rotated because txnMarker is not written yet.")
+      assertEquals(Short.MaxValue, txnMetadata.producerEpoch,
+        s"Expected transaction metadata producer epoch to be ${Short.MaxValue} 
" +
+          s"after timeout handling, but was ${txnMetadata.producerEpoch}"
+      )
+    }
+
+    // Execute the timeout abort process
+    coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete)
+
+    // the transaction completion callback was invoked.
+    assertTrue(callbackInvoked, "Callback should have been invoked")
+
+    val capturedTransitions = 
capturedTxnTransitMetadata.getAllValues.asScala.toList
+    val prepareAbortTransition = capturedTransitions.head
+    assertEquals(TransactionState.PREPARE_ABORT, 
prepareAbortTransition.txnState)
+    assertEquals(rotatedProducerId, prepareAbortTransition.nextProducerId)
+    assertTrue(capturedTransitions.exists(_.txnState == 
TransactionState.COMPLETE_ABORT))
+
+    // Verify the transaction metadata was correctly updated to the final 
epoch as a result of sendMarkerTxn.
+    assertEquals(TransactionState.COMPLETE_ABORT, txnMetadata.state())
+    assertEquals(rotatedProducerId, txnMetadata.producerId())
+    assertEquals(rotatedEpoch, txnMetadata.producerEpoch)
+
+    // Verify the basic flow was attempted
+    verify(transactionManager).timedOutTransactions()
+    verify(transactionManager, times(2)).appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      any(),
+      any(),
+      any(),
+      any()
+    )
+    verify(transactionManager, 
atLeast(1)).getTransactionState(ArgumentMatchers.eq(transactionalId))
+    verify(pidGenerator, times(1)).generateProducerId()
+    verify(transactionMarkerChannelManager, times(1)).addTxnMarkersToSend(
+      ArgumentMatchers.eq(coordinatorEpoch),
+      ArgumentMatchers.eq(TransactionResult.ABORT),
+      ArgumentMatchers.eq(txnMetadata),
+      any[TxnTransitMetadata]()
+    )
+
+    // Simulate the client sending a late abort request.
+    val clientPid = producerId
+    val clientEpoch = epochAtMaxBoundary
+
+    var clientCallbackInvoked = false
+    var clientErr: Errors = null
+    var clientReturnedPid: Long = -1L
+    var clientReturnedEpoch: Short = -1
+
+    def onClientEndTxn(error: Errors, newProducerId: Long, newProducerEpoch: 
Short): Unit = {
+      clientCallbackInvoked = true
+      clientErr = error
+      clientReturnedPid = newProducerId
+      clientReturnedEpoch = newProducerEpoch
+    }
+
+    // WHEN : The client tries to abort the transaction after the server has 
already aborted it due to a timeout.
+    coordinator.handleEndTransaction(
+      transactionalId,
+      clientPid,
+      clientEpoch,
+      TransactionResult.ABORT,
+      TV_2,
+      onClientEndTxn,
+      RequestLocal.noCaching
+    )
+
+    // THEN : It should be treated as a retry.
+    assertTrue(clientCallbackInvoked)
+    assertEquals(Errors.NONE, clientErr)
+    assertEquals(rotatedProducerId, clientReturnedPid)
+    assertEquals(rotatedEpoch, clientReturnedEpoch)
+  }
+
+  @Test
+  def 
shouldRotateProducerIdWhenInitPidFencesOngoingTxnAtEpochOverflowBoundaryTV2(): 
Unit = {
+    // 1. The transaction coordinator aborts the transaction because a new 
InitProducerId fences an ongoing transaction at epoch 32766. 
+    //    (InitProducerId -> fenced -> prepare abort -> complete abort) 
+    // 2. The client sends an abort request later.
+    
+    val epochAtMaxBoundary = (Short.MaxValue - 1).toShort // 32766
+    val now = time.milliseconds()
+
+    val rotatedProducerId = producerId + 1L
+    val rotatedEpoch = 0.toShort
+    when(pidGenerator.generateProducerId())
+      .thenReturn(rotatedProducerId)
+
+    val txnMetadata = new TransactionMetadata(
+      transactionalId,
+      producerId,
+      RecordBatch.NO_PRODUCER_ID,
+      RecordBatch.NO_PRODUCER_ID,
+      epochAtMaxBoundary,
+      RecordBatch.NO_PRODUCER_EPOCH,
+      txnTimeoutMs,
+      TransactionState.ONGOING,
+      partitions,
+      now,
+      now,
+      TV_2
+    )
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
+      .thenReturn(true)
+    
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
+    when(transactionManager.transactionVersionLevel()).thenReturn(TV_2)
+
+    when(transactionManager.appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      capturedTxnTransitMetadata.capture(),
+      any[Errors => Unit](),
+      any(),
+      any())
+    ).thenAnswer(invocation => {
+      val transitMetadata = invocation.getArgument[TxnTransitMetadata](2)
+      val callback = invocation.getArgument[Errors => Unit](3)
+      txnMetadata.completeTransitionTo(transitMetadata)
+      callback.apply(Errors.NONE)
+    })
+
+    // Simulate marker write completion by appending COMPLETE_ABORT.
+    doAnswer(invocation => {
+      val markerCoordinatorEpoch = invocation.getArgument[Int](0)
+      val markerTxnMetadata = invocation.getArgument[TransactionMetadata](2)
+      val newTxnMetadata = invocation.getArgument[TxnTransitMetadata](3)
+      transactionManager.appendTransactionToLog(
+        markerTxnMetadata.transactionalId(),
+        markerCoordinatorEpoch,
+        newTxnMetadata,
+        _ => (),
+        _ == Errors.COORDINATOR_NOT_AVAILABLE,
+        RequestLocal.noCaching
+      )
+      null
+    }).when(transactionMarkerChannelManager).addTxnMarkersToSend(
+      ArgumentMatchers.eq(coordinatorEpoch),
+      ArgumentMatchers.eq(TransactionResult.ABORT),
+      ArgumentMatchers.eq(txnMetadata),
+      any[TxnTransitMetadata]()
+    )
+
+    // WHEN1: Trigger fencing of the ongoing transaction.
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
+
+    // THEN1
+    assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), 
result)
+
+    val capturedTransitions = 
capturedTxnTransitMetadata.getAllValues.asScala.toList
+    val firstAbortTransition = capturedTransitions.head
+    assertEquals(TransactionState.PREPARE_ABORT, firstAbortTransition.txnState)
+    assertEquals(producerId, firstAbortTransition.producerId)
+    assertEquals(Short.MaxValue, firstAbortTransition.producerEpoch)
+    assertEquals(rotatedProducerId, firstAbortTransition.nextProducerId)
+    assertTrue(capturedTransitions.exists(_.txnState == 
TransactionState.COMPLETE_ABORT))
+
+    // Marker completion should rotate producer ID on COMPLETE_ABORT.
+    assertEquals(TransactionState.COMPLETE_ABORT, txnMetadata.state())
+    assertEquals(rotatedProducerId, txnMetadata.producerId())
+    assertEquals(rotatedEpoch, txnMetadata.producerEpoch)
+
+    // Client retries ABORT with old pid/epoch and should be treated as 
retryOnOverflow.
+    var clientCallbackInvoked = false
+    var clientErr: Errors = null
+    var clientReturnedPid: Long = -1L
+    var clientReturnedEpoch: Short = -1
+
+    def onClientEndTxn(error: Errors, newProducerId: Long, newProducerEpoch: 
Short): Unit = {
+      clientCallbackInvoked = true
+      clientErr = error
+      clientReturnedPid = newProducerId
+      clientReturnedEpoch = newProducerEpoch
+    }
+
+    // WHEN2 : The client tries to abort the transaction after the coordinator 
has already aborted it. 
+    coordinator.handleEndTransaction(
+      transactionalId,
+      producerId,
+      epochAtMaxBoundary,
+      TransactionResult.ABORT,
+      TV_2,
+      onClientEndTxn,
+      RequestLocal.noCaching
+    )
+
+    // THEN2 : It should be treated as a retry.
+    assertTrue(clientCallbackInvoked)
+    assertEquals(Errors.NONE, clientErr)
+    assertEquals(rotatedProducerId, clientReturnedPid)
+    assertEquals(rotatedEpoch, clientReturnedEpoch)
+    verify(transactionManager, times(2)).appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      any(),
+      any(),
+      any(),
+      any()
+    )
+    verify(pidGenerator, times(1)).generateProducerId()
+    verify(transactionMarkerChannelManager, times(1)).addTxnMarkersToSend(
+      ArgumentMatchers.eq(coordinatorEpoch),
+      ArgumentMatchers.eq(TransactionResult.ABORT),
+      ArgumentMatchers.eq(txnMetadata),
+      any[TxnTransitMetadata]()
+    )
+  }
+
+  @Test
+  def 
shouldHandleTimeoutAtEpochOverflowBoundaryCorrectlyAndRetryInitProducerIdTV2(): 
Unit = {
+    // 1. The transaction coordinator aborts the transaction due to a timeout 
at epoch 32766
+    //    (timeout -> prepare abort -> complete abort with producerId rotation)
+    // 2. The original client retries InitProducerId with the old 
producerId/epoch.
+
+    val epochAtMaxBoundary = (Short.MaxValue - 1).toShort // 32766
+    val now = time.milliseconds()
+
+    val rotatedProducerId = producerId + 1L
+    val rotatedEpoch = 0.toShort
+    when(pidGenerator.generateProducerId())
+      .thenReturn(rotatedProducerId)
+
+    // Create transaction metadata at the epoch boundary that would cause 
overflow IFF double-incremented
+    val txnMetadata = new TransactionMetadata(
+      transactionalId,
+      producerId,
+      RecordBatch.NO_PRODUCER_ID,
+      RecordBatch.NO_PRODUCER_ID,
+      epochAtMaxBoundary,
+      RecordBatch.NO_PRODUCER_EPOCH,
+      txnTimeoutMs,
+      TransactionState.ONGOING,
+      partitions,
+      now,
+      now,
+      TV_2
+    )
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+
+    // Mock the transaction manager to return our test transaction as timed out
+    when(transactionManager.timedOutTransactions())
+      .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, 
producerId, epochAtMaxBoundary)))
+    
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
+    when(transactionManager.transactionVersionLevel()).thenReturn(TV_2)
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
+      .thenReturn(true)
+
+    // Mock the append operation to simulate successful write and update the 
metadata
+    when(transactionManager.appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      capturedTxnTransitMetadata.capture(),
+      any[Errors => Unit](),
+      any(),
+      any())
+    ).thenAnswer(invocation => {
+      val transitMetadata = invocation.getArgument[TxnTransitMetadata](2)
+      val callback = invocation.getArgument[Errors => Unit](3)
+      txnMetadata.completeTransitionTo(transitMetadata)
+      callback.apply(Errors.NONE)
+    })
+
+    // Simulate marker write completion by appending COMPLETE_ABORT.
+    doAnswer(invocation => {
+      val markerCoordinatorEpoch = invocation.getArgument[Int](0)
+      val markerTxnMetadata = invocation.getArgument[TransactionMetadata](2)
+      val newTxnMetadata = invocation.getArgument[TxnTransitMetadata](3)
+      transactionManager.appendTransactionToLog(
+        markerTxnMetadata.transactionalId(),
+        markerCoordinatorEpoch,
+        newTxnMetadata,
+        _ => (),
+        _ == Errors.COORDINATOR_NOT_AVAILABLE,
+        RequestLocal.noCaching
+      )
+      null
+    }).when(transactionMarkerChannelManager).addTxnMarkersToSend(
+      ArgumentMatchers.eq(coordinatorEpoch),
+      ArgumentMatchers.eq(TransactionResult.ABORT),
+      ArgumentMatchers.eq(txnMetadata),
+      any[TxnTransitMetadata]()
+    )
+
+    // Track the actual behavior
+    var callbackInvoked = false
+
+    def checkOnEndTransactionComplete(txnIdAndPidEpoch: 
TransactionalIdAndProducerIdEpoch)
+                                     (error: Errors, newProducerId: Long, 
newProducerEpoch: Short): Unit = {
+      callbackInvoked = true
+
+      // TransitMetadata should be rotated.
+      assertEquals(Errors.NONE, error, "Expected no errors in the callback")
+      assertEquals(rotatedProducerId, newProducerId, "Expected producer ID 
should be rotated because of epoch exhausted.")
+      assertEquals(rotatedEpoch, newProducerEpoch, "Expected producer epoch to 
be 0 as a result of ProducerId rotation.")
+
+      // The local transaction state is not updated yet.
+      assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state())
+      assertEquals(producerId, txnMetadata.producerId(), "Expected producer ID 
should not be rotated because txnMarker is not written yet.")
+      assertEquals(Short.MaxValue, txnMetadata.producerEpoch,
+        s"Expected transaction metadata producer epoch to be ${Short.MaxValue} 
" +
+          s"after timeout handling, but was ${txnMetadata.producerEpoch}"
+      )
+    }
+
+    // Execute the timeout abort process
+    coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete)
+
+    // the transaction completion callback was invoked.
+    assertTrue(callbackInvoked, "Callback should have been invoked")
+
+    val capturedTransitions = 
capturedTxnTransitMetadata.getAllValues.asScala.toList
+    val prepareAbortTransition = capturedTransitions.head
+    assertEquals(TransactionState.PREPARE_ABORT, 
prepareAbortTransition.txnState)
+    assertEquals(rotatedProducerId, prepareAbortTransition.nextProducerId)
+    assertTrue(capturedTransitions.exists(_.txnState == 
TransactionState.COMPLETE_ABORT))
+
+    // Verify the transaction metadata was correctly updated to the final 
epoch as a result of sendMarkerTxn.
+    assertEquals(TransactionState.COMPLETE_ABORT, txnMetadata.state())
+    assertEquals(rotatedProducerId, txnMetadata.producerId())
+    assertEquals(rotatedEpoch, txnMetadata.producerEpoch)
+
+    // Verify the basic flow was attempted
+    verify(transactionManager).timedOutTransactions()
+    verify(transactionManager, times(2)).appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      any(),
+      any(),
+      any(),
+      any()
+    )
+    verify(transactionManager, 
atLeast(1)).getTransactionState(ArgumentMatchers.eq(transactionalId))
+    verify(pidGenerator, times(1)).generateProducerId()
+    verify(transactionMarkerChannelManager, times(1)).addTxnMarkersToSend(
+      ArgumentMatchers.eq(coordinatorEpoch),
+      ArgumentMatchers.eq(TransactionResult.ABORT),
+      ArgumentMatchers.eq(txnMetadata),
+      any[TxnTransitMetadata]()
+    )
+
+    // WHEN: The original client retries InitProducerId after the coordinator 
has already
+    // completed the timeout-driven abort and rotated the producerId.
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, epochAtMaxBoundary)),
+      initProducerIdMockCallback
+    )
+    
+    // THEN
+    val expectedResult = InitProducerIdResult(rotatedProducerId, rotatedEpoch, 
Errors.NONE) 
+    assertEquals(expectedResult, result)
   }
 
   @Test

Reply via email to