dajac commented on code in PR #13787:
URL: https://github.com/apache/kafka/pull/13787#discussion_r1221060938


##########
core/src/main/scala/kafka/log/UnifiedLog.scala:
##########
@@ -579,6 +579,28 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     result
   }
 
+  /**
+   * Maybe create and return the verification guard object for the given producer ID if the transaction is not yet ongoing.
+   * Creation starts the verification process. Otherwise return null.
+   */
+  def maybeStartTransactionVerification(producerId: Long): Object = lock synchronized {
+    if (hasOngoingTransaction(producerId))
+      null
+    else
+      verificationGuard(producerId, true)
+  }
+
+  /**
+   * Maybe create the VerificationStateEntry for the given producer ID -- if an entry is present, return its verification guard, otherwise, return null.
+   */
+  def verificationGuard(producerId: Long, createIfAbsent: Boolean = false): Object = lock synchronized {

Review Comment:
   nit: Should we call this `getOrMaybeCreateVerificationGuard`? Could this one 
be package private? It seems that we only access it externally from tests.
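
   For illustration only, a rough, self-contained sketch (not the Kafka code) of the get-or-maybe-create pattern and package-private visibility suggested here; `VerificationGuards`, the backing map, and `maybeStartVerification` are invented names for the sketch:

   ```scala
   package kafka.log

   import scala.collection.mutable

   // Standalone illustration of the suggested shape: a package-private accessor that
   // returns the existing guard and only creates one when explicitly asked to.
   class VerificationGuards {
     private val guards = mutable.Map.empty[Long, Object]

     private[log] def getOrMaybeCreateVerificationGuard(producerId: Long,
                                                        createIfAbsent: Boolean = false): Object = this.synchronized {
       if (createIfAbsent) guards.getOrElseUpdate(producerId, new Object)
       else guards.getOrElse(producerId, null) // null when absent
     }

     // Production code would go through a public entry point like this one.
     def maybeStartVerification(producerId: Long): Object =
       getOrMaybeCreateVerificationGuard(producerId, createIfAbsent = true)
   }
   ```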



##########
core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala:
##########
@@ -2095,82 +2105,143 @@ class ReplicaManagerTest {
   }
 
   @Test
-  def testVerificationForTransactionalPartitions(): Unit = {
-    val tp = new TopicPartition(topic, 0)
-    val transactionalId = "txn1"
+  def testVerificationForTransactionalPartitionsOnly(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val tp1 = new TopicPartition(topic, 1)
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 0
-    
-    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
-    val metadataCache = mock(classOf[MetadataCache])
+    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = new ReplicaManager(
-      metrics = metrics,
-      config = config,
-      time = time,
-      scheduler = new MockScheduler(time),
-      logManager = mockLogMgr,
-      quotaManagers = quotaManager,
-      metadataCache = metadataCache,
-      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
-      alterPartitionManager = alterPartitionManager,
-      addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
-
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1), node)
     try {
-      val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), tp, Seq(0, 1), LeaderAndIsr(1,  List(0, 1)))
-      replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      // We must set up the metadata cache to handle the append and verification.
-      val metadataResponseTopic = Seq(new MetadataResponseTopic()
-        .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
-        .setPartitions(Seq(
-          new MetadataResponsePartition()
-            .setPartitionIndex(0)
-            .setLeaderId(0)).asJava))
-      val node = new Node(0, "host1", 0)
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      when(metadataCache.contains(tp)).thenReturn(true)
-      when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
-      when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node))
-      when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None)
-
-      // We will attempt to schedule to the request handler thread using a non request handler thread. Set this to avoid error.
-      KafkaRequestHandler.setBypassThreadCheck(true)
+      // If we supply no transactional ID and idempotent records, we do not verify.
+      val idempotentRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecords(replicaManager, tp0, idempotentRecords)
+      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any[AddPartitionsToTxnManager.AppendCallback]())
+      assertEquals(null, getVerificationGuard(replicaManager, tp0, producerId))
+
+      // If we supply a transactional ID and some transactional and some idempotent records, we should only verify the topic partition with transactional records.
+      val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
+        new SimpleRecord("message".getBytes))
+
+      val transactionToAdd = new AddPartitionsToTxnTransaction()
+        .setTransactionalId(transactionalId)
+        .setProducerId(producerId)
+        .setProducerEpoch(producerEpoch)
+        .setVerifyOnly(true)
+        .setTopics(new AddPartitionsToTxnTopicCollection(
+          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
+        ))
+
+      val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> idempotentRecords2), transactionalId, Some(0))
+      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
+      assertNotEquals(null, getVerificationGuard(replicaManager, tp0, producerId))
+      assertEquals(null, getVerificationGuard(replicaManager, tp1, producerId))
+    } finally {
+      replicaManager.shutdown()
+    }
+
+    TestUtils.assertNoNonDaemonThreads(this.getClass.getName)
+  }
+
+  @Test
+  def testVerificationFlow(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val producerId = 24L
+    val producerEpoch = 0.toShort
+    val sequence = 6
+    val node = new Node(0, "host1", 0)
+    val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
+
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
+    try {
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
       // Append some transactional records.
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
-        new SimpleRecord(s"message $sequence".getBytes))
-      val result = appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
-
+        new SimpleRecord("message".getBytes))
+
       val transactionToAdd = new AddPartitionsToTxnTransaction()
         .setTransactionalId(transactionalId)
         .setProducerId(producerId)
         .setProducerEpoch(producerEpoch)
         .setVerifyOnly(true)
         .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new AddPartitionsToTxnTopic().setName(tp.topic).setPartitions(Collections.singletonList(tp.partition))).iterator.asJava
+          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
         ))
-
-      val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
+
       // We should add these partitions to the manager to verify.
+      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
       verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback.capture())
+      val verificationGuard = getVerificationGuard(replicaManager, tp0, producerId)
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
       // Confirm we did not write to the log and instead returned error.
       val callback: AddPartitionsToTxnManager.AppendCallback = appendCallback.getValue()
-      callback(Map(tp -> Errors.INVALID_RECORD).toMap)
+      callback(Map(tp0 -> Errors.INVALID_RECORD).toMap)
       assertEquals(Errors.INVALID_RECORD, result.assertFired.error)
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
+
+      // This time verification is successful

Review Comment:
   nit: `.`



##########
core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala:
##########
@@ -2095,82 +2105,143 @@ class ReplicaManagerTest {
   }
 
   @Test
-  def testVerificationForTransactionalPartitions(): Unit = {
-    val tp = new TopicPartition(topic, 0)
-    val transactionalId = "txn1"
+  def testVerificationForTransactionalPartitionsOnly(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val tp1 = new TopicPartition(topic, 1)
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 0
-    
-    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new 
File(_)))
-    val metadataCache = mock(classOf[MetadataCache])
+    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = new ReplicaManager(
-      metrics = metrics,
-      config = config,
-      time = time,
-      scheduler = new MockScheduler(time),
-      logManager = mockLogMgr,
-      quotaManagers = quotaManager,
-      metadataCache = metadataCache,
-      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
-      alterPartitionManager = alterPartitionManager,
-      addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
-
+    val replicaManager = 
setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager,
 List(tp0, tp1), node)
     try {
-      val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), 
tp, Seq(0, 1), LeaderAndIsr(1,  List(0, 1)))
-      replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => 
())
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), 
LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      // We must set up the metadata cache to handle the append and 
verification.
-      val metadataResponseTopic = Seq(new MetadataResponseTopic()
-        .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
-        .setPartitions(Seq(
-          new MetadataResponsePartition()
-            .setPartitionIndex(0)
-            .setLeaderId(0)).asJava))
-      val node = new Node(0, "host1", 0)
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), 
LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      when(metadataCache.contains(tp)).thenReturn(true)
-      
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
-      when(metadataCache.getAliveBrokerNode(0, 
config.interBrokerListenerName)).thenReturn(Some(node))
-      when(metadataCache.getAliveBrokerNode(1, 
config.interBrokerListenerName)).thenReturn(None)
-      
-      // We will attempt to schedule to the request handler thread using a non 
request handler thread. Set this to avoid error.
-      KafkaRequestHandler.setBypassThreadCheck(true)
+      // If we supply no transactional ID and idempotent records, we do not 
verify.
+      val idempotentRecords = 
MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, 
producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecords(replicaManager, tp0, idempotentRecords)
+      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), 
any[AddPartitionsToTxnManager.AppendCallback]())
+      assertEquals(null, getVerificationGuard(replicaManager, tp0, producerId))
+
+      // If we supply a transactional ID and some transactional and some 
idempotent records, we should only verify the topic partition with 
transactional records.
+      val transactionalRecords = 
MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, 
producerEpoch, sequence + 1,
+        new SimpleRecord("message".getBytes))
+
+      val transactionToAdd = new AddPartitionsToTxnTransaction()
+        .setTransactionalId(transactionalId)
+        .setProducerId(producerId)
+        .setProducerEpoch(producerEpoch)
+        .setVerifyOnly(true)
+        .setTopics(new AddPartitionsToTxnTopicCollection(
+          Seq(new 
AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
+        ))
+
+      val idempotentRecords2 = 
MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, 
producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> 
transactionalRecords, tp1 -> idempotentRecords2), transactionalId, Some(0))
+      verify(addPartitionsToTxnManager, 
times(1)).addTxnData(ArgumentMatchers.eq(node), 
ArgumentMatchers.eq(transactionToAdd), 
any[AddPartitionsToTxnManager.AppendCallback]())
+      assertNotEquals(null, getVerificationGuard(replicaManager, tp0, 
producerId))
+      assertEquals(null, getVerificationGuard(replicaManager, tp1, producerId))
+    } finally {
+      replicaManager.shutdown()
+    }
+
+    TestUtils.assertNoNonDaemonThreads(this.getClass.getName)
+  }
+
+  @Test
+  def testVerificationFlow(): Unit = {

Review Comment:
   nit: Could we come up with a better name? Perhaps 
`testTransactionVerificationFlow`?



##########
core/src/main/scala/kafka/log/UnifiedLog.scala:
##########
@@ -980,6 +1006,26 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           if (duplicateBatch.isPresent) {
             return (updatedProducers, completedTxns.toList, Some(duplicateBatch.get()))
           }
+
+          // Verify that if the record is transactional & the append origin is client, that we either have an ongoing transaction or verified transaction state.
+          // This guarantees that transactional records are never written to the log outside of the transaction coordinator's knowledge of an open transaction on
+          // the partition. If we do not have an ongoing transaction or correct guard, return an error and do not append.
+          // There are two phases -- the first append to the log and subsequent appends.
+          //
+          // 1. First append: Verification starts with creating a verification guard object, sending a verification request to the transaction coordinator, and
+          // given a "verified" response, continuing the append path. (A non-verified response throws an error.) We create the unique verification guard for the transaction
+          // to ensure there is no race between the transaction coordinator response and an abort marker getting written to the log. We need a unique guard because we could
+          // have a sequence of events where we start a transaction verification, have the transaction coordinator send a verified response, write an abort marker,
+          // start a new transaction not aware of the partition, and receive the stale verification (ABA problem). With a unique verification guard object, this sequence would not
+          // result in appending to the log and would return an error. The guard is removed after the first append to the transaction and from then, we can rely on phase 2.
+          //
+          // 2. Subsequent appends: Once we write to the transaction, the in-memory state currentTxnFirstOffset is populated. This field remains until the
+          // transaction is completed or aborted. We can guarantee the transaction coordinator knows about the transaction given step 1 and that the transaction is still
+          // ongoing. If the transaction is expected to be ongoing, we will not set a verification guard. If the transaction is aborted, hasOngoingTransaction is false and
+          // requestVerificationGuard is null, so we will throw an error. A subsequent produce request (retry) should create verification state and return to phase 1.
+          if (batch.isTransactional && producerStateManager.producerStateManagerConfig().transactionVerificationEnabled())
+            if (!hasOngoingTransaction(batch.producerId) && (requestVerificationGuard != verificationGuard(batch.producerId) || requestVerificationGuard == null))

Review Comment:
   nit: Any reason why we are using two if statements here? It seems that we could combine them.
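
   For reference, the combined condition could look roughly like this (same names as in the diff above, body elided):

   ```scala
   // Hypothetical combined form of the two nested ifs; behavior unchanged.
   if (batch.isTransactional &&
       producerStateManager.producerStateManagerConfig().transactionVerificationEnabled() &&
       !hasOngoingTransaction(batch.producerId) &&
       (requestVerificationGuard != verificationGuard(batch.producerId) || requestVerificationGuard == null)) {
     // return the same error as in the current PR body
   }
   ```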



##########
core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala:
##########
@@ -3667,6 +3667,118 @@ class UnifiedLogTest {
     listener.verify(expectedHighWatermark = 4)
   }
 
+  @Test
+  def testTransactionIsOngoingAndVerificationGuard(): Unit = {
+    val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
+
+    val producerId = 23L
+    val producerEpoch = 1.toShort
+    val sequence = 3
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertEquals(null, log.verificationGuard(producerId))
+
+    val idempotentRecords = MemoryRecords.withIdempotentRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    val verificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertTrue(verificationGuard != null)
+
+    log.appendAsLeader(idempotentRecords, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+
+    // Since we wrote idempotent records, we keep verification guard.
+    assertEquals(verificationGuard, log.verificationGuard(producerId))
+
+    val transactionalRecords = MemoryRecords.withTransactionalRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence + 2,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    log.appendAsLeader(transactionalRecords, leaderEpoch = 0, verificationGuard = verificationGuard)
+    assertTrue(log.hasOngoingTransaction(producerId))
+    // Verification guard should be cleared now.
+    assertEquals(null, log.verificationGuard(producerId))
+
+    // A subsequent maybeStartTransactionVerification will be empty since we are already verified.
+    assertEquals(null, log.maybeStartTransactionVerification(producerId))

Review Comment:
   ditto.



##########
core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala:
##########
@@ -3667,6 +3667,118 @@ class UnifiedLogTest {
     listener.verify(expectedHighWatermark = 4)
   }
 
+  @Test
+  def testTransactionIsOngoingAndVerificationGuard(): Unit = {
+    val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, 
true)
+
+    val producerId = 23L
+    val producerEpoch = 1.toShort
+    val sequence = 3
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, producerStateManagerConfig = 
producerStateManagerConfig)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertEquals(null, log.verificationGuard(producerId))
+
+    val idempotentRecords = MemoryRecords.withIdempotentRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    val verificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertTrue(verificationGuard != null)
+
+    log.appendAsLeader(idempotentRecords, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+
+    // Since we wrote idempotent records, we keep verification guard.
+    assertEquals(verificationGuard, log.verificationGuard(producerId))
+
+    val transactionalRecords = MemoryRecords.withTransactionalRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence + 2,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    log.appendAsLeader(transactionalRecords, leaderEpoch = 0, 
verificationGuard = verificationGuard)
+    assertTrue(log.hasOngoingTransaction(producerId))
+    // Verification guard should be cleared now.
+    assertEquals(null, log.verificationGuard(producerId))

Review Comment:
   nit: assertNull
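
   For example, the quoted assertion could become (JUnit's `assertNull`, same semantics):

   ```scala
   // before
   assertEquals(null, log.verificationGuard(producerId))
   // after
   assertNull(log.verificationGuard(producerId))
   ```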



##########
core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala:
##########
@@ -3667,6 +3667,118 @@ class UnifiedLogTest {
     listener.verify(expectedHighWatermark = 4)
   }
 
+  @Test
+  def testTransactionIsOngoingAndVerificationGuard(): Unit = {
+    val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, 
true)
+
+    val producerId = 23L
+    val producerEpoch = 1.toShort
+    val sequence = 3
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, producerStateManagerConfig = 
producerStateManagerConfig)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertEquals(null, log.verificationGuard(producerId))
+
+    val idempotentRecords = MemoryRecords.withIdempotentRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    val verificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertTrue(verificationGuard != null)
+
+    log.appendAsLeader(idempotentRecords, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+
+    // Since we wrote idempotent records, we keep verification guard.
+    assertEquals(verificationGuard, log.verificationGuard(producerId))
+
+    val transactionalRecords = MemoryRecords.withTransactionalRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence + 2,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    log.appendAsLeader(transactionalRecords, leaderEpoch = 0, 
verificationGuard = verificationGuard)
+    assertTrue(log.hasOngoingTransaction(producerId))
+    // Verification guard should be cleared now.
+    assertEquals(null, log.verificationGuard(producerId))
+
+    // A subsequent maybeStartTransactionVerification will be empty since we 
are already verified.
+    assertEquals(null, log.maybeStartTransactionVerification(producerId))
+
+    val endTransactionMarkerRecord = MemoryRecords.withEndTransactionMarker(
+      producerId,
+      producerEpoch,
+      new EndTransactionMarker(ControlRecordType.COMMIT, 0)
+    )
+
+    log.appendAsLeader(endTransactionMarkerRecord, origin = AppendOrigin.COORDINATOR, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertEquals(null, log.verificationGuard(producerId))

Review Comment:
   ditto.



##########
core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala:
##########
@@ -3667,6 +3667,118 @@ class UnifiedLogTest {
     listener.verify(expectedHighWatermark = 4)
   }
 
+  @Test
+  def testTransactionIsOngoingAndVerificationGuard(): Unit = {
+    val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, 
true)
+
+    val producerId = 23L
+    val producerEpoch = 1.toShort
+    val sequence = 3
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, producerStateManagerConfig = 
producerStateManagerConfig)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertEquals(null, log.verificationGuard(producerId))
+
+    val idempotentRecords = MemoryRecords.withIdempotentRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    val verificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertTrue(verificationGuard != null)
+
+    log.appendAsLeader(idempotentRecords, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+
+    // Since we wrote idempotent records, we keep verification guard.
+    assertEquals(verificationGuard, log.verificationGuard(producerId))
+
+    val transactionalRecords = MemoryRecords.withTransactionalRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence + 2,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    log.appendAsLeader(transactionalRecords, leaderEpoch = 0, 
verificationGuard = verificationGuard)
+    assertTrue(log.hasOngoingTransaction(producerId))
+    // Verification guard should be cleared now.
+    assertEquals(null, log.verificationGuard(producerId))
+
+    // A subsequent maybeStartTransactionVerification will be empty since we 
are already verified.
+    assertEquals(null, log.maybeStartTransactionVerification(producerId))
+
+    val endTransactionMarkerRecord = MemoryRecords.withEndTransactionMarker(
+      producerId,
+      producerEpoch,
+      new EndTransactionMarker(ControlRecordType.COMMIT, 0)
+    )
+
+    log.appendAsLeader(endTransactionMarkerRecord, origin = 
AppendOrigin.COORDINATOR, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertEquals(null, log.verificationGuard(producerId))
+
+    // A new maybeStartTransactionVerification will not be empty, as we need to verify the next transaction.
+    val newVerificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertTrue(newVerificationGuard != null)

Review Comment:
   nit: assertNotNull. There are other cases in this file.
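
   For example (JUnit's `assertNotNull`, same semantics as the explicit null check):

   ```scala
   // before
   assertTrue(newVerificationGuard != null)
   // after
   assertNotNull(newVerificationGuard)
   ```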



##########
core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala:
##########
@@ -2095,82 +2105,143 @@ class ReplicaManagerTest {
   }
 
   @Test
-  def testVerificationForTransactionalPartitions(): Unit = {
-    val tp = new TopicPartition(topic, 0)
-    val transactionalId = "txn1"
+  def testVerificationForTransactionalPartitionsOnly(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val tp1 = new TopicPartition(topic, 1)
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 0
-    
-    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new 
File(_)))
-    val metadataCache = mock(classOf[MetadataCache])
+    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = new ReplicaManager(
-      metrics = metrics,
-      config = config,
-      time = time,
-      scheduler = new MockScheduler(time),
-      logManager = mockLogMgr,
-      quotaManagers = quotaManager,
-      metadataCache = metadataCache,
-      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
-      alterPartitionManager = alterPartitionManager,
-      addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
-
+    val replicaManager = 
setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager,
 List(tp0, tp1), node)
     try {
-      val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), 
tp, Seq(0, 1), LeaderAndIsr(1,  List(0, 1)))
-      replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => 
())
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), 
LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      // We must set up the metadata cache to handle the append and 
verification.
-      val metadataResponseTopic = Seq(new MetadataResponseTopic()
-        .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
-        .setPartitions(Seq(
-          new MetadataResponsePartition()
-            .setPartitionIndex(0)
-            .setLeaderId(0)).asJava))
-      val node = new Node(0, "host1", 0)
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), 
LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      when(metadataCache.contains(tp)).thenReturn(true)
-      
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
-      when(metadataCache.getAliveBrokerNode(0, 
config.interBrokerListenerName)).thenReturn(Some(node))
-      when(metadataCache.getAliveBrokerNode(1, 
config.interBrokerListenerName)).thenReturn(None)
-      
-      // We will attempt to schedule to the request handler thread using a non 
request handler thread. Set this to avoid error.
-      KafkaRequestHandler.setBypassThreadCheck(true)
+      // If we supply no transactional ID and idempotent records, we do not 
verify.
+      val idempotentRecords = 
MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, 
producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecords(replicaManager, tp0, idempotentRecords)
+      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), 
any[AddPartitionsToTxnManager.AppendCallback]())
+      assertEquals(null, getVerificationGuard(replicaManager, tp0, producerId))

Review Comment:
   There are other cases here as well.



##########
core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala:
##########
@@ -2095,82 +2105,143 @@ class ReplicaManagerTest {
   }
 
   @Test
-  def testVerificationForTransactionalPartitions(): Unit = {
-    val tp = new TopicPartition(topic, 0)
-    val transactionalId = "txn1"
+  def testVerificationForTransactionalPartitionsOnly(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val tp1 = new TopicPartition(topic, 1)
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 0
-    
-    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new 
File(_)))
-    val metadataCache = mock(classOf[MetadataCache])
+    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = new ReplicaManager(
-      metrics = metrics,
-      config = config,
-      time = time,
-      scheduler = new MockScheduler(time),
-      logManager = mockLogMgr,
-      quotaManagers = quotaManager,
-      metadataCache = metadataCache,
-      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
-      alterPartitionManager = alterPartitionManager,
-      addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
-
+    val replicaManager = 
setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager,
 List(tp0, tp1), node)
     try {
-      val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), 
tp, Seq(0, 1), LeaderAndIsr(1,  List(0, 1)))
-      replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => 
())
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), 
LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      // We must set up the metadata cache to handle the append and 
verification.
-      val metadataResponseTopic = Seq(new MetadataResponseTopic()
-        .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
-        .setPartitions(Seq(
-          new MetadataResponsePartition()
-            .setPartitionIndex(0)
-            .setLeaderId(0)).asJava))
-      val node = new Node(0, "host1", 0)
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), 
LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      when(metadataCache.contains(tp)).thenReturn(true)
-      
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), 
config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
-      when(metadataCache.getAliveBrokerNode(0, 
config.interBrokerListenerName)).thenReturn(Some(node))
-      when(metadataCache.getAliveBrokerNode(1, 
config.interBrokerListenerName)).thenReturn(None)
-      
-      // We will attempt to schedule to the request handler thread using a non 
request handler thread. Set this to avoid error.
-      KafkaRequestHandler.setBypassThreadCheck(true)
+      // If we supply no transactional ID and idempotent records, we do not 
verify.
+      val idempotentRecords = 
MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, 
producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecords(replicaManager, tp0, idempotentRecords)
+      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), 
any[AddPartitionsToTxnManager.AppendCallback]())
+      assertEquals(null, getVerificationGuard(replicaManager, tp0, producerId))
+
+      // If we supply a transactional ID and some transactional and some 
idempotent records, we should only verify the topic partition with 
transactional records.
+      val transactionalRecords = 
MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, 
producerEpoch, sequence + 1,
+        new SimpleRecord("message".getBytes))
+
+      val transactionToAdd = new AddPartitionsToTxnTransaction()
+        .setTransactionalId(transactionalId)
+        .setProducerId(producerId)
+        .setProducerEpoch(producerEpoch)
+        .setVerifyOnly(true)
+        .setTopics(new AddPartitionsToTxnTopicCollection(
+          Seq(new 
AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
+        ))
+
+      val idempotentRecords2 = 
MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, 
producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> 
transactionalRecords, tp1 -> idempotentRecords2), transactionalId, Some(0))
+      verify(addPartitionsToTxnManager, 
times(1)).addTxnData(ArgumentMatchers.eq(node), 
ArgumentMatchers.eq(transactionToAdd), 
any[AddPartitionsToTxnManager.AppendCallback]())
+      assertNotEquals(null, getVerificationGuard(replicaManager, tp0, 
producerId))
+      assertEquals(null, getVerificationGuard(replicaManager, tp1, producerId))
+    } finally {
+      replicaManager.shutdown()
+    }
+
+    TestUtils.assertNoNonDaemonThreads(this.getClass.getName)
+  }
+
+  @Test
+  def testVerificationFlow(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val producerId = 24L
+    val producerEpoch = 0.toShort
+    val sequence = 6
+    val node = new Node(0, "host1", 0)
+    val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
+
+    val replicaManager = 
setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager,
 List(tp0), node)
+    try {
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), 
LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
       // Append some transactional records.
       val transactionalRecords = 
MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, 
producerEpoch, sequence,
-        new SimpleRecord(s"message $sequence".getBytes))
-      val result = appendRecords(replicaManager, tp, transactionalRecords, 
transactionalId = transactionalId, transactionStatePartition = Some(0))
-      
+        new SimpleRecord("message".getBytes))
+
       val transactionToAdd = new AddPartitionsToTxnTransaction()
         .setTransactionalId(transactionalId)
         .setProducerId(producerId)
         .setProducerEpoch(producerEpoch)
         .setVerifyOnly(true)
         .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new 
AddPartitionsToTxnTopic().setName(tp.topic).setPartitions(Collections.singletonList(tp.partition))).iterator.asJava
+          Seq(new 
AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
         ))
-      
-      val appendCallback = 
ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
+
       // We should add these partitions to the manager to verify.
+      val result = appendRecords(replicaManager, tp0, transactionalRecords, 
transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val appendCallback = 
ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
       verify(addPartitionsToTxnManager, 
times(1)).addTxnData(ArgumentMatchers.eq(node), 
ArgumentMatchers.eq(transactionToAdd), appendCallback.capture())
+      val verificationGuard = getVerificationGuard(replicaManager, tp0, 
producerId)
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, 
tp0, producerId))
 
       // Confirm we did not write to the log and instead returned error.
       val callback: AddPartitionsToTxnManager.AppendCallback = 
appendCallback.getValue()
-      callback(Map(tp -> Errors.INVALID_RECORD).toMap)
+      callback(Map(tp0 -> Errors.INVALID_RECORD).toMap)
       assertEquals(Errors.INVALID_RECORD, result.assertFired.error)
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, 
tp0, producerId))
+
+      // This time verification is successful
+      appendRecords(replicaManager, tp0, transactionalRecords)
+      val appendCallback2 = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
+      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback2.capture())
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
+
+      val callback2: AddPartitionsToTxnManager.AppendCallback = appendCallback.getValue()
+      callback2(Map.empty[TopicPartition, Errors].toMap)
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
+      assertTrue(replicaManager.localLog(tp0).get.hasOngoingTransaction(producerId))
+    } finally {
+      replicaManager.shutdown()
+    }
 
-      // If we supply no transactional ID and idempotent records, we do not verify, so counter stays the same.
-      val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
+    TestUtils.assertNoNonDaemonThreads(this.getClass.getName)
+  }
+
+  @Test
+  def testVerificationGuardOnMultiplePartitions(): Unit = {

Review Comment:
   nit: `TransactionVerification` as well?



##########
core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala:
##########
@@ -3667,6 +3667,118 @@ class UnifiedLogTest {
     listener.verify(expectedHighWatermark = 4)
   }
 
+  @Test
+  def testTransactionIsOngoingAndVerificationGuard(): Unit = {
+    val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, 
true)
+
+    val producerId = 23L
+    val producerEpoch = 1.toShort
+    val sequence = 3
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, producerStateManagerConfig = 
producerStateManagerConfig)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertEquals(null, log.verificationGuard(producerId))
+
+    val idempotentRecords = MemoryRecords.withIdempotentRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    val verificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertTrue(verificationGuard != null)

Review Comment:
   nit: assertNotNull



##########
core/src/test/scala/unit/kafka/cluster/PartitionTest.scala:
##########
@@ -3273,17 +3272,35 @@ class PartitionTest extends AbstractPartitionTest {
       baseOffset = 0L,
       producerId = producerId)
     partition.appendRecordsToLeader(idempotentRecords, origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching)
-    assertFalse(partition.hasOngoingTransaction(producerId))
 
-    val transactionRecords = createTransactionalRecords(List(
+    def transactionRecords() = createTransactionalRecords(List(
       new SimpleRecord("k1".getBytes, "v1".getBytes),
       new SimpleRecord("k2".getBytes, "v2".getBytes),
       new SimpleRecord("k3".getBytes, "v3".getBytes)),
       baseOffset = 0L,
       baseSequence = 3,
       producerId = producerId)
-    partition.appendRecordsToLeader(transactionRecords, origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching)
-    assertTrue(partition.hasOngoingTransaction(producerId))
+
+    // When verification guard is not there, we should not be able to append.
+    assertThrows(classOf[InvalidRecordException], () => partition.appendRecordsToLeader(transactionRecords(), origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching))
+
+    // Before appendRecordsToLeader is called, ReplicaManager will call maybeStartTransactionVerification. We should get a non-null verification object.
+    val verificationGuard = partition.maybeStartTransactionVerification(producerId)
+    assertTrue(verificationGuard != null)
+
+    // With the wrong verification guard, append should fail.
+    assertThrows(classOf[InvalidRecordException], () => partition.appendRecordsToLeader(transactionRecords(),
+      origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching, Optional.of(new Object)))
+
+    // We should return the same verification object when we still need to verify. Append should proceed.
+    val verificationGuard2 = partition.maybeStartTransactionVerification(producerId)
+    assertEquals(verificationGuard, verificationGuard2)
+    partition.appendRecordsToLeader(transactionRecords(), origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching, verificationGuard)
+
+    // We should no longer need a verification object. Future appends without verification guard will also succeed.
+    val verificationGuard3 = partition.maybeStartTransactionVerification(producerId)
+    assertEquals(null, verificationGuard3)

Review Comment:
   nit: You could use `assertNull`. There are a few other cases.


