This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f9d417fc17a [SPARK-44657][CONNECT] Fix incorrect limit handling in ArrowBatchWithSchemaIterator and config parsing of CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
f9d417fc17a is described below

commit f9d417fc17a82ddf02d6bbab82abc8e1aa264356
Author: vicennial <venkata.gud...@databricks.com>
AuthorDate: Tue Aug 8 17:30:16 2023 +0900

    [SPARK-44657][CONNECT] Fix incorrect limit handling in ArrowBatchWithSchemaIterator and config parsing of CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
    
    ### What changes were proposed in this pull request?
    
    Fixes the limit checking of `maxEstimatedBatchSize` and `maxRecordsPerBatch` so that the more restrictive of the two limits is respected, and fixes the config parsing of `CONNECT_GRPC_ARROW_MAX_BATCH_SIZE` by converting the value to bytes.
    
    ### Why are the changes needed?
    
    Bugfix.
    In the arrow writer [code](https://github.com/apache/spark/blob/6161bf44f40f8146ea4c115c788fd4eaeb128769/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala#L154-L163), the condition does not match what the inline comment says, namely "For maxBatchSize and maxRecordsPerBatch, respect whatever smaller": because of the `||` operator, the loop actually respects whichever limit is larger (i.e. less restrictive).
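    
    As a minimal illustration (a standalone sketch with hypothetical values, not the actual Spark code), the old check keeps writing rows while either limit still has room, so the larger limit wins; requiring both limits to still have room makes the smaller limit win:
    
    ```scala
    // Hypothetical limits and batch state, for illustration only.
    val maxRecordsPerBatch = 1000           // record limit
    val maxEstimatedBatchSize = 8L * 1024   // size limit in bytes
    val rowCountInLastBatch = 1000          // record limit already reached
    val estimatedBatchSize = 4L * 1024      // size limit not yet reached
    
    // Old check: keep writing while EITHER limit still has room,
    // so the record limit above is silently exceeded.
    val oldKeepWriting =
      estimatedBatchSize < maxEstimatedBatchSize || rowCountInLastBatch < maxRecordsPerBatch
    // oldKeepWriting == true
    
    // Fixed check: keep writing only while BOTH limits still have room,
    // i.e. the batch is closed as soon as the more restrictive limit is hit.
    val newKeepWriting =
      estimatedBatchSize < maxEstimatedBatchSize && rowCountInLastBatch < maxRecordsPerBatch
    // newKeepWriting == false
    ```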
    
    Further, when the `CONNECT_GRPC_ARROW_MAX_BATCH_SIZE` conf is read, the value is not converted from MiB to bytes ([example](https://github.com/apache/spark/blob/3e5203c64c06cc8a8560dfa0fb6f52e74589b583/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala#L103)).
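    
    As a rough sketch of the unit mismatch (hypothetical variable names, not the server code): with `.bytesConf(ByteUnit.MiB)` and a default of `"4m"`, reading the conf yields the value in MiB, so using it directly as a byte limit understates the intended limit by a factor of 1024 * 1024:
    
    ```scala
    // Hypothetical values, for illustration only.
    val confValueInMiB = 4L                      // what a MiB-based conf read yields for "4m"
    val intendedLimitInBytes = 4L * 1024 * 1024  // what the Arrow writer expects
    
    // Passing the raw conf value as maxEstimatedBatchSize would cap each batch at
    // roughly 4 bytes instead of 4 MiB. Storing the conf directly in bytes (as this
    // patch does with bytesConf(ByteUnit.BYTE) and a default of 4 * 1024 * 1024)
    // avoids the conversion altogether.
    assert(confValueInMiB * 1024 * 1024 == intendedLimitInBytes)
    ```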
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #42321 from vicennial/SPARK-44657.
    
    Authored-by: vicennial <venkata.gud...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../apache/spark/sql/connect/config/Connect.scala  | 10 ++--
 .../connect/planner/SparkConnectServiceSuite.scala | 61 ++++++++++++++++++++++
 .../sql/execution/arrow/ArrowConverters.scala      | 21 +++++---
 3 files changed, 79 insertions(+), 13 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 64c2d6f1cb6..e25cb5cbab2 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -49,12 +49,12 @@ object Connect {
   val CONNECT_GRPC_ARROW_MAX_BATCH_SIZE =
     ConfigBuilder("spark.connect.grpc.arrow.maxBatchSize")
       .doc(
-        "When using Apache Arrow, limit the maximum size of one arrow batch 
that " +
-          "can be sent from server side to client side. Currently, we 
conservatively use 70% " +
-          "of it because the size is not accurate but estimated.")
+        "When using Apache Arrow, limit the maximum size of one arrow batch, 
in bytes unless " +
+          "otherwise specified, that can be sent from server side to client 
side. Currently, we " +
+          "conservatively use 70% of it because the size is not accurate but 
estimated.")
       .version("3.4.0")
-      .bytesConf(ByteUnit.MiB)
-      .createWithDefaultString("4m")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(4 * 1024 * 1024)
 
   val CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE =
     ConfigBuilder("spark.connect.grpc.maxInboundMessageSize")
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index e833d12c4f5..285f3103b19 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -238,6 +238,67 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
     }
   }
 
+  test("SPARK-44657: Arrow batches respect max batch size limit") {
+    // Set 10 KiB as the batch size limit
+    val batchSize = 10 * 1024
+    withSparkConf("spark.connect.grpc.arrow.maxBatchSize" -> 
batchSize.toString) {
+      // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
+      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
+      val instance = new SparkConnectService(false)
+      val connect = new MockRemoteSession()
+      val context = proto.UserContext
+        .newBuilder()
+        .setUserId("c1")
+        .build()
+      val plan = proto.Plan
+        .newBuilder()
+        .setRoot(connect.sql("select * from range(0, 15000, 1, 1)"))
+        .build()
+      val request = proto.ExecutePlanRequest
+        .newBuilder()
+        .setPlan(plan)
+        .setUserContext(context)
+        .setSessionId(UUID.randomUUID.toString())
+        .build()
+
+      // Execute plan.
+      @volatile var done = false
+      val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+      instance.executePlan(
+        request,
+        new StreamObserver[proto.ExecutePlanResponse] {
+          override def onNext(v: proto.ExecutePlanResponse): Unit = {
+            responses += v
+          }
+
+          override def onError(throwable: Throwable): Unit = {
+            throw throwable
+          }
+
+          override def onCompleted(): Unit = {
+            done = true
+          }
+        })
+      // The current implementation is expected to be blocking. This is here to make sure it is.
+      assert(done)
+
+      // 1 schema + 1 metric + at least 2 data batches
+      assert(responses.size > 3)
+
+      val allocator = new RootAllocator()
+
+      // Check the 'data' batches
+      responses.tail.dropRight(1).foreach { response =>
+        assert(response.hasArrowBatch)
+        val batch = response.getArrowBatch
+        assert(batch.getData != null)
+        // Batch size must be <= 70% since we intentionally use this multiplier for the size
+        // estimator.
+        assert(batch.getData.size() <= batchSize * 0.7)
+      }
+    }
+  }
+
   gridTest("SPARK-43923: commands send events")(
     Seq(
       proto.Command
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 59d931bbe48..86dd7984b58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -150,17 +150,22 @@ private[sql] object ArrowConverters extends Logging {
         // Always write the schema.
         MessageSerializer.serialize(writeChannel, arrowSchema)
 
+        def isBatchSizeLimitExceeded: Boolean = {
+          // If `maxEstimatedBatchSize` is zero or negative, it implies unlimited.
+          maxEstimatedBatchSize > 0 && estimatedBatchSize >= maxEstimatedBatchSize
+        }
+        def isRecordLimitExceeded: Boolean = {
+          // If `maxRecordsPerBatch` is zero or negative, it implies unlimited.
+          maxRecordsPerBatch > 0 && rowCountInLastBatch >= maxRecordsPerBatch
+        }
         // Always write the first row.
         while (rowIter.hasNext && (
-          // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller.
          // If the size in bytes is positive (set properly), always write the first row.
-          rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
-            // If the size in bytes of rows are 0 or negative, unlimit it.
-            estimatedBatchSize <= 0 ||
-            estimatedBatchSize < maxEstimatedBatchSize ||
-            // If the size of rows are 0 or negative, unlimit it.
-            maxRecordsPerBatch <= 0 ||
-            rowCountInLastBatch < maxRecordsPerBatch)) {
+          (rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0) ||
+            // If either limit is hit, create a batch. This implies that the limit that is hit first
+            // triggers the creation of a batch even if the other limit is not yet hit, hence
+            // preferring the more restrictive limit.
+            (!isBatchSizeLimitExceeded && !isRecordLimitExceeded))) {
           val row = rowIter.next()
           arrowWriter.write(row)
           estimatedBatchSize += (row match {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
