This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 8ad4dddb17b [SPARK-44776][CONNECT] Add ProducedRowCount to SparkListenerConnectOperationFinished
8ad4dddb17b is described below

commit 8ad4dddb17b76c6c53a65730229cdb69cd1c1889
Author: Lingkai Kong <lingkai.k...@databricks.com>
AuthorDate: Tue Aug 22 10:06:50 2023 +0900

    [SPARK-44776][CONNECT] Add ProducedRowCount to SparkListenerConnectOperationFinished
    
    ### What changes were proposed in this pull request?
    Add a ProducedRowCount field to SparkListenerConnectOperationFinished.
    
    ### Why are the changes needed?
    Needed to show the number of rows produced by an operation.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #42454 from gjxdxh/SPARK-44776.
    
    Authored-by: Lingkai Kong <lingkai.k...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 4646991abd7f4a47a1b8712e2017a2fae98f7c5a)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
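
    (Editor's illustration, not part of this commit.) A minimal sketch of how a
    listener could consume the new field; the listener class name and the
    println are hypothetical, while the event class and its producedRowCount
    field come from the diff below:

        import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
        import org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished

        class ProducedRowCountListener extends SparkListener {
          override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
            case e: SparkListenerConnectOperationFinished =>
              // producedRowCount is None for operations that return no rows.
              val rows = e.producedRowCount.map(_.toString).getOrElse("n/a")
              println(s"Connect operation ${e.operationId} produced $rows row(s)")
            case _ => // ignore unrelated events
          }
        }

    Such a listener would be registered via
    spark.sparkContext.addSparkListener(new ProducedRowCountListener).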
---
 .../execution/SparkConnectPlanExecution.scala      |  10 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |   2 +-
 .../sql/connect/service/ExecuteEventsManager.scala |  27 +-
 .../connect/planner/SparkConnectServiceSuite.scala | 288 ++++++++++++++-------
 .../service/ExecuteEventsManagerSuite.scala        |  12 +
 5 files changed, 238 insertions(+), 101 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 131ddf76fa4..00fec4378c5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -110,6 +110,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       errorOnDuplicatedFieldNames = false)
 
     var numSent = 0
+    var totalNumRows: Long = 0
     def sendBatch(bytes: Array[Byte], count: Long): Unit = {
       val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
       val batch = proto.ExecutePlanResponse.ArrowBatch
@@ -120,14 +121,15 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       response.setArrowBatch(batch)
       responseObserver.onNext(response.build())
       numSent += 1
+      totalNumRows += count
     }
 
     dataframe.queryExecution.executedPlan match {
       case LocalTableScanExec(_, rows) =>
-        executePlan.eventsManager.postFinished()
         converter(rows.iterator).foreach { case (bytes, count) =>
           sendBatch(bytes, count)
         }
+        executePlan.eventsManager.postFinished(Some(totalNumRows))
       case _ =>
         SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
           val rows = dataframe.queryExecution.executedPlan.execute()
@@ -162,8 +164,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
                 resultFunc = () => ())
               // Collect errors and propagate them to the main thread.
               .andThen {
-                case Success(_) =>
-                  executePlan.eventsManager.postFinished()
+                case Success(_) => // do nothing
                 case Failure(throwable) =>
                   signal.synchronized {
                     error = Some(throwable)
@@ -200,8 +201,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
               currentPartitionId += 1
             }
             ThreadUtils.awaitReady(future, Duration.Inf)
+            executePlan.eventsManager.postFinished(Some(totalNumRows))
           } else {
-            executePlan.eventsManager.postFinished()
+            executePlan.eventsManager.postFinished(Some(totalNumRows))
           }
         }
     }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5120073e2f0..e81e9bb51cb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2513,7 +2513,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
               .putAllArgs(getSqlCommand.getArgsMap)
               .addAllPosArgs(getSqlCommand.getPosArgsList)))
     }
-    executeHolder.eventsManager.postFinished()
+    executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
     responseObserver.onNext(
       ExecutePlanResponse
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
index 5e831aaa98f..5b9267a9679 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
@@ -75,6 +75,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
 
   private var canceled = Option.empty[Boolean]
 
+  private var producedRowCount = Option.empty[Long]
+
   /**
    * @return
    *   Last event posted by the Connect request
@@ -95,6 +97,13 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
    */
   private[connect] def hasError: Option[Boolean] = error
 
+  /**
+   * @return
+   *   How many rows the Connect request has produced @link
+   *   org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished
+   */
+  private[connect] def getProducedRowCount: Option[Long] = producedRowCount
+
   /**
    * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationStarted.
    */
@@ -192,13 +201,23 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
 
   /**
    * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished.
+   * @param producedRowsCountOpt
+   *   Number of rows that are returned to the user. None is expected when the operation does not
+   *   return any rows.
    */
-  def postFinished(): Unit = {
+  def postFinished(producedRowsCountOpt: Option[Long] = None): Unit = {
     assertStatus(
       List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution),
       ExecuteStatus.Finished)
+    producedRowCount = producedRowsCountOpt
+
     listenerBus
-      .post(SparkListenerConnectOperationFinished(jobTag, operationId, clock.getTimeMillis()))
+      .post(
+        SparkListenerConnectOperationFinished(
+          jobTag,
+          operationId,
+          clock.getTimeMillis(),
+          producedRowCount))
   }
 
   /**
@@ -395,6 +414,9 @@ case class SparkListenerConnectOperationFailed(
  *   36 characters UUID assigned by Connect during a request.
  * @param eventTime:
  *   The time in ms when the event was generated.
+ * @param producedRowCount:
+ *   Number of rows that are returned to the user. None is expected when the operation does not
+ *   return any rows.
  * @param extraTags:
  *   Additional metadata during the request.
  */
@@ -402,6 +424,7 @@ case class SparkListenerConnectOperationFinished(
     jobTag: String,
     operationId: String,
     eventTime: Long,
+    producedRowCount: Option[Long] = None,
     extraTags: Map[String, String] = Map.empty)
     extends SparkListenerEvent
 
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 285f3103b19..74649e15e9e 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.CreateDataFrameViewCommand
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.dsl.MockRemoteSession
@@ -49,13 +50,18 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, Sessi
 import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
 import org.apache.spark.sql.streaming.StreamingQuery
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
  * Testing Connect Service implementation.
  */
-class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with Logging {
+class SparkConnectServiceSuite
+    extends SharedSparkSession
+    with MockitoSugar
+    with Logging
+    with SparkConnectPlanTest {
 
   private def sparkSessionHolder = SessionHolder.forTesting(spark)
   private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093")
@@ -190,7 +196,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
             done = true
           }
         })
-      verifyEvents.onCompleted()
+      verifyEvents.onCompleted(Some(100))
       // The current implementation is expected to be blocking. This is here to make sure it is.
       assert(done)
 
@@ -238,6 +244,77 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
     }
   }
 
+  test("SPARK-44776: LocalTableScanExec") {
+    withEvents { verifyEvents =>
+      // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
+      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
+      val instance = new SparkConnectService(false)
+      val connect = new MockRemoteSession()
+      val context = proto.UserContext
+        .newBuilder()
+        .setUserId("c1")
+        .build()
+
+      val rows = (0L to 5L).map { i =>
+        new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar)))
+      }
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+      val inputRows = rows.map { row =>
+        val proj = UnsafeProjection.create(schema)
+        proj(row).copy()
+      }
+
+      val localRelation = createLocalRelationProto(schema, inputRows)
+      val plan = proto.Plan
+        .newBuilder()
+        .setRoot(localRelation)
+        .build()
+
+      val request = proto.ExecutePlanRequest
+        .newBuilder()
+        .setPlan(plan)
+        .setUserContext(context)
+        .setSessionId(UUID.randomUUID.toString())
+        .build()
+
+      // Execute plan.
+      @volatile var done = false
+      val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+      instance.executePlan(
+        request,
+        new StreamObserver[proto.ExecutePlanResponse] {
+          override def onNext(v: proto.ExecutePlanResponse): Unit = {
+            responses += v
+            verifyEvents.onNext(v)
+          }
+
+          override def onError(throwable: Throwable): Unit = {
+            verifyEvents.onError(throwable)
+            throw throwable
+          }
+
+          override def onCompleted(): Unit = {
+            done = true
+          }
+        })
+      verifyEvents.onCompleted(Some(6))
+      // The current implementation is expected to be blocking. This is here to make sure it is.
+      assert(done)
+
+      // 1 Partitions + Metrics
+      assert(responses.size == 3)
+
+      // Make sure the first response is schema only
+      val head = responses.head
+      assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
+
+      // Make sure the last response is metrics only
+      val last = responses.last
+      assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
+    }
+  }
+
   test("SPARK-44657: Arrow batches respect max batch size limit") {
     // Set 10 KiB as the batch size limit
     val batchSize = 10 * 1024
@@ -301,101 +378,123 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
 
   gridTest("SPARK-43923: commands send events")(
     Seq(
-      proto.Command
-        .newBuilder()
-        .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()),
-      proto.Command
-        .newBuilder()
-        .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show tables").build()),
-      proto.Command
-        .newBuilder()
-        .setWriteOperation(
-          proto.WriteOperation
-            .newBuilder()
-            .setInput(
-              proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))
-            .setPath(Utils.createTempDir().getAbsolutePath)
-            .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)),
-      proto.Command
-        .newBuilder()
-        .setWriteOperationV2(
-          proto.WriteOperationV2
-            .newBuilder()
-            .setInput(proto.Relation.newBuilder.setRange(
-              proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L)))
-            .setTableName("testcat.testtable")
-            .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)),
-      proto.Command
-        .newBuilder()
-        .setCreateDataframeView(
-          CreateDataFrameViewCommand
-            .newBuilder()
-            .setName("testview")
-            .setInput(
-              proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))),
-      proto.Command
-        .newBuilder()
-        .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()),
-      proto.Command
-        .newBuilder()
-        .setExtension(
-          protobuf.Any.pack(
-            proto.ExamplePluginCommand
+      (
+        proto.Command
+          .newBuilder()
+          .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()),
+        Some(0L)),
+      (
+        proto.Command
+          .newBuilder()
+          .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show databases").build()),
+        Some(1L)),
+      (
+        proto.Command
+          .newBuilder()
+          .setWriteOperation(
+            proto.WriteOperation
               .newBuilder()
-              .setCustomField("SPARK-43923")
-              .build())),
-      proto.Command
-        .newBuilder()
-        .setWriteStreamOperationStart(
-          proto.WriteStreamOperationStart
-            .newBuilder()
-            .setInput(
-              proto.Relation
+              .setInput(
+                proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))
+              .setPath(Utils.createTempDir().getAbsolutePath)
+              .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)),
+        None),
+      (
+        proto.Command
+          .newBuilder()
+          .setWriteOperationV2(
+            proto.WriteOperationV2
+              .newBuilder()
+              .setInput(proto.Relation.newBuilder.setRange(
+                proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L)))
+              .setTableName("testcat.testtable")
+              .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)),
+        None),
+      (
+        proto.Command
+          .newBuilder()
+          .setCreateDataframeView(
+            CreateDataFrameViewCommand
+              .newBuilder()
+              .setName("testview")
+              .setInput(
+                proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))),
+        None),
+      (
+        proto.Command
+          .newBuilder()
+          .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()),
+        None),
+      (
+        proto.Command
+          .newBuilder()
+          .setExtension(
+            protobuf.Any.pack(
+              proto.ExamplePluginCommand
                 .newBuilder()
-                .setRead(proto.Read
+                .setCustomField("SPARK-43923")
+                .build())),
+        None),
+      (
+        proto.Command
+          .newBuilder()
+          .setWriteStreamOperationStart(
+            proto.WriteStreamOperationStart
+              .newBuilder()
+              .setInput(
+                proto.Relation
                   .newBuilder()
-                  .setIsStreaming(true)
-                  .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build())
+                  .setRead(proto.Read
+                    .newBuilder()
+                    .setIsStreaming(true)
+                    .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build())
+                    .build())
                   .build())
-                .build())
-            .setOutputMode("Append")
-            .setAvailableNow(true)
-            .setQueryName("test")
-            .setFormat("memory")
-            .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath)
-            .setPath("test-path")
-            .build()),
-      proto.Command
-        .newBuilder()
-        .setStreamingQueryCommand(
-          proto.StreamingQueryCommand
-            .newBuilder()
-            .setQueryId(
-              proto.StreamingQueryInstanceId
-                .newBuilder()
-                .setId(DEFAULT_UUID.toString)
-                .setRunId(DEFAULT_UUID.toString)
-                .build())
-            .setStop(true)),
-      proto.Command
-        .newBuilder()
-        .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand
+              .setOutputMode("Append")
+              .setAvailableNow(true)
+              .setQueryName("test")
+              .setFormat("memory")
+              .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath)
+              .setPath("test-path")
+              .build()),
+        None),
+      (
+        proto.Command
           .newBuilder()
-          .setListListeners(true)),
-      proto.Command
-        .newBuilder()
-        .setRegisterFunction(
-          proto.CommonInlineUserDefinedFunction
+          .setStreamingQueryCommand(
+            proto.StreamingQueryCommand
+              .newBuilder()
+              .setQueryId(
+                proto.StreamingQueryInstanceId
+                  .newBuilder()
+                  .setId(DEFAULT_UUID.toString)
+                  .setRunId(DEFAULT_UUID.toString)
+                  .build())
+              .setStop(true)),
+        None),
+      (
+        proto.Command
+          .newBuilder()
+          .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand
             .newBuilder()
-            .setFunctionName("function")
-            .setPythonUdf(
-              proto.PythonUDF
-                .newBuilder()
-                .setEvalType(100)
-                .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType))
-                .setCommand(ByteString.copyFrom("command".getBytes()))
-                .setPythonVer("3.10")
-                .build())))) { command =>
+            .setListListeners(true)),
+        None),
+      (
+        proto.Command
+          .newBuilder()
+          .setRegisterFunction(
+            proto.CommonInlineUserDefinedFunction
+              .newBuilder()
+              .setFunctionName("function")
+              .setPythonUdf(
+                proto.PythonUDF
+                  .newBuilder()
+                  .setEvalType(100)
+                  .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType))
+                  .setCommand(ByteString.copyFrom("command".getBytes()))
+                  .setPythonVer("3.10")
+                  .build())),
+        None))) { case (command, producedNumRows) =>
     val sessionId = UUID.randomUUID.toString()
     withCommandTest(sessionId) { verifyEvents =>
       val instance = new SparkConnectService(false)
@@ -435,7 +534,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
             done = true
           }
         })
-      verifyEvents.onCompleted()
+      verifyEvents.onCompleted(producedNumRows)
       // The current implementation is expected to be blocking.
       // This is here to make sure it is.
       assert(done)
@@ -788,8 +887,9 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
       assert(executeHolder.eventsManager.hasCanceled.isEmpty)
       assert(executeHolder.eventsManager.hasError.isDefined)
     }
-    def onCompleted(): Unit = {
+    def onCompleted(producedRowCount: Option[Long] = None): Unit = {
       assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
+      assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount)
     }
     def onCanceled(): Unit = {
       assert(executeHolder.eventsManager.hasCanceled.contains(true))
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
index e7cc8007142..7950f9c5474 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
@@ -141,6 +141,18 @@ class ExecuteEventsManagerSuite
           DEFAULT_CLOCK.getTimeMillis()))
   }
 
+  test("SPARK-44776: post finished with row number") {
+    val events = setupEvents(ExecuteStatus.Started)
+    events.postFinished(Some(100))
+    verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1))
+      .post(
+        SparkListenerConnectOperationFinished(
+          events.executeHolder.jobTag,
+          DEFAULT_QUERY_ID,
+          DEFAULT_CLOCK.getTimeMillis(),
+          Some(100)))
+  }
+
   test("SPARK-43923: post closed") {
     val events = setupEvents(ExecuteStatus.Finished)
     events.postClosed()

