This is an automated email from the ASF dual-hosted git repository.

junrao pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push: new c2273ad KAFKA-8334 Make sure the thread which tries to complete delayed reque… (#8657) c2273ad is described below commit c2273adc25b2bab0a3ac95bf7844fedf2860b40b Author: Chia-Ping Tsai <chia7...@gmail.com> AuthorDate: Thu Sep 10 05:42:37 2020 +0800 KAFKA-8334 Make sure the thread which tries to complete delayed reque… (#8657) The main changes of this PR are shown below. 1. replace tryLock by lock for DelayedOperation#maybeTryComplete 2. complete the delayed requests without holding group lock Reviewers: Ismael Juma <ism...@juma.me.uk>, Jun Rao <jun...@gmail.com> --- core/src/main/scala/kafka/cluster/Partition.scala | 22 +--- .../kafka/coordinator/group/DelayedJoin.scala | 11 +- .../coordinator/group/GroupMetadataManager.scala | 2 +- core/src/main/scala/kafka/log/Log.scala | 13 ++- core/src/main/scala/kafka/server/ActionQueue.scala | 56 ++++++++++ .../main/scala/kafka/server/DelayedOperation.scala | 118 +++++++++------------ core/src/main/scala/kafka/server/KafkaApis.scala | 4 + .../main/scala/kafka/server/ReplicaManager.scala | 31 ++++++ .../AbstractCoordinatorConcurrencyTest.scala | 10 +- .../group/GroupCoordinatorConcurrencyTest.scala | 65 ++++++++---- .../TransactionCoordinatorConcurrencyTest.scala | 5 +- .../unit/kafka/server/DelayedOperationTest.scala | 96 +++++++---------- 12 files changed, 257 insertions(+), 176 deletions(-) diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala index fb0576e..fc0852f 100755 --- a/core/src/main/scala/kafka/cluster/Partition.scala +++ b/core/src/main/scala/kafka/cluster/Partition.scala @@ -103,19 +103,7 @@ class DelayedOperations(topicPartition: TopicPartition, fetch.checkAndComplete(TopicPartitionOperationKey(topicPartition)) } - def checkAndCompleteProduce(): Unit = { - produce.checkAndComplete(TopicPartitionOperationKey(topicPartition)) - } - - def checkAndCompleteDeleteRecords(): Unit = { - deleteRecords.checkAndComplete(TopicPartitionOperationKey(topicPartition)) - } - def numDelayedDelete: Int = deleteRecords.numDelayed - - def numDelayedFetch: Int = fetch.numDelayed - - def numDelayedProduce: Int = produce.numDelayed } object Partition extends KafkaMetricsGroup { @@ -1010,15 +998,7 @@ class Partition(val topicPartition: TopicPartition, } } - // some delayed operations may be unblocked after HW changed - if (leaderHWIncremented) - tryCompleteDelayedRequests() - else { - // probably unblock some follower fetch requests since log end offset has been updated - delayedOperations.checkAndCompleteFetch() - } - - info + info.copy(leaderHwChange = if (leaderHWIncremented) LeaderHwChange.Increased else LeaderHwChange.Same) } def readRecords(fetchOffset: Long, diff --git a/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala b/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala index dad2b1e..92e8835 100644 --- a/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala +++ b/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala @@ -36,8 +36,15 @@ private[group] class DelayedJoin(coordinator: GroupCoordinator, rebalanceTimeout: Long) extends DelayedOperation(rebalanceTimeout, Some(group.lock)) { override def tryComplete(): Boolean = coordinator.tryCompleteJoin(group, forceComplete _) - override def onExpiration() = coordinator.onExpireJoin() - override def onComplete() = coordinator.onCompleteJoin(group) + override def onExpiration(): Unit = { + coordinator.onExpireJoin() + // try to complete 
delayed actions introduced by coordinator.onCompleteJoin + tryToCompleteDelayedAction() + } + override def onComplete(): Unit = coordinator.onCompleteJoin(group) + + // TODO: remove this ugly chain after we move the action queue to handler thread + private def tryToCompleteDelayedAction(): Unit = coordinator.groupManager.replicaManager.tryCompleteActions() } /** diff --git a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala index 1fcdd91..9d58b2d 100644 --- a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala +++ b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala @@ -56,7 +56,7 @@ import scala.jdk.CollectionConverters._ class GroupMetadataManager(brokerId: Int, interBrokerProtocolVersion: ApiVersion, config: OffsetConfig, - replicaManager: ReplicaManager, + val replicaManager: ReplicaManager, zkClient: KafkaZkClient, time: Time, metrics: Metrics) extends Logging with KafkaMetricsGroup { diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala index fa78666..8953156 100644 --- a/core/src/main/scala/kafka/log/Log.scala +++ b/core/src/main/scala/kafka/log/Log.scala @@ -68,6 +68,13 @@ object LogAppendInfo { offsetsMonotonic = false, -1L, recordErrors, errorMessage) } +sealed trait LeaderHwChange +object LeaderHwChange { + case object Increased extends LeaderHwChange + case object Same extends LeaderHwChange + case object None extends LeaderHwChange +} + /** * Struct to hold various quantities we compute about each message set before appending to the log * @@ -85,6 +92,9 @@ object LogAppendInfo { * @param validBytes The number of valid bytes * @param offsetsMonotonic Are the offsets in this message set monotonically increasing * @param lastOffsetOfFirstBatch The last offset of the first batch + * @param leaderHwChange Incremental if the high watermark needs to be increased after appending record. + * Same if high watermark is not changed. None is the default value and it means append failed + * */ case class LogAppendInfo(var firstOffset: Option[Long], var lastOffset: Long, @@ -100,7 +110,8 @@ case class LogAppendInfo(var firstOffset: Option[Long], offsetsMonotonic: Boolean, lastOffsetOfFirstBatch: Long, recordErrors: Seq[RecordError] = List(), - errorMessage: String = null) { + errorMessage: String = null, + leaderHwChange: LeaderHwChange = LeaderHwChange.None) { /** * Get the first offset if it exists, else get the last offset of the first batch * For magic versions 2 and newer, this method will return first offset. For magic versions diff --git a/core/src/main/scala/kafka/server/ActionQueue.scala b/core/src/main/scala/kafka/server/ActionQueue.scala new file mode 100644 index 0000000..1b6b832 --- /dev/null +++ b/core/src/main/scala/kafka/server/ActionQueue.scala @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.concurrent.ConcurrentLinkedQueue + +import kafka.utils.Logging + +/** + * This queue is used to collect actions which need to be executed later. One use case is that ReplicaManager#appendRecords + * produces record changes so we need to check and complete delayed requests. In order to avoid conflicting locking, + * we add those actions to this queue and then complete them at the end of KafkaApis.handle() or DelayedJoin.onExpiration. + */ +class ActionQueue extends Logging { + private val queue = new ConcurrentLinkedQueue[() => Unit]() + + /** + * add action to this queue. + * @param action action + */ + def add(action: () => Unit): Unit = queue.add(action) + + /** + * try to complete all delayed actions + */ + def tryCompleteActions(): Unit = { + val maxToComplete = queue.size() + var count = 0 + var done = false + while (!done && count < maxToComplete) { + try { + val action = queue.poll() + if (action == null) done = true + else action() + } catch { + case e: Throwable => + error("failed to complete delayed actions", e) + } finally count += 1 + } + } +} diff --git a/core/src/main/scala/kafka/server/DelayedOperation.scala b/core/src/main/scala/kafka/server/DelayedOperation.scala index 2756b4f..09fd337 100644 --- a/core/src/main/scala/kafka/server/DelayedOperation.scala +++ b/core/src/main/scala/kafka/server/DelayedOperation.scala @@ -41,13 +41,15 @@ import scala.collection.mutable.ListBuffer * forceComplete(). * * A subclass of DelayedOperation needs to provide an implementation of both onComplete() and tryComplete(). + * + * Noted that if you add a future delayed operation that calls ReplicaManager.appendRecords() in onComplete() + * like DelayedJoin, you must be aware that this operation's onExpiration() needs to call actionQueue.tryCompleteAction(). */ abstract class DelayedOperation(override val delayMs: Long, lockOpt: Option[Lock] = None) extends TimerTask with Logging { private val completed = new AtomicBoolean(false) - private val tryCompletePending = new AtomicBoolean(false) // Visible for testing private[server] val lock: Lock = lockOpt.getOrElse(new ReentrantLock) @@ -100,42 +102,24 @@ abstract class DelayedOperation(override val delayMs: Long, def tryComplete(): Boolean /** - * Thread-safe variant of tryComplete() that attempts completion only if the lock can be acquired - * without blocking. - * - * If threadA acquires the lock and performs the check for completion before completion criteria is met - * and threadB satisfies the completion criteria, but fails to acquire the lock because threadA has not - * yet released the lock, we need to ensure that completion is attempted again without blocking threadA - * or threadB. `tryCompletePending` is set by threadB when it fails to acquire the lock and at least one - * of threadA or threadB will attempt completion of the operation if this flag is set. This ensures that - * every invocation of `maybeTryComplete` is followed by at least one invocation of `tryComplete` until - * the operation is actually completed. 
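// A minimal, self-contained sketch of how the new ActionQueue above is meant to be
// used (hedged illustration only; the println stands in for the purgatory
// checkAndComplete() calls that ReplicaManager actually enqueues):

import kafka.server.ActionQueue

val actionQueue = new ActionQueue

// Producer side (e.g. inside ReplicaManager.appendRecords): only *enqueue* the
// delayed-operation check instead of running it while a group/state lock may
// still be held by the caller.
actionQueue.add { () =>
  // in the real code this is e.g. delayedFetchPurgatory.checkAndComplete(requestKey)
  println("checking delayed operations for the affected partition")
}

// Consumer side (end of KafkaApis.handle() or DelayedJoin.onExpiration()):
// drain everything queued so far, outside of any conflicting lock.
actionQueue.tryCompleteActions()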
+ * Thread-safe variant of tryComplete() and call extra function if first tryComplete returns false + * @param f else function to be executed after first tryComplete returns false + * @return result of tryComplete */ - private[server] def maybeTryComplete(): Boolean = { - var retry = false - var done = false - do { - if (lock.tryLock()) { - try { - tryCompletePending.set(false) - done = tryComplete() - } finally { - lock.unlock() - } - // While we were holding the lock, another thread may have invoked `maybeTryComplete` and set - // `tryCompletePending`. In this case we should retry. - retry = tryCompletePending.get() - } else { - // Another thread is holding the lock. If `tryCompletePending` is already set and this thread failed to - // acquire the lock, then the thread that is holding the lock is guaranteed to see the flag and retry. - // Otherwise, we should set the flag and retry on this thread since the thread holding the lock may have - // released the lock and returned by the time the flag is set. - retry = !tryCompletePending.getAndSet(true) - } - } while (!isCompleted && retry) - done + private[server] def safeTryCompleteOrElse(f: => Unit): Boolean = inLock(lock) { + if (tryComplete()) true + else { + f + // last completion check + tryComplete() + } } + /** + * Thread-safe variant of tryComplete() + */ + private[server] def safeTryComplete(): Boolean = inLock(lock)(tryComplete()) + /* * run() method defines a task that is executed on timeout */ @@ -219,38 +203,38 @@ final class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: Stri def tryCompleteElseWatch(operation: T, watchKeys: Seq[Any]): Boolean = { assert(watchKeys.nonEmpty, "The watch key list can't be empty") - // The cost of tryComplete() is typically proportional to the number of keys. Calling - // tryComplete() for each key is going to be expensive if there are many keys. Instead, - // we do the check in the following way. Call tryComplete(). If the operation is not completed, - // we just add the operation to all keys. Then we call tryComplete() again. At this time, if - // the operation is still not completed, we are guaranteed that it won't miss any future triggering - // event since the operation is already on the watcher list for all keys. This does mean that - // if the operation is completed (by another thread) between the two tryComplete() calls, the - // operation is unnecessarily added for watch. However, this is a less severe issue since the - // expire reaper will clean it up periodically. - - // At this point the only thread that can attempt this operation is this current thread - // Hence it is safe to tryComplete() without a lock - var isCompletedByMe = operation.tryComplete() - if (isCompletedByMe) - return true - - var watchCreated = false - for(key <- watchKeys) { - // If the operation is already completed, stop adding it to the rest of the watcher list. - if (operation.isCompleted) - return false - watchForOperation(key, operation) - - if (!watchCreated) { - watchCreated = true - estimatedTotalOperations.incrementAndGet() - } - } - - isCompletedByMe = operation.maybeTryComplete() - if (isCompletedByMe) - return true + // The cost of tryComplete() is typically proportional to the number of keys. Calling tryComplete() for each key is + // going to be expensive if there are many keys. Instead, we do the check in the following way through safeTryCompleteOrElse(). + // If the operation is not completed, we just add the operation to all keys. Then we call tryComplete() again. 
At + // this time, if the operation is still not completed, we are guaranteed that it won't miss any future triggering + // event since the operation is already on the watcher list for all keys. + // + // ==============[story about lock]============== + // Through safeTryCompleteOrElse(), we hold the operation's lock while adding the operation to watch list and doing + // the tryComplete() check. This is to avoid a potential deadlock between the callers to tryCompleteElseWatch() and + // checkAndComplete(). For example, the following deadlock can happen if the lock is only held for the final tryComplete() + // 1) thread_a holds readlock of stateLock from TransactionStateManager + // 2) thread_a is executing tryCompleteElseWatch() + // 3) thread_a adds op to watch list + // 4) thread_b requires writelock of stateLock from TransactionStateManager (blocked by thread_a) + // 5) thread_c calls checkAndComplete() and holds lock of op + // 6) thread_c is waiting readlock of stateLock to complete op (blocked by thread_b) + // 7) thread_a is waiting lock of op to call the final tryComplete() (blocked by thread_c) + // + // Note that even with the current approach, deadlocks could still be introduced. For example, + // 1) thread_a calls tryCompleteElseWatch() and gets lock of op + // 2) thread_a adds op to watch list + // 3) thread_a calls op#tryComplete and tries to require lock_b + // 4) thread_b holds lock_b and calls checkAndComplete() + // 5) thread_b sees op from watch list + // 6) thread_b needs lock of op + // To avoid the above scenario, we recommend DelayedOperationPurgatory.checkAndComplete() be called without holding + // any exclusive lock. Since DelayedOperationPurgatory.checkAndComplete() completes delayed operations asynchronously, + // holding a exclusive lock to make the call is often unnecessary. + if (operation.safeTryCompleteOrElse { + watchKeys.foreach(key => watchForOperation(key, operation)) + if (watchKeys.nonEmpty) estimatedTotalOperations.incrementAndGet() + }) return true // if it cannot be completed by now and hence is watched, add to the expire queue also if (!operation.isCompleted) { @@ -375,7 +359,7 @@ final class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: Stri if (curr.isCompleted) { // another thread has completed this operation, just remove it iter.remove() - } else if (curr.maybeTryComplete()) { + } else if (curr.safeTryComplete()) { iter.remove() completed += 1 } diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 867ff6a..ef2df97 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -186,6 +186,10 @@ class KafkaApis(val requestChannel: RequestChannel, case e: FatalExitError => throw e case e: Throwable => handleError(request, e) } finally { + // try to complete delayed action. In order to avoid conflicting locking, the actions to complete delayed requests + // are kept in a queue. We add the logic to check the ReplicaManager queue at the end of KafkaApis.handle() and the + // expiration thread for certain delayed operations (e.g. DelayedJoin) + replicaManager.tryCompleteActions() // The local completion time may be set while processing the request. Only record it if it's unset. 
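// A small sketch of the lock-ordering guideline above, assuming an illustrative
// ReentrantReadWriteLock and watch key (neither is a name from this patch): the
// purgatory is poked only after the exclusive lock has been released.

import java.util.concurrent.locks.ReentrantReadWriteLock

val stateLock = new ReentrantReadWriteLock() // stand-in for e.g. a coordinator's state lock
val key = "some-operation-key"               // stand-in watch key

stateLock.writeLock().lock()
try {
  // mutate coordinator/partition state while holding the exclusive lock
} finally {
  stateLock.writeLock().unlock()
}
// Only now poke the purgatory: checkAndComplete() may acquire per-operation locks
// and re-enter tryComplete(), so calling it under the write lock above could
// recreate the deadlock chain described in the comment.
// purgatory.checkAndComplete(key)           // purgatory assumed to be in scope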
if (request.apiLocalCompleteTimeNanos < 0) request.apiLocalCompleteTimeNanos = time.nanoseconds diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index 32c9aa4..1ad37ef 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -559,9 +559,20 @@ class ReplicaManager(val config: KafkaConfig, } /** + * TODO: move this action queue to handle thread so we can simplify concurrency handling + */ + private val actionQueue = new ActionQueue + + def tryCompleteActions(): Unit = actionQueue.tryCompleteActions() + + /** * Append messages to leader replicas of the partition, and wait for them to be replicated to other replicas; * the callback function will be triggered either when timeout or the required acks are satisfied; * if the callback function itself is already synchronized on some object then pass this object to avoid deadlock. + * + * Noted that all pending delayed check operations are stored in a queue. All callers to ReplicaManager.appendRecords() + * are expected to call ActionQueue.tryCompleteActions for all affected partitions, without holding any conflicting + * locks. */ def appendRecords(timeout: Long, requiredAcks: Short, @@ -585,6 +596,26 @@ class ReplicaManager(val config: KafkaConfig, result.info.logStartOffset, result.info.recordErrors.asJava, result.info.errorMessage)) // response status } + actionQueue.add { + () => + localProduceResults.foreach { + case (topicPartition, result) => + val requestKey = TopicPartitionOperationKey(topicPartition) + result.info.leaderHwChange match { + case LeaderHwChange.Increased => + // some delayed operations may be unblocked after HW changed + delayedProducePurgatory.checkAndComplete(requestKey) + delayedFetchPurgatory.checkAndComplete(requestKey) + delayedDeleteRecordsPurgatory.checkAndComplete(requestKey) + case LeaderHwChange.Same => + // probably unblock some follower fetch requests since log end offset has been updated + delayedFetchPurgatory.checkAndComplete(requestKey) + case LeaderHwChange.None => + // nothing + } + } + } + recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats }) if (delayedProduceRequestRequired(requiredAcks, entriesPerPartition, localProduceResults)) { diff --git a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala index c7145ce..62ee85d 100644 --- a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala @@ -17,8 +17,8 @@ package kafka.coordinator -import java.util.{Collections, Random} import java.util.concurrent.{ConcurrentHashMap, Executors} +import java.util.{Collections, Random} import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.locks.Lock @@ -97,7 +97,7 @@ abstract class AbstractCoordinatorConcurrencyTest[M <: CoordinatorMember] { } def enableCompletion(): Unit = { - replicaManager.tryCompleteDelayedRequests() + replicaManager.tryCompleteActions() scheduler.tick() } @@ -166,9 +166,8 @@ object AbstractCoordinatorConcurrencyTest { producePurgatory = new DelayedOperationPurgatory[DelayedProduce]("Produce", timer, 1, reaperEnabled = false) watchKeys = Collections.newSetFromMap(new ConcurrentHashMap[TopicPartitionOperationKey, java.lang.Boolean]()).asScala } - def 
tryCompleteDelayedRequests(): Unit = { - watchKeys.map(producePurgatory.checkAndComplete) - } + + override def tryCompleteActions(): Unit = watchKeys.map(producePurgatory.checkAndComplete) override def appendRecords(timeout: Long, requiredAcks: Short, @@ -204,7 +203,6 @@ object AbstractCoordinatorConcurrencyTest { val producerRequestKeys = entriesPerPartition.keys.map(TopicPartitionOperationKey(_)).toSeq watchKeys ++= producerRequestKeys producePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys) - tryCompleteDelayedRequests() } override def getMagic(topicPartition: TopicPartition): Option[Byte] = { Some(RecordBatch.MAGIC_VALUE_V2) diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala index 72de4a1..1f54bd5 100644 --- a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala @@ -18,6 +18,7 @@ package kafka.coordinator.group import java.util.Properties +import java.util.concurrent.locks.{Lock, ReentrantLock} import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import kafka.common.OffsetAndMetadata @@ -60,16 +61,6 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest new LeaveGroupOperation ) - private val allOperationsWithTxn = Seq( - new JoinGroupOperation, - new SyncGroupOperation, - new OffsetFetchOperation, - new CommitTxnOffsetsOperation, - new CompleteTxnOperation, - new HeartbeatOperation, - new LeaveGroupOperation - ) - var heartbeatPurgatory: DelayedOperationPurgatory[DelayedHeartbeat] = _ var joinPurgatory: DelayedOperationPurgatory[DelayedJoin] = _ var groupCoordinator: GroupCoordinator = _ @@ -119,12 +110,33 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest @Test def testConcurrentTxnGoodPathSequence(): Unit = { - verifyConcurrentOperations(createGroupMembers, allOperationsWithTxn) + verifyConcurrentOperations(createGroupMembers, Seq( + new JoinGroupOperation, + new SyncGroupOperation, + new OffsetFetchOperation, + new CommitTxnOffsetsOperation, + new CompleteTxnOperation, + new HeartbeatOperation, + new LeaveGroupOperation + )) } @Test def testConcurrentRandomSequence(): Unit = { - verifyConcurrentRandomSequences(createGroupMembers, allOperationsWithTxn) + /** + * handleTxnCommitOffsets does not complete delayed requests now so it causes error if handleTxnCompletion is executed + * before completing delayed request. In random mode, we use this global lock to prevent such an error. 
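// Pulling the Partition and ReplicaManager hunks above together: a condensed,
// hedged restatement of the new LeaderHwChange flow (the method name below is
// made up; the purgatories are passed in so the sketch compiles on its own).

import kafka.log.LeaderHwChange
import kafka.server.{DelayedDeleteRecords, DelayedFetch, DelayedOperationPurgatory, DelayedProduce, TopicPartitionOperationKey}

def deferredCompletionCheck(change: LeaderHwChange,
                            key: TopicPartitionOperationKey,
                            producePurgatory: DelayedOperationPurgatory[DelayedProduce],
                            fetchPurgatory: DelayedOperationPurgatory[DelayedFetch],
                            deleteRecordsPurgatory: DelayedOperationPurgatory[DelayedDeleteRecords]): Unit =
  change match {
    case LeaderHwChange.Increased =>
      // HW advanced: producers, fetchers and delete-records waiters may all unblock
      producePurgatory.checkAndComplete(key)
      fetchPurgatory.checkAndComplete(key)
      deleteRecordsPurgatory.checkAndComplete(key)
    case LeaderHwChange.Same =>
      // only the log end offset moved: follower fetches may unblock
      fetchPurgatory.checkAndComplete(key)
    case LeaderHwChange.None =>
      // the append failed; there is nothing to complete
  }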
+ */ + val lock = new ReentrantLock() + verifyConcurrentRandomSequences(createGroupMembers, Seq( + new JoinGroupOperation, + new SyncGroupOperation, + new OffsetFetchOperation, + new CommitTxnOffsetsOperation(lock = Some(lock)), + new CompleteTxnOperation(lock = Some(lock)), + new HeartbeatOperation, + new LeaveGroupOperation + )) } @Test @@ -198,6 +210,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest groupCoordinator.handleJoinGroup(member.groupId, member.memberId, None, requireKnownMemberId = false, "clientId", "clientHost", DefaultRebalanceTimeout, DefaultSessionTimeout, protocolType, protocols, responseCallback) + replicaManager.tryCompleteActions() } override def awaitAndVerify(member: GroupMember): Unit = { val joinGroupResult = await(member, DefaultRebalanceTimeout) @@ -221,6 +234,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId, Some(protocolType), Some(protocolName), member.groupInstanceId, Map.empty[String, Array[Byte]], responseCallback) } + replicaManager.tryCompleteActions() } override def awaitAndVerify(member: GroupMember): Unit = { val result = await(member, DefaultSessionTimeout) @@ -238,6 +252,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest override def runWithCallback(member: GroupMember, responseCallback: HeartbeatCallback): Unit = { groupCoordinator.handleHeartbeat(member.groupId, member.memberId, member.groupInstanceId, member.generationId, responseCallback) + replicaManager.tryCompleteActions() } override def awaitAndVerify(member: GroupMember): Unit = { val error = await(member, DefaultSessionTimeout) @@ -252,6 +267,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest } override def runWithCallback(member: GroupMember, responseCallback: OffsetFetchCallback): Unit = { val (error, partitionData) = groupCoordinator.handleFetchOffsets(member.groupId, requireStable = true, None) + replicaManager.tryCompleteActions() responseCallback(error, partitionData) } override def awaitAndVerify(member: GroupMember): Unit = { @@ -271,6 +287,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest val offsets = immutable.Map(tp -> OffsetAndMetadata(1, "", Time.SYSTEM.milliseconds())) groupCoordinator.handleCommitOffsets(member.groupId, member.memberId, member.groupInstanceId, member.generationId, offsets, responseCallback) + replicaManager.tryCompleteActions() } override def awaitAndVerify(member: GroupMember): Unit = { val offsets = await(member, 500) @@ -278,7 +295,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest } } - class CommitTxnOffsetsOperation extends CommitOffsetsOperation { + class CommitTxnOffsetsOperation(lock: Option[Lock] = None) extends CommitOffsetsOperation { override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback): Unit = { val tp = new TopicPartition("topic", 0) val offsets = immutable.Map(tp -> OffsetAndMetadata(1, "", Time.SYSTEM.milliseconds())) @@ -293,13 +310,17 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean) responseCallback(errors) } - groupCoordinator.handleTxnCommitOffsets(member.group.groupId, producerId, producerEpoch, - JoinGroupRequest.UNKNOWN_MEMBER_ID, Option.empty, JoinGroupRequest.UNKNOWN_GENERATION_ID, - offsets, 
callbackWithTxnCompletion) + lock.foreach(_.lock()) + try { + groupCoordinator.handleTxnCommitOffsets(member.group.groupId, producerId, producerEpoch, + JoinGroupRequest.UNKNOWN_MEMBER_ID, Option.empty, JoinGroupRequest.UNKNOWN_GENERATION_ID, + offsets, callbackWithTxnCompletion) + replicaManager.tryCompleteActions() + } finally lock.foreach(_.unlock()) } } - class CompleteTxnOperation extends GroupOperation[CompleteTxnCallbackParams, CompleteTxnCallback] { + class CompleteTxnOperation(lock: Option[Lock] = None) extends GroupOperation[CompleteTxnCallbackParams, CompleteTxnCallback] { override def responseCallback(responsePromise: Promise[CompleteTxnCallbackParams]): CompleteTxnCallback = { val callback: CompleteTxnCallback = error => responsePromise.success(error) callback @@ -307,9 +328,13 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest override def runWithCallback(member: GroupMember, responseCallback: CompleteTxnCallback): Unit = { val producerId = 1000L val offsetsPartitions = (0 to numPartitions).map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, _)) - groupCoordinator.groupManager.handleTxnCompletion(producerId, - offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean) - responseCallback(Errors.NONE) + lock.foreach(_.lock()) + try { + groupCoordinator.groupManager.handleTxnCompletion(producerId, + offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean) + responseCallback(Errors.NONE) + } finally lock.foreach(_.unlock()) + } override def awaitAndVerify(member: GroupMember): Unit = { val error = await(member, 500) diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala index 1be4969..3788cb1 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -509,6 +509,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren class InitProducerIdOperation(val producerIdAndEpoch: Option[ProducerIdAndEpoch] = None) extends TxnOperation[InitProducerIdResult] { override def run(txn: Transaction): Unit = { transactionCoordinator.handleInitProducerId(txn.transactionalId, 60000, producerIdAndEpoch, resultCallback) + replicaManager.tryCompleteActions() } override def awaitAndVerify(txn: Transaction): Unit = { val initPidResult = result.getOrElse(throw new IllegalStateException("InitProducerId has not completed")) @@ -525,6 +526,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren txnMetadata.producerEpoch, partitions, resultCallback) + replicaManager.tryCompleteActions() } } override def awaitAndVerify(txn: Transaction): Unit = { @@ -597,12 +599,13 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren } } txnStateManager.enableTransactionalIdExpiration() + replicaManager.tryCompleteActions() time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs + 1) } override def await(): Unit = { val (_, success) = TestUtils.computeUntilTrue({ - replicaManager.tryCompleteDelayedRequests() + replicaManager.tryCompleteActions() transactions.forall(txn => transactionMetadata(txn).isEmpty) })(identity) assertTrue("Transaction not expired", success) diff --git a/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala 
b/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala index c29dfec..8f481f1 100644 --- a/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala +++ b/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala @@ -33,12 +33,12 @@ import scala.jdk.CollectionConverters._ class DelayedOperationTest { - var purgatory: DelayedOperationPurgatory[MockDelayedOperation] = null + var purgatory: DelayedOperationPurgatory[DelayedOperation] = null var executorService: ExecutorService = null @Before def setUp(): Unit = { - purgatory = DelayedOperationPurgatory[MockDelayedOperation](purgatoryName = "mock") + purgatory = DelayedOperationPurgatory[DelayedOperation](purgatoryName = "mock") } @After @@ -49,6 +49,43 @@ class DelayedOperationTest { } @Test + def testLockInTryCompleteElseWatch(): Unit = { + val op = new DelayedOperation(100000L) { + override def onExpiration(): Unit = {} + override def onComplete(): Unit = {} + override def tryComplete(): Boolean = { + assertTrue(lock.asInstanceOf[ReentrantLock].isHeldByCurrentThread) + false + } + override def safeTryComplete(): Boolean = { + fail("tryCompleteElseWatch should not use safeTryComplete") + super.safeTryComplete() + } + } + purgatory.tryCompleteElseWatch(op, Seq("key")) + } + + @Test + def testSafeTryCompleteOrElse(): Unit = { + def op(shouldComplete: Boolean) = new DelayedOperation(100000L) { + override def onExpiration(): Unit = {} + override def onComplete(): Unit = {} + override def tryComplete(): Boolean = { + assertTrue(lock.asInstanceOf[ReentrantLock].isHeldByCurrentThread) + shouldComplete + } + } + var pass = false + assertFalse(op(false).safeTryCompleteOrElse { + pass = true + }) + assertTrue(pass) + assertTrue(op(true).safeTryCompleteOrElse { + fail("this method should NOT be executed") + }) + } + + @Test def testRequestSatisfaction(): Unit = { val r1 = new MockDelayedOperation(100000L) val r2 = new MockDelayedOperation(100000L) @@ -193,44 +230,6 @@ class DelayedOperationTest { } /** - * Verify that if there is lock contention between two threads attempting to complete, - * completion is performed without any blocking in either thread. 
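// As a companion to the new tests above, a minimal end-to-end sketch of a
// DelayedOperation under the reworked locking; the class name, watch key and
// timeout are made up, while tryCompleteElseWatch/checkAndComplete/forceComplete
// are the existing purgatory API.

import kafka.server.{DelayedOperation, DelayedOperationPurgatory}

// Hypothetical operation that completes once `completable` is flipped. With the
// tryLock-based maybeTryComplete() gone, tryComplete() is always invoked with
// `lock` held (via safeTryComplete / safeTryCompleteOrElse).
class DelayedFlag(delayMs: Long) extends DelayedOperation(delayMs) {
  @volatile var completable = false
  override def tryComplete(): Boolean = if (completable) forceComplete() else false
  override def onComplete(): Unit = ()    // e.g. send the client response here
  override def onExpiration(): Unit = ()  // e.g. bump an expiration metric here
}

val purgatory = DelayedOperationPurgatory[DelayedFlag](purgatoryName = "flag")
val op = new DelayedFlag(delayMs = 100000L)
purgatory.tryCompleteElseWatch(op, Seq("key"))  // not completable yet, so it is watched
op.completable = true
purgatory.checkAndComplete("key")               // completes op while holding op.lock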
- */ - @Test - def testTryCompleteLockContention(): Unit = { - executorService = Executors.newSingleThreadExecutor() - val completionAttemptsRemaining = new AtomicInteger(Int.MaxValue) - val tryCompleteSemaphore = new Semaphore(1) - val key = "key" - - val op = new MockDelayedOperation(100000L, None, None) { - override def tryComplete() = { - val shouldComplete = completionAttemptsRemaining.decrementAndGet <= 0 - tryCompleteSemaphore.acquire() - try { - if (shouldComplete) - forceComplete() - else - false - } finally { - tryCompleteSemaphore.release() - } - } - } - - purgatory.tryCompleteElseWatch(op, Seq(key)) - completionAttemptsRemaining.set(2) - tryCompleteSemaphore.acquire() - val future = runOnAnotherThread(purgatory.checkAndComplete(key), shouldComplete = false) - TestUtils.waitUntilTrue(() => tryCompleteSemaphore.hasQueuedThreads, "Not attempting to complete") - purgatory.checkAndComplete(key) // this should not block even though lock is not free - assertFalse("Operation should not have completed", op.isCompleted) - tryCompleteSemaphore.release() - future.get(10, TimeUnit.SECONDS) - assertTrue("Operation should have completed", op.isCompleted) - } - - /** * Test `tryComplete` with multiple threads to verify that there are no timing windows * when completion is not performed even if the thread that makes the operation completable * may not be able to acquire the operation lock. Since it is difficult to test all scenarios, @@ -280,23 +279,6 @@ class DelayedOperationTest { ops.foreach { op => assertTrue("Operation should have completed", op.isCompleted) } } - @Test - def testDelayedOperationLock(): Unit = { - verifyDelayedOperationLock(new MockDelayedOperation(100000L), mismatchedLocks = false) - } - - @Test - def testDelayedOperationLockOverride(): Unit = { - def newMockOperation = { - val lock = new ReentrantLock - new MockDelayedOperation(100000L, Some(lock), Some(lock)) - } - verifyDelayedOperationLock(newMockOperation, mismatchedLocks = false) - - verifyDelayedOperationLock(new MockDelayedOperation(100000L, None, Some(new ReentrantLock)), - mismatchedLocks = true) - } - def verifyDelayedOperationLock(mockDelayedOperation: => MockDelayedOperation, mismatchedLocks: Boolean): Unit = { val key = "key" executorService = Executors.newSingleThreadExecutor