This is an automated email from the ASF dual-hosted git repository.

jolshan pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 3a3159b01e1 KAFKA-18953: [1/N] Add broker side handling for 2 PC 
(KIP-939) (#19193)
3a3159b01e1 is described below

commit 3a3159b01e10b7d3ac1733feb70e013fdec93b79
Author: Ritika Reddy <[email protected]>
AuthorDate: Wed Mar 19 09:22:00 2025 -0700

    KAFKA-18953: [1/N] Add broker side handling for 2 PC (KIP-939) (#19193)
    
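    This patch adds logic to enable and handle two-phase commit (2PC)
    transactions, following KIP-939.
    The changes are as follows:
    1) Add a new broker config,
    **transaction.two.phase.commit.enable**, which is set to false by default.
    2) Add new flags **enableTwoPCFlag** and **keepPreparedTxn** to
    handleInitProducerId. A request that sets enableTwoPCFlag while the broker
    config is disabled is rejected with TRANSACTIONAL_ID_AUTHORIZATION_FAILED.
    3) Return an UNSUPPORTED_VERSION error if keepPreparedTxn is set to true
    (for now, since keeping a prepared transaction is not implemented yet).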
    
    Reviewers: Artem Livshits <[email protected]>, Justine Olshan
    <[email protected]>
---
 .../common/requests/InitProducerIdRequest.java     |   7 +
 .../transaction/TransactionCoordinator.scala       |  19 +-
 .../transaction/TransactionMetadata.scala          |   6 +
 .../transaction/TransactionStateManager.scala      |  21 +-
 core/src/main/scala/kafka/server/KafkaApis.scala   |  15 +-
 .../TransactionCoordinatorConcurrencyTest.scala    |  12 +-
 .../transaction/TransactionCoordinatorTest.scala   | 301 ++++++++++++++++-----
 .../transaction/TransactionStateManagerTest.scala  |  15 +-
 .../scala/unit/kafka/server/KafkaApisTest.scala    |   4 +
 .../transaction/TransactionStateManagerConfig.java |  15 +-
 .../TransactionStateManagerConfigTest.java         |   8 +-
 11 files changed, 338 insertions(+), 85 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdRequest.java
 
b/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdRequest.java
index 9d92f0e5351..e02f6757a48 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdRequest.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdRequest.java
@@ -77,4 +77,11 @@ public class InitProducerIdRequest extends AbstractRequest {
         return data;
     }
 
+    public boolean enable2Pc() {
+        return data.enable2Pc();
+    }
+
+    public boolean keepPreparedTxn() {
+        return data.keepPreparedTxn();
+    }
 }
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 0b9f4b5ab64..697d6a30144 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -55,6 +55,7 @@ object TransactionCoordinator {
       config.transactionLogConfig.transactionTopicMinISR,
       
config.transactionStateManagerConfig.transactionAbortTimedOutTransactionCleanupIntervalMs,
       
config.transactionStateManagerConfig.transactionRemoveExpiredTransactionalIdCleanupIntervalMs,
+      config.transactionStateManagerConfig.transaction2PCEnabled,
       config.requestTimeoutMs)
 
     val txnStateManager = new TransactionStateManager(config.brokerId, 
scheduler, replicaManager, metadataCache, txnConfig,
@@ -109,6 +110,8 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
 
   def handleInitProducerId(transactionalId: String,
                            transactionTimeoutMs: Int,
+                           enableTwoPCFlag: Boolean,
+                           keepPreparedTxn: Boolean,
                            expectedProducerIdAndEpoch: 
Option[ProducerIdAndEpoch],
                            responseCallback: InitProducerIdCallback,
                            requestLocal: RequestLocal = 
RequestLocal.noCaching): Unit = {
@@ -125,10 +128,20 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
       // if transactional id is empty then return error as invalid request. 
This is
       // to make TransactionCoordinator's behavior consistent with producer 
client
       responseCallback(initTransactionError(Errors.INVALID_REQUEST))
-    } else if (!txnManager.validateTransactionTimeoutMs(transactionTimeoutMs)) 
{
+    } else if (enableTwoPCFlag && !txnManager.isTransaction2pcEnabled()) {
+      // if the request is to enable two-phase commit but the broker 2PC 
config is set to false,
+      // 2PC functionality is disabled, clients that attempt to use this 
functionality
+      // would receive an authorization failed error.
+      
responseCallback(initTransactionError(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED))
+    } else if (keepPreparedTxn) {
+      // if the request is to keep the prepared transaction, then return an
+      // unsupported version error since the feature hasn't been implemented 
yet.
+      responseCallback(initTransactionError(Errors.UNSUPPORTED_VERSION))
+    } else if (!txnManager.validateTransactionTimeoutMs(enableTwoPCFlag, 
transactionTimeoutMs)) {
       // check transactionTimeoutMs is not larger than the broker configured 
maximum allowed value
       
responseCallback(initTransactionError(Errors.INVALID_TRANSACTION_TIMEOUT))
     } else {
+      val resolvedTxnTimeoutMs = if (enableTwoPCFlag) Int.MaxValue else 
transactionTimeoutMs
       val coordinatorEpochAndMetadata = 
txnManager.getTransactionState(transactionalId).flatMap {
         case None =>
           try {
@@ -138,7 +151,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
               nextProducerId = RecordBatch.NO_PRODUCER_ID,
               producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
               lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
-              txnTimeoutMs = transactionTimeoutMs,
+              txnTimeoutMs = resolvedTxnTimeoutMs,
               state = Empty,
               topicPartitions = collection.mutable.Set.empty[TopicPartition],
               txnLastUpdateTimestamp = time.milliseconds(),
@@ -157,7 +170,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
           val txnMetadata = existingEpochAndMetadata.transactionMetadata
 
           txnMetadata.inLock {
-            prepareInitProducerIdTransit(transactionalId, 
transactionTimeoutMs, coordinatorEpoch, txnMetadata,
+            prepareInitProducerIdTransit(transactionalId, 
resolvedTxnTimeoutMs, coordinatorEpoch, txnMetadata,
               expectedProducerIdAndEpoch)
           }
       }
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index 99ef4711171..aff68749513 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -419,6 +419,12 @@ private[transaction] class TransactionMetadata(val 
transactionalId: String,
    */
   def isProducerEpochExhausted: Boolean = 
TransactionMetadata.isEpochExhausted(producerEpoch)
 
+  /**
+   * Check if this is a distributed two phase commit transaction.
+   * Such transactions have no timeout (identified by maximum value for 
timeout).
+   */
+  def isDistributedTwoPhaseCommitTxn: Boolean = txnTimeoutMs == Int.MaxValue
+
   private def hasPendingTransaction: Boolean = {
     state match {
       case Ongoing | PrepareAbort | PrepareCommit => true
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index b61fd4622f8..0c6391af6db 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -108,6 +108,8 @@ class TransactionStateManager(brokerId: Int,
     version
   }
 
+  private[transaction] def isTransaction2pcEnabled(): Boolean = { 
config.transaction2PCEnable }
+
   // visible for testing only
   private[transaction] def addLoadingPartition(partitionId: Int, 
coordinatorEpoch: Int): Unit = {
     val partitionAndLeaderEpoch = 
TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch)
@@ -130,7 +132,9 @@ class TransactionStateManager(brokerId: Int,
           } else {
             txnMetadata.state match {
               case Ongoing =>
-                txnMetadata.txnStartTimestamp + txnMetadata.txnTimeoutMs < now
+                // Do not apply timeout to distributed two phase commit 
transactions.
+                (!txnMetadata.isDistributedTwoPhaseCommitTxn) &&
+                (txnMetadata.txnStartTimestamp + txnMetadata.txnTimeoutMs < 
now)
               case _ => false
             }
           }
@@ -396,10 +400,18 @@ class TransactionStateManager(brokerId: Int,
   }
 
   /**
-   * Validate the given transaction timeout value
+   * Validates the provided transaction timeout.
+   * - If 2PC is enabled, the timeout is always valid (set to Int.MAX by 
default).
+   * - Otherwise, the timeout must be a positive value and not exceed the
+   *   configured transaction max timeout.
+   *
+   * @param enableTwoPC       Whether Two-Phase Commit (2PC) is enabled.
+   * @param txnTimeoutMs      The requested transaction timeout in 
milliseconds.
+   * @return `true` if the timeout is valid, `false` otherwise.
    */
-  def validateTransactionTimeoutMs(txnTimeoutMs: Int): Boolean =
-    txnTimeoutMs <= config.transactionMaxTimeoutMs && txnTimeoutMs > 0
+  def validateTransactionTimeoutMs(enableTwoPC: Boolean, txnTimeoutMs: Int): 
Boolean = {
+    enableTwoPC || (txnTimeoutMs <= config.transactionMaxTimeoutMs && 
txnTimeoutMs > 0)
+  }
 
   def transactionTopicConfigs: Properties = {
     val props = new Properties
@@ -826,6 +838,7 @@ private[transaction] case class 
TransactionConfig(transactionalIdExpirationMs: I
                                                   
transactionLogMinInsyncReplicas: Int = 
TransactionLogConfig.TRANSACTIONS_TOPIC_MIN_ISR_DEFAULT,
                                                   
abortTimedOutTransactionsIntervalMs: Int = 
TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT,
                                                   
removeExpiredTransactionalIdsIntervalMs: Int = 
TransactionStateManagerConfig.TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_DEFAULT,
+                                                  transaction2PCEnable: 
Boolean = TransactionStateManagerConfig.TRANSACTIONS_2PC_ENABLED_DEFAULT,
                                                   requestTimeoutMs: Int = 
ServerConfigs.REQUEST_TIMEOUT_MS_DEFAULT)
 
 case class TransactionalIdAndProducerIdEpoch(transactionalId: String, 
producerId: Long, producerEpoch: Short) {
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala 
b/core/src/main/scala/kafka/server/KafkaApis.scala
index d48947b2017..adf5c5a6e53 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -1554,8 +1554,19 @@ class KafkaApis(val requestChannel: RequestChannel,
     }
 
     producerIdAndEpoch match {
-      case Right(producerIdAndEpoch) => 
txnCoordinator.handleInitProducerId(transactionalId, 
initProducerIdRequest.data.transactionTimeoutMs,
-        producerIdAndEpoch, sendResponseCallback, requestLocal)
+      case Right(producerIdAndEpoch) =>
+        val enableTwoPC = initProducerIdRequest.enable2Pc()
+        val keepPreparedTxn = initProducerIdRequest.keepPreparedTxn()
+
+        txnCoordinator.handleInitProducerId(
+            transactionalId,
+            initProducerIdRequest.data.transactionTimeoutMs,
+            enableTwoPC,
+            keepPreparedTxn,
+            producerIdAndEpoch,
+            sendResponseCallback,
+            requestLocal
+        )
       case Left(error) => 
requestHelper.sendErrorResponseMaybeThrottle(request, error.exception)
     }
   }
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index bf69c29dfc2..cfe89a41649 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -523,10 +523,18 @@ class TransactionCoordinatorConcurrencyTest extends 
AbstractCoordinatorConcurren
 
   class InitProducerIdOperation(val producerIdAndEpoch: 
Option[ProducerIdAndEpoch] = None) extends TxnOperation[InitProducerIdResult] {
     override def run(txn: Transaction): Unit = {
-      transactionCoordinator.handleInitProducerId(txn.transactionalId, 60000, 
producerIdAndEpoch, resultCallback,
-        RequestLocal.withThreadConfinedCaching)
+      transactionCoordinator.handleInitProducerId(
+        txn.transactionalId,
+        60000,
+        enableTwoPCFlag = false,
+        keepPreparedTxn = false,
+        producerIdAndEpoch,
+        resultCallback,
+        RequestLocal.withThreadConfinedCaching
+      )
       replicaManager.tryCompleteActions()
     }
+
     override def awaitAndVerify(txn: Transaction): Unit = {
       val initPidResult = result.getOrElse(throw new 
IllegalStateException("InitProducerId has not completed"))
       assertEquals(Errors.NONE, initPidResult.error)
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index 7cfe5acb728..94ccd6dc03d 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -29,8 +29,8 @@ import org.apache.kafka.server.util.MockScheduler
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.ValueSource
-import org.mockito.ArgumentMatchers.{any, anyInt}
+import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
+import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt}
 import org.mockito.Mockito._
 import org.mockito.{ArgumentCaptor, ArgumentMatchers}
 
@@ -82,7 +82,7 @@ class TransactionCoordinatorTest {
 
   private def initPidGenericMocks(transactionalId: String): Unit = {
     mockPidGenerator()
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
   }
 
@@ -90,9 +90,29 @@ class TransactionCoordinatorTest {
   def shouldReturnInvalidRequestWhenTransactionalIdIsEmpty(): Unit = {
     mockPidGenerator()
 
-    coordinator.handleInitProducerId("", txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId("", txnTimeoutMs, enableTwoPCFlag = false,
+      keepPreparedTxn = false, None, initProducerIdMockCallback)
     assertEquals(InitProducerIdResult(-1L, -1, Errors.INVALID_REQUEST), result)
-    coordinator.handleInitProducerId("", txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId("", txnTimeoutMs, enableTwoPCFlag = false,
+      keepPreparedTxn = false, None, initProducerIdMockCallback)
+    assertEquals(InitProducerIdResult(-1L, -1, Errors.INVALID_REQUEST), result)
+  }
+
+  @Test
+  def shouldReturnInvalidRequestWhenKeepPreparedIsTrue(): Unit = {
+    mockPidGenerator()
+
+    coordinator.handleInitProducerId("", txnTimeoutMs, enableTwoPCFlag = false,
+      keepPreparedTxn = true, None, initProducerIdMockCallback)
+    assertEquals(InitProducerIdResult(-1L, -1, Errors.INVALID_REQUEST), result)
+  }
+
+  @Test
+  def shouldReturnInvalidRequestWhen2PCEnabledButBroker2PCConfigFalse(): Unit 
= {
+    mockPidGenerator()
+
+    coordinator.handleInitProducerId("", txnTimeoutMs, enableTwoPCFlag = true,
+      keepPreparedTxn = false, None, initProducerIdMockCallback)
     assertEquals(InitProducerIdResult(-1L, -1, Errors.INVALID_REQUEST), result)
   }
 
@@ -100,9 +120,11 @@ class TransactionCoordinatorTest {
   def shouldAcceptInitPidAndReturnNextPidWhenTransactionalIdIsNull(): Unit = {
     mockPidGenerator()
 
-    coordinator.handleInitProducerId(null, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(null, txnTimeoutMs, enableTwoPCFlag = 
false,
+      keepPreparedTxn = false, None, initProducerIdMockCallback)
     assertEquals(InitProducerIdResult(0L, 0, Errors.NONE), result)
-    coordinator.handleInitProducerId(null, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(null, txnTimeoutMs, enableTwoPCFlag = 
false,
+      keepPreparedTxn = false, None, initProducerIdMockCallback)
     assertEquals(InitProducerIdResult(1L, 0, Errors.NONE), result)
   }
 
@@ -127,7 +149,14 @@ class TransactionCoordinatorTest {
       any())
     ).thenAnswer(_ => capturedErrorsCallback.getValue.apply(Errors.NONE))
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(nextPid - 1, 0, Errors.NONE), result)
   }
 
@@ -152,8 +181,14 @@ class TransactionCoordinatorTest {
       any())
     ).thenAnswer(_ => capturedErrorsCallback.getValue.apply(Errors.NONE))
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId, producerEpoch)),
-      initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, producerEpoch)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(nextPid - 1, 0, Errors.NONE), result)
   }
 
@@ -176,7 +211,14 @@ class TransactionCoordinatorTest {
       any()
     )).thenAnswer(_ => capturedErrorsCallback.getValue.apply(Errors.NONE))
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertNotEquals(producerId, result.producerId)
     assertEquals(0, result.producerEpoch)
     assertEquals(Errors.NONE, result.error)
@@ -218,23 +260,37 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRespondWithNotCoordinatorOnInitPidWhenNotCoordinator(): Unit = {
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Left(Errors.NOT_COORDINATOR))
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(-1, -1, Errors.NOT_COORDINATOR), result)
   }
 
   @Test
   def 
shouldRespondWithCoordinatorLoadInProgressOnInitPidWhenCoordinatorLoading(): 
Unit = {
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(-1, -1, 
Errors.COORDINATOR_LOAD_IN_PROGRESS), result)
   }
 
@@ -471,7 +527,7 @@ class TransactionCoordinatorTest {
 
   @ParameterizedTest
   @ValueSource(shorts = Array(0, 2))
-  def 
shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(transactionVersion:
 Short): Unit = {
+  def 
shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDoesntMatchMapped(transactionVersion:
 Short): Unit = {
     val clientTransactionVersion = 
TransactionVersion.fromFeatureLevel(transactionVersion)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
@@ -967,12 +1023,16 @@ class TransactionCoordinatorTest {
     
validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareAbort)
   }
 
-  @Test
-  def 
shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): 
Unit = {
+  @ParameterizedTest(name = "enableTwoPCFlag={0}, keepPreparedTxn={1}")
+  @CsvSource(Array("false, false"))
+  def 
shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(
+    enableTwoPCFlag: Boolean,
+    keepPreparedTxn:  Boolean
+  ): Unit = {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 
producerId, RecordBatch.NO_PRODUCER_ID,
       producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, 
partitions, time.milliseconds(), time.milliseconds(), TV_0)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
 
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
@@ -991,10 +1051,17 @@ class TransactionCoordinatorTest {
       any())
     ).thenAnswer(_ => capturedErrorsCallback.getValue.apply(Errors.NONE))
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag,
+      keepPreparedTxn,
+      None,
+      initProducerIdMockCallback
+    )
 
     assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), 
result)
-    verify(transactionManager).validateTransactionTimeoutMs(anyInt())
+    verify(transactionManager).validateTransactionTimeoutMs(anyBoolean(), 
anyInt())
     verify(transactionManager, 
times(3)).getTransactionState(ArgumentMatchers.eq(transactionalId))
     verify(transactionManager).appendTransactionToLog(
       ArgumentMatchers.eq(transactionalId),
@@ -1010,7 +1077,7 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 
producerId, RecordBatch.NO_PRODUCER_ID,
       producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, 
partitions, time.milliseconds(), time.milliseconds(), TV_0)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
 
     val bumpedTxnMetadata = new TransactionMetadata(transactionalId, 
producerId, producerId, RecordBatch.NO_PRODUCER_ID,
@@ -1021,11 +1088,18 @@ class TransactionCoordinatorTest {
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
bumpedTxnMetadata))))
 
     when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
 
     assertEquals(InitProducerIdResult(-1, -1, Errors.PRODUCER_FENCED), result)
 
-    verify(transactionManager).validateTransactionTimeoutMs(anyInt())
+    verify(transactionManager).validateTransactionTimeoutMs(anyBoolean(), 
anyInt())
     verify(transactionManager, 
times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
   }
 
@@ -1034,7 +1108,7 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 
producerId, RecordBatch.NO_PRODUCER_ID,
       producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, 
partitions, time.milliseconds(), time.milliseconds(), TV_0)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
 
     
when(transactionManager.putTransactionStateIfNotExists(any[TransactionMetadata]()))
@@ -1070,26 +1144,47 @@ class TransactionCoordinatorTest {
     })
 
     // For the first two calls, verify that the epoch was only bumped once
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(-1, -1, Errors.NOT_ENOUGH_REPLICAS), 
result)
 
     assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
     assertTrue(txnMetadata.hasFailedEpochFence)
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(-1, -1, Errors.NOT_ENOUGH_REPLICAS), 
result)
 
     assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
     assertTrue(txnMetadata.hasFailedEpochFence)
 
     // For the last, successful call, verify that the epoch was not bumped 
further
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), 
result)
 
     assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
     assertFalse(txnMetadata.hasFailedEpochFence)
 
-    verify(transactionManager, times(3)).validateTransactionTimeoutMs(anyInt())
+    verify(transactionManager, 
times(3)).validateTransactionTimeoutMs(anyBoolean(), anyInt())
     verify(transactionManager, 
times(9)).getTransactionState(ArgumentMatchers.eq(transactionalId))
     verify(transactionManager, times(3)).appendTransactionToLog(
       ArgumentMatchers.eq(transactionalId),
@@ -1106,7 +1201,7 @@ class TransactionCoordinatorTest {
       (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, 
txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), 
TV_0)
     assertTrue(txnMetadata.isProducerEpochExhausted)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
 
     val postFenceTxnMetadata = new TransactionMetadata(transactionalId, 
producerId, producerId, RecordBatch.NO_PRODUCER_ID,
@@ -1139,11 +1234,18 @@ class TransactionCoordinatorTest {
       any())
     ).thenAnswer(_ => capturedErrorsCallback.getValue.apply(Errors.NONE))
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
 
     assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), 
result)
-    verify(transactionManager).validateTransactionTimeoutMs(anyInt())
+    verify(transactionManager).validateTransactionTimeoutMs(anyBoolean(), 
anyInt())
     verify(transactionManager, 
times(3)).getTransactionState(ArgumentMatchers.eq(transactionalId))
     verify(transactionManager).appendTransactionToLog(
       ArgumentMatchers.eq(transactionalId),
@@ -1172,14 +1274,20 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 
RecordBatch.NO_PRODUCER_ID,
       RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, 
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, 
time.milliseconds, time.milliseconds, TV_0)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
 
     // Simulate producer trying to continue after new producer has already 
been initialized
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId, producerEpoch)),
-      initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, producerEpoch)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, 
RecordBatch.NO_PRODUCER_EPOCH, Errors.PRODUCER_FENCED), result)
   }
 
@@ -1189,14 +1297,20 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, 
producerId,
       RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 
1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, 
time.milliseconds, TV_0)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
 
     // Simulate producer trying to continue after new producer has already 
been initialized
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId, producerEpoch)),
-      initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, producerEpoch)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, 
RecordBatch.NO_PRODUCER_EPOCH, Errors.PRODUCER_FENCED), result)
   }
 
@@ -1207,7 +1321,7 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 
producerId,
       RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, 
time.milliseconds, time.milliseconds, TV_0)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
@@ -1225,13 +1339,25 @@ class TransactionCoordinatorTest {
     })
 
     // Re-initialization should succeed and bump the producer epoch
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId, 10)),
-      initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, 10)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(producerId, 11, Errors.NONE), result)
 
     // Simulate producer retrying after successfully re-initializing but 
failing to receive the response
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId, 10)),
-      initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, 10)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(producerId, 11, Errors.NONE), result)
   }
 
@@ -1242,7 +1368,7 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, 
producerId,
       RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, 
time.milliseconds, time.milliseconds, TV_0)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
@@ -1263,12 +1389,25 @@ class TransactionCoordinatorTest {
     })
 
     // With producer epoch at 10, new producer calls InitProducerId and should 
get epoch 11
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(producerId, 11, Errors.NONE), result)
 
     // Simulate old producer trying to continue from epoch 10
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId, 10)),
-      initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, 10)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, 
RecordBatch.NO_PRODUCER_EPOCH, Errors.PRODUCER_FENCED), result)
   }
 
@@ -1281,7 +1420,7 @@ class TransactionCoordinatorTest {
     when(pidGenerator.generateProducerId())
       .thenReturn(producerId + 1)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
@@ -1303,13 +1442,25 @@ class TransactionCoordinatorTest {
     })
 
     // Bump epoch and cause producer ID to be rotated
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId,
-      (Short.MaxValue - 1).toShort)), initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, (Short.MaxValue - 1).toShort)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(producerId + 1, 0, Errors.NONE), result)
 
     // Simulate producer retrying old request after producer bump
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId,
-      (Short.MaxValue - 1).toShort)), initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, (Short.MaxValue - 1).toShort)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(producerId + 1, 0, Errors.NONE), result)
   }
 
@@ -1322,7 +1473,7 @@ class TransactionCoordinatorTest {
     when(pidGenerator.generateProducerId())
       .thenReturn(producerId + 1)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
@@ -1344,13 +1495,25 @@ class TransactionCoordinatorTest {
     })
 
     // Bump epoch and cause producer ID to be rotated
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId,
-      (Short.MaxValue - 1).toShort)), initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, (Short.MaxValue - 1).toShort)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(producerId + 1, 0, Errors.NONE), result)
 
     // Validate that producer with old producer ID and stale epoch is fenced
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId,
-      (Short.MaxValue - 2).toShort)), initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, (Short.MaxValue - 2).toShort)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, 
RecordBatch.NO_PRODUCER_EPOCH, Errors.PRODUCER_FENCED), result)
   }
 
@@ -1500,16 +1663,22 @@ class TransactionCoordinatorTest {
       RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, 
time.milliseconds(), time.milliseconds(), TV_0)
     txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, 
RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
txnMetadata))))
 
-    coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new 
ProducerIdAndEpoch(producerId, 10)),
-      initProducerIdMockCallback)
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      Some(new ProducerIdAndEpoch(producerId, 10)),
+      initProducerIdMockCallback
+    )
     assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, 
RecordBatch.NO_PRODUCER_EPOCH, Errors.CONCURRENT_TRANSACTIONS), result)
 
-    verify(transactionManager).validateTransactionTimeoutMs(anyInt())
+    verify(transactionManager).validateTransactionTimeoutMs(anyBoolean(), 
anyInt())
     
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
   }
 
@@ -1575,7 +1744,7 @@ class TransactionCoordinatorTest {
   }
 
   private def 
validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(state: 
TransactionState): Unit = {
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
 
     // Since the clientTransactionVersion doesn't matter, use 2 since the 
states are PrepareCommit and PrepareAbort.
@@ -1584,7 +1753,8 @@ class TransactionCoordinatorTest {
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, 
metadata))))
 
-    coordinator.handleInitProducerId(transactionalId, 10, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(transactionalId, 10, enableTwoPCFlag = 
false,
+      keepPreparedTxn = false, None, initProducerIdMockCallback)
 
     assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), 
result)
   }
@@ -1594,7 +1764,7 @@ class TransactionCoordinatorTest {
     when(pidGenerator.generateProducerId())
       .thenReturn(producerId)
 
-    when(transactionManager.validateTransactionTimeoutMs(anyInt()))
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), 
anyInt()))
       .thenReturn(true)
 
     val metadata = new TransactionMetadata(transactionalId, producerId, 
producerId, RecordBatch.NO_PRODUCER_EPOCH,
@@ -1616,7 +1786,8 @@ class TransactionCoordinatorTest {
     })
 
     val newTxnTimeoutMs = 10
-    coordinator.handleInitProducerId(transactionalId, newTxnTimeoutMs, None, 
initProducerIdMockCallback)
+    coordinator.handleInitProducerId(transactionalId, newTxnTimeoutMs, 
enableTwoPCFlag = false,
+      keepPreparedTxn = false, None, initProducerIdMockCallback)
 
     assertEquals(InitProducerIdResult(producerId, (producerEpoch + 1).toShort, 
Errors.NONE), result)
     assertEquals(newTxnTimeoutMs, metadata.txnTimeoutMs)
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index eb45273945d..5bca8653347 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -104,11 +104,16 @@ class TransactionStateManagerTest {
 
   @Test
   def testValidateTransactionTimeout(): Unit = {
-    assertTrue(transactionManager.validateTransactionTimeoutMs(1))
-    assertFalse(transactionManager.validateTransactionTimeoutMs(-1))
-    assertFalse(transactionManager.validateTransactionTimeoutMs(0))
-    
assertTrue(transactionManager.validateTransactionTimeoutMs(txnConfig.transactionMaxTimeoutMs))
-    
assertFalse(transactionManager.validateTransactionTimeoutMs(txnConfig.transactionMaxTimeoutMs
 + 1))
+    assertTrue(transactionManager.validateTransactionTimeoutMs(enableTwoPC = 
false, 1))
+    assertFalse(transactionManager.validateTransactionTimeoutMs(enableTwoPC = 
false, -1))
+    assertFalse(transactionManager.validateTransactionTimeoutMs(enableTwoPC = 
false, 0))
+    assertTrue(transactionManager.validateTransactionTimeoutMs(enableTwoPC = 
false, txnConfig.transactionMaxTimeoutMs))
+    assertFalse(transactionManager.validateTransactionTimeoutMs(enableTwoPC = 
false, txnConfig.transactionMaxTimeoutMs + 1))
+    // KIP-939 Always return true when two phase commit is enabled on 
transaction. Two phase commit is enabled in case of
+    // externally coordinated distributed transactions.
+    assertTrue(transactionManager.validateTransactionTimeoutMs(enableTwoPC = 
true, -1))
+    assertTrue(transactionManager.validateTransactionTimeoutMs(enableTwoPC = 
true, 10))
+    assertTrue(transactionManager.validateTransactionTimeoutMs(enableTwoPC = 
true, txnConfig.transactionMaxTimeoutMs + 1))
   }
 
   @Test
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 4ed3ee5fa77..70a64f47160 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -1582,6 +1582,8 @@ class KafkaApisTest extends Logging {
         new InitProducerIdRequestData()
           .setTransactionalId(transactionalId)
           .setTransactionTimeoutMs(txnTimeoutMs)
+          .setEnable2Pc(false)
+          .setKeepPreparedTxn(false)
           .setProducerId(producerId)
           .setProducerEpoch(epoch)
       ).build(version.toShort)
@@ -1597,6 +1599,8 @@ class KafkaApisTest extends Logging {
       when(txnCoordinator.handleInitProducerId(
         ArgumentMatchers.eq(transactionalId),
         ArgumentMatchers.eq(txnTimeoutMs),
+        ArgumentMatchers.eq(false),
+        ArgumentMatchers.eq(false),
         ArgumentMatchers.eq(expectedProducerIdAndEpoch),
         responseCallback.capture(),
         ArgumentMatchers.eq(requestLocal)
diff --git 
a/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionStateManagerConfig.java
 
b/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionStateManagerConfig.java
index 46dfb46d129..f75f496edcf 100644
--- 
a/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionStateManagerConfig.java
+++ 
b/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionStateManagerConfig.java
@@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit;
 import static org.apache.kafka.common.config.ConfigDef.Importance.HIGH;
 import static org.apache.kafka.common.config.ConfigDef.Importance.LOW;
 import static org.apache.kafka.common.config.ConfigDef.Range.atLeast;
+import static org.apache.kafka.common.config.ConfigDef.Type.BOOLEAN;
 import static org.apache.kafka.common.config.ConfigDef.Type.INT;
 
 public final class TransactionStateManagerConfig {
@@ -47,25 +48,33 @@ public final class TransactionStateManagerConfig {
     public static final int 
TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_DEFAULT = 
(int) TimeUnit.HOURS.toMillis(1);
     public static final String 
TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONS_INTERVAL_MS_DOC = "The interval at 
which to remove transactions that have expired due to 
<code>transactional.id.expiration.ms</code> passing";
 
+    public static final String TRANSACTIONS_2PC_ENABLED_CONFIG = 
"transaction.two.phase.commit.enable";
+    public static final boolean TRANSACTIONS_2PC_ENABLED_DEFAULT = false;
+    public static final String TRANSACTIONS_2PC_ENABLED_DOC = "Allow 
participation in Two-Phase Commit (2PC) transactions with an external 
transaction coordinator";
+
     public static final String METRICS_GROUP = 
"transaction-coordinator-metrics";
     public static final String LOAD_TIME_SENSOR = 
"TransactionsPartitionLoadTime";
     public static final ConfigDef CONFIG_DEF =  new ConfigDef()
             .define(TRANSACTIONAL_ID_EXPIRATION_MS_CONFIG, INT, 
TRANSACTIONAL_ID_EXPIRATION_MS_DEFAULT, atLeast(1), HIGH, 
TRANSACTIONAL_ID_EXPIRATION_MS_DOC)
             .define(TRANSACTIONS_MAX_TIMEOUT_MS_CONFIG, INT, 
TRANSACTIONS_MAX_TIMEOUT_MS_DEFAULT, atLeast(1), HIGH, 
TRANSACTIONS_MAX_TIMEOUT_MS_DOC)
             
.define(TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_CONFIG, 
INT, TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, 
atLeast(1), LOW, TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTIONS_INTERVAL_MS_DOC)
-            
.define(TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_CONFIG,
 INT, TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_DEFAULT, 
atLeast(1), LOW, TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONS_INTERVAL_MS_DOC);
+            
.define(TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_CONFIG,
 INT, TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_DEFAULT, 
atLeast(1), LOW, TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONS_INTERVAL_MS_DOC)
+            .define(TRANSACTIONS_2PC_ENABLED_CONFIG, BOOLEAN, 
TRANSACTIONS_2PC_ENABLED_DEFAULT, LOW, TRANSACTIONS_2PC_ENABLED_DOC);
 
     private final int transactionalIdExpirationMs;
     private final int transactionMaxTimeoutMs;
     private final int transactionAbortTimedOutTransactionCleanupIntervalMs;
     private final int transactionRemoveExpiredTransactionalIdCleanupIntervalMs;
+    private final boolean transaction2PCEnabled;
 
     public TransactionStateManagerConfig(AbstractConfig config) {
         transactionalIdExpirationMs = 
config.getInt(TransactionStateManagerConfig.TRANSACTIONAL_ID_EXPIRATION_MS_CONFIG);
         transactionMaxTimeoutMs = 
config.getInt(TransactionStateManagerConfig.TRANSACTIONS_MAX_TIMEOUT_MS_CONFIG);
         transactionAbortTimedOutTransactionCleanupIntervalMs = 
config.getInt(TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_CONFIG);
         transactionRemoveExpiredTransactionalIdCleanupIntervalMs = 
config.getInt(TransactionStateManagerConfig.TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_CONFIG);
+        transaction2PCEnabled = 
config.getBoolean(TransactionStateManagerConfig.TRANSACTIONS_2PC_ENABLED_CONFIG);
     }
+
     public int transactionalIdExpirationMs() {
         return transactionalIdExpirationMs;
     }
@@ -81,4 +90,8 @@ public final class TransactionStateManagerConfig {
     public int transactionRemoveExpiredTransactionalIdCleanupIntervalMs() {
         return transactionRemoveExpiredTransactionalIdCleanupIntervalMs;
     }
+
+    public boolean transaction2PCEnabled() {
+        return transaction2PCEnabled;
+    }
 }
diff --git 
a/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionStateManagerConfigTest.java
 
b/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionStateManagerConfigTest.java
index 7b68f692ef8..ddcf5ff9463 100644
--- 
a/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionStateManagerConfigTest.java
+++ 
b/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionStateManagerConfigTest.java
@@ -50,6 +50,7 @@ class TransactionStateManagerConfigTest {
         
doReturn(2).when(config).getInt(TransactionStateManagerConfig.TRANSACTIONAL_ID_EXPIRATION_MS_CONFIG);
         
doReturn(3).when(config).getInt(TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_CONFIG);
         
doReturn(4).when(config).getInt(TransactionStateManagerConfig.TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_CONFIG);
+        
doReturn(true).when(config).getBoolean(TransactionStateManagerConfig.TRANSACTIONS_2PC_ENABLED_CONFIG);
 
         TransactionStateManagerConfig transactionStateManagerConfig = new 
TransactionStateManagerConfig(config);
 
@@ -57,7 +58,7 @@ class TransactionStateManagerConfigTest {
         assertEquals(2, 
transactionStateManagerConfig.transactionalIdExpirationMs());
         assertEquals(3, 
transactionStateManagerConfig.transactionAbortTimedOutTransactionCleanupIntervalMs());
         assertEquals(4, 
transactionStateManagerConfig.transactionRemoveExpiredTransactionalIdCleanupIntervalMs());
-
+        assertEquals(true, 
transactionStateManagerConfig.transaction2PCEnabled());
 
         // If the following calls are missing, we won’t be able to distinguish 
whether the value is set in the constructor or if
         // it fetches the latest value from AbstractConfig with each call.
@@ -65,11 +66,12 @@ class TransactionStateManagerConfigTest {
         transactionStateManagerConfig.transactionalIdExpirationMs();
         
transactionStateManagerConfig.transactionAbortTimedOutTransactionCleanupIntervalMs();
         
transactionStateManagerConfig.transactionRemoveExpiredTransactionalIdCleanupIntervalMs();
+        transactionStateManagerConfig.transaction2PCEnabled();
 
         verify(config, 
times(1)).getInt(TransactionStateManagerConfig.TRANSACTIONS_MAX_TIMEOUT_MS_CONFIG);
         verify(config, 
times(1)).getInt(TransactionStateManagerConfig.TRANSACTIONAL_ID_EXPIRATION_MS_CONFIG);
         verify(config, 
times(1)).getInt(TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_CONFIG);
         verify(config, 
times(1)).getInt(TransactionStateManagerConfig.TRANSACTIONS_REMOVE_EXPIRED_TRANSACTIONAL_ID_CLEANUP_INTERVAL_MS_CONFIG);
+        verify(config, 
times(1)).getBoolean(TransactionStateManagerConfig.TRANSACTIONS_2PC_ENABLED_CONFIG);
     }
-
-}
\ No newline at end of file
+}

Reply via email to