This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f9581781ff6 [SPARK-54200][SS] Call close() against underlying 
InputPartition when LowLatencyReaderWrap.close() is called
2f9581781ff6 is described below

commit 2f9581781ff67be2a6c10be67e4f3936192c916d
Author: Jungtaek Lim <[email protected]>
AuthorDate: Thu Nov 6 17:54:05 2025 +0900

    [SPARK-54200][SS] Call close() against underlying InputPartition when 
LowLatencyReaderWrap.close() is called
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a bug where LowLatencyReaderWrap.close() did not call close() on the
    underlying reader of the InputPartition.
    
    ### Why are the changes needed?
    
    Not closing the underlying reader leaks resources. For example, the Kafka consumer is
    never returned to the pool, which defeats the purpose of the consumer pool and causes new
    Kafka consumer instances to be created in every batch.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a new unit test for the Kafka data source rather than a generic one, since the Kafka
    data source exposes internal state (the consumer pool's active consumer count) that
    provides the information needed for validation.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52903 from HeartSaVioR/SPARK-54200.
    
    Authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../consumer/InternalKafkaConsumerPool.scala       |  8 ++
 .../sql/kafka010/consumer/KafkaDataConsumer.scala  |  4 +
 .../sql/kafka010/KafkaRealTimeModeSuite.scala      | 99 +++++++++++++++++++++-
 .../datasources/v2/RealTimeStreamScanExec.scala    |  4 +-
 4 files changed, 112 insertions(+), 3 deletions(-)

diff --git 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala
 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala
index edd5121cfbee..06ccd7548a04 100644
--- 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala
+++ 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala
@@ -129,6 +129,14 @@ private[consumer] class InternalKafkaConsumerPool(
 
   def size(key: CacheKey): Int = numIdle(key) + numActive(key)
 
+  /**
+   * Returns the number of borrowed (active) consumers whose group id starts with
+   * `groupIdPrefix` followed by "-". Intended for tests.
+   */
+  private[kafka010] def numActiveInGroupIdPrefix(groupIdPrefix: String): Int = {
+    import scala.jdk.CollectionConverters._
+
+    // GenericKeyedObjectPool.getNumActivePerKey() keys its result by `key.toString`, so the
+    // keys here are rendered CacheKeys ("CacheKey(<groupId>,<topicPartition>)"), not raw
+    // group ids; matching on the bare group id would never hit and always yield 0.
+    // NOTE(review): assumes CacheKey is the case class CacheKey(groupId, topicPartition) —
+    // revisit if its definition or toString changes.
+    val renderedKeyPrefix = s"CacheKey($groupIdPrefix-"
+    pool.getNumActivePerKey().asScala.collect {
+      case (key, numActive) if key.startsWith(renderedKeyPrefix) => numActive.toInt
+    }.sum
+  }
+
   // TODO: revisit the relation between CacheKey and kafkaParams - for now it 
looks a bit weird
   //   as we force all consumers having same (groupId, topicPartition) to have 
same kafkaParams
   //   which might be viable in performance perspective (kafkaParams might be 
too huge to use
diff --git 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
index af4e5bab2947..126434625a8d 100644
--- 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
+++ 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
@@ -848,4 +848,8 @@ private[kafka010] object KafkaDataConsumer extends Logging {
 
     new KafkaDataConsumer(topicPartition, kafkaParams, consumerPool, 
fetchedDataPool)
   }
+
+  /**
+   * Test hook: the active-consumer count reported by the shared consumer pool for consumers
+   * whose group id starts with `groupIdPrefix`.
+   */
+  private[kafka010] def getActiveSizeInConsumerPool(groupIdPrefix: String): Int =
+    consumerPool.numActiveInGroupIdPrefix(groupIdPrefix)
 }
diff --git 
a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
 
b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
index 83aae64d84f7..468d1da7f467 100644
--- 
a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
+++ 
b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.kafka010
 
+import java.util.UUID
+
 import org.scalatest.matchers.should.Matchers
 import org.scalatest.time.SpanSugar._
 
@@ -26,6 +28,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.ContinuousMemorySink
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer
 import org.apache.spark.sql.streaming.{StreamingQuery, Trigger}
 import org.apache.spark.sql.streaming.OutputMode.Update
 import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
@@ -39,9 +42,7 @@ class KafkaRealTimeModeSuite
   override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds")
 
   override protected def sparkConf: SparkConf = {
-    // Should turn to use StreamingShuffleManager when it is ready.
     super.sparkConf
-      .set("spark.databricks.streaming.realTimeMode.enabled", "true")
       .set(
         SQLConf.STATE_STORE_PROVIDER_CLASS,
         classOf[RocksDBStateStoreProvider].getName)
@@ -679,3 +680,97 @@ class KafkaRealTimeModeSuite
       )
   }
 }
