This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 8ad4dddb17b [SPARK-44776][CONNECT] Add ProducedRowCount to SparkListenerConnectOperationFinished
8ad4dddb17b is described below

commit 8ad4dddb17b76c6c53a65730229cdb69cd1c1889
Author: Lingkai Kong <lingkai.k...@databricks.com>
AuthorDate: Tue Aug 22 10:06:50 2023 +0900

    [SPARK-44776][CONNECT] Add ProducedRowCount to SparkListenerConnectOperationFinished

    ### What changes were proposed in this pull request?
    Add ProducedRowCount field to SparkListenerConnectOperationFinished

    ### Why are the changes needed?
    Needed for showing number of rows getting produced

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Added Unit test

    Closes #42454 from gjxdxh/SPARK-44776.

    Authored-by: Lingkai Kong <lingkai.k...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 4646991abd7f4a47a1b8712e2017a2fae98f7c5a)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../execution/SparkConnectPlanExecution.scala      |  10 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |   2 +-
 .../sql/connect/service/ExecuteEventsManager.scala |  27 +-
 .../connect/planner/SparkConnectServiceSuite.scala | 288 ++++++++++++++-------
 .../service/ExecuteEventsManagerSuite.scala        |  12 +
 5 files changed, 238 insertions(+), 101 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 131ddf76fa4..00fec4378c5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -110,6 +110,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
