This is an automated email from the ASF dual-hosted git repository.
jolshan pushed a commit to branch 3.6
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/3.6 by this push:
new 814de813eaa cherrypick KAFKA-15653: Pass requestLocal as argument to
callback so we use the correct one for the thread (#14712)
814de813eaa is described below
commit 814de813eaa35ff47f526983f3ec63db558864d5
Author: Justine Olshan <[email protected]>
AuthorDate: Thu Nov 9 21:14:25 2023 -0800
cherrypick KAFKA-15653: Pass requestLocal as argument to callback so we use
the correct one for the thread (#14712)
With the new callback mechanism, we were accidentally passing context with
the wrong request local. Now we include a RequestLocal as an explicit
argument to the callback.
Also make the arguments passed through the callback clearer by separating
appendEntries out into its own method.
Added a test to ensure we use the request handler's request local and not
the one passed in when the callback is executed via the request handler.
Reviewers: Ismael Juma [email protected], Divij Vaidya [email protected],
David Jacot [email protected], Jason Gustafson [email protected], Artem
Livshits [email protected], Jun Rao [email protected],
Conflicts:
core/src/main/scala/kafka/server/KafkaRequestHandler.scala
core/src/main/scala/kafka/server/ReplicaManager.scala
core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
Conflicts around verification guard, running the callback on the same
thread, and checking the coordinator node before AddPartitionsToTxnManager.
Removed a test that is not applicable since we don't have
https://github.com/apache/kafka/commit/08aa33127a4254497456aa7a0c1646c7c38adf81
---
.../main/scala/kafka/network/RequestChannel.scala | 4 +-
.../scala/kafka/server/KafkaRequestHandler.scala | 21 +-
.../main/scala/kafka/server/ReplicaManager.scala | 245 ++++++++++++---------
.../kafka/server/KafkaRequestHandlerTest.scala | 64 +++++-
4 files changed, 205 insertions(+), 129 deletions(-)
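
For illustration, here is a minimal standalone sketch of the idea behind this change, using hypothetical simplified types (RequestLocal here is a stand-in, not the real kafka.server.RequestLocal): instead of the wrapped callback silently reusing state captured from the thread that created it, the callback now declares the RequestLocal it needs, and the thread that eventually runs it supplies its own.

object CallbackSketch {
  // Hypothetical stand-in for kafka.server.RequestLocal, used only in this sketch.
  final case class RequestLocal(owner: String)

  // Old shape: the wrapped callback closes over whatever was in scope when it was created.
  def wrapOld(fun: Int => Unit): Int => Unit = value => fun(value)

  // New shape: the callback takes a RequestLocal explicitly; the executing request handler
  // thread passes its own, and the caller's RequestLocal is only a synchronous fallback.
  def wrapNew(fun: (RequestLocal, Int) => Unit, callerLocal: RequestLocal): Int => Unit =
    value => fun(callerLocal, value)

  def main(args: Array[String]): Unit = {
    val handlerLocal = RequestLocal("request-handler-1")
    val callback = wrapNew((local, v) => println(s"append $v using ${local.owner}"), handlerLocal)
    callback(42) // prints: append 42 using request-handler-1
  }
}
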
diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala
index 477f02a9c98..88686082998 100644
--- a/core/src/main/scala/kafka/network/RequestChannel.scala
+++ b/core/src/main/scala/kafka/network/RequestChannel.scala
@@ -24,7 +24,7 @@ import com.fasterxml.jackson.databind.JsonNode
import com.typesafe.scalalogging.Logger
import com.yammer.metrics.core.Meter
import kafka.network
-import kafka.server.KafkaConfig
+import kafka.server.{KafkaConfig, RequestLocal}
import kafka.utils.{Logging, NotNothing, Pool}
import kafka.utils.Implicits._
import org.apache.kafka.common.config.ConfigResource
@@ -80,7 +80,7 @@ object RequestChannel extends Logging {
}
}
- case class CallbackRequest(fun: () => Unit,
+ case class CallbackRequest(fun: RequestLocal => Unit,
originalRequest: Request) extends BaseRequest
class Request(val processor: Int,
diff --git a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
index 316cf92ca5a..7885b436c59 100755
--- a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
+++ b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
@@ -54,23 +54,27 @@ object KafkaRequestHandler {
}
/**
- * Wrap callback to schedule it on a request thread.
- * NOTE: this function must be called on a request thread.
- * @param fun Callback function to execute
- * @return Wrapped callback that would execute `fun` on a request thread
+ * Creates a wrapped callback to be executed asynchronously on an arbitrary request thread.
+ * NOTE: this function must be originally called from a request thread.
+ * @param asyncCompletionCallback A callback method that we intend to call from another thread after an asynchronous
+ * action completes. The RequestLocal passed in must belong to the request handler
+ * thread that is executing the callback.
+ * @param requestLocal The RequestLocal for the current request handler thread in case we need to execute the callback
+ * function synchronously from the calling thread (used for testing)
+ * @return Wrapped callback will schedule `asyncCompletionCallback` on an arbitrary request thread
*/
- def wrap[T](fun: T => Unit): T => Unit = {
+ def wrapAsyncCallback[T](asyncCompletionCallback: (RequestLocal, T) => Unit, requestLocal: RequestLocal): T => Unit = {
val requestChannel = threadRequestChannel.get()
val currentRequest = threadCurrentRequest.get()
if (requestChannel == null || currentRequest == null) {
if (!bypassThreadCheck)
throw new IllegalStateException("Attempted to reschedule to request handler thread from non-request handler thread.")
- T => fun(T)
+ T => asyncCompletionCallback(requestLocal, T)
} else {
T => {
// The requestChannel and request are captured in this lambda, so when it's executed on the callback thread
// we can re-schedule the original callback on a request thread and update the metrics accordingly.
- requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(() => fun(T), currentRequest))
+ requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(newRequestLocal => asyncCompletionCallback(newRequestLocal, T), currentRequest))
}
}
}
@@ -127,7 +131,7 @@ class KafkaRequestHandler(id: Int,
}
threadCurrentRequest.set(originalRequest)
- callback.fun()
+ callback.fun(requestLocal)
} catch {
case e: FatalExitError =>
completeShutdown()
@@ -169,6 +173,7 @@ class KafkaRequestHandler(id: Int,
private def completeShutdown(): Unit = {
requestLocal.close()
+ threadRequestChannel.remove()
shutdownComplete.countDown()
}
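
For reference, a rough sketch of the execution flow introduced above, using hypothetical minimal stand-ins rather than the real RequestChannel and KafkaRequestHandler classes: the CallbackRequest now carries a RequestLocal => Unit, so the request handler thread that dequeues it can supply its own RequestLocal when invoking it (mirroring callback.fun(requestLocal) in the hunk above).

import java.util.concurrent.LinkedBlockingQueue

object HandlerFlowSketch {
  // Hypothetical stand-ins, not the real kafka.network / kafka.server classes.
  final case class RequestLocal(owner: String)
  final case class CallbackRequest(fun: RequestLocal => Unit)

  private val callbackQueue = new LinkedBlockingQueue[CallbackRequest]()

  // An asynchronous completion (e.g. a verification response) enqueues the wrapped callback.
  def sendCallbackRequest(cb: CallbackRequest): Unit = callbackQueue.put(cb)

  // A request handler thread dequeues the callback and runs it with its own RequestLocal.
  def runOne(handlerLocal: RequestLocal): Unit = callbackQueue.take().fun(handlerLocal)

  def main(args: Array[String]): Unit = {
    sendCallbackRequest(CallbackRequest(local => println(s"callback ran with ${local.owner}")))
    runOne(RequestLocal("request-handler-0")) // prints: callback ran with request-handler-0
  }
}
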
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index a574ba37716..53a9cd12efe 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -727,7 +727,6 @@ class ReplicaManager(val config: KafkaConfig,
transactionStatePartition: Option[Int] = None,
actionQueue: ActionQueue = this.actionQueue): Unit = {
if (isValidRequiredAcks(requiredAcks)) {
- val sTime = time.milliseconds
val verificationGuards: mutable.Map[TopicPartition, Object] = mutable.Map[TopicPartition, Object]()
val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition, errorsPerPartition) =
@@ -741,117 +740,19 @@ class ReplicaManager(val config: KafkaConfig,
(verifiedEntries.toMap, unverifiedEntries.toMap, errorEntries.toMap)
}
- def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
- val verifiedEntries =
- if (unverifiedEntries.isEmpty)
- allEntries
- else
- allEntries.filter { case (tp, _) =>
- !unverifiedEntries.contains(tp)
- }
-
- val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
- origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
- debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
-
- def produceStatusResult(appendResult: Map[TopicPartition, LogAppendResult],
- useCustomMessage: Boolean): Map[TopicPartition, ProducePartitionStatus] = {
- appendResult.map { case (topicPartition, result) =>
- topicPartition -> ProducePartitionStatus(
- result.info.lastOffset + 1, // required offset
- new PartitionResponse(
- result.error,
- result.info.firstOffset.map[Long](_.messageOffset).orElse(-1L),
- result.info.lastOffset,
- result.info.logAppendTime,
- result.info.logStartOffset,
- result.info.recordErrors,
- if (useCustomMessage) result.exception.get.getMessage else result.info.errorMessage
- )
- ) // response status
- }
- }
-
- val unverifiedResults = unverifiedEntries.map {
- case (topicPartition, error) =>
- val finalException =
- error match {
- case Errors.INVALID_TXN_STATE => error.exception("Partition
was not added to the transaction")
- case Errors.CONCURRENT_TRANSACTIONS |
- Errors.COORDINATOR_LOAD_IN_PROGRESS |
- Errors.COORDINATOR_NOT_AVAILABLE |
- Errors.NOT_COORDINATOR => new NotEnoughReplicasException(
- s"Unable to verify the partition has been added to
the transaction. Underlying error: ${error.toString}")
- case _ => error.exception()
- }
- topicPartition -> LogAppendResult(
- LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
- Some(finalException)
- )
- }
-
- val errorResults = errorsPerPartition.map {
- case (topicPartition, error) =>
- topicPartition -> LogAppendResult(
- LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
- Some(error.exception())
- )
- }
-
- val produceStatus = Set((localProduceResults, false), (unverifiedResults, true), (errorResults, false)).flatMap {
- case (results, useCustomError) => produceStatusResult(results, useCustomError)
- }.toMap
- val allResults = localProduceResults ++ unverifiedResults ++ errorResults
-
- actionQueue.add {
- () => allResults.foreach { case (topicPartition, result) =>
- val requestKey = TopicPartitionOperationKey(topicPartition)
- result.info.leaderHwChange match {
- case LeaderHwChange.INCREASED =>
- // some delayed operations may be unblocked after HW changed
- delayedProducePurgatory.checkAndComplete(requestKey)
- delayedFetchPurgatory.checkAndComplete(requestKey)
- delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
- case LeaderHwChange.SAME =>
- // probably unblock some follower fetch requests since log end offset has been updated
- delayedFetchPurgatory.checkAndComplete(requestKey)
- case LeaderHwChange.NONE =>
- // nothing
- }
- }
- }
-
- recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
-
- if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) {
- // create delayed produce operation
- val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
- val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
-
- // create a list of (topic, partition) pairs to use as keys for this delayed produce operation
- val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq
-
- // try to complete the request immediately, otherwise put it into the purgatory
- // this is because while the delayed produce operation is being created, new
- // requests may arrive and hence make this operation completable.
- delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
- } else {
- // we can respond immediately
- val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus }
- responseCallback(produceResponseStatus)
- }
- }
-
if (notYetVerifiedEntriesPerPartition.isEmpty || addPartitionsToTxnManager.isEmpty) {
- appendEntries(verifiedEntriesPerPartition)(Map.empty)
+ appendEntries(verifiedEntriesPerPartition, internalTopicsAllowed, origin, requiredAcks, verificationGuards.toMap,
+ errorsPerPartition, recordConversionStatsCallback, timeout, responseCallback, delayedProduceLock)(requestLocal, Map.empty)
} else {
// For unverified entries, send a request to verify. When verified, the append process will proceed via the callback.
val (error, node) = getTransactionCoordinator(transactionStatePartition.get)
if (error != Errors.NONE) {
- appendEntries(entriesPerPartition)(notYetVerifiedEntriesPerPartition.map {
- case (tp, _) => (tp, error)
- }.toMap)
+ appendEntries(verifiedEntriesPerPartition, internalTopicsAllowed, origin, requiredAcks, verificationGuards.toMap,
+ errorsPerPartition, recordConversionStatsCallback, timeout, responseCallback, delayedProduceLock)(requestLocal,
+ notYetVerifiedEntriesPerPartition.map {
+ case (tp, _) => (tp, error)
+ }.toMap)
} else {
val topicGrouping = notYetVerifiedEntriesPerPartition.keySet.groupBy(tp => tp.topic())
val topicCollection = new AddPartitionsToTxnTopicCollection()
@@ -871,7 +772,21 @@ class ReplicaManager(val config: KafkaConfig,
.setVerifyOnly(true)
.setTopics(topicCollection)
- addPartitionsToTxnManager.foreach(_.addTxnData(node, notYetVerifiedTransaction, KafkaRequestHandler.wrap(appendEntries(entriesPerPartition)(_))))
+ addPartitionsToTxnManager.foreach(_.addTxnData(node, notYetVerifiedTransaction, KafkaRequestHandler.wrapAsyncCallback(
+ appendEntries(
+ entriesPerPartition,
+ internalTopicsAllowed,
+ origin,
+ requiredAcks,
+ verificationGuards.toMap,
+ errorsPerPartition,
+ recordConversionStatsCallback,
+ timeout,
+ responseCallback,
+ delayedProduceLock
+ ),
+ requestLocal)
+ ))
}
}
} else {
@@ -889,6 +804,122 @@ class ReplicaManager(val config: KafkaConfig,
}
}
+ /*
+ * Note: This method can be used as a callback in a different request thread. Ensure that correct RequestLocal
+ * is passed when executing this method. Accessing non-thread-safe data structures should be avoided if possible.
+ */
+ private def appendEntries(allEntries: Map[TopicPartition, MemoryRecords],
+ internalTopicsAllowed: Boolean,
+ origin: AppendOrigin,
+ requiredAcks: Short,
+ verificationGuards: Map[TopicPartition, Object],
+ errorsPerPartition: Map[TopicPartition, Errors],
+ recordConversionStatsCallback: Map[TopicPartition, RecordConversionStats] => Unit,
+ timeout: Long,
+ responseCallback: Map[TopicPartition, PartitionResponse] => Unit,
+ delayedProduceLock: Option[Lock])
+ (requestLocal: RequestLocal, unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
+ val sTime = time.milliseconds
+ val verifiedEntries =
+ if (unverifiedEntries.isEmpty)
+ allEntries
+ else
+ allEntries.filter { case (tp, _) =>
+ !unverifiedEntries.contains(tp)
+ }
+
+ val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
+ origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
+ debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
+
+ def produceStatusResult(appendResult: Map[TopicPartition, LogAppendResult],
+ useCustomMessage: Boolean): Map[TopicPartition, ProducePartitionStatus] = {
+ appendResult.map { case (topicPartition, result) =>
+ topicPartition -> ProducePartitionStatus(
+ result.info.lastOffset + 1, // required offset
+ new PartitionResponse(
+ result.error,
+ result.info.firstOffset.map[Long](_.messageOffset).orElse(-1L),
+ result.info.lastOffset,
+ result.info.logAppendTime,
+ result.info.logStartOffset,
+ result.info.recordErrors,
+ if (useCustomMessage) result.exception.get.getMessage else result.info.errorMessage
+ )
+ ) // response status
+ }
+ }
+
+ val unverifiedResults = unverifiedEntries.map {
+ case (topicPartition, error) =>
+ val finalException =
+ error match {
+ case Errors.INVALID_TXN_STATE => error.exception("Partition was
not added to the transaction")
+ case Errors.CONCURRENT_TRANSACTIONS |
+ Errors.COORDINATOR_LOAD_IN_PROGRESS |
+ Errors.COORDINATOR_NOT_AVAILABLE |
+ Errors.NOT_COORDINATOR => new NotEnoughReplicasException(
+ s"Unable to verify the partition has been added to the
transaction. Underlying error: ${error.toString}")
+ case _ => error.exception()
+ }
+ topicPartition -> LogAppendResult(
+ LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
+ Some(finalException)
+ )
+ }
+
+ val errorResults = errorsPerPartition.map {
+ case (topicPartition, error) =>
+ topicPartition -> LogAppendResult(
+ LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
+ Some(error.exception())
+ )
+ }
+
+ val produceStatus = Set((localProduceResults, false), (unverifiedResults, true), (errorResults, false)).flatMap {
+ case (results, useCustomError) => produceStatusResult(results, useCustomError)
+ }.toMap
+ val allResults = localProduceResults ++ unverifiedResults ++ errorResults
+
+ actionQueue.add {
+ () => allResults.foreach { case (topicPartition, result) =>
+ val requestKey = TopicPartitionOperationKey(topicPartition)
+ result.info.leaderHwChange match {
+ case LeaderHwChange.INCREASED =>
+ // some delayed operations may be unblocked after HW changed
+ delayedProducePurgatory.checkAndComplete(requestKey)
+ delayedFetchPurgatory.checkAndComplete(requestKey)
+ delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
+ case LeaderHwChange.SAME =>
+ // probably unblock some follower fetch requests since log end offset has been updated
+ delayedFetchPurgatory.checkAndComplete(requestKey)
+ case LeaderHwChange.NONE =>
+ // nothing
+ }
+ }
+ }
+
+ recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
+
+ if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) {
+ // create delayed produce operation
+ val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
+ val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
+
+ // create a list of (topic, partition) pairs to use as keys for this delayed produce operation
+ val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq
+
+ // try to complete the request immediately, otherwise put it into the purgatory
+ // this is because while the delayed produce operation is being created, new
+ // requests may arrive and hence make this operation completable.
+ delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
+ } else {
+ // we can respond immediately
+ val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus }
+ responseCallback(produceResponseStatus)
+ }
+ }
+
private def partitionEntriesForVerification(verificationGuards: mutable.Map[TopicPartition, Object],
entriesPerPartition: Map[TopicPartition, MemoryRecords],
verifiedEntries: mutable.Map[TopicPartition, MemoryRecords],
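
The ReplicaManager changes above rely on appendEntries being curried: binding the first parameter list up front yields a function whose remaining (RequestLocal, unverified errors) parameter list matches the (RequestLocal, T) => Unit shape that wrapAsyncCallback expects. A simplified sketch of that pattern, with hypothetical types standing in for ReplicaManager and its records:

object CurriedCallbackSketch {
  // Hypothetical stand-in for kafka.server.RequestLocal.
  final case class RequestLocal(owner: String)

  // Simplified analogue of appendEntries: the first parameter list is fixed when the produce
  // request arrives; the second is supplied later by whichever thread runs the callback.
  def appendEntries(entries: Map[String, String])
                   (requestLocal: RequestLocal, unverified: Map[String, String]): Unit = {
    val verified = entries.filter { case (tp, _) => !unverified.contains(tp) }
    println(s"appending ${verified.keySet} on ${requestLocal.owner}, skipping ${unverified.keySet}")
  }

  def main(args: Array[String]): Unit = {
    // Partial application produces exactly the callback shape the verification response invokes.
    val callback: (RequestLocal, Map[String, String]) => Unit =
      appendEntries(Map("tp-0" -> "records", "tp-1" -> "records"))
    callback(RequestLocal("request-handler-2"), Map("tp-1" -> "INVALID_TXN_STATE"))
  }
}
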
diff --git a/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala b/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
index 822dd5263d6..dd3f7996999 100644
--- a/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
+++ b/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
@@ -32,10 +32,11 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import org.mockito.ArgumentMatchers
import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.{mock, when}
+import org.mockito.Mockito.{mock, times, verify, when}
import java.net.InetAddress
import java.nio.ByteBuffer
+import java.util.concurrent.CompletableFuture
import java.util.concurrent.atomic.AtomicInteger
class KafkaRequestHandlerTest {
@@ -53,14 +54,17 @@ class KafkaRequestHandlerTest {
val request = makeRequest(time, metrics)
requestChannel.sendRequest(request)
- def callback(ms: Int): Unit = {
- time.sleep(ms)
- handler.stop()
- }
-
when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
time.sleep(2)
- KafkaRequestHandler.wrap(callback(_: Int))(1)
+ // Prepare the callback.
+ val callback = KafkaRequestHandler.wrapAsyncCallback(
+ (reqLocal: RequestLocal, ms: Int) => {
+ time.sleep(ms)
+ handler.stop()
+ },
+ RequestLocal.NoCaching)
+ // Execute the callback asynchronously.
+ CompletableFuture.runAsync(() => callback(1))
request.apiLocalCompleteTimeNanos = time.nanoseconds
}
@@ -86,16 +90,19 @@ class KafkaRequestHandlerTest {
var handledCount = 0
var tryCompleteActionCount = 0
- def callback(x: Int): Unit = {
- handler.stop()
- }
-
val request = makeRequest(time, metrics)
requestChannel.sendRequest(request)
when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
handledCount = handledCount + 1
- KafkaRequestHandler.wrap(callback(_: Int))(1)
+ // Prepare the callback.
+ val callback = KafkaRequestHandler.wrapAsyncCallback(
+ (reqLocal: RequestLocal, ms: Int) => {
+ handler.stop()
+ },
+ RequestLocal.NoCaching)
+ // Execute the callback asynchronously.
+ CompletableFuture.runAsync(() => callback(1))
}
when(apiHandler.tryCompleteActions()).thenAnswer { _ =>
@@ -108,6 +115,39 @@ class KafkaRequestHandlerTest {
assertEquals(1, tryCompleteActionCount)
}
+ @Test
+ def testHandlingCallbackOnNewThread(): Unit = {
+ val time = new MockTime()
+ val metrics = mock(classOf[RequestChannel.Metrics])
+ val apiHandler = mock(classOf[ApiRequestHandler])
+ val requestChannel = new RequestChannel(10, "", time, metrics)
+ val handler = new KafkaRequestHandler(0, 0, mock(classOf[Meter]), new AtomicInteger(1), requestChannel, apiHandler, time)
+
+ val originalRequestLocal = mock(classOf[RequestLocal])
+
+ var handledCount = 0
+
+ val request = makeRequest(time, metrics)
+ requestChannel.sendRequest(request)
+
+ when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
+ // Prepare the callback.
+ val callback = KafkaRequestHandler.wrapAsyncCallback(
+ (reqLocal: RequestLocal, ms: Int) => {
+ reqLocal.bufferSupplier.close()
+ handledCount = handledCount + 1
+ handler.stop()
+ },
+ originalRequestLocal)
+ // Execute the callback asynchronously.
+ CompletableFuture.runAsync(() => callback(1))
+ }
+
+ handler.run()
+ // Verify that we don't use the request local that we passed in.
+ verify(originalRequestLocal, times(0)).bufferSupplier
+ assertEquals(1, handledCount)
+ }
@ParameterizedTest
@ValueSource(booleans = Array(true, false))