jdesjean commented on code in PR #41443:
URL: https://github.com/apache/spark/pull/41443#discussion_r1264160170
##########
connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala:
##########
@@ -131,126 +144,265 @@ class SparkConnectServiceSuite extends SharedSparkSession {
}
test("SPARK-41224: collect data using arrow") {
- // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
- assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
- val instance = new SparkConnectService(false)
- val connect = new MockRemoteSession()
- val context = proto.UserContext
- .newBuilder()
- .setUserId("c1")
- .build()
- val plan = proto.Plan
- .newBuilder()
- .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
- .build()
- val request = proto.ExecutePlanRequest
- .newBuilder()
- .setPlan(plan)
- .setUserContext(context)
- .build()
-
- // Execute plan.
- @volatile var done = false
- val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
- instance.executePlan(
- request,
- new StreamObserver[proto.ExecutePlanResponse] {
- override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
-
- override def onError(throwable: Throwable): Unit = throw throwable
-
- override def onCompleted(): Unit = done = true
- })
-
- // The current implementation is expected to be blocking. This is here to make sure it is.
- assert(done)
-
- // 4 Partitions + Metrics
- assert(responses.size == 6)
-
- // Make sure the first response is schema only
- val head = responses.head
- assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
-
- // Make sure the last response is metrics only
- val last = responses.last
- assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
-
- val allocator = new RootAllocator()
-
- // Check the 'data' batches
- var expectedId = 0L
- var previousEId = 0.0d
- responses.tail.dropRight(1).foreach { response =>
- assert(response.hasArrowBatch)
- val batch = response.getArrowBatch
- assert(batch.getData != null)
- assert(batch.getRowCount == 25)
-
- val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
- while (reader.loadNextBatch()) {
- val root = reader.getVectorSchemaRoot
- val idVector = root.getVector(0).asInstanceOf[BigIntVector]
- val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
- val numRows = root.getRowCount
- var i = 0
- while (i < numRows) {
- assert(idVector.get(i) == expectedId)
- expectedId += 1
- val eid = eidVector.get(i)
- assert(eid > previousEId)
- previousEId = eid
- i += 1
+ withEvents { verifyEvents =>
+ // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
+ assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
+ val instance = new SparkConnectService(false)
+ val connect = new MockRemoteSession()
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
+ .build()
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setUserContext(context)
+ .build()
+
+ // Execute plan.
+ @volatile var done = false
+ val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+ instance.executePlan(
+ request,
+ new StreamObserver[proto.ExecutePlanResponse] {
+ override def onNext(v: proto.ExecutePlanResponse): Unit = {
+ responses += v
+ verifyEvents.onNext(v)
+ }
+
+ override def onError(throwable: Throwable): Unit = {
+ verifyEvents.onError(throwable)
+ throw throwable
+ }
+
+ override def onCompleted(): Unit = {
+ done = true
+ }
+ })
+ verifyEvents.onCompleted()
+ // The current implementation is expected to be blocking. This is here to make sure it is.
+ assert(done)
+
+ // 4 Partitions + Metrics
+ assert(responses.size == 6)
+
+ // Make sure the first response is schema only
+ val head = responses.head
+ assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
+
+ // Make sure the last response is metrics only
+ val last = responses.last
+ assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
+
+ val allocator = new RootAllocator()
+
+ // Check the 'data' batches
+ var expectedId = 0L
+ var previousEId = 0.0d
+ responses.tail.dropRight(1).foreach { response =>
+ assert(response.hasArrowBatch)
+ val batch = response.getArrowBatch
+ assert(batch.getData != null)
+ assert(batch.getRowCount == 25)
+
+ val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
+ while (reader.loadNextBatch()) {
+ val root = reader.getVectorSchemaRoot
+ val idVector = root.getVector(0).asInstanceOf[BigIntVector]
+ val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
+ val numRows = root.getRowCount
+ var i = 0
+ while (i < numRows) {
+ assert(idVector.get(i) == expectedId)
+ expectedId += 1
+ val eid = eidVector.get(i)
+ assert(eid > previousEId)
+ previousEId = eid
+ i += 1
+ }
+ }
+ reader.close()
+ }
+ allocator.close()
+ }
+ }
+
+ gridTest("SPARK-43923: commands send events")(
Review Comment:
@bogao007
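
For context on the new withEvents { verifyEvents => ... } wrapper used in the diff above: below is a minimal sketch of the shape such a helper could take, assuming a listener-based design. VerifyEvents, its fields, and the listener wiring are illustrative assumptions for readers of this thread, not the actual helper defined in SparkConnectServiceSuite.

    // Hedged sketch only: the real helper in SparkConnectServiceSuite may track
    // operation state directly rather than raw listener events.
    import scala.collection.mutable

    import org.apache.spark.SparkContext
    import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}

    class VerifyEvents(sc: SparkContext) {
      // Collect every custom SparkListenerEvent posted while the test body runs.
      val events = mutable.Buffer.empty[SparkListenerEvent]
      val listener: SparkListener = new SparkListener {
        override def onOtherEvent(event: SparkListenerEvent): Unit =
          events.synchronized { events += event }
      }

      // Hooks called from the StreamObserver in the test body; real assertions
      // about the posted operation events would live here.
      def onNext(response: Any): Unit = ()
      def onError(throwable: Throwable): Unit = ()
      def onCompleted(): Unit = ()
    }

    def withEvents(testFn: VerifyEvents => Unit)(implicit sc: SparkContext): Unit = {
      val verifyEvents = new VerifyEvents(sc)
      sc.addSparkListener(verifyEvents.listener)
      try {
        testFn(verifyEvents)
      } finally {
        sc.removeSparkListener(verifyEvents.listener)
      }
    }

Routing onNext/onError/onCompleted through the helper is what lets the test assert that the SPARK-43923 events are emitted at the right points in the RPC lifecycle.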