juliuszsompolski commented on code in PR #41443:
URL: https://github.com/apache/spark/pull/41443#discussion_r1265320272
##########
connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala:
##########
@@ -131,126 +150,363 @@ class SparkConnectServiceSuite extends SharedSparkSession {
}
test("SPARK-41224: collect data using arrow") {
- // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
- assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
- val instance = new SparkConnectService(false)
- val connect = new MockRemoteSession()
- val context = proto.UserContext
- .newBuilder()
- .setUserId("c1")
- .build()
- val plan = proto.Plan
- .newBuilder()
- .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1,
4)"))
- .build()
- val request = proto.ExecutePlanRequest
- .newBuilder()
- .setPlan(plan)
- .setUserContext(context)
- .build()
-
- // Execute plan.
- @volatile var done = false
- val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
- instance.executePlan(
- request,
- new StreamObserver[proto.ExecutePlanResponse] {
- override def onNext(v: proto.ExecutePlanResponse): Unit = responses +=
v
-
- override def onError(throwable: Throwable): Unit = throw throwable
-
- override def onCompleted(): Unit = done = true
- })
-
- // The current implementation is expected to be blocking. This is here to
make sure it is.
- assert(done)
-
- // 4 Partitions + Metrics
- assert(responses.size == 6)
-
- // Make sure the first response is schema only
- val head = responses.head
- assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
-
- // Make sure the last response is metrics only
- val last = responses.last
- assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
-
- val allocator = new RootAllocator()
-
- // Check the 'data' batches
- var expectedId = 0L
- var previousEId = 0.0d
- responses.tail.dropRight(1).foreach { response =>
- assert(response.hasArrowBatch)
- val batch = response.getArrowBatch
- assert(batch.getData != null)
- assert(batch.getRowCount == 25)
-
- val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
- while (reader.loadNextBatch()) {
- val root = reader.getVectorSchemaRoot
- val idVector = root.getVector(0).asInstanceOf[BigIntVector]
- val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
- val numRows = root.getRowCount
- var i = 0
- while (i < numRows) {
- assert(idVector.get(i) == expectedId)
- expectedId += 1
- val eid = eidVector.get(i)
- assert(eid > previousEId)
- previousEId = eid
- i += 1
+ withEvents { verifyEvents =>
+ // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
+ assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
+ val instance = new SparkConnectService(false)
+ val connect = new MockRemoteSession()
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1,
4)"))
+ .build()
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setUserContext(context)
+ .build()
+
+ // Execute plan.
+ @volatile var done = false
+ val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+ instance.executePlan(
+ request,
+ new StreamObserver[proto.ExecutePlanResponse] {
+ override def onNext(v: proto.ExecutePlanResponse): Unit = {
+ responses += v
+ verifyEvents.onNext(v)
+ }
+
+ override def onError(throwable: Throwable): Unit = {
+ verifyEvents.onError(throwable)
+ throw throwable
+ }
+
+ override def onCompleted(): Unit = {
+ done = true
+ }
+ })
+ verifyEvents.onCompleted()
+ // The current implementation is expected to be blocking. This is here
to make sure it is.
+ assert(done)
+
+ // 4 Partitions + Metrics
+ assert(responses.size == 6)
+
+ // Make sure the first response is schema only
+ val head = responses.head
+ assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
+
+ // Make sure the last response is metrics only
+ val last = responses.last
+ assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
+
+ val allocator = new RootAllocator()
+
+ // Check the 'data' batches
+ var expectedId = 0L
+ var previousEId = 0.0d
+ responses.tail.dropRight(1).foreach { response =>
+ assert(response.hasArrowBatch)
+ val batch = response.getArrowBatch
+ assert(batch.getData != null)
+ assert(batch.getRowCount == 25)
+
+ val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
+ while (reader.loadNextBatch()) {
+ val root = reader.getVectorSchemaRoot
+ val idVector = root.getVector(0).asInstanceOf[BigIntVector]
+ val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
+ val numRows = root.getRowCount
+ var i = 0
+ while (i < numRows) {
+ assert(idVector.get(i) == expectedId)
+ expectedId += 1
+ val eid = eidVector.get(i)
+ assert(eid > previousEId)
+ previousEId = eid
+ i += 1
+ }
}
+ reader.close()
}
- reader.close()
+ allocator.close()
}
- allocator.close()
}
- test("SPARK-41165: failures in the arrow collect path should not cause
hangs") {
- val instance = new SparkConnectService(false)
+ gridTest("SPARK-43923: commands send events")(
+ Seq(
+ proto.Command
+ .newBuilder()
+ .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select
1").build()),
+ proto.Command
+ .newBuilder()
+ .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show
tables").build()),
+ proto.Command
+ .newBuilder()
+ .setWriteOperation(
+ proto.WriteOperation
+ .newBuilder()
+ .setInput(
+
proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))
+ .setPath("my/test/path")
+ .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)),
+ proto.Command
+ .newBuilder()
+ .setWriteOperationV2(
+ proto.WriteOperationV2
+ .newBuilder()
+ .setInput(proto.Relation.newBuilder.setRange(
+ proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L)))
+ .setTableName("testcat.testtable")
+ .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)),
+ proto.Command
+ .newBuilder()
+ .setCreateDataframeView(
+ CreateDataFrameViewCommand
+ .newBuilder()
+ .setName("testview")
+ .setInput(
+
proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select
1")))),
+ proto.Command
+ .newBuilder()
+ .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()),
+ proto.Command
+ .newBuilder()
+ .setExtension(
+ protobuf.Any.pack(
+ proto.ExamplePluginCommand
+ .newBuilder()
+ .setCustomField("SPARK-43923")
+ .build())),
+ proto.Command
+ .newBuilder()
+ .setWriteStreamOperationStart(
+ proto.WriteStreamOperationStart
+ .newBuilder()
+ .setInput(
+ proto.Relation
+ .newBuilder()
+ .setRead(proto.Read
+ .newBuilder()
+ .setIsStreaming(true)
+
.setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build())
+ .build())
+ .build())
+ .setOutputMode("Append")
+ .setAvailableNow(true)
+ .setQueryName("test")
+ .setFormat("memory")
+ .putOptions("checkpointLocation", s"${UUID.randomUUID}")
+ .setPath("test-path")
+ .build()),
+ proto.Command
+ .newBuilder()
+ .setStreamingQueryCommand(
+ proto.StreamingQueryCommand
+ .newBuilder()
+ .setQueryId(
+ proto.StreamingQueryInstanceId
+ .newBuilder()
+ .setId(DEFAULT_UUID.toString)
+ .setRunId(DEFAULT_UUID.toString)
+ .build())
+ .setStop(true)),
+ proto.Command
+ .newBuilder()
+ .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand
+ .newBuilder()
+ .setListListeners(true)),
+ proto.Command
+ .newBuilder()
+ .setRegisterFunction(
+ proto.CommonInlineUserDefinedFunction
+ .newBuilder()
+ .setFunctionName("function")
+ .setPythonUdf(
+ proto.PythonUDF
+ .newBuilder()
+ .setEvalType(100)
+
.setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType))
+ .setCommand(ByteString.copyFrom("command".getBytes()))
+ .setPythonVer("3.10")
+ .build())))) { command =>
+ withCommandTest { verifyEvents =>
+ val instance = new SparkConnectService(false)
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setCommand(command)
+ .build()
+
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setSessionId("s1")
+ .setUserContext(context)
+ .build()
+
+ // Execute plan.
+ @volatile var done = false
+ val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+ instance.executePlan(
+ request,
+ new StreamObserver[proto.ExecutePlanResponse] {
+ override def onNext(v: proto.ExecutePlanResponse): Unit = {
+ responses += v
+ verifyEvents.onNext(v)
+ }
+
+ override def onError(throwable: Throwable): Unit = {
+ verifyEvents.onError(throwable)
+ throw throwable
+ }
+
+ override def onCompleted(): Unit = {
+ done = true
+ }
+ })
+ verifyEvents.onCompleted()
+ // The current implementation is expected to be blocking.
+ // This is here to make sure it is.
+ assert(done)
+
+ // Result + Metrics
+ if (responses.size > 1) {
+ assert(responses.size == 2)
- // Add an always crashing UDF
- val session = SparkConnectService.getOrCreateIsolatedSession("c1",
"session").session
- val instaKill: Long => Long = { _ =>
- throw new Exception("Kaboom")
+ // Make sure the first response result only
+ val head = responses.head
+ assert(head.hasSqlCommandResult && !head.hasMetrics)
+
+ // Make sure the last response is metrics only
+ val last = responses.last
+ assert(last.hasMetrics && !last.hasSqlCommandResult)
+ }
}
- session.udf.register("insta_kill", instaKill)
-
- val connect = new MockRemoteSession()
- val context = proto.UserContext
- .newBuilder()
- .setUserId("c1")
- .build()
- val plan = proto.Plan
- .newBuilder()
- .setRoot(connect.sql("select insta_kill(id) from range(10)"))
- .build()
- val request = proto.ExecutePlanRequest
- .newBuilder()
- .setPlan(plan)
- .setUserContext(context)
- .setSessionId("session")
- .build()
-
- // The observer is executed inside this thread. So
- // we can perform the checks inside the observer.
- instance.executePlan(
- request,
- new StreamObserver[proto.ExecutePlanResponse] {
- override def onNext(v: proto.ExecutePlanResponse): Unit = {
- fail("this should not receive responses")
- }
+ }
- override def onError(throwable: Throwable): Unit = {
- assert(throwable.isInstanceOf[StatusRuntimeException])
- }
+ test("SPARK-43923: canceled request send events") {
+ withEvents { verifyEvents =>
+ val instance = new SparkConnectService(false)
+
+ // Add an always crashing UDF
+ val session = SparkConnectService.getOrCreateIsolatedSession("c1",
"session").session
+ val sleep: Long => Long = { time =>
+ Thread.sleep(time)
+ time
+ }
+ session.udf.register("sleep", sleep)
+
+ val connect = new MockRemoteSession()
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setRoot(connect.sql("select sleep(10000)"))
+ .build()
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setUserContext(context)
+ .setSessionId("session")
+ .build()
- override def onCompleted(): Unit = {
- fail("this should not complete")
+ val thread = new Thread {
+ override def run: Unit = {
+ Thread.sleep(1000)
Review Comment:
This fixed `Thread.sleep` is likely to make the test flaky.
See, for example,
https://github.com/apache/spark/blob/85d8d62216d3b830cc5af3dec05422a9cda4cea0/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala#L133
Could you instead have a listener signal a semaphore when the timing is right (e.g. after a job-start event), and have the canceling thread wait on that semaphore rather than sleeping for a fixed time?
##########
connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala:
##########
@@ -378,4 +634,210 @@ class SparkConnectServiceSuite extends SharedSparkSession {
assert(valuesList.last.hasLong && valuesList.last.getLong == 99)
}
}
+
+ protected def withCommandTest(f: VerifyEvents => Unit): Unit = {
+ withView("testview") {
+ withTable("testcat.testtable") {
+ withSparkConf(
+ "spark.sql.catalog.testcat" ->
classOf[InMemoryPartitionTableCatalog].getName,
+ Connect.CONNECT_EXTENSIONS_COMMAND_CLASSES.key ->
+ "org.apache.spark.sql.connect.plugin.ExampleCommandPlugin") {
+ withEvents { verifyEvents =>
+ val restartedQuery = mock[StreamingQuery]
+ when(restartedQuery.id).thenReturn(DEFAULT_UUID)
+ when(restartedQuery.runId).thenReturn(DEFAULT_UUID)
+
SparkConnectService.streamingSessionManager.registerNewStreamingQuery(
+ SparkConnectService.getOrCreateIsolatedSession("c1", "s1"),
+ restartedQuery)
+ f(verifyEvents)
+ }
+ }
+ }
+ }
+ }
+
+ protected def withSparkConf(pairs: (String, String)*)(f: => Unit): Unit = {
+ val conf = SparkEnv.get.conf
+ pairs.foreach { kv => conf.set(kv._1, kv._2) }
+ try f
+ finally {
+ pairs.foreach { kv => conf.remove(kv._1) }
+ }
+ }
+
+ protected def withEvents(f: VerifyEvents => Unit): Unit = {
+ val verifyEvents = new VerifyEvents(spark.sparkContext)
+ spark.sparkContext.addSparkListener(verifyEvents.listener)
+ Utils.tryWithSafeFinally({
+ f(verifyEvents)
+ SparkConnectService.invalidateAllSessions()
+ verifyEvents.onSessionClosed()
+ }) {
+ verifyEvents.waitUntilEmpty()
+ spark.sparkContext.removeSparkListener(verifyEvents.listener)
+ SparkConnectService.invalidateAllSessions()
+ SparkConnectPluginRegistry.reset()
+ }
+ }
+
+ protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params:
Seq[A])(
+ testFun: A => Unit): Unit = {
+ for (param <- params) {
+ test(testNamePrefix + s" ($param)", testTags: _*)(testFun(param))
+ }
+ }
+
+ sealed abstract class Status(value: Int)
+
+ object Status {
+ case object Pending extends Status(0)
+ case object SessionStarted extends Status(1)
+ case object Started extends Status(2)
+ case object Analyzed extends Status(3)
+ case object ReadyForExecution extends Status(4)
+ case object Finished extends Status(5)
+ case object Failed extends Status(6)
+ case object Canceled extends Status(7)
+ case object Closed extends Status(8)
+ case object SessionClosed extends Status(9)
+ }
Review Comment:
Should this kind of status be tracked by the `ExecuteEventsManager` itself, with the manager asserting that each state transition is valid?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]