errorOnDuplicatedFieldNames = false) var numSent = 0 + var totalNumRows: Long = 0 def sendBatch(bytes: Array[Byte], count: Long): Unit = { val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId) val batch = proto.ExecutePlanResponse.ArrowBatch @@ -120,14 +121,15 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) response.setArrowBatch(batch) responseObserver.onNext(response.build()) numSent += 1 + totalNumRows += count } dataframe.queryExecution.executedPlan match { case LocalTableScanExec(_, rows) => - executePlan.eventsManager.postFinished() converter(rows.iterator).foreach { case (bytes, count) => sendBatch(bytes, count) } + executePlan.eventsManager.postFinished(Some(totalNumRows)) case _ => SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) { val rows = dataframe.queryExecution.executedPlan.execute() @@ -162,8 +164,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) resultFunc = () => ()) // Collect errors and propagate them to the main thread. 
.andThen { - case Success(_) => - executePlan.eventsManager.postFinished() + case Success(_) => // do nothing case Failure(throwable) => signal.synchronized { error = Some(throwable) @@ -200,8 +201,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) currentPartitionId += 1 } ThreadUtils.awaitReady(future, Duration.Inf) + executePlan.eventsManager.postFinished(Some(totalNumRows)) } else { - executePlan.eventsManager.postFinished() + executePlan.eventsManager.postFinished(Some(totalNumRows)) } } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5120073e2f0..e81e9bb51cb 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2513,7 +2513,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .putAllArgs(getSqlCommand.getArgsMap) .addAllPosArgs(getSqlCommand.getPosArgsList))) } - executeHolder.eventsManager.postFinished() + executeHolder.eventsManager.postFinished(Some(rows.size)) // Exactly one SQL Command Result Batch responseObserver.onNext( ExecutePlanResponse diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala index 5e831aaa98f..5b9267a9679 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala @@ -75,6 +75,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { private var canceled = 
Option.empty[Boolean] + private var producedRowCount = Option.empty[Long] + /** * @return * Last event posted by the Connect request @@ -95,6 +97,13 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { */ private[connect] def hasError: Option[Boolean] = error + /** + * @return + * How many rows the Connect request has produced @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished + */ + private[connect] def getProducedRowCount: Option[Long] = producedRowCount + /** * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationStarted. */ @@ -192,13 +201,23 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { /** * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished. + * @param producedRowsCountOpt + * Number of rows that are returned to the user. None is expected when the operation does not + * return any rows. */ - def postFinished(): Unit = { + def postFinished(producedRowsCountOpt: Option[Long] = None): Unit = { assertStatus( List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution), ExecuteStatus.Finished) + producedRowCount = producedRowsCountOpt + listenerBus - .post(SparkListenerConnectOperationFinished(jobTag, operationId, clock.getTimeMillis())) + .post( + SparkListenerConnectOperationFinished( + jobTag, + operationId, + clock.getTimeMillis(), + producedRowCount)) } /** @@ -395,6 +414,9 @@ case class SparkListenerConnectOperationFailed( * 36 characters UUID assigned by Connect during a request. * @param eventTime: * The time in ms when the event was generated. + * @param producedRowCount: + * Number of rows that are returned to the user. None is expected when the operation does not + * return any rows. * @param extraTags: * Additional metadata during the request. 
*/ @@ -402,6 +424,7 @@ case class SparkListenerConnectOperationFinished( jobTag: String, operationId: String, eventTime: Long, + producedRowCount: Option[Long] = None, extraTags: Map[String, String] = Map.empty) extends SparkListenerEvent diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 285f3103b19..74649e15e9e 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.CreateDataFrameViewCommand import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.dsl.MockRemoteSession @@ -49,13 +50,18 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, Sessi import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** * Testing Connect Service implementation. 
*/ -class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with Logging { +class SparkConnectServiceSuite + extends SharedSparkSession + with MockitoSugar + with Logging + with SparkConnectPlanTest { private def sparkSessionHolder = SessionHolder.forTesting(spark) private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093") @@ -190,7 +196,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with done = true } }) - verifyEvents.onCompleted() + verifyEvents.onCompleted(Some(100)) // The current implementation is expected to be blocking. This is here to make sure it is. assert(done) @@ -238,6 +244,77 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } } + test("SPARK-44776: LocalTableScanExec") { + withEvents { verifyEvents => + // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 + assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) + val instance = new SparkConnectService(false) + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + + val rows = (0L to 5L).map { i => + new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar))) + } + + val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType))) + val inputRows = rows.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + + val localRelation = createLocalRelationProto(schema, inputRows) + val plan = proto.Plan + .newBuilder() + .setRoot(localRelation) + .build() + + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) + .build() + + // Execute plan. 
+ @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + verifyEvents.onNext(v) + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onError(throwable) + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + verifyEvents.onCompleted(Some(6)) + // The current implementation is expected to be blocking. This is here to make sure it is. + assert(done) + + // 1 Partitions + Metrics + assert(responses.size == 3) + + // Make sure the first response is schema only + val head = responses.head + assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) + + // Make sure the last response is metrics only + val last = responses.last + assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) + } + } + test("SPARK-44657: Arrow batches respect max batch size limit") { // Set 10 KiB as the batch size limit val batchSize = 10 * 1024 @@ -301,101 +378,123 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with gridTest("SPARK-43923: commands send events")( Seq( - proto.Command - .newBuilder() - .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()), - proto.Command - .newBuilder() - .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show tables").build()), - proto.Command - .newBuilder() - .setWriteOperation( - proto.WriteOperation - .newBuilder() - .setInput( - proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))) - .setPath(Utils.createTempDir().getAbsolutePath) - .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)), - proto.Command - .newBuilder() - .setWriteOperationV2( - proto.WriteOperationV2 - .newBuilder() - .setInput(proto.Relation.newBuilder.setRange( - proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L))) - 
.setTableName("testcat.testtable") - .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)), - proto.Command - .newBuilder() - .setCreateDataframeView( - CreateDataFrameViewCommand - .newBuilder() - .setName("testview") - .setInput( - proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), - proto.Command - .newBuilder() - .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), - proto.Command - .newBuilder() - .setExtension( - protobuf.Any.pack( - proto.ExamplePluginCommand + ( + proto.Command + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()), + Some(0L)), + ( + proto.Command + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show databases").build()), + Some(1L)), + ( + proto.Command + .newBuilder() + .setWriteOperation( + proto.WriteOperation .newBuilder() - .setCustomField("SPARK-43923") - .build())), - proto.Command - .newBuilder() - .setWriteStreamOperationStart( - proto.WriteStreamOperationStart - .newBuilder() - .setInput( - proto.Relation + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))) + .setPath(Utils.createTempDir().getAbsolutePath) + .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)), + None), + ( + proto.Command + .newBuilder() + .setWriteOperationV2( + proto.WriteOperationV2 + .newBuilder() + .setInput(proto.Relation.newBuilder.setRange( + proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L))) + .setTableName("testcat.testtable") + .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)), + None), + ( + proto.Command + .newBuilder() + .setCreateDataframeView( + CreateDataFrameViewCommand + .newBuilder() + .setName("testview") + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), + None), + ( + proto.Command + .newBuilder() + .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), + None), + ( + proto.Command + .newBuilder() + .setExtension( + 
protobuf.Any.pack( + proto.ExamplePluginCommand .newBuilder() - .setRead(proto.Read + .setCustomField("SPARK-43923") + .build())), + None), + ( + proto.Command + .newBuilder() + .setWriteStreamOperationStart( + proto.WriteStreamOperationStart + .newBuilder() + .setInput( + proto.Relation .newBuilder() - .setIsStreaming(true) - .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build()) + .setRead(proto.Read + .newBuilder() + .setIsStreaming(true) + .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build()) + .build()) .build()) - .build()) - .setOutputMode("Append") - .setAvailableNow(true) - .setQueryName("test") - .setFormat("memory") - .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath) - .setPath("test-path") - .build()), - proto.Command - .newBuilder() - .setStreamingQueryCommand( - proto.StreamingQueryCommand - .newBuilder() - .setQueryId( - proto.StreamingQueryInstanceId - .newBuilder() - .setId(DEFAULT_UUID.toString) - .setRunId(DEFAULT_UUID.toString) - .build()) - .setStop(true)), - proto.Command - .newBuilder() - .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand + .setOutputMode("Append") + .setAvailableNow(true) + .setQueryName("test") + .setFormat("memory") + .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath) + .setPath("test-path") + .build()), + None), + ( + proto.Command .newBuilder() - .setListListeners(true)), - proto.Command - .newBuilder() - .setRegisterFunction( - proto.CommonInlineUserDefinedFunction + .setStreamingQueryCommand( + proto.StreamingQueryCommand + .newBuilder() + .setQueryId( + proto.StreamingQueryInstanceId + .newBuilder() + .setId(DEFAULT_UUID.toString) + .setRunId(DEFAULT_UUID.toString) + .build()) + .setStop(true)), + None), + ( + proto.Command + .newBuilder() + .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand .newBuilder() - .setFunctionName("function") - .setPythonUdf( - proto.PythonUDF - .newBuilder() - 
.setEvalType(100) - .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType)) - .setCommand(ByteString.copyFrom("command".getBytes())) - .setPythonVer("3.10") - .build())))) { command => + .setListListeners(true)), + None), + ( + proto.Command + .newBuilder() + .setRegisterFunction( + proto.CommonInlineUserDefinedFunction + .newBuilder() + .setFunctionName("function") + .setPythonUdf( + proto.PythonUDF + .newBuilder() + .setEvalType(100) + .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType)) + .setCommand(ByteString.copyFrom("command".getBytes())) + .setPythonVer("3.10") + .build())), + None))) { case (command, producedNumRows) => val sessionId = UUID.randomUUID.toString() withCommandTest(sessionId) { verifyEvents => val instance = new SparkConnectService(false) @@ -435,7 +534,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with done = true } }) - verifyEvents.onCompleted() + verifyEvents.onCompleted(producedNumRows) // The current implementation is expected to be blocking. // This is here to make sure it is. 
assert(done) @@ -788,8 +887,9 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with assert(executeHolder.eventsManager.hasCanceled.isEmpty) assert(executeHolder.eventsManager.hasError.isDefined) } - def onCompleted(): Unit = { + def onCompleted(producedRowCount: Option[Long] = None): Unit = { assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount) } def onCanceled(): Unit = { assert(executeHolder.eventsManager.hasCanceled.contains(true)) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala index e7cc8007142..7950f9c5474 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -141,6 +141,18 @@ class ExecuteEventsManagerSuite DEFAULT_CLOCK.getTimeMillis())) } + test("SPARK-44776: post finished with row number") { + val events = setupEvents(ExecuteStatus.Started) + events.postFinished(Some(100)) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationFinished( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + Some(100))) + } + test("SPARK-43923: post closed") { val events = setupEvents(ExecuteStatus.Finished) events.postClosed() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org