This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 4fec40e  refactor: Skipping slicing on shuffle arrays in shuffle reader (#189)
4fec40e is described below

commit 4fec40e5b81a6ef04e35be1ae8332bcfdf8597fe
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Mar 11 14:19:51 2024 -0700

    refactor: Skipping slicing on shuffle arrays in shuffle reader (#189)
    
    * refactor: Skipping slicing on shuffle arrays
    
    * Add note for columnar shuffle batch size.
---
 .../main/scala/org/apache/comet/CometConf.scala    |  4 ++-
 .../execution/shuffle/ArrowReaderIterator.scala    | 37 +++-------------------
 2 files changed, 8 insertions(+), 33 deletions(-)

diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 1153b55..de49fdf 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -227,7 +227,9 @@ object CometConf {
   val COMET_COLUMNAR_SHUFFLE_BATCH_SIZE: ConfigEntry[Int] =
     conf("spark.comet.columnar.shuffle.batch.size")
       .internal()
-      .doc("Batch size when writing out sorted spill files on the native 
side.")
+      .doc("Batch size when writing out sorted spill files on the native side. 
Note that " +
+        "this should not be larger than batch size (i.e., 
`spark.comet.batchSize`). Otherwise " +
+        "it will produce larger batches than expected in the native operator 
after shuffle.")
       .intConf
       .createWithDefault(8192)
 
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
index c17c5bc..e8dba93 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
@@ -21,22 +21,14 @@ package org.apache.spark.sql.comet.execution.shuffle
 
 import java.nio.channels.ReadableByteChannel
 
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-import org.apache.comet.CometConf
-import org.apache.comet.vector.{NativeUtil, StreamReader}
+import org.apache.comet.vector.StreamReader
 
 class ArrowReaderIterator(channel: ReadableByteChannel) extends 
Iterator[ColumnarBatch] {
 
-  private val nativeUtil = new NativeUtil
-
-  private val maxBatchSize = CometConf.COMET_BATCH_SIZE.get(SQLConf.get)
-
   private val reader = StreamReader(channel)
-  private var currentIdx = -1
   private var batch = nextBatch()
-  private var previousBatch: ColumnarBatch = null
   private var currentBatch: ColumnarBatch = null
 
   override def hasNext: Boolean = {
@@ -57,40 +49,20 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
     }
 
     val nextBatch = batch.get
-    val batchRows = nextBatch.numRows()
-    val numRows = Math.min(batchRows - currentIdx, maxBatchSize)
 
-    // Release the previous sliced batch.
+    // Release the previous batch.
     // If it is not released, when closing the reader, arrow library will 
complain about
     // memory leak.
     if (currentBatch != null) {
-      // Close plain arrays in the previous sliced batch.
-      // The dictionary arrays will be closed when closing the entire batch.
       currentBatch.close()
     }
 
-    currentBatch = nativeUtil.takeRows(nextBatch, currentIdx, numRows)
-    currentIdx += numRows
-
-    if (currentIdx == batchRows) {
-      // We cannot close the batch here, because if there is dictionary array 
in the batch,
-      // the dictionary array will be closed immediately, and the returned 
sliced batch will
-      // be invalid.
-      previousBatch = batch.get
-
-      batch = None
-      currentIdx = -1
-    }
-
+    currentBatch = nextBatch
+    batch = None
     currentBatch
   }
 
   private def nextBatch(): Option[ColumnarBatch] = {
-    if (previousBatch != null) {
-      previousBatch.close()
-      previousBatch = null
-    }
-    currentIdx = 0
     reader.nextBatch()
   }
 
@@ -98,6 +70,7 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
     synchronized {
       if (currentBatch != null) {
         currentBatch.close()
+        currentBatch = null
       }
       reader.close()
     }

Reply via email to