This is an automated email from the ASF dual-hosted git repository.

jolshan pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 2f71708955b KAFKA-15028: AddPartitionsToTxnManager metrics (#13798)
2f71708955b is described below

commit 2f71708955b293658cec3b27e9a5588d39c38d7e
Author: Justine Olshan <[email protected]>
AuthorDate: Wed Jun 28 09:00:37 2023 -0700

    KAFKA-15028: AddPartitionsToTxnManager metrics (#13798)
    
    Adding the following metrics as per kip-890:
    
    VerificationTimeMs – number of milliseconds from the time partition info is
    added to the manager to the time the response is sent. This includes the
    round trip to the transaction coordinator when the coordinator is called,
    and it also covers verifications that fail before the coordinator is called.
    
    VerificationFailureRate – rate of failed verifications, whether the failure
    comes from the AddPartitionsToTxn response or from errors in the manager
    itself. The meter is marked once for each partition that fails verification.
    
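    AddPartitionsToTxnVerification metrics – request metrics for verify-only
    AddPartitionsToTxn requests are recorded separately from those for regular
    AddPartitionsToTxn requests, similar to how the FetchFollower and
    FetchConsumer request metrics are separated. Unlike fetch requests, which are
    also counted under the Fetch metric, verify-only requests are recorded only
    under the verification metric.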
    
    Reviewers: Divij Vaidya <[email protected]>
---
 .../common/requests/AddPartitionsToTxnRequest.java |  19 ++-
 .../main/scala/kafka/network/RequestChannel.scala  |  21 ++--
 .../kafka/server/AddPartitionsToTxnManager.scala   |  64 +++++++---
 .../server/AddPartitionsToTxnManagerTest.scala     |  79 +++++++++++-
 .../scala/unit/kafka/server/KafkaApisTest.scala    | 134 +++++++++++++--------
 5 files changed, 240 insertions(+), 77 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java
 
b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java
index c91374fc507..b83e997b199 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java
@@ -40,6 +40,10 @@ import java.util.Map;
 
 public class AddPartitionsToTxnRequest extends AbstractRequest {
 
+    private static final short LAST_CLIENT_VERSION = (short) 3;
+    // Note: earliest broker version is also the first version to support 
verification requests.
+    private static final short EARLIEST_BROKER_VERSION = (short) 4;
+
     private final AddPartitionsToTxnRequestData data;
 
     public static class Builder extends 
AbstractRequest.Builder<AddPartitionsToTxnRequest> {
@@ -52,7 +56,7 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
 
             AddPartitionsToTxnTopicCollection topics = 
buildTxnTopicCollection(partitions);
             
-            return new Builder(ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion(), 
(short) 3,
+            return new Builder(ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion(), 
LAST_CLIENT_VERSION,
                 new AddPartitionsToTxnRequestData()
                     .setV3AndBelowTransactionalId(transactionalId)
                     .setV3AndBelowProducerId(producerId)
@@ -61,7 +65,7 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
         }
         
         public static Builder 
forBroker(AddPartitionsToTxnTransactionCollection transactions) {
-            return new Builder((short) 4, 
ApiKeys.ADD_PARTITIONS_TO_TXN.latestVersion(),
+            return new Builder(EARLIEST_BROKER_VERSION, 
ApiKeys.ADD_PARTITIONS_TO_TXN.latestVersion(),
                 new AddPartitionsToTxnRequestData()
                     .setTransactions(transactions));
         }
@@ -120,7 +124,7 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
     public AddPartitionsToTxnResponse getErrorResponse(int throttleTimeMs, 
Throwable e) {
         Errors error = Errors.forException(e);
         AddPartitionsToTxnResponseData response = new 
AddPartitionsToTxnResponseData();
-        if (version() < 4) {
+        if (version() < EARLIEST_BROKER_VERSION) {
             
response.setResultsByTopicV3AndBelow(errorResponseForTopics(data.v3AndBelowTopics(),
 error));
         } else {
             response.setErrorCode(error.code());
@@ -149,11 +153,18 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
         return partitionsByTransaction;
     }
 
-    // Takes a version 3 or below request and returns a v4+ singleton (one 
transaction ID) request.
+    // Takes a version 3 or below request (client request) and returns a v4+ 
singleton (one transaction ID) request.
     public AddPartitionsToTxnRequest normalizeRequest() {
         return new AddPartitionsToTxnRequest(new 
AddPartitionsToTxnRequestData().setTransactions(singletonTransaction()), 
version());
     }
 
+    // This method returns true if all the transactions in it are verify only. 
One reason to distinguish is to separate
+    // requests that will need to write to log in the non error case (adding 
partitions) from ones that will not (verify only).
+    public boolean allVerifyOnlyRequest() {
+        return version() > LAST_CLIENT_VERSION &&
+            
data.transactions().stream().allMatch(AddPartitionsToTxnTransaction::verifyOnly);
+    }
+
     private AddPartitionsToTxnTransactionCollection singletonTransaction() {
         AddPartitionsToTxnTransactionCollection singleTxn = new 
AddPartitionsToTxnTransactionCollection();
         singleTxn.add(new AddPartitionsToTxnTransaction()
diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala 
b/core/src/main/scala/kafka/network/RequestChannel.scala
index 34a860a2098..477f02a9c98 100644
--- a/core/src/main/scala/kafka/network/RequestChannel.scala
+++ b/core/src/main/scala/kafka/network/RequestChannel.scala
@@ -69,7 +69,7 @@ object RequestChannel extends Logging {
     private val metricsMap = mutable.Map[String, RequestMetrics]()
 
     (enabledApis.map(_.name) ++
-      Seq(RequestMetrics.consumerFetchMetricName, 
RequestMetrics.followFetchMetricName)).foreach { name =>
+      Seq(RequestMetrics.consumerFetchMetricName, 
RequestMetrics.followFetchMetricName, 
RequestMetrics.verifyPartitionsInTxnMetricName)).foreach { name =>
       metricsMap.put(name, new RequestMetrics(name))
     }
 
@@ -240,17 +240,18 @@ object RequestChannel extends Logging {
       val responseSendTimeMs = nanosToMs(endTimeNanos - 
responseDequeueTimeNanos)
       val messageConversionsTimeMs = nanosToMs(messageConversionsTimeNanos)
       val totalTimeMs = nanosToMs(endTimeNanos - startTimeNanos)
-      val fetchMetricNames =
+      val overrideMetricNames =
         if (header.apiKey == ApiKeys.FETCH) {
-          val isFromFollower = body[FetchRequest].isFromFollower
-          Seq(
-            if (isFromFollower) RequestMetrics.followFetchMetricName
+          val specifiedMetricName =
+            if (body[FetchRequest].isFromFollower) 
RequestMetrics.followFetchMetricName
             else RequestMetrics.consumerFetchMetricName
-          )
+          Seq(specifiedMetricName, header.apiKey.name)
+        } else if (header.apiKey == ApiKeys.ADD_PARTITIONS_TO_TXN && 
body[AddPartitionsToTxnRequest].allVerifyOnlyRequest) {
+            Seq(RequestMetrics.verifyPartitionsInTxnMetricName)
+        } else {
+          Seq(header.apiKey.name)
         }
-        else Seq.empty
-      val metricNames = fetchMetricNames :+ header.apiKey.name
-      metricNames.foreach { metricName =>
+      overrideMetricNames.foreach { metricName =>
         val m = metrics(metricName)
         m.requestRate(header.apiVersion).mark()
         m.requestQueueTimeHist.update(Math.round(requestQueueTimeMs))
@@ -517,6 +518,8 @@ object RequestMetrics {
   val consumerFetchMetricName = ApiKeys.FETCH.name + "Consumer"
   val followFetchMetricName = ApiKeys.FETCH.name + "Follower"
 
+  val verifyPartitionsInTxnMetricName = ApiKeys.ADD_PARTITIONS_TO_TXN.name + 
"Verification"
+
   val RequestsPerSec = "RequestsPerSec"
   val RequestQueueTimeMs = "RequestQueueTimeMs"
   val LocalTimeMs = "LocalTimeMs"
diff --git a/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala 
b/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala
index cbf981a76dd..fc5705042fe 100644
--- a/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala
+++ b/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala
@@ -17,6 +17,7 @@
 
 package kafka.server
 
+import 
kafka.server.AddPartitionsToTxnManager.{VerificationFailureRateMetricName, 
VerificationTimeMsMetricName}
 import kafka.utils.Logging
 import org.apache.kafka.clients.{ClientResponse, NetworkClient, 
RequestCompletionHandler}
 import org.apache.kafka.common.{Node, TopicPartition}
@@ -24,18 +25,29 @@ import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartiti
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, 
AddPartitionsToTxnResponse}
 import org.apache.kafka.common.utils.Time
+import org.apache.kafka.server.metrics.KafkaMetricsGroup
 import org.apache.kafka.server.util.{InterBrokerSendThread, 
RequestAndCompletionHandler}
 
 import java.util
+import java.util.concurrent.TimeUnit
 import scala.collection.mutable
 
 object AddPartitionsToTxnManager {
   type AppendCallback = Map[TopicPartition, Errors] => Unit
+
+  val VerificationFailureRateMetricName = "VerificationFailureRate"
+  val VerificationTimeMsMetricName = "VerificationTimeMs"
 }
 
 
+/*
+ * Data structure to hold the transactional data to send to a node. Note -- at 
most one request per transactional ID
+ * will exist at a time in the map. If a given transactional ID exists in the 
map, and a new request with the same ID
+ * comes in, one request will be in the map and one will return to the 
producer with a response depending on the epoch.
+ */
 class TransactionDataAndCallbacks(val transactionData: 
AddPartitionsToTxnTransactionCollection,
-                                  val callbacks: mutable.Map[String, 
AddPartitionsToTxnManager.AppendCallback])
+                                  val callbacks: mutable.Map[String, 
AddPartitionsToTxnManager.AppendCallback],
+                                  val startTimeMs: mutable.Map[String, Long])
 
 
 class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, 
time: Time)
@@ -47,13 +59,19 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
   private val inflightNodes = mutable.HashSet[Node]()
   private val nodesToTransactions = mutable.Map[Node, 
TransactionDataAndCallbacks]()
 
+  private val metricsGroup = new KafkaMetricsGroup(this.getClass)
+  val verificationFailureRate = 
metricsGroup.newMeter(VerificationFailureRateMetricName, "failures", 
TimeUnit.SECONDS)
+  val verificationTimeMs = 
metricsGroup.newHistogram(VerificationTimeMsMetricName)
+
   def addTxnData(node: Node, transactionData: AddPartitionsToTxnTransaction, 
callback: AddPartitionsToTxnManager.AppendCallback): Unit = {
     nodesToTransactions.synchronized {
+      val curTime = time.milliseconds()
       // Check if we have already have either node or individual transaction. 
Add the Node if it isn't there.
       val existingNodeAndTransactionData = 
nodesToTransactions.getOrElseUpdate(node,
         new TransactionDataAndCallbacks(
           new AddPartitionsToTxnTransactionCollection(1),
-          mutable.Map[String, AddPartitionsToTxnManager.AppendCallback]()))
+          mutable.Map[String, AddPartitionsToTxnManager.AppendCallback](),
+          mutable.Map[String, Long]()))
 
       val existingTransactionData = 
existingNodeAndTransactionData.transactionData.find(transactionData.transactionalId)
 
@@ -69,16 +87,17 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
             Errors.NETWORK_EXCEPTION
           val oldCallback = 
existingNodeAndTransactionData.callbacks(transactionData.transactionalId)
           
existingNodeAndTransactionData.transactionData.remove(transactionData)
-          oldCallback(topicPartitionsToError(existingTransactionData, error))
+          sendCallback(oldCallback, 
topicPartitionsToError(existingTransactionData, error), 
existingNodeAndTransactionData.startTimeMs(transactionData.transactionalId))
         } else {
           // If the incoming transactionData's epoch is lower, we can return 
with INVALID_PRODUCER_EPOCH immediately.
-          callback(topicPartitionsToError(transactionData, 
Errors.INVALID_PRODUCER_EPOCH))
+          sendCallback(callback, topicPartitionsToError(transactionData, 
Errors.INVALID_PRODUCER_EPOCH), curTime)
           return
         }
       }
 
       existingNodeAndTransactionData.transactionData.add(transactionData)
       
existingNodeAndTransactionData.callbacks.put(transactionData.transactionalId, 
callback)
+      
existingNodeAndTransactionData.startTimeMs.put(transactionData.transactionalId, 
curTime)
       wakeup()
     }
   }
@@ -90,9 +109,15 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
         topicPartitionsToError.put(new TopicPartition(topic.name, partition), 
error)
       }
     }
+    verificationFailureRate.mark(topicPartitionsToError.size)
     topicPartitionsToError.toMap
   }
 
+  private def sendCallback(callback: AddPartitionsToTxnManager.AppendCallback, 
errorMap: Map[TopicPartition, Errors], startTimeMs: Long): Unit = {
+    verificationTimeMs.update(time.milliseconds() - startTimeMs)
+    callback(errorMap)
+  }
+
   private class AddPartitionsToTxnHandler(node: Node, 
transactionDataAndCallbacks: TransactionDataAndCallbacks) extends 
RequestCompletionHandler {
     override def onComplete(response: ClientResponse): Unit = {
       // Note: Synchronization is not needed on inflightNodes since it is 
always accessed from this thread.
@@ -100,20 +125,18 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
       if (response.authenticationException != null) {
         error(s"AddPartitionsToTxnRequest failed for node 
${response.destination} with an " +
           "authentication exception.", response.authenticationException)
-        transactionDataAndCallbacks.callbacks.foreach { case (txnId, callback) 
=>
-          callback(buildErrorMap(txnId, 
Errors.forException(response.authenticationException).code))
-        }
+        
sendCallbacksToAll(Errors.forException(response.authenticationException).code)
       } else if (response.versionMismatch != null) {
         // We may see unsupported version exception if we try to send a verify 
only request to a broker that can't handle it.
         // In this case, skip verification.
         warn(s"AddPartitionsToTxnRequest failed for node 
${response.destination} with invalid version exception. This suggests 
verification is not supported." +
           s"Continuing handling the produce request.")
-        transactionDataAndCallbacks.callbacks.values.foreach(_(Map.empty))
-      } else if (response.wasDisconnected || response.wasTimedOut) {
-        warn(s"AddPartitionsToTxnRequest failed for node 
${response.destination} with a network exception.")
         transactionDataAndCallbacks.callbacks.foreach { case (txnId, callback) 
=>
-          callback(buildErrorMap(txnId, Errors.NETWORK_EXCEPTION.code))
+          sendCallback(callback, Map.empty, 
transactionDataAndCallbacks.startTimeMs(txnId))
         }
+      } else if (response.wasDisconnected || response.wasTimedOut) {
+        warn(s"AddPartitionsToTxnRequest failed for node 
${response.destination} with a network exception.")
+        sendCallbacksToAll(Errors.NETWORK_EXCEPTION.code)
       } else {
         val addPartitionsToTxnResponseData = 
response.responseBody.asInstanceOf[AddPartitionsToTxnResponse].data
         if (addPartitionsToTxnResponseData.errorCode != 0) {
@@ -125,9 +148,7 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
           else
             addPartitionsToTxnResponseData.errorCode
 
-          transactionDataAndCallbacks.callbacks.foreach { case (txnId, 
callback) =>
-            callback(buildErrorMap(txnId, finalError))
-          }
+          sendCallbacksToAll(finalError)
         } else {
           addPartitionsToTxnResponseData.resultsByTransaction.forEach { 
transactionResult =>
             val unverified = mutable.Map[TopicPartition, Errors]()
@@ -148,8 +169,9 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
                 }
               }
             }
+            verificationFailureRate.mark(unverified.size)
             val callback = 
transactionDataAndCallbacks.callbacks(transactionResult.transactionalId)
-            callback(unverified.toMap)
+            sendCallback(callback, unverified.toMap, 
transactionDataAndCallbacks.startTimeMs(transactionResult.transactionalId))
           }
         }
       }
@@ -160,6 +182,12 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
       val transactionData = 
transactionDataAndCallbacks.transactionData.find(transactionalId)
       topicPartitionsToError(transactionData, Errors.forCode(errorCode))
     }
+
+    private def sendCallbacksToAll(errorCode: Short): Unit = {
+      transactionDataAndCallbacks.callbacks.foreach { case (txnId, callback) =>
+        sendCallback(callback, buildErrorMap(txnId, errorCode), 
transactionDataAndCallbacks.startTimeMs(txnId))
+      }
+    }
   }
 
   override def generateRequests(): 
util.Collection[RequestAndCompletionHandler] = {
@@ -188,4 +216,10 @@ class AddPartitionsToTxnManager(config: KafkaConfig, 
client: NetworkClient, time
     list
   }
 
+  override def shutdown(): Unit = {
+    super.shutdown()
+    metricsGroup.removeMetric(VerificationFailureRateMetricName)
+    metricsGroup.removeMetric(VerificationTimeMsMetricName)
+  }
+
 }
diff --git 
a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala 
b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala
index 01ced6ab5d4..232e9d012d9 100644
--- a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala
@@ -17,6 +17,7 @@
 
 package unit.kafka.server
 
+import com.yammer.metrics.core.{Histogram, Meter}
 import kafka.server.{AddPartitionsToTxnManager, KafkaConfig}
 import kafka.utils.TestUtils
 import org.apache.kafka.clients.{ClientResponse, NetworkClient}
@@ -28,11 +29,16 @@ import org.apache.kafka.common.{Node, TopicPartition}
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.{AbstractResponse, 
AddPartitionsToTxnRequest, AddPartitionsToTxnResponse}
 import org.apache.kafka.common.utils.MockTime
+import org.apache.kafka.server.metrics.KafkaMetricsGroup
 import org.apache.kafka.server.util.RequestAndCompletionHandler
 import org.junit.jupiter.api.Assertions.assertEquals
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
-import org.mockito.Mockito.mock
+import org.mockito.ArgumentMatchers
+import org.mockito.ArgumentMatchers.{any, anyLong, anyString}
+import org.mockito.MockedConstruction.Context
+import org.mockito.Mockito.{mock, mockConstruction, times, verify, 
verifyNoMoreInteractions, when}
 
+import java.util.concurrent.TimeUnit
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
@@ -145,7 +151,7 @@ class AddPartitionsToTxnManagerTest {
     addPartitionsToTxnManager.addTxnData(node2, 
transactionData(transactionalId3, producerId3), setErrors(transactionErrors))
 
     // Test creationTimeMs increases too.
-    time.sleep(1000)
+    time.sleep(10)
 
     val requestsAndHandlers2 = 
addPartitionsToTxnManager.generateRequests().asScala
     // The request for node1 should not be added because one request is 
already inflight.
@@ -222,6 +228,75 @@ class AddPartitionsToTxnManagerTest {
     assertEquals(expectedTransaction2Errors, transaction2Errors)
   }
 
+  @Test
+  def testAddPartitionsToTxnManagerMetrics(): Unit = {
+    val startTime = time.milliseconds()
+    val transactionErrors = mutable.Map[TopicPartition, Errors]()
+
+    var maxVerificationTime: Long = 0
+    val mockVerificationFailureMeter = mock(classOf[Meter])
+    val mockVerificationTime = mock(classOf[Histogram])
+
+    // Update max verification time when we see a higher verification time.
+    when(mockVerificationTime.update(anyLong())).thenAnswer(
+      {
+        invocation =>
+          val newTime = invocation.getArgument(0).asInstanceOf[Long]
+          if (newTime > maxVerificationTime)
+            maxVerificationTime = newTime
+      }
+    )
+
+    val mockMetricsGroupCtor = mockConstruction(classOf[KafkaMetricsGroup], 
(mock: KafkaMetricsGroup, context: Context) => {
+        
when(mock.newMeter(ArgumentMatchers.eq(AddPartitionsToTxnManager.VerificationFailureRateMetricName),
 anyString(), any(classOf[TimeUnit]))).thenReturn(mockVerificationFailureMeter)
+        
when(mock.newHistogram(ArgumentMatchers.eq(AddPartitionsToTxnManager.VerificationTimeMsMetricName))).thenReturn(mockVerificationTime)
+      })
+
+    val addPartitionsManagerWithMockedMetrics = new AddPartitionsToTxnManager(
+      KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:2181")),
+      networkClient,
+      time)
+
+    try {
+      addPartitionsManagerWithMockedMetrics.addTxnData(node0, 
transactionData(transactionalId1, producerId1), setErrors(transactionErrors))
+      addPartitionsManagerWithMockedMetrics.addTxnData(node1, 
transactionData(transactionalId2, producerId2), setErrors(transactionErrors))
+
+      time.sleep(100)
+
+      val requestsAndHandlers = 
addPartitionsManagerWithMockedMetrics.generateRequests()
+      var requestsHandled = 0
+
+      requestsAndHandlers.forEach { requestAndCompletionHandler =>
+        time.sleep(100)
+        
requestAndCompletionHandler.handler.onComplete(authenticationErrorResponse)
+        requestsHandled += 1
+        verify(mockVerificationTime, times(requestsHandled)).update(anyLong())
+        assertEquals(maxVerificationTime, time.milliseconds() - startTime)
+        verify(mockVerificationFailureMeter, times(requestsHandled)).mark(3) 
// since there are 3 partitions
+      }
+
+      // shutdown the manager so that metrics are removed.
+      addPartitionsManagerWithMockedMetrics.shutdown()
+
+      val mockMetricsGroup = mockMetricsGroupCtor.constructed.get(0)
+
+      
verify(mockMetricsGroup).newMeter(ArgumentMatchers.eq(AddPartitionsToTxnManager.VerificationFailureRateMetricName),
 anyString(), any(classOf[TimeUnit]))
+      
verify(mockMetricsGroup).newHistogram(ArgumentMatchers.eq(AddPartitionsToTxnManager.VerificationTimeMsMetricName))
+      
verify(mockMetricsGroup).removeMetric(AddPartitionsToTxnManager.VerificationFailureRateMetricName)
+      
verify(mockMetricsGroup).removeMetric(AddPartitionsToTxnManager.VerificationTimeMsMetricName)
+
+      // assert that we have verified all invocations on the metrics group.
+      verifyNoMoreInteractions(mockMetricsGroup)
+    } finally {
+      if (mockMetricsGroupCtor != null) {
+        mockMetricsGroupCtor.close()
+      }
+      if (addPartitionsManagerWithMockedMetrics.isRunning) {
+        addPartitionsManagerWithMockedMetrics.shutdown()
+      }
+    }
+  }
+
   private def clientResponse(response: AbstractResponse, authException: 
AuthenticationException = null, mismatchException: UnsupportedVersionException 
= null, disconnected: Boolean = false): ClientResponse = {
     new ClientResponse(null, null, null, 0, 0, disconnected, 
mismatchException, authException, response)
   }
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index fd95b74ad9f..6f135af363a 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -27,7 +27,7 @@ import kafka.api.LeaderAndIsr
 import kafka.cluster.Broker
 import kafka.controller.{ControllerContext, KafkaController}
 import kafka.coordinator.transaction.{InitProducerIdResult, 
TransactionCoordinator}
-import kafka.network.RequestChannel
+import kafka.network.{RequestChannel, RequestMetrics}
 import kafka.server.QuotaFactory.QuotaManagers
 import kafka.server.metadata.{ConfigRepository, KRaftMetadataCache, 
MockConfigRepository, ZkMetadataCache}
 import kafka.utils.{Log4jController, TestUtils}
@@ -2122,56 +2122,64 @@ class KafkaApisTest {
 
   @ParameterizedTest
   @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN)
-  def testHandleAddPartitionsToTxnAuthorizationFailed(version: Short): Unit = {
-    val topic = "topic"
+  def testHandleAddPartitionsToTxnAuthorizationFailedAndMetrics(version: 
Short): Unit = {
+    val requestMetrics = new 
RequestChannel.Metrics(Seq(ApiKeys.ADD_PARTITIONS_TO_TXN))
+    try {
+      val topic = "topic"
 
-    val transactionalId = "txnId1"
-    val producerId = 15L
-    val epoch = 0.toShort
+      val transactionalId = "txnId1"
+      val producerId = 15L
+      val epoch = 0.toShort
 
-    val tp = new TopicPartition(topic, 0)
+      val tp = new TopicPartition(topic, 0)
 
-    val addPartitionsToTxnRequest = 
-      if (version < 4) 
-        AddPartitionsToTxnRequest.Builder.forClient(
-          transactionalId,
-          producerId,
-          epoch,
-          Collections.singletonList(tp)).build(version)
-      else
-        AddPartitionsToTxnRequest.Builder.forBroker(
-          new AddPartitionsToTxnTransactionCollection(
-            List(new AddPartitionsToTxnTransaction()
-              .setTransactionalId(transactionalId)
-              .setProducerId(producerId)
-              .setProducerEpoch(epoch)
-              .setVerifyOnly(true)
-              .setTopics(new AddPartitionsToTxnTopicCollection(
-                Collections.singletonList(new AddPartitionsToTxnTopic()
-                  .setName(tp.topic)
-                  .setPartitions(Collections.singletonList(tp.partition))
-                ).iterator()))
-            ).asJava.iterator())).build(version)
-
-    val requestChannelRequest = buildRequest(addPartitionsToTxnRequest)
+      val addPartitionsToTxnRequest =
+        if (version < 4)
+          AddPartitionsToTxnRequest.Builder.forClient(
+            transactionalId,
+            producerId,
+            epoch,
+            Collections.singletonList(tp)).build(version)
+        else
+          AddPartitionsToTxnRequest.Builder.forBroker(
+            new AddPartitionsToTxnTransactionCollection(
+              List(new AddPartitionsToTxnTransaction()
+                .setTransactionalId(transactionalId)
+                .setProducerId(producerId)
+                .setProducerEpoch(epoch)
+                .setVerifyOnly(true)
+                .setTopics(new AddPartitionsToTxnTopicCollection(
+                  Collections.singletonList(new AddPartitionsToTxnTopic()
+                    .setName(tp.topic)
+                    .setPartitions(Collections.singletonList(tp.partition))
+                  ).iterator()))
+              ).asJava.iterator())).build(version)
+
+      val requestChannelRequest = buildRequest(addPartitionsToTxnRequest, 
requestMetrics = requestMetrics)
+
+      val authorizer: Authorizer = mock(classOf[Authorizer])
+      when(authorizer.authorize(any[RequestContext], any[util.List[Action]]))
+        .thenReturn(Seq(AuthorizationResult.DENIED).asJava)
+
+      createKafkaApis(authorizer = Some(authorizer)).handle(
+        requestChannelRequest,
+        RequestLocal.NoCaching
+      )
 
-    val authorizer: Authorizer = mock(classOf[Authorizer])
-    when(authorizer.authorize(any[RequestContext], any[util.List[Action]]))
-      .thenReturn(Seq(AuthorizationResult.DENIED).asJava)
+      val response = 
verifyNoThrottlingAndUpdateMetrics[AddPartitionsToTxnResponse](requestChannelRequest)
+      val error = if (version < 4)
+        
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID).get(tp)
+      else
+        Errors.forCode(response.data().errorCode)
 
-    createKafkaApis(authorizer = Some(authorizer)).handle(
-      requestChannelRequest,
-      RequestLocal.NoCaching
-    )
+      val expectedError = if (version < 4) 
Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED else 
Errors.CLUSTER_AUTHORIZATION_FAILED
+      assertEquals(expectedError, error)
 
-    val response = 
verifyNoThrottling[AddPartitionsToTxnResponse](requestChannelRequest)
-    val error = if (version < 4) 
-      
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID).get(tp) 
-    else
-      Errors.forCode(response.data().errorCode)
-      
-    val expectedError = if (version < 4) 
Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED else 
Errors.CLUSTER_AUTHORIZATION_FAILED
-    assertEquals(expectedError, error)
+      val metricName = if (version < 4) ApiKeys.ADD_PARTITIONS_TO_TXN.name 
else RequestMetrics.verifyPartitionsInTxnMetricName
+      assertEquals(8, TestUtils.metersCount(metricName))
+    } finally {
+      requestMetrics.close()
+    }
   }
 
   @ParameterizedTest
@@ -5275,7 +5283,8 @@ class KafkaApisTest {
   private def buildRequest(request: AbstractRequest,
                            listenerName: ListenerName = 
ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT),
                            fromPrivilegedListener: Boolean = false,
-                           requestHeader: Option[RequestHeader] = None): 
RequestChannel.Request = {
+                           requestHeader: Option[RequestHeader] = None,
+                           requestMetrics: RequestChannel.Metrics = 
requestChannelMetrics): RequestChannel.Request = {
     val buffer = request.serializeWithHeader(
       requestHeader.getOrElse(new RequestHeader(request.apiKey, 
request.version, clientId, 0)))
 
@@ -5285,7 +5294,7 @@ class KafkaApisTest {
       listenerName, SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, 
fromPrivilegedListener,
       Optional.of(kafkaPrincipalSerde))
     new RequestChannel.Request(processor = 1, context = context, 
startTimeNanos = 0, MemoryPool.NONE, buffer,
-      requestChannelMetrics, envelope = None)
+      requestMetrics, envelope = None)
   }
 
   private def verifyNoThrottling[T <: AbstractResponse](
@@ -5309,6 +5318,37 @@ class KafkaApisTest {
     ).asInstanceOf[T]
   }
 
+  private def verifyNoThrottlingAndUpdateMetrics[T <: AbstractResponse](
+    request: RequestChannel.Request
+  ): T = {
+    val capturedResponse: ArgumentCaptor[AbstractResponse] = 
ArgumentCaptor.forClass(classOf[AbstractResponse])
+    verify(requestChannel).sendResponse(
+      ArgumentMatchers.eq(request),
+      capturedResponse.capture(),
+      any()
+    )
+    val response = capturedResponse.getValue
+    val buffer = MessageUtil.toByteBuffer(
+      response.data,
+      request.context.header.apiVersion
+    )
+
+    // Create the RequestChannel.Response that is created when sendResponse is 
called in order to update the metrics.
+    val sendResponse = new RequestChannel.SendResponse(
+      request,
+      request.buildResponseSend(response),
+      request.responseNode(response),
+      None
+    )
+    request.updateRequestMetrics(time.milliseconds(), sendResponse)
+
+    AbstractResponse.parseResponse(
+      request.context.header.apiKey,
+      buffer,
+      request.context.header.apiVersion,
+    ).asInstanceOf[T]
+  }
+
   private def createBasicMetadataRequest(topic: String,
                                          numPartitions: Int,
                                          brokerEpoch: Long,

Reply via email to