juliuszsompolski commented on code in PR #41443:
URL: https://github.com/apache/spark/pull/41443#discussion_r1265555819


##########
connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala:
##########
@@ -131,126 +151,363 @@ class SparkConnectServiceSuite extends SharedSparkSession {
   }
 
   test("SPARK-41224: collect data using arrow") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
-    val instance = new SparkConnectService(false)
-    val connect = new MockRemoteSession()
-    val context = proto.UserContext
-      .newBuilder()
-      .setUserId("c1")
-      .build()
-    val plan = proto.Plan
-      .newBuilder()
-      .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
-      .build()
-    val request = proto.ExecutePlanRequest
-      .newBuilder()
-      .setPlan(plan)
-      .setUserContext(context)
-      .build()
-
-    // Execute plan.
-    @volatile var done = false
-    val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
-    instance.executePlan(
-      request,
-      new StreamObserver[proto.ExecutePlanResponse] {
-        override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
-
-        override def onError(throwable: Throwable): Unit = throw throwable
-
-        override def onCompleted(): Unit = done = true
-      })
-
-    // The current implementation is expected to be blocking. This is here to make sure it is.
-    assert(done)
-
-    // 4 Partitions + Metrics
-    assert(responses.size == 6)
-
-    // Make sure the first response is schema only
-    val head = responses.head
-    assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
-
-    // Make sure the last response is metrics only
-    val last = responses.last
-    assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
-
-    val allocator = new RootAllocator()
-
-    // Check the 'data' batches
-    var expectedId = 0L
-    var previousEId = 0.0d
-    responses.tail.dropRight(1).foreach { response =>
-      assert(response.hasArrowBatch)
-      val batch = response.getArrowBatch
-      assert(batch.getData != null)
-      assert(batch.getRowCount == 25)
-
-      val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
-      while (reader.loadNextBatch()) {
-        val root = reader.getVectorSchemaRoot
-        val idVector = root.getVector(0).asInstanceOf[BigIntVector]
-        val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
-        val numRows = root.getRowCount
-        var i = 0
-        while (i < numRows) {
-          assert(idVector.get(i) == expectedId)
-          expectedId += 1
-          val eid = eidVector.get(i)
-          assert(eid > previousEId)
-          previousEId = eid
-          i += 1
+    withEvents { verifyEvents =>
+      // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
+      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
+      val instance = new SparkConnectService(false)
+      val connect = new MockRemoteSession()
+      val context = proto.UserContext
+        .newBuilder()
+        .setUserId("c1")
+        .build()
+      val plan = proto.Plan
+        .newBuilder()
+        .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
+        .build()
+      val request = proto.ExecutePlanRequest
+        .newBuilder()
+        .setPlan(plan)
+        .setUserContext(context)
+        .build()
+
+      // Execute plan.
+      @volatile var done = false
+      val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+      instance.executePlan(
+        request,
+        new StreamObserver[proto.ExecutePlanResponse] {
+          override def onNext(v: proto.ExecutePlanResponse): Unit = {
+            responses += v
+            verifyEvents.onNext(v)
+          }
+
+          override def onError(throwable: Throwable): Unit = {
+            verifyEvents.onError(throwable)
+            throw throwable
+          }
+
+          override def onCompleted(): Unit = {
+            done = true
+          }
+        })
+      verifyEvents.onCompleted()
+      // The current implementation is expected to be blocking. This is here to make sure it is.
+      assert(done)
+
+      // 4 Partitions + Metrics
+      assert(responses.size == 6)
+
+      // Make sure the first response is schema only
+      val head = responses.head
+      assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
+
+      // Make sure the last response is metrics only
+      val last = responses.last
+      assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
+
+      val allocator = new RootAllocator()
+
+      // Check the 'data' batches
+      var expectedId = 0L
+      var previousEId = 0.0d
+      responses.tail.dropRight(1).foreach { response =>
+        assert(response.hasArrowBatch)
+        val batch = response.getArrowBatch
+        assert(batch.getData != null)
+        assert(batch.getRowCount == 25)
+
+        val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
+        while (reader.loadNextBatch()) {
+          val root = reader.getVectorSchemaRoot
+          val idVector = root.getVector(0).asInstanceOf[BigIntVector]
+          val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
+          val numRows = root.getRowCount
+          var i = 0
+          while (i < numRows) {
+            assert(idVector.get(i) == expectedId)
+            expectedId += 1
+            val eid = eidVector.get(i)
+            assert(eid > previousEId)
+            previousEId = eid
+            i += 1
+          }
         }
+        reader.close()
       }
-      reader.close()
+      allocator.close()
     }
-    allocator.close()
   }
 
-  test("SPARK-41165: failures in the arrow collect path should not cause 
hangs") {
-    val instance = new SparkConnectService(false)
+  gridTest("SPARK-43923: commands send events")(
+    Seq(
+      proto.Command
+        .newBuilder()
+        .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()),
+      proto.Command
+        .newBuilder()
+        .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show tables").build()),
+      proto.Command
+        .newBuilder()
+        .setWriteOperation(
+          proto.WriteOperation
+            .newBuilder()
+            .setInput(
+              proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))
+            .setPath("my/test/path")
+            .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)),
+      proto.Command
+        .newBuilder()
+        .setWriteOperationV2(
+          proto.WriteOperationV2
+            .newBuilder()
+            .setInput(proto.Relation.newBuilder.setRange(
+              proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L)))
+            .setTableName("testcat.testtable")
+            .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)),
+      proto.Command
+        .newBuilder()
+        .setCreateDataframeView(
+          CreateDataFrameViewCommand
+            .newBuilder()
+            .setName("testview")
+            .setInput(
+              proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))),
+      proto.Command
+        .newBuilder()
+        .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()),
+      proto.Command
+        .newBuilder()
+        .setExtension(
+          protobuf.Any.pack(
+            proto.ExamplePluginCommand
+              .newBuilder()
+              .setCustomField("SPARK-43923")
+              .build())),
+      proto.Command
+        .newBuilder()
+        .setWriteStreamOperationStart(
+          proto.WriteStreamOperationStart
+            .newBuilder()
+            .setInput(
+              proto.Relation
+                .newBuilder()
+                .setRead(proto.Read
+                  .newBuilder()
+                  .setIsStreaming(true)
+                  .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build())
+                  .build())
+                .build())
+            .setOutputMode("Append")
+            .setAvailableNow(true)
+            .setQueryName("test")
+            .setFormat("memory")
+            .putOptions("checkpointLocation", s"${UUID.randomUUID}")
+            .setPath("test-path")
+            .build()),
+      proto.Command
+        .newBuilder()
+        .setStreamingQueryCommand(
+          proto.StreamingQueryCommand
+            .newBuilder()
+            .setQueryId(
+              proto.StreamingQueryInstanceId
+                .newBuilder()
+                .setId(DEFAULT_UUID.toString)
+                .setRunId(DEFAULT_UUID.toString)
+                .build())
+            .setStop(true)),
+      proto.Command
+        .newBuilder()
+        .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand
+          .newBuilder()
+          .setListListeners(true)),
+      proto.Command
+        .newBuilder()
+        .setRegisterFunction(
+          proto.CommonInlineUserDefinedFunction
+            .newBuilder()
+            .setFunctionName("function")
+            .setPythonUdf(
+              proto.PythonUDF
+                .newBuilder()
+                .setEvalType(100)
+                .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType))
+                .setCommand(ByteString.copyFrom("command".getBytes()))
+                .setPythonVer("3.10")
+                .build())))) { command =>
+    withCommandTest { verifyEvents =>
+      val instance = new SparkConnectService(false)
+      val context = proto.UserContext
+        .newBuilder()
+        .setUserId("c1")
+        .build()
+      val plan = proto.Plan
+        .newBuilder()
+        .setCommand(command)
+        .build()
+
+      val request = proto.ExecutePlanRequest
+        .newBuilder()
+        .setPlan(plan)
+        .setSessionId("s1")
+        .setUserContext(context)
+        .build()
+
+      // Execute plan.
+      @volatile var done = false
+      val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+      instance.executePlan(
+        request,
+        new StreamObserver[proto.ExecutePlanResponse] {
+          override def onNext(v: proto.ExecutePlanResponse): Unit = {
+            responses += v
+            verifyEvents.onNext(v)
+          }
+
+          override def onError(throwable: Throwable): Unit = {
+            verifyEvents.onError(throwable)
+            throw throwable
+          }
+
+          override def onCompleted(): Unit = {
+            done = true
+          }
+        })
+      verifyEvents.onCompleted()
+      // The current implementation is expected to be blocking.
+      // This is here to make sure it is.
+      assert(done)
+
+      // Result + Metrics
+      if (responses.size > 1) {
+        assert(responses.size == 2)
 
-    // Add an always crashing UDF
-    val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session
-    val instaKill: Long => Long = { _ =>
-      throw new Exception("Kaboom")
+        // Make sure the first response result only
+        val head = responses.head
+        assert(head.hasSqlCommandResult && !head.hasMetrics)
+
+        // Make sure the last response is metrics only
+        val last = responses.last
+        assert(last.hasMetrics && !last.hasSqlCommandResult)
+      }
     }
-    session.udf.register("insta_kill", instaKill)
-
-    val connect = new MockRemoteSession()
-    val context = proto.UserContext
-      .newBuilder()
-      .setUserId("c1")
-      .build()
-    val plan = proto.Plan
-      .newBuilder()
-      .setRoot(connect.sql("select insta_kill(id) from range(10)"))
-      .build()
-    val request = proto.ExecutePlanRequest
-      .newBuilder()
-      .setPlan(plan)
-      .setUserContext(context)
-      .setSessionId("session")
-      .build()
-
-    // The observer is executed inside this thread. So
-    // we can perform the checks inside the observer.
-    instance.executePlan(
-      request,
-      new StreamObserver[proto.ExecutePlanResponse] {
-        override def onNext(v: proto.ExecutePlanResponse): Unit = {
-          fail("this should not receive responses")
-        }
+  }
 
-        override def onError(throwable: Throwable): Unit = {
-          assert(throwable.isInstanceOf[StatusRuntimeException])
-        }
+  test("SPARK-43923: canceled request send events") {
+    withEvents { verifyEvents =>
+      val instance = new SparkConnectService(false)
+
+      // Add an always crashing UDF
+      val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session
+      val sleep: Long => Long = { time =>
+        Thread.sleep(time)
+        time
+      }
+      session.udf.register("sleep", sleep)
+
+      val connect = new MockRemoteSession()
+      val context = proto.UserContext
+        .newBuilder()
+        .setUserId("c1")
+        .build()
+      val plan = proto.Plan
+        .newBuilder()
+        .setRoot(connect.sql("select sleep(10000)"))
+        .build()
+      val request = proto.ExecutePlanRequest
+        .newBuilder()
+        .setPlan(plan)
+        .setUserContext(context)
+        .setSessionId("session")
+        .build()
 
-        override def onCompleted(): Unit = {
-          fail("this should not complete")
+      new Thread {

Review Comment:
   Should we `join` this thread at the end of the test?
   Or `import scala.concurrent.ExecutionContext.Implicits.global` and make this a Future; see how various other suites use Futures...
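
A rough sketch of the Future-based variant suggested above (illustrative only, not the PR's code; it reuses `instance`, `request`, and the `verifyEvents` observer wiring from the surrounding diff, and the timeout is an arbitrary placeholder):

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Run the blocking executePlan call off the test thread so the main test body
// can cancel the in-flight query, then await the Future before the test ends
// instead of leaving a bare, un-joined Thread behind.
val execution = Future {
  instance.executePlan(
    request,
    new StreamObserver[proto.ExecutePlanResponse] {
      override def onNext(v: proto.ExecutePlanResponse): Unit = verifyEvents.onNext(v)
      override def onError(throwable: Throwable): Unit = verifyEvents.onError(throwable)
      override def onCompleted(): Unit = ()
    })
}

// ... interrupt / cancel the running query from the main test thread ...

// Waiting here guarantees the background execution cannot outlive the test.
Await.ready(execution, 1.minute)
```

Keeping the `new Thread { ... }` and calling `join()` on it before the test returns would give the same guarantee; the Future form simply matches how other suites already handle background execution with the implicit global execution context.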


