This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.5
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.5 by this push:
     new e9013ef  KAFKA-13099; Transactional expiration should account for max 
batch size (#11098)
e9013ef is described below

commit e9013efb987efb0be9cb1b0e75e828cbc69e1194
Author: Jason Gustafson <[email protected]>
AuthorDate: Tue Jul 27 18:23:00 2021 -0700

    KAFKA-13099; Transactional expiration should account for max batch size 
(#11098)
    
    When expiring transactionalIds, we group the tombstones together into 
batches. Currently there is no limit on the size of these batches, which can 
lead to `MESSAGE_TOO_LARGE` errors when a bunch of transactionalIds need to be 
expired at the same time. This patch fixes the problem by ensuring that the 
batch size respects the configured limit. Any transactionalIds which are 
eligible for expiration and cannot be fit into the batch are postponed until 
the next periodic check.
    
    Reviewers: David Jacot <[email protected]>, Guozhang Wang 
<[email protected]>
---
 .../apache/kafka/common/record/MemoryRecords.java  |  14 ++
 .../transaction/TransactionMetadata.scala          |  20 +-
 .../transaction/TransactionStateManager.scala      | 207 +++++++++++-----
 core/src/main/scala/kafka/utils/Pool.scala         |   2 +-
 .../AbstractCoordinatorConcurrencyTest.scala       |   7 +-
 .../TransactionCoordinatorConcurrencyTest.scala    |   6 +-
 .../transaction/TransactionStateManagerTest.scala  | 266 ++++++++++++++++++---
 7 files changed, 421 insertions(+), 101 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java 
b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index 8f73565..b5efe12 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -405,6 +405,20 @@ public class MemoryRecords extends AbstractRecords {
         return builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, 
compressionType, timestampType, baseOffset);
     }
 
+    public static MemoryRecordsBuilder builder(ByteBuffer buffer,
+                                               CompressionType compressionType,
+                                               TimestampType timestampType,
+                                               long baseOffset,
+                                               int maxSize) {
+        long logAppendTime = RecordBatch.NO_TIMESTAMP;
+        if (timestampType == TimestampType.LOG_APPEND_TIME)
+            logAppendTime = System.currentTimeMillis();
+
+        return new MemoryRecordsBuilder(buffer, 
RecordBatch.CURRENT_MAGIC_VALUE, compressionType, timestampType, baseOffset,
+            logAppendTime, RecordBatch.NO_PRODUCER_ID, 
RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE,
+            false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, maxSize);
+    }
+
     public static MemoryRecordsBuilder idempotentBuilder(ByteBuffer buffer,
                                                          CompressionType 
compressionType,
                                                          long baseOffset,
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index 24b418a..ffd58e6 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -25,7 +25,10 @@ import org.apache.kafka.common.record.RecordBatch
 
 import scala.collection.{immutable, mutable}
 
-private[transaction] sealed trait TransactionState { def byte: Byte }
+private[transaction] sealed trait TransactionState {
+  def byte: Byte
+  def isExpirationAllowed: Boolean = false
+}
 
 /**
  * Transaction has not existed yet
@@ -33,7 +36,10 @@ private[transaction] sealed trait TransactionState { def 
byte: Byte }
  * transition: received AddPartitionsToTxnRequest => Ongoing
  *             received AddOffsetsToTxnRequest => Ongoing
  */
-private[transaction] case object Empty extends TransactionState { val byte: 
Byte = 0 }
+private[transaction] case object Empty extends TransactionState {
+  val byte: Byte = 0
+  override def isExpirationAllowed: Boolean = true
+}
 
 /**
  * Transaction has started and ongoing
@@ -64,14 +70,20 @@ private[transaction] case object PrepareAbort extends 
TransactionState { val byt
  *
  * Will soon be removed from the ongoing transaction cache
  */
-private[transaction] case object CompleteCommit extends TransactionState { val 
byte: Byte = 4 }
+private[transaction] case object CompleteCommit extends TransactionState {
+  val byte: Byte = 4
+  override def isExpirationAllowed: Boolean = true
+}
 
 /**
  * Group has completed abort
  *
  * Will soon be removed from the ongoing transaction cache
  */
-private[transaction] case object CompleteAbort extends TransactionState { val 
byte: Byte = 5 }
+private[transaction] case object CompleteAbort extends TransactionState {
+  val byte: Byte = 5
+  override def isExpirationAllowed: Boolean = true
+}
 
 /**
   * TransactionalId has expired and is about to be removed from the 
transaction cache
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index 174e3a5..7b0cc3c 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -33,7 +33,7 @@ import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.metrics.stats.{Avg, Max}
 import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.record.{FileRecords, MemoryRecords, 
SimpleRecord}
+import org.apache.kafka.common.record.{FileRecords, MemoryRecords, 
MemoryRecordsBuilder, Record, SimpleRecord, TimestampType}
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
 import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.{Time, Utils}
@@ -140,80 +140,163 @@ class TransactionStateManager(brokerId: Int,
     }
   }
 
-  def enableTransactionalIdExpiration(): Unit = {
-    scheduler.schedule("transactionalId-expiration", () => {
-      val now = time.milliseconds()
-      inReadLock(stateLock) {
-        val transactionalIdByPartition: Map[Int, 
mutable.Iterable[TransactionalIdCoordinatorEpochAndMetadata]] =
-          transactionMetadataCache.flatMap { case (_, entry) =>
-            entry.metadataPerTransactionalId.filter { case (_, txnMetadata) => 
txnMetadata.state match {
-              case Empty | CompleteCommit | CompleteAbort => true
-              case _ => false
-            }
-            }.filter { case (_, txnMetadata) =>
-              txnMetadata.txnLastUpdateTimestamp <= now - 
config.transactionalIdExpirationMs
-            }.map { case (transactionalId, txnMetadata) =>
-              val txnMetadataTransition = txnMetadata.inLock {
-                txnMetadata.prepareDead()
+  private def removeExpiredTransactionalIds(
+    transactionPartition: TopicPartition,
+    txnMetadataCacheEntry: TxnMetadataCacheEntry
+  ): Unit = {
+    inReadLock(stateLock) {
+      replicaManager.getLogConfig(transactionPartition) match {
+        case Some(logConfig) =>
+          val currentTimeMs = time.milliseconds()
+          val maxBatchSize = logConfig.maxMessageSize
+          val expired = 
mutable.ListBuffer.empty[TransactionalIdCoordinatorEpochAndMetadata]
+          var recordsBuilder: MemoryRecordsBuilder = null
+          val stateEntries = 
txnMetadataCacheEntry.metadataPerTransactionalId.values.iterator.buffered
+
+          def flushRecordsBuilder(): Unit = {
+            writeTombstonesForExpiredTransactionalIds(
+              transactionPartition,
+              expired.toSeq,
+              recordsBuilder.build()
+            )
+            expired.clear()
+            recordsBuilder = null
+          }
+
+          while (stateEntries.hasNext) {
+            val txnMetadata = stateEntries.head
+            val transactionalId = txnMetadata.transactionalId
+            var fullBatch = false
+
+            txnMetadata.inLock {
+              if (txnMetadata.pendingState.isEmpty && 
shouldExpire(txnMetadata, currentTimeMs)) {
+                if (recordsBuilder == null) {
+                  recordsBuilder = MemoryRecords.builder(
+                    ByteBuffer.allocate(math.min(16384, maxBatchSize)),
+                    TransactionLog.EnforcedCompressionType,
+                    TimestampType.CREATE_TIME,
+                    0L,
+                    maxBatchSize
+                  )
+                }
+
+                if (maybeAppendExpiration(txnMetadata, recordsBuilder, 
currentTimeMs)) {
+                  val transitMetadata = txnMetadata.prepareDead()
+                  expired += TransactionalIdCoordinatorEpochAndMetadata(
+                    transactionalId,
+                    txnMetadataCacheEntry.coordinatorEpoch,
+                    transitMetadata
+                  )
+                } else {
+                  fullBatch = true
+                }
               }
-              TransactionalIdCoordinatorEpochAndMetadata(transactionalId, 
entry.coordinatorEpoch, txnMetadataTransition)
             }
-          }.groupBy { transactionalIdCoordinatorEpochAndMetadata =>
-            
partitionFor(transactionalIdCoordinatorEpochAndMetadata.transactionalId)
+
+            if (fullBatch) {
+              flushRecordsBuilder()
+            } else {
+              // Advance the iterator if we do not need to retry the append
+              stateEntries.next()
+            }
           }
 
-        val recordsPerPartition = transactionalIdByPartition
-          .map { case (partition, transactionalIdCoordinatorEpochAndMetadatas) 
=>
-            val deletes: Array[SimpleRecord] = 
transactionalIdCoordinatorEpochAndMetadatas.map { entry =>
-              new SimpleRecord(now, 
TransactionLog.keyToBytes(entry.transactionalId), null)
-            }.toArray
-            val records = 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, deletes: _*)
-            val topicPartition = new 
TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partition)
-            (topicPartition, records)
+          if (expired.nonEmpty) {
+            flushRecordsBuilder()
           }
 
-        def removeFromCacheCallback(responses: collection.Map[TopicPartition, 
PartitionResponse]): Unit = {
-          responses.foreach { case (topicPartition, response) =>
-            inReadLock(stateLock) {
-              val toRemove = 
transactionalIdByPartition(topicPartition.partition)
-              transactionMetadataCache.get(topicPartition.partition).foreach { 
txnMetadataCacheEntry =>
-                toRemove.foreach { idCoordinatorEpochAndMetadata =>
-                  val transactionalId = 
idCoordinatorEpochAndMetadata.transactionalId
-                  val txnMetadata = 
txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
-                  txnMetadata.inLock {
-                    if (txnMetadataCacheEntry.coordinatorEpoch == 
idCoordinatorEpochAndMetadata.coordinatorEpoch
-                      && txnMetadata.pendingState.contains(Dead)
-                      && txnMetadata.producerEpoch == 
idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
-                      && response.error == Errors.NONE) {
-                      
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
-                    } else {
-                      warn(s"Failed to remove expired transactionalId: 
$transactionalId" +
-                        s" from cache. Tombstone append error code: 
${response.error}," +
-                        s" pendingState: ${txnMetadata.pendingState}, 
producerEpoch: ${txnMetadata.producerEpoch}," +
-                        s" expected producerEpoch: 
${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
-                        s" coordinatorEpoch: 
${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
-                        s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
-                      txnMetadata.pendingState = None
-                    }
-                  }
+        case None =>
+          warn(s"Transaction expiration for partition $transactionPartition 
failed because the log " +
+            "config was not available, which likely means the partition is not 
online or is no longer local.")
+      }
+    }
+  }
+
+  private def shouldExpire(
+    txnMetadata: TransactionMetadata,
+    currentTimeMs: Long
+  ): Boolean = {
+    txnMetadata.state.isExpirationAllowed &&
+      txnMetadata.txnLastUpdateTimestamp <= currentTimeMs - 
config.transactionalIdExpirationMs
+  }
+
+  private def maybeAppendExpiration(
+    txnMetadata: TransactionMetadata,
+    recordsBuilder: MemoryRecordsBuilder,
+    currentTimeMs: Long
+  ): Boolean = {
+    val keyBytes = TransactionLog.keyToBytes(txnMetadata.transactionalId)
+    if (recordsBuilder.hasRoomFor(currentTimeMs, keyBytes, null, 
Record.EMPTY_HEADERS)) {
+      recordsBuilder.append(currentTimeMs, keyBytes, null, 
Record.EMPTY_HEADERS)
+      true
+    } else {
+      false
+    }
+  }
+
+  private[transaction] def removeExpiredTransactionalIds(): Unit = {
+    inReadLock(stateLock) {
+      transactionMetadataCache.foreach { case (partitionId, 
partitionCacheEntry) =>
+        val transactionPartition = new 
TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId)
+        removeExpiredTransactionalIds(transactionPartition, 
partitionCacheEntry)
+      }
+    }
+  }
+
+  private def writeTombstonesForExpiredTransactionalIds(
+    transactionPartition: TopicPartition,
+    expiredForPartition: Iterable[TransactionalIdCoordinatorEpochAndMetadata],
+    tombstoneRecords: MemoryRecords
+  ): Unit = {
+    def removeFromCacheCallback(responses: collection.Map[TopicPartition, 
PartitionResponse]): Unit = {
+      responses.foreach { case (topicPartition, response) =>
+        inReadLock(stateLock) {
+          transactionMetadataCache.get(topicPartition.partition).foreach { 
txnMetadataCacheEntry =>
+            expiredForPartition.foreach { idCoordinatorEpochAndMetadata =>
+              val transactionalId = 
idCoordinatorEpochAndMetadata.transactionalId
+              val txnMetadata = 
txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
+              txnMetadata.inLock {
+                if (txnMetadataCacheEntry.coordinatorEpoch == 
idCoordinatorEpochAndMetadata.coordinatorEpoch
+                  && txnMetadata.pendingState.contains(Dead)
+                  && txnMetadata.producerEpoch == 
idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
+                  && response.error == Errors.NONE) {
+                  
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
+                } else {
+                  warn(s"Failed to remove expired transactionalId: 
$transactionalId" +
+                    s" from cache. Tombstone append error code: 
${response.error}," +
+                    s" pendingState: ${txnMetadata.pendingState}, 
producerEpoch: ${txnMetadata.producerEpoch}," +
+                    s" expected producerEpoch: 
${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
+                    s" coordinatorEpoch: 
${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
+                    s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
+                  txnMetadata.pendingState = None
                 }
               }
             }
           }
         }
-
-        replicaManager.appendRecords(
-          config.requestTimeoutMs,
-          TransactionLog.EnforcedRequiredAcks,
-          internalTopicsAllowed = true,
-          origin = AppendOrigin.Coordinator,
-          recordsPerPartition,
-          removeFromCacheCallback,
-          Some(stateLock.readLock)
-        )
       }
+    }
 
-    }, delay = config.removeExpiredTransactionalIdsIntervalMs, period = 
config.removeExpiredTransactionalIdsIntervalMs)
+    inReadLock(stateLock) {
+      replicaManager.appendRecords(
+        config.requestTimeoutMs,
+        TransactionLog.EnforcedRequiredAcks,
+        internalTopicsAllowed = true,
+        origin = AppendOrigin.Coordinator,
+        entriesPerPartition = Map(transactionPartition -> tombstoneRecords),
+        removeFromCacheCallback,
+        Some(stateLock.readLock)
+      )
+    }
+  }
+
+  def enableTransactionalIdExpiration(): Unit = {
+    scheduler.schedule(
+      name = "transactionalId-expiration",
+      fun = removeExpiredTransactionalIds,
+      delay = config.removeExpiredTransactionalIdsIntervalMs,
+      period = config.removeExpiredTransactionalIdsIntervalMs
+    )
   }
 
   def getTransactionState(transactionalId: String): Either[Errors, 
Option[CoordinatorEpochAndTxnMetadata]] = {
diff --git a/core/src/main/scala/kafka/utils/Pool.scala 
b/core/src/main/scala/kafka/utils/Pool.scala
index 2f24aff..c46a98e 100644
--- a/core/src/main/scala/kafka/utils/Pool.scala
+++ b/core/src/main/scala/kafka/utils/Pool.scala
@@ -74,7 +74,7 @@ class Pool[K,V](valueFactory: Option[K => V] = None) extends 
Iterable[(K, V)] {
   def values: Iterable[V] = pool.values.asScala
 
   def clear(): Unit = { pool.clear() }
-  
+
   override def size: Int = pool.size
   
   override def iterator: Iterator[(K, V)] = new Iterator[(K,V)]() {
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
index c39a23f..9927ae4 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger
 import java.util.concurrent.locks.Lock
 
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
-import kafka.log.{AppendOrigin, Log}
+import kafka.log.{AppendOrigin, Log, LogConfig}
 import kafka.server._
 import kafka.utils._
 import kafka.utils.timer.MockTimer
@@ -218,6 +218,11 @@ object AbstractCoordinatorConcurrencyTest {
     def updateLog(topicPartition: TopicPartition, log: Log, endOffset: Long): 
Unit = {
       getOrCreateLogs().put(topicPartition, (log, endOffset))
     }
+
+    override def getLogConfig(topicPartition: TopicPartition): 
Option[LogConfig] = {
+      getOrCreateLogs().get(topicPartition).map(_._1.config)
+    }
+
     override def getLog(topicPartition: TopicPartition): Option[Log] =
       getOrCreateLogs().get(topicPartition).map(l => l._1)
     override def getLogEndOffset(topicPartition: TopicPartition): Option[Long] 
=
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index 6b05ef3..fc5964b 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -17,14 +17,15 @@
 package kafka.coordinator.transaction
 
 import java.nio.ByteBuffer
+import java.util.Collections
 
 import kafka.api.KAFKA_2_4_IV1
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
 import kafka.coordinator.transaction.TransactionCoordinatorConcurrencyTest._
-import kafka.log.Log
 import kafka.server.{DelayedOperationPurgatory, FetchDataInfo, FetchLogEnd, 
KafkaConfig, LogOffsetMetadata, MetadataCache}
 import kafka.utils.timer.MockTimer
+import kafka.log.{Log, LogConfig}
 import kafka.utils.{Pool, TestUtils}
 import org.apache.kafka.clients.{ClientResponse, NetworkClient}
 import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME
@@ -442,8 +443,9 @@ class TransactionCoordinatorConcurrencyTest extends 
AbstractCoordinatorConcurren
   }
 
   private def prepareTxnLog(partitionId: Int): Unit = {
-
     val logMock: Log =  EasyMock.mock(classOf[Log])
+    EasyMock.expect(logMock.config).andStubReturn(new 
LogConfig(Collections.emptyMap()))
+
     val fileRecordsMock: FileRecords = EasyMock.mock(classOf[FileRecords])
 
     val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
partitionId)
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index dab9181..3e15d26 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -20,10 +20,11 @@ import java.lang.management.ManagementFactory
 import java.nio.ByteBuffer
 import java.util.concurrent.CountDownLatch
 import java.util.concurrent.locks.ReentrantLock
+import java.util.function.Consumer
 
 import javax.management.ObjectName
 import kafka.api.KAFKA_2_4_IV1
-import kafka.log.{AppendOrigin, Log}
+import kafka.log.{AppendOrigin, Defaults, Log, LogConfig}
 import kafka.server.{FetchDataInfo, FetchLogEnd, LogOffsetMetadata, 
ReplicaManager}
 import kafka.utils.{MockScheduler, Pool, TestUtils}
 import kafka.zk.KafkaZkClient
@@ -36,12 +37,12 @@ import 
org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
 import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.MockTime
 import org.easymock.{Capture, EasyMock, IAnswer}
-import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
+import org.junit.Assert.{assertEquals, assertFalse, assertNull, assertTrue}
 import org.junit.{After, Before, Test}
 import org.scalatest.Assertions.fail
 
-import scala.collection.JavaConverters._
 import scala.collection.{Map, mutable}
+import scala.jdk.CollectionConverters._
 
 class TransactionStateManagerTest {
 
@@ -513,7 +514,7 @@ class TransactionStateManagerTest {
   }
 
   @Test
-  def shouldRemoveCompleteCommmitExpiredTransactionalIds(): Unit = {
+  def shouldRemoveCompleteCommitExpiredTransactionalIds(): Unit = {
     setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteCommit)
     verifyMetadataDoesntExist(transactionalId1)
     verifyMetadataDoesExistAndIsUsable(transactionalId2)
@@ -562,6 +563,159 @@ class TransactionStateManagerTest {
   }
 
   @Test
+  def testTransactionalExpirationWithTooSmallBatchSize(): Unit = {
+    // The batch size is too small, but we nevertheless expect the
+    // coordinator to attempt the append. This test mainly ensures
+    // that the expiration task does not get stuck.
+
+    val partitionIds = 0 until numPartitions
+    val maxBatchSize = 16
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 20)
+
+    EasyMock.reset(replicaManager)
+    expectLogConfig(partitionIds, maxBatchSize)
+
+    val attemptedAppends = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.MESSAGE_TOO_LARGE, attemptedAppends)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    for (batches <- attemptedAppends.values; batch <- batches) {
+      assertTrue(batch.sizeInBytes() > maxBatchSize)
+    }
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+  }
+
+  @Test
+  def testTransactionalExpirationWithOfflineLogDir(): Unit = {
+    val onlinePartitionId = 0
+    val offlinePartitionId = 1
+
+    val partitionIds = Seq(onlinePartitionId, offlinePartitionId)
+    val maxBatchSize = 512
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 20)
+
+    EasyMock.reset(replicaManager)
+
+    // Partition 0 returns log config as normal
+    expectLogConfig(Seq(onlinePartitionId), maxBatchSize)
+    // No log config returned for partition 1 since it is offline
+    EasyMock.expect(replicaManager.getLogConfig(new 
TopicPartition(TRANSACTION_STATE_TOPIC_NAME, offlinePartitionId)))
+      .andStubReturn(None)
+
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    assertEquals(Set(onlinePartitionId), 
appendedRecords.keySet.map(_.partition))
+
+    val (transactionalIdsForOnlinePartition, 
transactionalIdsForOfflinePartition) =
+      allTransactionalIds.partition { transactionalId =>
+        transactionManager.partitionFor(transactionalId) == onlinePartitionId
+      }
+
+    val expiredTransactionalIds = 
collectTransactionalIdsFromTombstones(appendedRecords)
+    assertEquals(transactionalIdsForOnlinePartition, expiredTransactionalIds)
+    assertEquals(transactionalIdsForOfflinePartition, 
listExpirableTransactionalIds())
+  }
+
+  @Test
+  def testTransactionExpirationShouldRespectBatchSize(): Unit = {
+    val partitionIds = 0 until numPartitions
+    val maxBatchSize = 512
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 1000)
+
+    EasyMock.reset(replicaManager)
+    expectLogConfig(partitionIds, maxBatchSize)
+
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    assertEquals(Set.empty, listExpirableTransactionalIds())
+    assertEquals(partitionIds.toSet, appendedRecords.keys.map(_.partition))
+
+    appendedRecords.values.foreach { batches =>
+      assertTrue(batches.size > 1) // Ensure a non-trivial test case
+      assertTrue(batches.forall(_.sizeInBytes() < maxBatchSize))
+    }
+
+    val expiredTransactionalIds = 
collectTransactionalIdsFromTombstones(appendedRecords)
+    assertEquals(allTransactionalIds, expiredTransactionalIds)
+  }
+
+  private def collectTransactionalIdsFromTombstones(
+    appendedRecords: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]]
+  ): Set[String] = {
+    val expiredTransactionalIds = mutable.Set.empty[String]
+    appendedRecords.values.foreach { batches =>
+      batches.foreach { records =>
+        records.records.forEach(new Consumer[Record] {
+          override def accept(record: Record): Unit = {
+            val transactionalId = 
TransactionLog.readTxnRecordKey(record.key).transactionalId
+            assertNull(record.value)
+            expiredTransactionalIds += transactionalId
+            assertEquals(Right(None), 
transactionManager.getTransactionState(transactionalId))
+
+          }
+        })
+      }
+    }
+    expiredTransactionalIds.toSet
+  }
+
+  private def loadExpiredTransactionalIds(
+    numTransactionalIds: Int
+  ): Set[String] = {
+    val allTransactionalIds = mutable.Set.empty[String]
+    for (i <- 0 to numTransactionalIds) {
+      val txnlId = s"id_$i"
+      val producerId = i
+      val txnMetadata = transactionMetadata(txnlId, producerId)
+      txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - 
txnConfig.transactionalIdExpirationMs
+      transactionManager.putTransactionStateIfNotExists(txnMetadata)
+      allTransactionalIds += txnlId
+    }
+    allTransactionalIds.toSet
+  }
+
+  private def listExpirableTransactionalIds(): Set[String] = {
+    val activeTransactionalIds = transactionManager.transactionMetadataCache
+        .values
+        .flatMap(_.metadataPerTransactionalId.values.map(_.transactionalId))
+
+    activeTransactionalIds.filter { transactionalId =>
+      transactionManager.getTransactionState(transactionalId) match {
+        case Right(Some(epochAndMetadata)) =>
+          val txnMetadata = epochAndMetadata.transactionMetadata
+          val timeSinceLastUpdate = time.milliseconds() - 
txnMetadata.txnLastUpdateTimestamp
+          timeSinceLastUpdate >= txnConfig.transactionalIdExpirationMs &&
+            txnMetadata.state.isExpirationAllowed &&
+            txnMetadata.pendingState.isEmpty
+        case _ => false
+      }
+    }.toSet
+  }
+
+  @Test
   def testSuccessfulReimmigration(): Unit = {
     txnMetadata1.state = PrepareCommit
     txnMetadata1.addPartitions(Set[TopicPartition](new 
TopicPartition("topic1", 0),
@@ -633,35 +787,69 @@ class TransactionStateManagerTest {
     }
   }
 
-  private def setupAndRunTransactionalIdExpiration(error: Errors, txnState: 
TransactionState): Unit = {
-    for (partitionId <- 0 until numPartitions) {
+  private def expectTransactionalIdExpiration(
+    appendError: Errors,
+    capturedAppends: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]]
+  ): Unit = {
+    val recordsCapture: Capture[Map[TopicPartition, MemoryRecords]] = 
EasyMock.newCapture()
+    val callbackCapture: Capture[Map[TopicPartition, PartitionResponse] => 
Unit] = EasyMock.newCapture()
+
+    EasyMock.expect(replicaManager.appendRecords(
+      EasyMock.anyLong(),
+      EasyMock.eq((-1).toShort),
+      EasyMock.eq(true),
+      EasyMock.eq(AppendOrigin.Coordinator),
+      EasyMock.capture(recordsCapture),
+      EasyMock.capture(callbackCapture),
+      EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]],
+      EasyMock.anyObject()
+    )).andAnswer(new IAnswer[Unit] {
+      override def answer(): Unit = {
+        callbackCapture.getValue.apply(
+          recordsCapture.getValue.map { case (topicPartition, records) =>
+            val batches = capturedAppends.getOrElse(topicPartition, {
+              val batches = mutable.Buffer.empty[MemoryRecords]
+              capturedAppends += topicPartition -> batches
+              batches
+            })
+
+            batches += records
+
+            topicPartition -> new PartitionResponse(appendError, 0L, 
RecordBatch.NO_TIMESTAMP, 0L)
+          }.toMap
+        )
+      }
+    }).anyTimes()
+  }
+
+  private def loadTransactionsForPartitions(
+    partitionIds: Seq[Int]
+  ): Unit = {
+    for (partitionId <- partitionIds) {
       transactionManager.addLoadedTransactionsToCache(partitionId, 0, new 
Pool[String, TransactionMetadata]())
     }
+  }
 
-    val capturedArgument: Capture[Map[TopicPartition, PartitionResponse] => 
Unit] = EasyMock.newCapture()
+  private def expectLogConfig(
+    partitionIds: Seq[Int],
+    maxBatchSize: Int
+  ): Unit = {
+    val logConfig: LogConfig = EasyMock.mock(classOf[LogConfig])
+    EasyMock.expect(logConfig.maxMessageSize).andStubReturn(maxBatchSize)
 
-    val partition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
transactionManager.partitionFor(transactionalId1))
-    val recordsByPartition = Map(partition -> 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType,
-      new SimpleRecord(time.milliseconds() + 
txnConfig.removeExpiredTransactionalIdsIntervalMs, 
TransactionLog.keyToBytes(transactionalId1), null)))
-
-    txnState match {
-      case Empty | CompleteCommit | CompleteAbort =>
-
-        EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(),
-          EasyMock.eq((-1).toShort),
-          EasyMock.eq(true),
-          EasyMock.eq(AppendOrigin.Coordinator),
-          EasyMock.eq(recordsByPartition),
-          EasyMock.capture(capturedArgument),
-          EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]],
-          EasyMock.anyObject()
-        )).andAnswer(() => capturedArgument.getValue.apply(
-          Map(partition -> new PartitionResponse(error, 0L, 
RecordBatch.NO_TIMESTAMP, 0L)))
-        )
-      case _ => // shouldn't append
+    for (partitionId <- partitionIds) {
+      EasyMock.expect(replicaManager.getLogConfig(new 
TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId)))
+        .andStubReturn(Some(logConfig))
     }
 
-    EasyMock.replay(replicaManager)
+    EasyMock.replay(logConfig)
+  }
+
+  private def setupAndRunTransactionalIdExpiration(error: Errors, txnState: 
TransactionState): Unit = {
+    val partitionIds = 0 until numPartitions
+
+    loadTransactionsForPartitions(partitionIds)
+    expectLogConfig(partitionIds, Defaults.MaxMessageSize)
 
     txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() - 
txnConfig.transactionalIdExpirationMs
     txnMetadata1.state = txnState
@@ -670,12 +858,28 @@ class TransactionStateManagerTest {
     txnMetadata2.txnLastUpdateTimestamp = time.milliseconds()
     transactionManager.putTransactionStateIfNotExists(txnMetadata2)
 
-    transactionManager.enableTransactionalIdExpiration()
-    time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs)
-
-    scheduler.tick()
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(error, appendedRecords)
 
+    EasyMock.replay(replicaManager)
+    transactionManager.removeExpiredTransactionalIds()
     EasyMock.verify(replicaManager)
+
+    val stateAllowsExpiration = txnState match {
+      case Empty | CompleteCommit | CompleteAbort => true
+      case _ => false
+    }
+
+    if (stateAllowsExpiration) {
+      val partitionId = transactionManager.partitionFor(transactionalId1)
+      val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
partitionId)
+      val expectedTombstone = new SimpleRecord(time.milliseconds(), 
TransactionLog.keyToBytes(transactionalId1), null)
+      val expectedRecords = 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, 
expectedTombstone)
+      assertEquals(Set(topicPartition), appendedRecords.keySet)
+      assertEquals(Seq(expectedRecords), appendedRecords(topicPartition).toSeq)
+    } else {
+      assertEquals(Map.empty, appendedRecords)
+    }
   }
 
   private def verifyWritesTxnMarkersInPrepareState(state: TransactionState): 
Unit = {

Reply via email to