This is an automated email from the ASF dual-hosted git repository.

dajac pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 08aa33127a4 MINOR: Push logic to resolve the transaction coordinator 
into the AddPartitionsToTxnManager (#14402)
08aa33127a4 is described below

commit 08aa33127a4254497456aa7a0c1646c7c38adf81
Author: David Jacot <[email protected]>
AuthorDate: Mon Sep 25 09:05:30 2023 -0700

    MINOR: Push logic to resolve the transaction coordinator into the 
AddPartitionsToTxnManager (#14402)
    
    This patch refactors the ReplicaManager.appendRecords method and the
    AddPartitionsToTxnManager class in order to move the logic that resolves
    the transaction coordinator from the transactional id out of the former
    and into the latter. While working on KAFKA-14505, I found it rather
    annoying that callers must pass the transaction state partition to
    appendRecords, since the group coordinator would have to do the same.
    It seems preferable to delegate that job to the AddPartitionsToTxnManager.
    
    Reviewers: Justine Olshan <[email protected]>
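
For illustration, the shape of the call-site change on the produce path
(other appendRecords arguments are elided; parameter names are those used in
this patch):

    // Before: the caller resolved the transaction coordinator partition itself.
    replicaManager.appendRecords(
      // ... other arguments ...
      transactionalId = produceRequest.transactionalId(),
      transactionStatePartition =
        Some(txnCoordinator.partitionFor(produceRequest.transactionalId())))

    // After: the caller passes only the transactional id, and the
    // AddPartitionsToTxnManager resolves the coordinator on demand.
    replicaManager.appendRecords(
      // ... other arguments ...
      transactionalId = produceRequest.transactionalId())
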
---
 .../kafka/server/AddPartitionsToTxnManager.scala   |  89 ++++++-
 .../src/main/scala/kafka/server/BrokerServer.scala |  13 +-
 core/src/main/scala/kafka/server/KafkaApis.scala   |  10 +-
 .../scala/kafka/server/KafkaRequestHandler.scala   |  16 +-
 core/src/main/scala/kafka/server/KafkaServer.scala |  12 +-
 .../main/scala/kafka/server/ReplicaManager.scala   |  69 +----
 .../kafka/server/KafkaRequestHandlerTest.scala     |  25 +-
 .../AbstractCoordinatorConcurrencyTest.scala       |   1 -
 .../group/CoordinatorPartitionWriterTest.scala     |   2 -
 .../coordinator/group/GroupCoordinatorTest.scala   |  12 +-
 .../group/GroupMetadataManagerTest.scala           |  14 +-
 .../transaction/TransactionStateManagerTest.scala  |  10 +-
 .../server/AddPartitionsToTxnManagerTest.scala     | 277 ++++++++++++++++----
 .../scala/unit/kafka/server/KafkaApisTest.scala    |  22 +-
 .../unit/kafka/server/ReplicaManagerTest.scala     | 287 +++++++--------------
 15 files changed, 465 insertions(+), 394 deletions(-)

diff --git a/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala 
b/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala
index 05e40014669..cb9707d14aa 100644
--- a/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala
+++ b/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala
@@ -18,19 +18,22 @@
 package kafka.server
 
 import 
kafka.server.AddPartitionsToTxnManager.{VerificationFailureRateMetricName, 
VerificationTimeMsMetricName}
+import kafka.utils.Implicits.MapExtensionMethods
 import kafka.utils.Logging
 import org.apache.kafka.clients.{ClientResponse, NetworkClient, 
RequestCompletionHandler}
+import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.{Node, TopicPartition}
-import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTransaction,
 AddPartitionsToTxnTransactionCollection}
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic,
 AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction, 
AddPartitionsToTxnTransactionCollection}
 import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, 
AddPartitionsToTxnResponse}
+import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, 
AddPartitionsToTxnResponse, MetadataResponse}
 import org.apache.kafka.common.utils.Time
 import org.apache.kafka.server.metrics.KafkaMetricsGroup
 import org.apache.kafka.server.util.{InterBrokerSendThread, 
RequestAndCompletionHandler}
 
 import java.util
 import java.util.concurrent.TimeUnit
-import scala.collection.mutable
+import scala.collection.{Set, Seq, mutable}
+import scala.jdk.CollectionConverters._
 
 object AddPartitionsToTxnManager {
   type AppendCallback = Map[TopicPartition, Errors] => Unit
@@ -39,7 +42,6 @@ object AddPartitionsToTxnManager {
   val VerificationTimeMsMetricName = "VerificationTimeMs"
 }
 
-
 /*
  * Data structure to hold the transactional data to send to a node. Note -- at 
most one request per transactional ID
  * will exist at a time in the map. If a given transactional ID exists in the 
map, and a new request with the same ID
@@ -49,10 +51,18 @@ class TransactionDataAndCallbacks(val transactionData: 
AddPartitionsToTxnTransac
                                   val callbacks: mutable.Map[String, 
AddPartitionsToTxnManager.AppendCallback],
                                   val startTimeMs: mutable.Map[String, Long])
 
-
-class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, 
time: Time)
-  extends InterBrokerSendThread("AddPartitionsToTxnSenderThread-" + 
config.brokerId, client, config.requestTimeoutMs, time)
-  with Logging {
+class AddPartitionsToTxnManager(
+  config: KafkaConfig,
+  client: NetworkClient,
+  metadataCache: MetadataCache,
+  partitionFor: String => Int,
+  time: Time
+) extends InterBrokerSendThread(
+  "AddPartitionsToTxnSenderThread-" + config.brokerId,
+  client,
+  config.requestTimeoutMs,
+  time
+) with Logging {
 
   this.logIdent = logPrefix
 
@@ -63,7 +73,41 @@ class AddPartitionsToTxnManager(config: KafkaConfig, client: 
NetworkClient, time
   val verificationFailureRate = 
metricsGroup.newMeter(VerificationFailureRateMetricName, "failures", 
TimeUnit.SECONDS)
   val verificationTimeMs = 
metricsGroup.newHistogram(VerificationTimeMsMetricName)
 
-  def addTxnData(node: Node, transactionData: AddPartitionsToTxnTransaction, 
callback: AddPartitionsToTxnManager.AppendCallback): Unit = {
+  def verifyTransaction(
+    transactionalId: String,
+    producerId: Long,
+    producerEpoch: Short,
+    topicPartitions: Seq[TopicPartition],
+    callback: AddPartitionsToTxnManager.AppendCallback
+  ): Unit = {
+    val (error, node) = 
getTransactionCoordinator(partitionFor(transactionalId))
+
+    if (error != Errors.NONE) {
+      callback(topicPartitions.map(tp => tp -> error).toMap)
+    } else {
+      val topicCollection = new AddPartitionsToTxnTopicCollection()
+      topicPartitions.groupBy(_.topic).forKeyValue { (topic, tps) =>
+        topicCollection.add(new AddPartitionsToTxnTopic()
+          .setName(topic)
+          .setPartitions(tps.map(tp => Int.box(tp.partition)).toList.asJava))
+      }
+
+      val transactionData = new AddPartitionsToTxnTransaction()
+        .setTransactionalId(transactionalId)
+        .setProducerId(producerId)
+        .setProducerEpoch(producerEpoch)
+        .setVerifyOnly(true)
+        .setTopics(topicCollection)
+
+      addTxnData(node, transactionData, callback)
+    }
+  }
+
+  private def addTxnData(
+    node: Node,
+    transactionData: AddPartitionsToTxnTransaction,
+    callback: AddPartitionsToTxnManager.AppendCallback
+  ): Unit = {
     nodesToTransactions.synchronized {
       val curTime = time.milliseconds()
      // Check if we already have either the node or the individual transaction. Add the Node if it isn't there.
@@ -102,6 +146,33 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
     }
   }
 
+  private def getTransactionCoordinator(partition: Int): (Errors, Node) = {
+    val listenerName = config.interBrokerListenerName
+
+    val topicMetadata = 
metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
listenerName)
+
+    if (topicMetadata.headOption.isEmpty) {
+      // If the topic is not created, then the transaction is definitely not started.
+      (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode)
+    } else {
+      if (topicMetadata.head.errorCode != Errors.NONE.code) {
+        (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode)
+      } else {
+        val coordinatorEndpoint = topicMetadata.head.partitions.asScala
+          .find(_.partitionIndex == partition)
+          .filter(_.leaderId != MetadataResponse.NO_LEADER_ID)
+          .flatMap(metadata => 
metadataCache.getAliveBrokerNode(metadata.leaderId, listenerName))
+
+        coordinatorEndpoint match {
+          case Some(endpoint) =>
+            (Errors.NONE, endpoint)
+          case _ =>
+            (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode)
+        }
+      }
+    }
+  }
+
   private def topicPartitionsToError(transactionData: 
AddPartitionsToTxnTransaction, error: Errors): Map[TopicPartition, Errors] = {
     val topicPartitionsToError = mutable.Map[TopicPartition, Errors]()
     transactionData.topics.forEach { topic =>
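
For reference, a hypothetical call to the new entry point (the ids, epoch,
partitions, and callback body are illustrative; the signature is the one
added above):

    addPartitionsToTxnManager.verifyTransaction(
      transactionalId = "txn-1",
      producerId = 1000L,
      producerEpoch = 0,
      topicPartitions = Seq(new TopicPartition("foo", 0)),
      callback = errors =>
        // A non-empty map signals failure, e.g. COORDINATOR_NOT_AVAILABLE
        // when the coordinator cannot be resolved from the metadata cache.
        errors.foreach { case (tp, error) => println(s"$tp -> $error") }
    )
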
diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala 
b/core/src/main/scala/kafka/server/BrokerServer.scala
index 99554a90bbe..96e8dc8a5a1 100644
--- a/core/src/main/scala/kafka/server/BrokerServer.scala
+++ b/core/src/main/scala/kafka/server/BrokerServer.scala
@@ -27,7 +27,6 @@ import kafka.raft.KafkaRaftManager
 import kafka.security.CredentialProvider
 import kafka.server.metadata.{AclPublisher, BrokerMetadataPublisher, 
ClientQuotaMetadataManager, DelegationTokenPublisher, 
DynamicClientQuotaPublisher, DynamicConfigPublisher, KRaftMetadataCache, 
ScramPublisher}
 import kafka.utils.CoreUtils
-import org.apache.kafka.clients.NetworkClient
 import org.apache.kafka.common.config.ConfigException
 import org.apache.kafka.common.feature.SupportedVersionRange
 import org.apache.kafka.common.message.ApiMessageType.ListenerType
@@ -258,8 +257,16 @@ class BrokerServer(
       alterPartitionManager.start()
 
       val addPartitionsLogContext = new 
LogContext(s"[AddPartitionsToTxnManager broker=${config.brokerId}]")
-      val addPartitionsToTxnNetworkClient: NetworkClient = 
NetworkUtils.buildNetworkClient("AddPartitionsManager", config, metrics, time, 
addPartitionsLogContext)
-      val addPartitionsToTxnManager: AddPartitionsToTxnManager = new 
AddPartitionsToTxnManager(config, addPartitionsToTxnNetworkClient, time)
+      val addPartitionsToTxnNetworkClient = 
NetworkUtils.buildNetworkClient("AddPartitionsManager", config, metrics, time, 
addPartitionsLogContext)
+      val addPartitionsToTxnManager = new AddPartitionsToTxnManager(
+        config,
+        addPartitionsToTxnNetworkClient,
+        metadataCache,
+      // The transaction coordinator is not created at this point, so we must
+      // use a lambda here.
+        transactionalId => 
transactionCoordinator.partitionFor(transactionalId),
+        time
+      )
 
       this._replicaManager = new ReplicaManager(
         config = config,
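
The lambda passed above matters: transactionCoordinator is not yet
initialized when the manager is constructed, and a Scala method value would
capture the receiver eagerly. A sketch of the distinction (assuming a field
that is assigned later during startup, as in BrokerServer):

    // Evaluates transactionCoordinator now: captures the uninitialized value.
    val eager: String => Int = transactionCoordinator.partitionFor _

    // Re-reads transactionCoordinator on each call: safe once startup has
    // assigned it.
    val deferred: String => Int =
      transactionalId => transactionCoordinator.partitionFor(transactionalId)
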
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala 
b/core/src/main/scala/kafka/server/KafkaApis.scala
index 5c8091dccfc..8bef9daab5e 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -677,12 +677,6 @@ class KafkaApis(val requestChannel: RequestChannel,
     else {
       val internalTopicsAllowed = request.header.clientId == 
AdminUtils.ADMIN_CLIENT_ID
 
-      val transactionStatePartition =
-        if (produceRequest.transactionalId() == null)
-          None
-        else
-          Some(txnCoordinator.partitionFor(produceRequest.transactionalId()))
-
       // call the replica manager to append messages to the replicas
       replicaManager.appendRecords(
         timeout = produceRequest.timeout.toLong,
@@ -693,8 +687,8 @@ class KafkaApis(val requestChannel: RequestChannel,
         requestLocal = requestLocal,
         responseCallback = sendResponseCallback,
         recordConversionStatsCallback = processingStatsCallback,
-        transactionalId = produceRequest.transactionalId(),
-        transactionStatePartition = transactionStatePartition)
+        transactionalId = produceRequest.transactionalId()
+      )
 
       // if the request is put into the purgatory, it will have a held 
reference and hence cannot be garbage collected;
       // hence we clear its data here in order to let GC reclaim its memory 
since it is already appended to log
diff --git a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala 
b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
index 316cf92ca5a..5bfe500f13a 100755
--- a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
+++ b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
@@ -48,10 +48,6 @@ object KafkaRequestHandler {
   def setBypassThreadCheck(bypassCheck: Boolean): Unit = {
     bypassThreadCheck = bypassCheck
   }
-  
-  def currentRequestOnThread(): RequestChannel.Request = {
-    threadCurrentRequest.get()
-  }
 
   /**
    * Wrap callback to schedule it on a request thread.
@@ -68,9 +64,15 @@ object KafkaRequestHandler {
       T => fun(T)
     } else {
       T => {
-        // The requestChannel and request are captured in this lambda, so when 
it's executed on the callback thread
-        // we can re-schedule the original callback on a request thread and 
update the metrics accordingly.
-        requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(() 
=> fun(T), currentRequest))
+        if (threadCurrentRequest.get() != null) {
+          // If the callback is actually executed on a request thread, we can 
directly execute
+          // it without re-scheduling it.
+          fun(T)
+        } else {
+          // The requestChannel and request are captured in this lambda, so 
when it's executed on the callback thread
+          // we can re-schedule the original callback on a request thread and 
update the metrics accordingly.
+          requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(() 
=> fun(T), currentRequest))
+        }
       }
     }
   }
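
A simplified sketch of the fast path added to KafkaRequestHandler.wrap
(surrounding bookkeeping is elided; threadCurrentRequest is the ThreadLocal
that is non-null only while a request thread is processing a request, and
currentRequest was captured when wrap was invoked):

    def wrap[T](fun: T => Unit): T => Unit = { t =>
      if (threadCurrentRequest.get() != null) {
        // Already on a request thread: run the callback inline.
        fun(t)
      } else {
        // On another thread (e.g. a network client thread): re-schedule the
        // callback onto a request thread via the request channel.
        requestChannel.sendCallbackRequest(
          RequestChannel.CallbackRequest(() => fun(t), currentRequest))
      }
    }
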
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala 
b/core/src/main/scala/kafka/server/KafkaServer.scala
index 7ca3373549d..cb1fee8778a 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -629,8 +629,16 @@ class KafkaServer(
 
   protected def createReplicaManager(isShuttingDown: AtomicBoolean): 
ReplicaManager = {
     val addPartitionsLogContext = new LogContext(s"[AddPartitionsToTxnManager 
broker=${config.brokerId}]")
-    val addPartitionsToTxnNetworkClient: NetworkClient = 
NetworkUtils.buildNetworkClient("AddPartitionsManager", config, metrics, time, 
addPartitionsLogContext)
-    val addPartitionsToTxnManager: AddPartitionsToTxnManager = new 
AddPartitionsToTxnManager(config, addPartitionsToTxnNetworkClient, time)
+    val addPartitionsToTxnNetworkClient = 
NetworkUtils.buildNetworkClient("AddPartitionsManager", config, metrics, time, 
addPartitionsLogContext)
+    val addPartitionsToTxnManager = new AddPartitionsToTxnManager(
+      config,
+      addPartitionsToTxnNetworkClient,
+      metadataCache,
+      // The transaction coordinator is not created at this point, so we must
+      // use a lambda here.
+      transactionalId => transactionCoordinator.partitionFor(transactionalId),
+      time
+    )
 
     new ReplicaManager(
       metrics = metrics,
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala 
b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 73dd527ed16..07e7418776a 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -33,7 +33,6 @@ import kafka.utils._
 import kafka.zk.KafkaZkClient
 import org.apache.kafka.common.errors._
 import org.apache.kafka.common.internals.Topic
-import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic,
 AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction}
 import 
org.apache.kafka.common.message.DeleteRecordsResponseData.DeleteRecordsPartitionResult
 import 
org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState
 import 
org.apache.kafka.common.message.LeaderAndIsrResponseData.{LeaderAndIsrPartitionError,
 LeaderAndIsrTopicError}
@@ -711,7 +710,6 @@ class ReplicaManager(val config: KafkaConfig,
    * @param recordConversionStatsCallback callback for updating stats on 
record conversions
    * @param requestLocal                  container for the stateful instances 
scoped to this request
    * @param transactionalId               transactional ID if the request is 
from a producer and the producer is transactional
-   * @param transactionStatePartition     partition that holds the 
transactional state if transactionalId is present
    * @param actionQueue                   the action queue to use. 
ReplicaManager#actionQueue is used by default.
    */
   def appendRecords(timeout: Long,
@@ -724,14 +722,13 @@ class ReplicaManager(val config: KafkaConfig,
                     recordConversionStatsCallback: Map[TopicPartition, 
RecordConversionStats] => Unit = _ => (),
                     requestLocal: RequestLocal = RequestLocal.NoCaching,
                     transactionalId: String = null,
-                    transactionStatePartition: Option[Int] = None,
                     actionQueue: ActionQueue = this.actionQueue): Unit = {
     if (isValidRequiredAcks(requiredAcks)) {
       val sTime = time.milliseconds
 
       val verificationGuards: mutable.Map[TopicPartition, Object] = 
mutable.Map[TopicPartition, Object]()
       val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition, 
errorsPerPartition) =
-        if (transactionStatePartition.isEmpty || 
!config.transactionPartitionVerificationEnable)
+        if (transactionalId == null || 
!config.transactionPartitionVerificationEnable)
           (entriesPerPartition, Map.empty[TopicPartition, MemoryRecords], 
Map.empty[TopicPartition, Errors])
         else {
           val verifiedEntries = mutable.Map[TopicPartition, MemoryRecords]()
@@ -846,33 +843,15 @@ class ReplicaManager(val config: KafkaConfig,
         appendEntries(verifiedEntriesPerPartition)(Map.empty)
       } else {
         // For unverified entries, send a request to verify. When verified, 
the append process will proceed via the callback.
-        val (error, node) = 
getTransactionCoordinator(transactionStatePartition.get)
-
-        if (error != Errors.NONE) {
-          
appendEntries(entriesPerPartition)(notYetVerifiedEntriesPerPartition.map {
-            case (tp, _) => (tp, error)
-          }.toMap)
-        } else {
-          val topicGrouping = 
notYetVerifiedEntriesPerPartition.keySet.groupBy(tp => tp.topic())
-          val topicCollection = new AddPartitionsToTxnTopicCollection()
-          topicGrouping.foreach { case (topic, tps) =>
-            topicCollection.add(new AddPartitionsToTxnTopic()
-              .setName(topic)
-              .setPartitions(tps.map(tp => 
Integer.valueOf(tp.partition())).toList.asJava))
-          }
-
-          // Map not yet verified partitions to a request object.
-          // We verify above that all partitions use the same producer ID.
-          val batchInfo = 
notYetVerifiedEntriesPerPartition.head._2.firstBatch()
-          val notYetVerifiedTransaction = new AddPartitionsToTxnTransaction()
-            .setTransactionalId(transactionalId)
-            .setProducerId(batchInfo.producerId())
-            .setProducerEpoch(batchInfo.producerEpoch())
-            .setVerifyOnly(true)
-            .setTopics(topicCollection)
-
-          addPartitionsToTxnManager.foreach(_.addTxnData(node, 
notYetVerifiedTransaction, 
KafkaRequestHandler.wrap(appendEntries(entriesPerPartition)(_))))
-        }
+        // We verify above that all partitions use the same producer ID.
+        val batchInfo = notYetVerifiedEntriesPerPartition.head._2.firstBatch()
+        addPartitionsToTxnManager.foreach(_.verifyTransaction(
+          transactionalId = transactionalId,
+          producerId = batchInfo.producerId,
+          producerEpoch = batchInfo.producerEpoch,
+          topicPartitions = notYetVerifiedEntriesPerPartition.keySet.toSeq,
+          callback = 
KafkaRequestHandler.wrap(appendEntries(entriesPerPartition)(_))
+        ))
       }
     } else {
       // If required.acks is outside accepted range, something is wrong with 
the client
@@ -2644,32 +2623,4 @@ class ReplicaManager(val config: KafkaConfig,
       }
     }
   }
-
-  private[server] def getTransactionCoordinator(partition: Int): (Errors, 
Node) = {
-    val listenerName = config.interBrokerListenerName
-
-    val topicMetadata = 
metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
listenerName)
-
-    if (topicMetadata.headOption.isEmpty) {
-      // If topic is not created, then the transaction is definitely not 
started.
-      (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode)
-    } else {
-      if (topicMetadata.head.errorCode != Errors.NONE.code) {
-        (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode)
-      } else {
-        val coordinatorEndpoint = topicMetadata.head.partitions.asScala
-          .find(_.partitionIndex == partition)
-          .filter(_.leaderId != MetadataResponse.NO_LEADER_ID)
-          .flatMap(metadata => metadataCache.
-            getAliveBrokerNode(metadata.leaderId, listenerName))
-
-        coordinatorEndpoint match {
-          case Some(endpoint) =>
-            (Errors.NONE, endpoint)
-          case _ =>
-            (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode)
-        }
-      }
-    }
-  }
 }
diff --git a/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala 
b/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
index 822dd5263d6..6016d7c99d3 100644
--- a/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
+++ b/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
@@ -36,6 +36,7 @@ import org.mockito.Mockito.{mock, when}
 
 import java.net.InetAddress
 import java.nio.ByteBuffer
+import java.util.concurrent.CompletableFuture
 import java.util.concurrent.atomic.AtomicInteger
 
 class KafkaRequestHandlerTest {
@@ -53,14 +54,15 @@ class KafkaRequestHandlerTest {
       val request = makeRequest(time, metrics)
       requestChannel.sendRequest(request)
 
-      def callback(ms: Int): Unit = {
-        time.sleep(ms)
-        handler.stop()
-      }
-
       when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer 
{ _ =>
         time.sleep(2)
-        KafkaRequestHandler.wrap(callback(_: Int))(1)
+        // Prepare the callback.
+        val callback = KafkaRequestHandler.wrap((ms: Int) => {
+          time.sleep(ms)
+          handler.stop()
+        })
+        // Execute the callback asynchronously.
+        CompletableFuture.runAsync(() => callback(1))
         request.apiLocalCompleteTimeNanos = time.nanoseconds
       }
 
@@ -86,16 +88,17 @@ class KafkaRequestHandlerTest {
     var handledCount = 0
     var tryCompleteActionCount = 0
 
-    def callback(x: Int): Unit = {
-      handler.stop()
-    }
-
     val request = makeRequest(time, metrics)
     requestChannel.sendRequest(request)
 
     when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { 
_ =>
       handledCount = handledCount + 1
-      KafkaRequestHandler.wrap(callback(_: Int))(1)
+      // Prepare the callback.
+      val callback = KafkaRequestHandler.wrap((ms: Int) => {
+        handler.stop()
+      })
+      // Execute the callback asynchronously.
+      CompletableFuture.runAsync(() => callback(1))
     }
 
     when(apiHandler.tryCompleteActions()).thenAnswer { _ =>
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
index 255b8dbb866..9b8c02249aa 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
@@ -180,7 +180,6 @@ object AbstractCoordinatorConcurrencyTest {
                                processingStatsCallback: Map[TopicPartition, 
RecordConversionStats] => Unit = _ => (),
                                requestLocal: RequestLocal = 
RequestLocal.NoCaching,
                                transactionalId: String = null,
-                               transactionStatePartition: Option[Int],
                                actionQueue: ActionQueue = null): Unit = {
 
       if (entriesPerPartition.isEmpty)
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/group/CoordinatorPartitionWriterTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/group/CoordinatorPartitionWriterTest.scala
index 436458ccc48..badcb6f8cba 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/group/CoordinatorPartitionWriterTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/group/CoordinatorPartitionWriterTest.scala
@@ -112,7 +112,6 @@ class CoordinatorPartitionWriterTest {
       ArgumentMatchers.any(),
       ArgumentMatchers.any(),
       ArgumentMatchers.any(),
-      ArgumentMatchers.any(),
       ArgumentMatchers.any()
     )).thenAnswer( _ => {
       callbackCapture.getValue.apply(Map(
@@ -183,7 +182,6 @@ class CoordinatorPartitionWriterTest {
       ArgumentMatchers.any(),
       ArgumentMatchers.any(),
       ArgumentMatchers.any(),
-      ArgumentMatchers.any(),
       ArgumentMatchers.any()
     )).thenAnswer(_ => {
       callbackCapture.getValue.apply(Map(
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala 
b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala
index 787e76d6aef..ac9c16dde6c 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala
@@ -3865,7 +3865,6 @@ class GroupCoordinatorTest {
       any(),
       any(),
       any(),
-      any(),
       any()
     )).thenAnswer(_ => {
       capturedArgument.getValue.apply(
@@ -3902,7 +3901,6 @@ class GroupCoordinatorTest {
       any(),
       any(), 
       any(),
-      any(),
       any())).thenAnswer(_ => {
         capturedArgument.getValue.apply(
           Map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 
groupPartitionId) ->
@@ -4049,9 +4047,8 @@ class GroupCoordinatorTest {
       any(),
       any(),
       any(),
-      any(),
-      any())
-    ).thenAnswer(_ => {
+      any()
+    )).thenAnswer(_ => {
       capturedArgument.getValue.apply(
         Map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 
groupPartitionId) ->
           new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L)
@@ -4085,9 +4082,8 @@ class GroupCoordinatorTest {
       any(),
       any(),
       any(),
-      any(),
-      any())
-    ).thenAnswer(_ => {
+      any()
+    )).thenAnswer(_ => {
       capturedArgument.getValue.apply(
         Map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 
groupCoordinator.partitionFor(groupId)) ->
           new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L)
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
index 4304841453b..578317eeb88 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
@@ -1185,7 +1185,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     verify(replicaManager).getMagic(any())
   }
@@ -1224,7 +1223,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     verify(replicaManager).getMagic(any())
   }
@@ -1301,7 +1299,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     // Will update sensor after commit
     assertEquals(1, TestUtils.totalMetricValue(metrics, "offset-commit-count"))
@@ -1344,7 +1341,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     verify(replicaManager).getMagic(any())
     capturedResponseCallback.getValue.apply(Map(groupTopicPartition ->
@@ -1405,7 +1401,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     verify(replicaManager).getMagic(any())
   }
@@ -1456,7 +1451,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     verify(replicaManager).getMagic(any())
   }
@@ -1609,7 +1603,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     verify(replicaManager).getMagic(any())
     assertEquals(1, TestUtils.totalMetricValue(metrics, "offset-commit-count"))
@@ -1717,7 +1710,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     verify(replicaManager, times(2)).getMagic(any())
   }
@@ -2825,7 +2817,6 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
     capturedArgument
   }
@@ -2843,9 +2834,8 @@ class GroupMetadataManagerTest {
       any(),
       any(),
       any(),
-      any(),
-      any())
-    ).thenAnswer(_ => {
+      any()
+    )).thenAnswer(_ => {
       capturedCallback.getValue.apply(
         Map(groupTopicPartition ->
           new PartitionResponse(error, 0L, RecordBatch.NO_TIMESTAMP, 0L)
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index e0a4a4470cc..94ffe7a9795 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -658,7 +658,6 @@ class TransactionStateManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any()
     )
 
@@ -704,7 +703,6 @@ class TransactionStateManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any()
     )
 
@@ -747,7 +745,6 @@ class TransactionStateManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any())
 
     assertEquals(Set.empty, listExpirableTransactionalIds())
@@ -806,7 +803,6 @@ class TransactionStateManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any()
     )
 
@@ -957,7 +953,6 @@ class TransactionStateManagerTest {
       any(),
       any(),
       any(),
-      any(),
       any()
     )).thenAnswer(_ => callbackCapture.getValue.apply(
       recordsCapture.getValue.map { case (topicPartition, records) =>
@@ -1110,9 +1105,8 @@ class TransactionStateManagerTest {
       any(),
       any(),
       any(),
-      any(),
-      any())
-    ).thenAnswer(_ => capturedArgument.getValue.apply(
+      any()
+    )).thenAnswer(_ => capturedArgument.getValue.apply(
       Map(new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId) ->
         new PartitionResponse(error, 0L, RecordBatch.NO_TIMESTAMP, 0L)))
     )
diff --git 
a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala 
b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala
index 9231fdc124f..9e34322ec96 100644
--- a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala
@@ -15,15 +15,16 @@
  * limitations under the License.
  */
 
-package unit.kafka.server
+package kafka.server
 
 import com.yammer.metrics.core.{Histogram, Meter}
-import kafka.server.{AddPartitionsToTxnManager, KafkaConfig}
+import kafka.utils.Implicits.MapExtensionMethods
 import kafka.utils.TestUtils
 import org.apache.kafka.clients.{ClientResponse, NetworkClient}
 import org.apache.kafka.common.errors.{AuthenticationException, 
SaslAuthenticationException, UnsupportedVersionException}
+import org.apache.kafka.common.internals.Topic
 import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic,
 AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction, 
AddPartitionsToTxnTransactionCollection}
-import org.apache.kafka.common.message.AddPartitionsToTxnResponseData
+import org.apache.kafka.common.message.{AddPartitionsToTxnResponseData, 
MetadataResponseData}
 import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResultCollection
 import org.apache.kafka.common.{Node, TopicPartition}
 import org.apache.kafka.common.protocol.Errors
@@ -44,13 +45,15 @@ import scala.jdk.CollectionConverters._
 
 class AddPartitionsToTxnManagerTest {
   private val networkClient: NetworkClient = mock(classOf[NetworkClient])
+  private val metadataCache: MetadataCache = mock(classOf[MetadataCache])
+  private val partitionFor: String => Int = mock(classOf[String => Int])
 
   private val time = new MockTime
 
   private var addPartitionsToTxnManager: AddPartitionsToTxnManager = _
 
-  val topic = "foo"
-  val topicPartitions = List(new TopicPartition(topic, 1), new 
TopicPartition(topic, 2), new TopicPartition(topic, 3))
+  private val topic = "foo"
+  private val topicPartitions = List(new TopicPartition(topic, 1), new 
TopicPartition(topic, 2), new TopicPartition(topic, 3))
 
   private val node0 = new Node(0, "host1", 0)
   private val node1 = new Node(1, "host2", 1)
@@ -68,12 +71,17 @@ class AddPartitionsToTxnManagerTest {
   private val versionMismatchResponse = clientResponse(null, mismatchException 
= new UnsupportedVersionException(""))
   private val disconnectedResponse = clientResponse(null, disconnected = true)
 
+  private val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, 
"localhost:2181"))
+
   @BeforeEach
   def setup(): Unit = {
     addPartitionsToTxnManager = new AddPartitionsToTxnManager(
-      KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:2181")),
+      config,
       networkClient,
-      time)
+      metadataCache,
+      partitionFor,
+      time
+    )
   }
 
   @AfterEach
@@ -81,21 +89,40 @@ class AddPartitionsToTxnManagerTest {
     addPartitionsToTxnManager.shutdown()
   }
 
-  def setErrors(errors: mutable.Map[TopicPartition, Errors])(callbackErrors: 
Map[TopicPartition, Errors]): Unit = {
-    callbackErrors.foreach {
-      case (tp, error) => errors.put(tp, error)
-    }
+  private def setErrors(errors: mutable.Map[TopicPartition, 
Errors])(callbackErrors: Map[TopicPartition, Errors]): Unit = {
+    callbackErrors.forKeyValue(errors.put)
   }
 
   @Test
   def testAddTxnData(): Unit = {
+    when(partitionFor.apply(transactionalId1)).thenReturn(0)
+    when(partitionFor.apply(transactionalId2)).thenReturn(1)
+    when(partitionFor.apply(transactionalId3)).thenReturn(0)
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq(
+        new MetadataResponseData.MetadataResponseTopic()
+          .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+          .setPartitions(List(
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(0)
+              .setLeaderId(0),
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(1)
+              .setLeaderId(1)
+          ).asJava)
+      ))
+    when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName))
+      .thenReturn(Some(node0))
+    when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName))
+      .thenReturn(Some(node1))
+
     val transaction1Errors = mutable.Map[TopicPartition, Errors]()
     val transaction2Errors = mutable.Map[TopicPartition, Errors]()
     val transaction3Errors = mutable.Map[TopicPartition, Errors]()
 
-    addPartitionsToTxnManager.addTxnData(node0, 
transactionData(transactionalId1, producerId1), setErrors(transaction1Errors))
-    addPartitionsToTxnManager.addTxnData(node1, 
transactionData(transactionalId2, producerId2), setErrors(transaction2Errors))
-    addPartitionsToTxnManager.addTxnData(node0, 
transactionData(transactionalId3, producerId3), setErrors(transaction3Errors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId1, producerId1, 
producerEpoch = 0, topicPartitions, setErrors(transaction1Errors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId2, producerId2, 
producerEpoch = 0, topicPartitions, setErrors(transaction2Errors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId3, producerId3, 
producerEpoch = 0, topicPartitions, setErrors(transaction3Errors))
 
     // We will try to add transaction1 3 more times (retries). One will have 
the same epoch, one will have a newer epoch, and one will have an older epoch 
than the new one we just added.
     val transaction1RetryWithSameEpochErrors = mutable.Map[TopicPartition, 
Errors]()
@@ -103,26 +130,32 @@ class AddPartitionsToTxnManagerTest {
     val transaction1RetryWithOldEpochErrors = mutable.Map[TopicPartition, 
Errors]()
 
     // Trying to add more transactional data for the same transactional ID, 
producer ID, and epoch should simply replace the old data and send a retriable 
response.
-    addPartitionsToTxnManager.addTxnData(node0, 
transactionData(transactionalId1, producerId1), 
setErrors(transaction1RetryWithSameEpochErrors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId1, producerId1, 
producerEpoch = 0, topicPartitions, 
setErrors(transaction1RetryWithSameEpochErrors))
     val expectedNetworkErrors = topicPartitions.map(_ -> 
Errors.NETWORK_EXCEPTION).toMap
     assertEquals(expectedNetworkErrors, transaction1Errors)
 
     // Trying to add more transactional data for the same transactional ID and 
producer ID, but new epoch should replace the old data and send an error 
response for it.
-    addPartitionsToTxnManager.addTxnData(node0, 
transactionData(transactionalId1, producerId1, producerEpoch = 1), 
setErrors(transaction1RetryWithNewerEpochErrors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId1, producerId1, 
producerEpoch = 1, topicPartitions, 
setErrors(transaction1RetryWithNewerEpochErrors))
     val expectedEpochErrors = topicPartitions.map(_ -> 
Errors.INVALID_PRODUCER_EPOCH).toMap
     assertEquals(expectedEpochErrors, transaction1RetryWithSameEpochErrors)
 
     // Trying to add more transactional data for the same transactional ID and 
producer ID, but an older epoch should immediately return with error and keep 
the old data queued to send.
-    addPartitionsToTxnManager.addTxnData(node0, 
transactionData(transactionalId1, producerId1, producerEpoch = 0), 
setErrors(transaction1RetryWithOldEpochErrors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId1, producerId1, 
producerEpoch = 0, topicPartitions, 
setErrors(transaction1RetryWithOldEpochErrors))
     assertEquals(expectedEpochErrors, transaction1RetryWithOldEpochErrors)
 
     val requestsAndHandlers = 
addPartitionsToTxnManager.generateRequests().asScala
     requestsAndHandlers.foreach { requestAndHandler =>
       if (requestAndHandler.destination == node0) {
         assertEquals(time.milliseconds(), requestAndHandler.creationTimeMs)
-        assertEquals(AddPartitionsToTxnRequest.Builder.forBroker(
-          new 
AddPartitionsToTxnTransactionCollection(Seq(transactionData(transactionalId3, 
producerId3), transactionData(transactionalId1, producerId1, producerEpoch = 
1)).iterator.asJava)).data,
-          
requestAndHandler.request.asInstanceOf[AddPartitionsToTxnRequest.Builder].data) 
// insertion order
+        assertEquals(
+          AddPartitionsToTxnRequest.Builder.forBroker(
+            new AddPartitionsToTxnTransactionCollection(Seq(
+              transactionData(transactionalId3, producerId3),
+              transactionData(transactionalId1, producerId1, producerEpoch = 1)
+            ).iterator.asJava)
+          ).data,
+          
requestAndHandler.request.asInstanceOf[AddPartitionsToTxnRequest.Builder].data 
// insertion order
+        )
       } else {
         verifyRequest(node1, transactionalId2, producerId2, requestAndHandler)
       }
@@ -131,15 +164,41 @@ class AddPartitionsToTxnManagerTest {
 
   @Test
   def testGenerateRequests(): Unit = {
+    when(partitionFor.apply(transactionalId1)).thenReturn(0)
+    when(partitionFor.apply(transactionalId2)).thenReturn(1)
+    when(partitionFor.apply(transactionalId3)).thenReturn(2)
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq(
+        new MetadataResponseData.MetadataResponseTopic()
+          .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+          .setPartitions(List(
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(0)
+              .setLeaderId(0),
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(1)
+              .setLeaderId(1),
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(2)
+              .setLeaderId(2)
+          ).asJava)
+      ))
+    when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName))
+      .thenReturn(Some(node0))
+    when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName))
+      .thenReturn(Some(node1))
+    when(metadataCache.getAliveBrokerNode(2, config.interBrokerListenerName))
+      .thenReturn(Some(node2))
+
     val transactionErrors = mutable.Map[TopicPartition, Errors]()
 
-    addPartitionsToTxnManager.addTxnData(node0, 
transactionData(transactionalId1, producerId1), setErrors(transactionErrors))
-    addPartitionsToTxnManager.addTxnData(node1, 
transactionData(transactionalId2, producerId2), setErrors(transactionErrors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId1, producerId1, 
producerEpoch = 0, topicPartitions, setErrors(transactionErrors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId2, producerId2, 
producerEpoch = 0, topicPartitions, setErrors(transactionErrors))
 
     val requestsAndHandlers = 
addPartitionsToTxnManager.generateRequests().asScala
     assertEquals(2, requestsAndHandlers.size)
     // Note: handlers are tested in testAddPartitionsToTxnHandlerErrorHandling
-    requestsAndHandlers.foreach{ requestAndHandler =>
+    requestsAndHandlers.foreach { requestAndHandler =>
       if (requestAndHandler.destination == node0) {
         verifyRequest(node0, transactionalId1, producerId1, requestAndHandler)
       } else {
@@ -147,8 +206,8 @@ class AddPartitionsToTxnManagerTest {
       }
     }
 
-    addPartitionsToTxnManager.addTxnData(node1, 
transactionData(transactionalId2, producerId2), setErrors(transactionErrors))
-    addPartitionsToTxnManager.addTxnData(node2, 
transactionData(transactionalId3, producerId3), setErrors(transactionErrors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId2, producerId2, 
producerEpoch = 0, topicPartitions, setErrors(transactionErrors))
+    addPartitionsToTxnManager.verifyTransaction(transactionalId3, producerId3, 
producerEpoch = 0, topicPartitions, setErrors(transactionErrors))
 
     // Test creationTimeMs increases too.
     time.sleep(10)
@@ -169,8 +228,90 @@ class AddPartitionsToTxnManagerTest {
     }
   }
 
+  @Test
+  def testTransactionCoordinatorResolution(): Unit = {
+    when(partitionFor.apply(transactionalId1)).thenReturn(0)
+
+    def checkError(): Unit = {
+      val errors = mutable.Map[TopicPartition, Errors]()
+
+      addPartitionsToTxnManager.verifyTransaction(
+        transactionalId1,
+        producerId1,
+        producerEpoch = 0,
+        topicPartitions,
+        setErrors(errors)
+      )
+
+      assertEquals(topicPartitions.map(tp => tp -> 
Errors.COORDINATOR_NOT_AVAILABLE).toMap, errors)
+    }
+
+    // The transaction state topic does not exist.
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq())
+    checkError()
+
+    // The metadata of the transaction state topic returns an error.
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq(
+        new MetadataResponseData.MetadataResponseTopic()
+          .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+          .setErrorCode(Errors.BROKER_NOT_AVAILABLE.code)
+      ))
+    checkError()
+
+    // The partition does not exist.
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq(
+        new MetadataResponseData.MetadataResponseTopic()
+          .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+      ))
+    checkError()
+
+    // The partition has no leader.
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq(
+        new MetadataResponseData.MetadataResponseTopic()
+          .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+          .setPartitions(List(
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(0)
+              .setLeaderId(-1)
+          ).asJava)
+      ))
+    checkError()
+
+    // The leader is not available.
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq(
+        new MetadataResponseData.MetadataResponseTopic()
+          .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+          .setPartitions(List(
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(0)
+              .setLeaderId(0)
+          ).asJava)
+      ))
+    checkError()
+  }
+
   @Test
   def testAddPartitionsToTxnHandlerErrorHandling(): Unit = {
+    when(partitionFor.apply(transactionalId1)).thenReturn(0)
+    when(partitionFor.apply(transactionalId2)).thenReturn(0)
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq(
+        new MetadataResponseData.MetadataResponseTopic()
+          .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+          .setPartitions(List(
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(0)
+              .setLeaderId(0)
+          ).asJava)
+      ))
+    when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName))
+      .thenReturn(Some(node0))
+
     val transaction1Errors = mutable.Map[TopicPartition, Errors]()
     val transaction2Errors = mutable.Map[TopicPartition, Errors]()
 
@@ -178,8 +319,8 @@ class AddPartitionsToTxnManagerTest {
       transaction1Errors.clear()
       transaction2Errors.clear()
 
-      addPartitionsToTxnManager.addTxnData(node0, 
transactionData(transactionalId1, producerId1), setErrors(transaction1Errors))
-      addPartitionsToTxnManager.addTxnData(node0, 
transactionData(transactionalId2, producerId2), setErrors(transaction2Errors))
+      addPartitionsToTxnManager.verifyTransaction(transactionalId1, 
producerId1, producerEpoch = 0, topicPartitions, setErrors(transaction1Errors))
+      addPartitionsToTxnManager.verifyTransaction(transactionalId2, 
producerId2, producerEpoch = 0, topicPartitions, setErrors(transaction2Errors))
     }
 
     val expectedAuthErrors = topicPartitions.map(_ -> 
Errors.SASL_AUTHENTICATION_FAILED).toMap
@@ -237,29 +378,49 @@ class AddPartitionsToTxnManagerTest {
     val mockVerificationFailureMeter = mock(classOf[Meter])
     val mockVerificationTime = mock(classOf[Histogram])
 
+    when(partitionFor.apply(transactionalId1)).thenReturn(0)
+    when(partitionFor.apply(transactionalId2)).thenReturn(1)
+    
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName))
+      .thenReturn(Seq(
+        new MetadataResponseData.MetadataResponseTopic()
+          .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+          .setPartitions(List(
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(0)
+              .setLeaderId(0),
+            new MetadataResponseData.MetadataResponsePartition()
+              .setPartitionIndex(1)
+              .setLeaderId(1)
+          ).asJava)
+      ))
+    when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName))
+      .thenReturn(Some(node0))
+    when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName))
+      .thenReturn(Some(node1))
+
     // Update max verification time when we see a higher verification time.
-    when(mockVerificationTime.update(anyLong())).thenAnswer(
-      {
-        invocation =>
-          val newTime = invocation.getArgument(0).asInstanceOf[Long]
-          if (newTime > maxVerificationTime)
-            maxVerificationTime = newTime
-      }
-    )
+    when(mockVerificationTime.update(anyLong())).thenAnswer { invocation =>
+      val newTime = invocation.getArgument(0).asInstanceOf[Long]
+      if (newTime > maxVerificationTime)
+        maxVerificationTime = newTime
+    }
 
     val mockMetricsGroupCtor = mockConstruction(classOf[KafkaMetricsGroup], 
(mock: KafkaMetricsGroup, context: Context) => {
-        
when(mock.newMeter(ArgumentMatchers.eq(AddPartitionsToTxnManager.VerificationFailureRateMetricName),
 anyString(), any(classOf[TimeUnit]))).thenReturn(mockVerificationFailureMeter)
-        
when(mock.newHistogram(ArgumentMatchers.eq(AddPartitionsToTxnManager.VerificationTimeMsMetricName))).thenReturn(mockVerificationTime)
-      })
+      
when(mock.newMeter(ArgumentMatchers.eq(AddPartitionsToTxnManager.VerificationFailureRateMetricName),
 anyString(), any(classOf[TimeUnit]))).thenReturn(mockVerificationFailureMeter)
+      
when(mock.newHistogram(ArgumentMatchers.eq(AddPartitionsToTxnManager.VerificationTimeMsMetricName))).thenReturn(mockVerificationTime)
+    })
 
     val addPartitionsManagerWithMockedMetrics = new AddPartitionsToTxnManager(
-      KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:2181")),
+      config,
       networkClient,
-      time)
+      metadataCache,
+      partitionFor,
+      time
+    )
 
     try {
-      addPartitionsManagerWithMockedMetrics.addTxnData(node0, 
transactionData(transactionalId1, producerId1), setErrors(transactionErrors))
-      addPartitionsManagerWithMockedMetrics.addTxnData(node1, 
transactionData(transactionalId2, producerId2), setErrors(transactionErrors))
+      
addPartitionsManagerWithMockedMetrics.verifyTransaction(transactionalId1, 
producerId1, producerEpoch = 0, topicPartitions, setErrors(transactionErrors))
+      
addPartitionsManagerWithMockedMetrics.verifyTransaction(transactionalId2, 
producerId2, producerEpoch = 0, topicPartitions, setErrors(transactionErrors))
 
       time.sleep(100)
 
@@ -297,15 +458,25 @@ class AddPartitionsToTxnManagerTest {
     }
   }
 
-  private def clientResponse(response: AbstractResponse, authException: 
AuthenticationException = null, mismatchException: UnsupportedVersionException 
= null, disconnected: Boolean = false): ClientResponse = {
+  private def clientResponse(
+    response: AbstractResponse,
+    authException: AuthenticationException = null,
+    mismatchException: UnsupportedVersionException = null,
+    disconnected: Boolean = false
+  ): ClientResponse = {
     new ClientResponse(null, null, null, 0, 0, disconnected, 
mismatchException, authException, response)
   }
 
-  private def transactionData(transactionalId: String, producerId: Long, 
producerEpoch: Short = 0): AddPartitionsToTxnTransaction = {
+  private def transactionData(
+    transactionalId: String,
+    producerId: Long,
+    producerEpoch: Short = 0
+  ): AddPartitionsToTxnTransaction = {
     new AddPartitionsToTxnTransaction()
       .setTransactionalId(transactionalId)
       .setProducerId(producerId)
       .setProducerEpoch(producerEpoch)
+      .setVerifyOnly(true)
       .setTopics(new AddPartitionsToTxnTopicCollection(
         Seq(new AddPartitionsToTxnTopic()
           .setName(topic)
@@ -316,11 +487,21 @@ class AddPartitionsToTxnManagerTest {
     
addPartitionsToTxnManager.generateRequests().asScala.head.handler.onComplete(response)
   }
 
-  private def verifyRequest(expectedDestination: Node, transactionalId: 
String, producerId: Long, requestAndHandler: RequestAndCompletionHandler): Unit 
= {
+  private def verifyRequest(
+    expectedDestination: Node,
+    transactionalId: String,
+    producerId: Long,
+    requestAndHandler: RequestAndCompletionHandler
+  ): Unit = {
     assertEquals(time.milliseconds(), requestAndHandler.creationTimeMs)
     assertEquals(expectedDestination, requestAndHandler.destination)
-    assertEquals(AddPartitionsToTxnRequest.Builder.forBroker(
-      new 
AddPartitionsToTxnTransactionCollection(Seq(transactionData(transactionalId, 
producerId)).iterator.asJava)).data,
-      
requestAndHandler.request.asInstanceOf[AddPartitionsToTxnRequest.Builder].data)
+    assertEquals(
+      AddPartitionsToTxnRequest.Builder.forBroker(
+        new AddPartitionsToTxnTransactionCollection(
+          Seq(transactionData(transactionalId, producerId)).iterator.asJava
+        )
+      ).data,
+      
requestAndHandler.request.asInstanceOf[AddPartitionsToTxnRequest.Builder].data
+    )
   }
 }
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 7b398622cc3..13fc0412fd3 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -2326,9 +2326,8 @@ class KafkaApisTest {
         any(),
         any(),
         any(),
-        any(),
-        any())
-      ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new 
PartitionResponse(Errors.INVALID_PRODUCER_EPOCH))))
+        any()
+      )).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new 
PartitionResponse(Errors.INVALID_PRODUCER_EPOCH))))
 
       
when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
         any[Long])).thenReturn(0)
@@ -2351,7 +2350,6 @@ class KafkaApisTest {
   def testTransactionalParametersSetCorrectly(): Unit = {
     val topic = "topic"
     val transactionalId = "txn1"
-    val transactionCoordinatorPartition = 35
 
     addTopicToMetadataCache(topic, numPartitions = 2)
 
@@ -2379,10 +2377,6 @@ class KafkaApisTest {
 
       val kafkaApis = createKafkaApis()
       
-      when(txnCoordinator.partitionFor(
-        ArgumentMatchers.eq(transactionalId))
-      ).thenReturn(transactionCoordinatorPartition)
-      
       kafkaApis.handleProduceRequest(request, 
RequestLocal.withThreadConfinedCaching)
       
       verify(replicaManager).appendRecords(anyLong,
@@ -2395,7 +2389,6 @@ class KafkaApisTest {
         any(),
         any(),
         ArgumentMatchers.eq(transactionalId),
-        ArgumentMatchers.eq(Some(transactionCoordinatorPartition)),
         any())
     }
   }
@@ -2523,9 +2516,8 @@ class KafkaApisTest {
       any(),
       ArgumentMatchers.eq(requestLocal),
       any(),
-      any(),
-      any())
-    ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp2 -> new 
PartitionResponse(Errors.NONE))))
+      any()
+    )).thenAnswer(_ => responseCallback.getValue.apply(Map(tp2 -> new 
PartitionResponse(Errors.NONE))))
 
     createKafkaApis().handleWriteTxnMarkersRequest(request, requestLocal)
 
@@ -2656,9 +2648,8 @@ class KafkaApisTest {
       any(),
       ArgumentMatchers.eq(requestLocal),
       any(),
-      any(),
-      any())
-    ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp2 -> new 
PartitionResponse(Errors.NONE))))
+      any()
+    )).thenAnswer(_ => responseCallback.getValue.apply(Map(tp2 -> new 
PartitionResponse(Errors.NONE))))
 
     createKafkaApis().handleWriteTxnMarkersRequest(request, requestLocal)
     verify(requestChannel).sendResponse(
@@ -2691,7 +2682,6 @@ class KafkaApisTest {
       any(),
       ArgumentMatchers.eq(requestLocal),
       any(),
-      any(),
       any())
   }
 
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala 
b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 3663e023dfc..9028e323564 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -68,9 +68,6 @@ import org.junit.jupiter.params.provider.{EnumSource, 
ValueSource}
 import com.yammer.metrics.core.Gauge
 import kafka.log.remote.RemoteLogManager
 import org.apache.kafka.common.config.AbstractConfig
-import org.apache.kafka.common.internals.Topic
-import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic, AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction}
-import org.apache.kafka.common.message.MetadataResponseData.{MetadataResponsePartition, MetadataResponseTopic}
 import org.apache.kafka.server.log.remote.storage.{NoOpRemoteLogMetadataManager, NoOpRemoteStorageManager, RemoteLogManagerConfig}
 import org.apache.kafka.server.util.timer.MockTimer
 import org.mockito.invocation.InvocationOnMock
@@ -117,10 +114,9 @@ class ReplicaManagerTest {
     addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
     // Anytime we try to verify, just automatically run the callback as though the transaction was verified.
-    when(addPartitionsToTxnManager.addTxnData(any(), any(), any())).thenAnswer {
-      invocationOnMock =>
-        val callback = invocationOnMock.getArgument(2, classOf[AddPartitionsToTxnManager.AppendCallback])
-        callback(Map.empty[TopicPartition, Errors].toMap)
+    when(addPartitionsToTxnManager.verifyTransaction(any(), any(), any(), any(), any())).thenAnswer { invocationOnMock =>
+      val callback = invocationOnMock.getArgument(4, classOf[AddPartitionsToTxnManager.AppendCallback])
+      callback(Map.empty[TopicPartition, Errors].toMap)
     }
   }
 
@@ -657,7 +653,7 @@ class ReplicaManagerTest {
       val sequence = 9
       val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
         new SimpleRecord(time.milliseconds(), s"message $sequence".getBytes))
-      appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
+      appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId).onFire { response =>
         assertEquals(Errors.NONE, response.error)
       }
       assertLateTransactionCount(Some(0))
@@ -721,7 +717,7 @@ class ReplicaManagerTest {
       for (sequence <- 0 until numRecords) {
         val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
           new SimpleRecord(s"message $sequence".getBytes))
-        appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
+        appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId).onFire { response =>
           assertEquals(Errors.NONE, response.error)
         }
       }
@@ -842,7 +838,7 @@ class ReplicaManagerTest {
       for (sequence <- 0 until numRecords) {
         val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
           new SimpleRecord(s"message $sequence".getBytes))
-        appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
+        appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId).onFire { response =>
           assertEquals(Errors.NONE, response.error)
         }
       }
@@ -2150,10 +2146,9 @@ class ReplicaManagerTest {
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 0
-    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1), node)
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1))
     try {
       replicaManager.becomeLeaderOrFollower(1,
         makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
@@ -2167,26 +2162,23 @@ class ReplicaManagerTest {
       val idempotentRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord("message".getBytes))
       appendRecords(replicaManager, tp0, idempotentRecords)
-      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any[AddPartitionsToTxnManager.AppendCallback]())
+      verify(addPartitionsToTxnManager, times(0)).verifyTransaction(any(), any(), any(), any(), any[AddPartitionsToTxnManager.AppendCallback]())
       assertNull(getVerificationGuard(replicaManager, tp0, producerId))
 
       // If we supply a transactional ID and some transactional and some idempotent records, we should only verify the topic partition with transactional records.
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
         new SimpleRecord("message".getBytes))
 
-      val transactionToAdd = new AddPartitionsToTxnTransaction()
-        .setTransactionalId(transactionalId)
-        .setProducerId(producerId)
-        .setProducerEpoch(producerEpoch)
-        .setVerifyOnly(true)
-        .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
-        ))
-
       val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord("message".getBytes))
-      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> idempotentRecords2), transactionalId, Some(0))
-      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
+      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> idempotentRecords2), transactionalId)
+      verify(addPartitionsToTxnManager, times(1)).verifyTransaction(
+        ArgumentMatchers.eq(transactionalId),
+        ArgumentMatchers.eq(producerId),
+        ArgumentMatchers.eq(producerEpoch),
+        ArgumentMatchers.eq(Seq(tp0)),
+        any[AddPartitionsToTxnManager.AppendCallback]()
+      )
       assertNotNull(getVerificationGuard(replicaManager, tp0, producerId))
       assertNull(getVerificationGuard(replicaManager, tp1, producerId))
     } finally {
@@ -2200,10 +2192,9 @@ class ReplicaManagerTest {
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 6
-    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0))
     try {
       replicaManager.becomeLeaderOrFollower(1,
         makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
@@ -2213,19 +2204,16 @@ class ReplicaManagerTest {
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord("message".getBytes))
 
-      val transactionToAdd = new AddPartitionsToTxnTransaction()
-        .setTransactionalId(transactionalId)
-        .setProducerId(producerId)
-        .setProducerEpoch(producerEpoch)
-        .setVerifyOnly(true)
-        .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
-        ))
-
       // We should add these partitions to the manager to verify.
-      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId)
       val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
-      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback.capture())
+      verify(addPartitionsToTxnManager, times(1)).verifyTransaction(
+        ArgumentMatchers.eq(transactionalId),
+        ArgumentMatchers.eq(producerId),
+        ArgumentMatchers.eq(producerEpoch),
+        ArgumentMatchers.eq(Seq(tp0)),
+        appendCallback.capture()
+      )
       val verificationGuard = getVerificationGuard(replicaManager, tp0, producerId)
       assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
@@ -2236,9 +2224,15 @@ class ReplicaManagerTest {
       assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
       // This time verification is successful.
-      appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId)
       val appendCallback2 = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
-      verify(addPartitionsToTxnManager, times(2)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback2.capture())
+      verify(addPartitionsToTxnManager, times(2)).verifyTransaction(
+        ArgumentMatchers.eq(transactionalId),
+        ArgumentMatchers.eq(producerId),
+        ArgumentMatchers.eq(producerEpoch),
+        ArgumentMatchers.eq(Seq(tp0)),
+        appendCallback2.capture()
+      )
       assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
       val callback2: AddPartitionsToTxnManager.AppendCallback = appendCallback2.getValue()
@@ -2256,10 +2250,9 @@ class ReplicaManagerTest {
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 6
-    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0))
     try {
       replicaManager.becomeLeaderOrFollower(1,
         makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
@@ -2269,19 +2262,16 @@ class ReplicaManagerTest {
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord("message".getBytes))
 
-      val transactionToAdd = new AddPartitionsToTxnTransaction()
-        .setTransactionalId(transactionalId)
-        .setProducerId(producerId)
-        .setProducerEpoch(producerEpoch)
-        .setVerifyOnly(true)
-        .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
-        ))
-
       // We should add these partitions to the manager to verify.
-      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId)
       val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
-      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback.capture())
+      verify(addPartitionsToTxnManager, times(1)).verifyTransaction(
+        ArgumentMatchers.eq(transactionalId),
+        ArgumentMatchers.eq(producerId),
+        ArgumentMatchers.eq(producerEpoch),
+        ArgumentMatchers.eq(Seq(tp0)),
+        appendCallback.capture()
+      )
       val verificationGuard = getVerificationGuard(replicaManager, tp0, producerId)
       assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
@@ -2295,9 +2285,15 @@ class ReplicaManagerTest {
       val transactionalRecords2 = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
         new SimpleRecord("message".getBytes))
 
-      val result2 = appendRecords(replicaManager, tp0, transactionalRecords2, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val result2 = appendRecords(replicaManager, tp0, transactionalRecords2, transactionalId = transactionalId)
       val appendCallback2 = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
-      verify(addPartitionsToTxnManager, times(2)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback2.capture())
+      verify(addPartitionsToTxnManager, times(2)).verifyTransaction(
+        ArgumentMatchers.eq(transactionalId),
+        ArgumentMatchers.eq(producerId),
+        ArgumentMatchers.eq(producerEpoch),
+        ArgumentMatchers.eq(Seq(tp0)),
+        appendCallback2.capture()
+      )
       assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
       // Verification should succeed, but we expect to fail with OutOfOrderSequence and for the verification guard to remain.
@@ -2332,9 +2328,9 @@ class ReplicaManagerTest {
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord(s"message $sequence".getBytes))
 
-      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> transactionalRecords), transactionalId, Some(0)).onFire { responses =>
+      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> transactionalRecords), transactionalId).onFire { responses =>
         responses.foreach {
-          entry => assertEquals(Errors.NONE, entry._2)
+          entry => assertEquals(Errors.NONE, entry._2.error)
         }
       }
     } finally {
@@ -2351,10 +2347,9 @@ class ReplicaManagerTest {
     val producerEpoch = 0.toShort
     val sequence = 0
 
-    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1), node)
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1))
 
     try {
       replicaManager.becomeLeaderOrFollower(1,
@@ -2373,9 +2368,9 @@ class ReplicaManagerTest {
         new SimpleRecord(s"message $sequence".getBytes)))
 
       assertThrows(classOf[InvalidPidMappingException],
-        () => appendRecordsToMultipleTopics(replicaManager, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0)))
+        () => appendRecordsToMultipleTopics(replicaManager, transactionalRecords, transactionalId = transactionalId))
       // We should not add these partitions to the manager to verify.
-      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any())
+      verify(addPartitionsToTxnManager, times(0)).verifyTransaction(any(), any(), any(), any(), any())
     } finally {
       replicaManager.shutdown(checkpointHW = false)
     }
@@ -2387,29 +2382,19 @@ class ReplicaManagerTest {
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 6
-    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0))
     try {
       // Append some transactional records.
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord("message".getBytes))
 
-      val transactionToAdd = new AddPartitionsToTxnTransaction()
-        .setTransactionalId(transactionalId)
-        .setProducerId(producerId)
-        .setProducerEpoch(producerEpoch)
-        .setVerifyOnly(true)
-        .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
-        ))
-
       // We should not add these partitions to the manager to verify, but instead throw an error.
-      appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
+      appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId).onFire { response =>
         assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, response.error)
       }
-      verify(addPartitionsToTxnManager, times(0)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
+      verify(addPartitionsToTxnManager, times(0)).verifyTransaction(any(), any(), any(), any(), any())
     } finally {
       replicaManager.shutdown(checkpointHW = false)
     }
@@ -2427,10 +2412,9 @@ class ReplicaManagerTest {
     val producerEpoch = 0.toShort
     val sequence = 0
 
-    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp), node, config = config)
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp), config = config)
 
     try {
       val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), tp, Seq(0, 1), LeaderAndIsr(0, List(0, 1)))
@@ -2438,11 +2422,11 @@ class ReplicaManagerTest {
 
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord(s"message $sequence".getBytes))
-      appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId)
       assertNull(getVerificationGuard(replicaManager, tp, producerId))
 
       // We should not add these partitions to the manager to verify.
-      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any())
+      verify(addPartitionsToTxnManager, times(0)).verifyTransaction(any(), any(), any(), any(), any())
 
       // Dynamically enable verification.
       config.dynamicConfig.initialize(None)
@@ -2455,8 +2439,8 @@ class ReplicaManagerTest {
       val moreTransactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
         new SimpleRecord("message".getBytes))
 
-      appendRecords(replicaManager, tp, moreTransactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
-      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any())
+      appendRecords(replicaManager, tp, moreTransactionalRecords, transactionalId = transactionalId)
+      verify(addPartitionsToTxnManager, times(0)).verifyTransaction(any(), any(), any(), any(), any())
       assertEquals(null, getVerificationGuard(replicaManager, tp, producerId))
       assertTrue(replicaManager.localLog(tp).get.hasOngoingTransaction(producerId))
     } finally {
@@ -2470,10 +2454,9 @@ class ReplicaManagerTest {
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 6
-    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0))
     try {
       replicaManager.becomeLeaderOrFollower(1,
         makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
@@ -2483,19 +2466,16 @@ class ReplicaManagerTest {
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord("message".getBytes))
 
-      val transactionToAdd = new AddPartitionsToTxnTransaction()
-        .setTransactionalId(transactionalId)
-        .setProducerId(producerId)
-        .setProducerEpoch(producerEpoch)
-        .setVerifyOnly(true)
-        .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
-        ))
-
       // We should add these partitions to the manager to verify.
-      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId)
       val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
-      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback.capture())
+      verify(addPartitionsToTxnManager, times(1)).verifyTransaction(
+        ArgumentMatchers.eq(transactionalId),
+        ArgumentMatchers.eq(producerId),
+        ArgumentMatchers.eq(producerEpoch),
+        ArgumentMatchers.eq(Seq(tp0)),
+        appendCallback.capture()
+      )
       val verificationGuard = getVerificationGuard(replicaManager, tp0, producerId)
       assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
@@ -2513,8 +2493,8 @@ class ReplicaManagerTest {
       assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
       // This time we do not verify
-      appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
-      verify(addPartitionsToTxnManager, times(1)).addTxnData(any(), any(), any())
+      appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId)
+      verify(addPartitionsToTxnManager, times(1)).verifyTransaction(any(), any(), any(), any(), any())
       assertEquals(null, getVerificationGuard(replicaManager, tp0, producerId))
       assertTrue(replicaManager.localLog(tp0).get.hasOngoingTransaction(producerId))
     } finally {
@@ -2522,75 +2502,6 @@ class ReplicaManagerTest {
     }
   }
 
-  @Test
-  def testGetTransactionCoordinator(): Unit = {
-    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
-
-    val metadataCache = mock(classOf[MetadataCache])
-
-    val replicaManager = new ReplicaManager(
-      metrics = metrics,
-      config = config,
-      time = time,
-      scheduler = new MockScheduler(time),
-      logManager = mockLogMgr,
-      quotaManagers = quotaManager,
-      metadataCache = metadataCache,
-      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
-      alterPartitionManager = alterPartitionManager,
-      addPartitionsToTxnManager = Some(addPartitionsToTxnManager)
-    )
-
-    try {
-      val txnCoordinatorPartition0 = 0
-      val txnCoordinatorPartition1 = 1
-
-      // Before we set up the metadata cache, return nothing for the topic.
-      when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(Seq())
-      assertEquals((Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode), replicaManager.getTransactionCoordinator(txnCoordinatorPartition0))
-
-      // Return an error response.
-      when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).
-        thenReturn(Seq(new MetadataResponseTopic().setErrorCode(Errors.UNSUPPORTED_VERSION.code)))
-      assertEquals((Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode), replicaManager.getTransactionCoordinator(txnCoordinatorPartition0))
-
-      val metadataResponseTopic = Seq(new MetadataResponseTopic()
-        .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
-        .setPartitions(Seq(
-          new MetadataResponsePartition()
-            .setPartitionIndex(0)
-            .setLeaderId(0),
-          new MetadataResponsePartition()
-            .setPartitionIndex(1)
-            .setLeaderId(1)).asJava))
-      val node0 = new Node(0, "host1", 0)
-
-      when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
-      when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node0))
-      when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None)
-
-      assertEquals((Errors.NONE, node0), replicaManager.getTransactionCoordinator(txnCoordinatorPartition0))
-      assertEquals((Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode), replicaManager.getTransactionCoordinator(txnCoordinatorPartition1))
-
-      // Test we convert the error correctly when trying to append and coordinator is not available
-      val tp0 = new TopicPartition(topic, 0)
-      val producerId = 24L
-      val producerEpoch = 0.toShort
-      val sequence = 0
-      replicaManager.becomeLeaderOrFollower(1,
-        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
-        (_, _) => ())
-      val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
-        new SimpleRecord("message".getBytes))
-      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(txnCoordinatorPartition1))
-      val expectedError = s"Unable to verify the partition has been added to the transaction. Underlying error: ${Errors.COORDINATOR_NOT_AVAILABLE.toString}"
-      assertEquals(Errors.NOT_ENOUGH_REPLICAS, result.assertFired.error)
-      assertEquals(expectedError, result.assertFired.errorMessage)
-    } finally {
-      replicaManager.shutdown(checkpointHW = false)
-    }
-  }
-
   @ParameterizedTest
   @EnumSource(value = classOf[Errors], names = Array("NOT_COORDINATOR", "CONCURRENT_TRANSACTIONS", "COORDINATOR_LOAD_IN_PROGRESS", "COORDINATOR_NOT_AVAILABLE"))
   def testVerificationErrorConversions(error: Errors): Unit = {
@@ -2598,10 +2509,9 @@ class ReplicaManagerTest {
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 0
-    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0))
     try {
       replicaManager.becomeLeaderOrFollower(1,
         makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
@@ -2610,20 +2520,17 @@ class ReplicaManagerTest {
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord("message".getBytes))
 
-      val transactionToAdd = new AddPartitionsToTxnTransaction()
-        .setTransactionalId(transactionalId)
-        .setProducerId(producerId)
-        .setProducerEpoch(producerEpoch)
-        .setVerifyOnly(true)
-        .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
-        ))
-
       // Start verification and return the coordinator related errors.
       val expectedMessage = s"Unable to verify the partition has been added to the transaction. Underlying error: ${error.toString}"
-      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId)
       val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
-      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback.capture())
+      verify(addPartitionsToTxnManager, times(1)).verifyTransaction(
+        ArgumentMatchers.eq(transactionalId),
+        ArgumentMatchers.eq(producerId),
+        ArgumentMatchers.eq(producerEpoch),
+        ArgumentMatchers.eq(Seq(tp0)),
+        appendCallback.capture()
+      )
 
       // Confirm we did not write to the log and instead returned the converted error with the correct error message.
       val callback: AddPartitionsToTxnManager.AppendCallback = appendCallback.getValue()
@@ -2880,8 +2787,7 @@ class ReplicaManagerTest {
                             records: MemoryRecords,
                             origin: AppendOrigin = AppendOrigin.CLIENT,
                             requiredAcks: Short = -1,
-                            transactionalId: String = null,
-                            transactionStatePartition: Option[Int] = None): CallbackResult[PartitionResponse] = {
+                            transactionalId: String = null): CallbackResult[PartitionResponse] = {
     val result = new CallbackResult[PartitionResponse]()
     def appendCallback(responses: Map[TopicPartition, PartitionResponse]): Unit = {
       val response = responses.get(partition)
@@ -2897,7 +2803,7 @@ class ReplicaManagerTest {
       entriesPerPartition = Map(partition -> records),
       responseCallback = appendCallback,
       transactionalId = transactionalId,
-      transactionStatePartition = transactionStatePartition)
+    )
 
     result
   }
@@ -2905,7 +2811,6 @@ class ReplicaManagerTest {
   private def appendRecordsToMultipleTopics(replicaManager: ReplicaManager,
                                             entriesToAppend: Map[TopicPartition, MemoryRecords],
                                             transactionalId: String,
-                                            transactionStatePartition: Option[Int],
                                             origin: AppendOrigin = AppendOrigin.CLIENT,
                                             requiredAcks: Short = -1): CallbackResult[Map[TopicPartition, PartitionResponse]] = {
     val result = new CallbackResult[Map[TopicPartition, PartitionResponse]]()
@@ -2922,7 +2827,7 @@ class ReplicaManagerTest {
       entriesPerPartition = entriesToAppend,
       responseCallback = appendCallback,
       transactionalId = transactionalId,
-      transactionStatePartition = transactionStatePartition)
+    )
 
     result
   }
@@ -3051,7 +2956,6 @@ class ReplicaManagerTest {
 
   private def setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager: AddPartitionsToTxnManager,
                                                                      transactionalTopicPartitions: List[TopicPartition],
-                                                                     node: Node,
                                                                      config: KafkaConfig = config): ReplicaManager = {
     val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
     val metadataCache = mock(classOf[MetadataCache])
@@ -3068,17 +2972,7 @@ class ReplicaManagerTest {
       alterPartitionManager = alterPartitionManager,
       addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
 
-    val metadataResponseTopic = Seq(new MetadataResponseTopic()
-      .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
-      .setPartitions(Seq(
-        new MetadataResponsePartition()
-          .setPartitionIndex(0)
-          .setLeaderId(0)).asJava))
-
     transactionalTopicPartitions.foreach(tp => when(metadataCache.contains(tp)).thenReturn(true))
-    when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
-    when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node))
-    when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None)
 
     // We will attempt to schedule to the request handler thread using a non request handler thread. Set this to avoid error.
     KafkaRequestHandler.setBypassThreadCheck(true)
@@ -3125,15 +3019,8 @@ class ReplicaManagerTest {
     val mockDelayedRemoteFetchPurgatory = new DelayedOperationPurgatory[DelayedRemoteFetch](
       purgatoryName = "DelayedRemoteFetch", timer, reaperEnabled = false)
 
-    // Set up transactions
-    val metadataResponseTopic = Seq(new MetadataResponseTopic()
-      .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
-      .setPartitions(Seq(
-        new MetadataResponsePartition()
-          .setPartitionIndex(0)
-          .setLeaderId(0)).asJava))
     when(metadataCache.contains(new TopicPartition(topic, 0))).thenReturn(true)
-    when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
+
     // Transactional appends attempt to schedule to the request handler thread using a non request handler thread. Set this to avoid error.
     KafkaRequestHandler.setBypassThreadCheck(true)
 
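Taken together, the deleted test and mock setup show what moved: brokers no longer resolve the transaction coordinator inside ReplicaManager. A rough, hypothetical sketch of the resolution rule that the removed testGetTransactionCoordinator exercised follows; the types are simplified stand-ins, whereas the real code reads the __transaction_state topic metadata from the MetadataCache and maps the partition leader to an alive broker node:

    object CoordinatorResolutionSketch {
      case class Node(id: Int, host: String, port: Int)
      val NoNode: Node = Node(-1, "", -1)

      // partitionIndex -> alive leader node, standing in for the metadata cache.
      def getTransactionCoordinator(partition: Int, leaders: Map[Int, Node]): (String, Node) =
        leaders.get(partition) match {
          case Some(node) => ("NONE", node)
          case None       => ("COORDINATOR_NOT_AVAILABLE", NoNode) // missing topic, error response, or dead leader
        }

      def main(args: Array[String]): Unit = {
        val leaders = Map(0 -> Node(0, "host1", 0)) // partition 1 has no alive leader
        println(getTransactionCoordinator(0, leaders)) // (NONE, node 0)
        println(getTransactionCoordinator(1, leaders)) // (COORDINATOR_NOT_AVAILABLE, NoNode)
      }
    }

With this patch that lookup lives behind AddPartitionsToTxnManager.verifyTransaction, so append paths only supply the transactional id and the partitions to verify.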