xi-db commented on code in PR #52271:
URL: https://github.com/apache/spark/pull/52271#discussion_r2333327019


##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala:
##########
@@ -374,6 +374,132 @@ class SparkConnectServiceSuite
     }
   }
 
+  test("Arrow batch chunking") {
+    withEvents { verifyEvents =>
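+      // Use a tiny max chunk size so that each Arrow batch is forced to be split into chunks.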
+      val overriddenMaxChunkSize = 100
+      withSparkConf(
+        Connect.CONNECT_SESSION_RESULT_CHUNKING_MAX_CHUNK_SIZE.key ->
+          overriddenMaxChunkSize.toString) {
+        val instance = new SparkConnectService(false)
+        val connect = new MockRemoteSession()
+        val context = proto.UserContext
+          .newBuilder()
+          .setUserId("c1")
+          .build()
+        val plan = proto.Plan
+          .newBuilder()
+          .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
+          .build()
+        val request = proto.ExecutePlanRequest
+          .newBuilder()
+          .setPlan(plan)
+          .addRequestOptions(
+            proto.ExecutePlanRequest.RequestOption
+              .newBuilder()
+              .setResultChunkingOptions(proto.ResultChunkingOptions
+                .newBuilder()
+                .setAllowArrowBatchChunking(true)
+                .build())
+              .build())
+          .setUserContext(context)
+          .setSessionId(UUID.randomUUID.toString())
+          .build()
+
+        // Execute plan.
+        @volatile var done = false
+        val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+        instance.executePlan(
+          request,
+          new StreamObserver[proto.ExecutePlanResponse] {
+            override def onNext(v: proto.ExecutePlanResponse): Unit = {
+              responses += v
+              verifyEvents.onNext(v)
+            }
+
+            override def onError(throwable: Throwable): Unit = {
+              verifyEvents.onError(throwable)
+              throw throwable
+            }
+
+            override def onCompleted(): Unit = {
+              done = true
+            }
+          })
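+        // The query range(0, 100, 1, 4) produces 100 rows across 4 partitions.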
+        verifyEvents.onCompleted(Some(100))
+        // The current implementation is expected to be blocking. This is here to make sure it is.
+        assert(done)
+
+        // ArrowBatch chunks of 4 partitions + Metrics + optional progress messages
+        val filteredResponses = responses.filter(!_.hasExecutionProgress)
+
+        // Make sure the first response is schema only
+        val head = filteredResponses.head
+        assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
+
+        // Make sure the last response is metrics only
+        val last = filteredResponses.last
+        assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
+
+        val allocator = new RootAllocator()
+
+        // Check the 'data' batches
+        val arrowResponseQueue = mutable.Queue(filteredResponses.tail.dropRight(1).toSeq: _*)
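+        // Drain the responses that belong to the next Arrow batch, using the chunk count
+        // reported by the batch's first chunk.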
+        def dequeueNextBatch(): List[proto.ExecutePlanResponse] = {
+          if (arrowResponseQueue.isEmpty) return Nil
+          val n = arrowResponseQueue.front.getArrowBatch.getNumChunksInBatch.toInt
+          List.fill(n)(arrowResponseQueue.dequeue())
+        }
+
+        var batchCount = 0
+        var expectedId = 0L
+        var previousEId = 0.0d
+        while (arrowResponseQueue.nonEmpty) {
+          val batch = dequeueNextBatch()
+          // In this example, the max chunk size is set to a small value, so each Arrow batch
+          // should be split into multiple chunks.
+          assert(batch.size > 5)
+          batchCount += 1
+
+          val rowCount = batch.head.getArrowBatch.getRowCount
+          val rowStartOffset = batch.head.getArrowBatch.getStartOffset
+          batch.zipWithIndex.foreach { case (chunk, chunkIndex) =>
+            assert(chunk.getArrowBatch.getChunkIndex == chunkIndex)
+            assert(chunk.getArrowBatch.getNumChunksInBatch == batch.size)
+            assert(chunk.getArrowBatch.getRowCount == rowCount)
+            assert(chunk.getArrowBatch.getStartOffset == rowStartOffset)
+            assert(chunk.getArrowBatch.getData != null)
+            assert(chunk.getArrowBatch.getData.size() > 0)
+            assert(chunk.getArrowBatch.getData.size() <= overriddenMaxChunkSize)
+          }
+
+          // Reassemble the chunks into a single Arrow batch and validate its content.
+          val batchData: ByteString =

Review Comment:
   Yeah, it's a good point; I've updated the test as well.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