+
+/**
+ * Verifies that real-time mode returns Kafka consumers to the consumer pool after each batch
+ * (SPARK-54200), rather than keeping them borrowed and creating new ones every batch.
+ */
+class KafkaConsumerPoolRealTimeModeSuite
+  extends KafkaSourceTest
+  with Matchers {
+  override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds")
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set(
+        SQLConf.STATE_STORE_PROVIDER_CLASS,
+        classOf[RocksDBStateStoreProvider].getName)
+  }
+
+  import testImplicits._
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Align the minimum real-time batch duration with the trigger's batch duration.
+    spark.conf.set(
+      SQLConf.STREAMING_REAL_TIME_MODE_MIN_BATCH_DURATION,
+      defaultTrigger.batchDurationMs
+    )
+  }
+
+  test("SPARK-54200: Kafka consumers in consumer pool should be properly reused") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 2)
+
+    testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+    testUtils.sendMessages(topic, Array("3"), Some(1))
+
+    // A unique prefix so that only consumers created for this query are counted.
+    val groupIdPrefix = UUID.randomUUID().toString
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .option("groupIdPrefix", groupIdPrefix)
+      .load()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    // The topic has 2 partitions, so after each completed batch the pool should hold at most
+    // 2 active consumers for this query; consumers that are never returned would make the
+    // count grow from batch to batch.
+    testStream(reader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      CheckAnswerWithTimeout(60000, 2, 3, 4),
+      WaitUntilCurrentBatchProcessed,
+      // After completion of batch 0
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
+
+          testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+          testUtils.sendMessages(topic, Array("6"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7),
+      WaitUntilCurrentBatchProcessed,
+      // After completion of batch 1
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
+
+          testUtils.sendMessages(topic, Array("7"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8),
+      WaitUntilCurrentBatchProcessed,
+      // After completion of batch 2
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
+        }
+      },
+      StopStream
+    )
+  }
+
+  /**
+   * Asserts that the consumer pool reports at most `maxAllowedActiveSize` active consumers
+   * for `groupIdPrefix`.
+   *
+   * NOTE: This relies on the test code, driver and executors running in the same process in a
+   * normal unit test setup (e.g. a `local[<number, or *>]` master), which lets us read the
+   * singleton consumer pool directly.
+   */
+  private def assertActiveSizeOnConsumerPool(
+      groupIdPrefix: String,
+      maxAllowedActiveSize: Int): Unit = {
+    val activeSize = KafkaDataConsumer.getActiveSizeInConsumerPool(groupIdPrefix)
+    assert(activeSize <= maxAllowedActiveSize, s"Active consumer count in the pool is " +
+      s"expected to be at most $maxAllowedActiveSize, but was $activeSize.")
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
index c4e072f184e6..3432f28e12cc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
@@ -83,7 +83,9 @@ case class LowLatencyReaderWrap(
     reader.get()
   }
 
-  override def close(): Unit = {}
+  /** Closes the wrapped reader so that it can release the resources it holds. */
+  override def close(): Unit = reader.close()
 }
 
 /**


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to