This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch branch-1.7
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/branch-1.7 by this push:
     new 9439534ab [KYUUBI #4710] [ARROW] LocalTableScanExec should not trigger job
9439534ab is described below

commit 9439534abf02f5619abb39970f93236029cad209
Author: Fu Chen <[email protected]>
AuthorDate: Thu Apr 20 17:58:27 2023 +0800

    [KYUUBI #4710] [ARROW] LocalTableScanExec should not trigger job
    
    ### _Why are the changes needed?_
    
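    When the physical plan of an arrow-based result is a `LocalTableScanExec`, its rows are already held on the driver. Previously such a plan fell through to `toArrowBatchRdd(plan).collect()`, which launched an unnecessary Spark job. This PR converts `LocalTableScanExec.rows` to Arrow batches directly on the driver instead, so no job is triggered.
    
    Before this PR: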
    
    ![Screenshot 2023-04-14 5:19:52 PM](https://user-images.githubusercontent.com/8537877/232003579-95c56f56-1fd7-4c8a-a13f-58d4bc16fef1.png)
    
    After this PR:
    
    ![Screenshot 2023-04-14 5:18:16 PM](https://user-images.githubusercontent.com/8537877/232003652-77b38d08-c741-4977-bf69-6eb70f6d991a.png)
    
    ### _How was this patch tested?_
    - [ ] Add some test cases that check the changes thoroughly, including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #4710 from cfmcgrady/arrow-local-table-scan-exec.
    
    Closes #4710
    
    e4c2891d1 [Fu Chen] fix ci
    1049200ea [Fu Chen] fix style
    4d45fe8b7 [Fu Chen] add assert
    b8bd5b5a7 [Fu Chen] LocalTableScanExec should not trigger job
    
    Authored-by: Fu Chen <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
    (cherry picked from commit 9086b28bc392a928e7cd68aa03b69f3e60d73643)
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../execution/arrow/KyuubiArrowConverters.scala    |   8 +-
 .../spark/sql/kyuubi/SparkDatasetHelper.scala      |  15 ++-
 .../operation/SparkArrowbasedOperationSuite.scala  | 139 +++++++++++++--------
 3 files changed, 103 insertions(+), 59 deletions(-)

diff --git 
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
 
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
index 2feadbced..8a34943cc 100644
--- 
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
+++ 
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
@@ -203,7 +203,7 @@ object KyuubiArrowConverters extends SQLConfHelper with 
Logging {
    * Different from 
[[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
    * each output arrow batch contains this batch row count.
    */
-  private def toBatchIterator(
+  def toBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
       maxRecordsPerBatch: Long,
@@ -226,6 +226,7 @@ object KyuubiArrowConverters extends SQLConfHelper with 
Logging {
    * with two key differences:
    *   1. there is no requirement to write the schema at the batch header
    *   2. iteration halts when `rowCount` equals `limit`
+   * Note that `limit < 0` means no limit, and return all rows the in the 
iterator.
    */
   private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
@@ -255,7 +256,7 @@ object KyuubiArrowConverters extends SQLConfHelper with 
Logging {
       }
     }
 
-    override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
+    override def hasNext: Boolean = (rowIter.hasNext && (rowCount < limit || 
limit < 0)) || {
       root.close()
       allocator.close()
       false
@@ -283,7 +284,8 @@ object KyuubiArrowConverters extends SQLConfHelper with 
Logging {
               // If the size of rows are 0 or negative, unlimit it.
               maxRecordsPerBatch <= 0 ||
               rowCountInLastBatch < maxRecordsPerBatch ||
-              rowCount < limit)) {
+              rowCount < limit ||
+              limit < 0)) {
           val row = rowIter.next()
           arrowWriter.write(row)
           estimatedBatchSize += (row match {
diff --git 
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
 
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index 1c8d32c48..10b178324 100644
--- 
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ 
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, 
SQLExecution}
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, 
SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.{ArrowConverters, 
KyuubiArrowConverters}
 import org.apache.spark.sql.functions._
@@ -51,6 +51,8 @@ object SparkDatasetHelper extends Logging {
       doCollectLimit(collectLimit)
     case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
       executeArrowBatchCollect(collectLimit.child)
+    case localTableScan: LocalTableScanExec =>
+      doLocalTableScan(localTableScan)
     case plan: SparkPlan =>
       toArrowBatchRdd(plan).collect()
   }
@@ -175,6 +177,17 @@ object SparkDatasetHelper extends Logging {
     result.toArray
   }
 
+  def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] 
= {
+    localTableScan.longMetric("numOutputRows").add(localTableScan.rows.size)
+    KyuubiArrowConverters.toBatchIterator(
+      localTableScan.rows.iterator,
+      localTableScan.schema,
+      SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch,
+      maxBatchSize,
+      -1,
+      SparkSession.active.sessionState.conf.sessionLocalTimeZone).toArray
+  }
+
   /**
    * This method provides a reflection-based implementation of
    * [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to 
the Spark runtime
diff --git 
a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
 
b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
index 2ef29b398..27310992f 100644
--- 
a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
+++ 
b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
@@ -23,8 +23,8 @@ import java.util.{Set => JSet}
 import org.apache.spark.KyuubiSparkContextHelper
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.{QueryTest, Row, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, 
SparkPlan}
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, 
QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.exchange.Exchange
@@ -104,48 +104,29 @@ class SparkArrowbasedOperationSuite extends 
WithSparkSQLEngine with SparkDataTyp
   }
 
   test("assign a new execution id for arrow-based result") {
-    var plan: LogicalPlan = null
-
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: 
Long): Unit = {
-        plan = qe.analyzed
+    val listener = new SQLMetricsListener
+    withJdbcStatement() { statement =>
+      withSparkListener(listener) {
+        val result = statement.executeQuery("select 1 as c1")
+        assert(result.next())
+        assert(result.getInt("c1") == 1)
       }
-      override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {}
     }
-    withJdbcStatement() { statement =>
-      // since all the new sessions have their owner listener bus, we should 
register the listener
-      // in the current session.
-      registerListener(listener)
 
-      val result = statement.executeQuery("select 1 as c1")
-      assert(result.next())
-      assert(result.getInt("c1") == 1)
-    }
-    KyuubiSparkContextHelper.waitListenerBus(spark)
-    unregisterListener(listener)
-    assert(plan.isInstanceOf[Project])
+    assert(listener.queryExecution.analyzed.isInstanceOf[Project])
   }
 
   test("arrow-based query metrics") {
-    var queryExecution: QueryExecution = null
-
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: 
Long): Unit = {
-        queryExecution = qe
-      }
-      override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {}
-    }
+    val listener = new SQLMetricsListener
     withJdbcStatement() { statement =>
-      registerListener(listener)
-      val result = statement.executeQuery("select 1 as c1")
-      assert(result.next())
-      assert(result.getInt("c1") == 1)
+      withSparkListener(listener) {
+        val result = statement.executeQuery("select 1 as c1")
+        assert(result.next())
+        assert(result.getInt("c1") == 1)
+      }
     }
 
-    KyuubiSparkContextHelper.waitListenerBus(spark)
-    unregisterListener(listener)
-
-    val metrics = queryExecution.executedPlan.collectLeaves().head.metrics
+    val metrics = 
listener.queryExecution.executedPlan.collectLeaves().head.metrics
     assert(metrics.contains("numOutputRows"))
     assert(metrics("numOutputRows").value === 1)
   }
@@ -273,7 +254,6 @@ class SparkArrowbasedOperationSuite extends 
WithSparkSQLEngine with SparkDataTyp
         withPartitionedTable("t_3") {
           statement.executeQuery("select * from t_3 limit 10 offset 10")
         }
-        KyuubiSparkContextHelper.waitListenerBus(spark)
       }
     }
     // the extra shuffle be introduced if the `offset` > 0
@@ -292,13 +272,49 @@ class SparkArrowbasedOperationSuite extends 
WithSparkSQLEngine with SparkDataTyp
         withPartitionedTable("t_3") {
           statement.executeQuery("select * from t_3 limit 1000")
         }
-        KyuubiSparkContextHelper.waitListenerBus(spark)
       }
     }
     // Should be only one stage since there is no shuffle.
     assert(numStages == 1)
   }
 
+  test("LocalTableScanExec should not trigger job") {
+    val listener = new JobCountListener
+    withJdbcStatement("view_1") { statement =>
+      withSparkListener(listener) {
+        withAllSessions { s =>
+          import s.implicits._
+          Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
+          val plan = s.sql("select * from view_1").queryExecution.executedPlan
+          assert(plan.isInstanceOf[LocalTableScanExec])
+        }
+        val resultSet = statement.executeQuery("select * from view_1")
+        assert(resultSet.next())
+        assert(!resultSet.next())
+      }
+    }
+    assert(listener.numJobs == 0)
+  }
+
+  test("LocalTableScanExec metrics") {
+    val listener = new SQLMetricsListener
+    withJdbcStatement("view_1") { statement =>
+      withSparkListener(listener) {
+        withAllSessions { s =>
+          import s.implicits._
+          Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
+        }
+        val result = statement.executeQuery("select * from view_1")
+        assert(result.next())
+        assert(!result.next())
+      }
+    }
+
+    val metrics = 
listener.queryExecution.executedPlan.collectLeaves().head.metrics
+    assert(metrics.contains("numOutputRows"))
+    assert(metrics("numOutputRows").value === 1)
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: 
String): Unit = {
     val query =
       s"""
@@ -321,32 +337,30 @@ class SparkArrowbasedOperationSuite extends 
WithSparkSQLEngine with SparkDataTyp
     assert(resultSet.getString("col") === expect)
   }
 
-  private def registerListener(listener: QueryExecutionListener): Unit = {
-    // since all the new sessions have their owner listener bus, we should 
register the listener
-    // in the current session.
-    SparkSQLEngine.currentEngine.get
-      .backendService
-      .sessionManager
-      .allSessions()
-      
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
-  }
-
-  private def unregisterListener(listener: QueryExecutionListener): Unit = {
-    SparkSQLEngine.currentEngine.get
-      .backendService
-      .sessionManager
-      .allSessions()
-      
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
+  // since all the new sessions have their owner listener bus, we should 
register the listener
+  // in the current session.
+  private def withSparkListener[T](listener: QueryExecutionListener)(body: => 
T): T = {
+    withAllSessions(s => s.listenerManager.register(listener))
+    try {
+      val result = body
+      KyuubiSparkContextHelper.waitListenerBus(spark)
+      result
+    } finally {
+      withAllSessions(s => s.listenerManager.unregister(listener))
+    }
   }
 
+  // since all the new sessions have their owner listener bus, we should 
register the listener
+  // in the current session.
   private def withSparkListener[T](listener: SparkListener)(body: => T): T = {
     withAllSessions(s => s.sparkContext.addSparkListener(listener))
     try {
-      body
+      val result = body
+      KyuubiSparkContextHelper.waitListenerBus(spark)
+      result
     } finally {
       withAllSessions(s => s.sparkContext.removeSparkListener(listener))
     }
-
   }
 
   private def withPartitionedTable[T](viewName: String)(body: => T): T = {
@@ -432,6 +446,21 @@ class SparkArrowbasedOperationSuite extends 
WithSparkSQLEngine with SparkDataTyp
       .get()
     staticConfKeys.contains(key)
   }
+
+  class JobCountListener extends SparkListener {
+    var numJobs = 0
+    override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+      numJobs += 1
+    }
+  }
+
+  class SQLMetricsListener extends QueryExecutionListener {
+    var queryExecution: QueryExecution = null
+    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: 
Long): Unit = {
+      queryExecution = qe
+    }
+    override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {}
+  }
 }
 
 case class TestData(key: Int, value: String)

Reply via email to