This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 897dde7  chore: Add allocation source to StreamReader (#332)
897dde7 is described below

commit 897dde7f7d430b572659d57197af04c22526ad72
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Apr 26 14:55:05 2024 -0700

    chore: Add allocation source to StreamReader (#332)
    
    * chore: Add allocation source to StreamReader
    
    * Use simple name
---
 common/src/main/scala/org/apache/comet/vector/NativeUtil.scala     | 1 +
 common/src/main/scala/org/apache/comet/vector/StreamReader.scala   | 3 ++-
 .../spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala    | 5 +++--
 .../sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala | 2 +-
 .../org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala    | 3 ++-
 spark/src/main/scala/org/apache/spark/sql/comet/operators.scala    | 7 ++++---
 6 files changed, 13 insertions(+), 8 deletions(-)
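
For context on the change itself: Apache Arrow's Java allocators form a tree, and
each child allocator carries the name passed to newChildAllocator, which Arrow
then uses in its accounting output and leak reports. This patch names Comet's
child allocators after the owning class (and, for StreamReader, the caller that
opened the stream), so an unclosed buffer can be traced back to its source. A
minimal sketch of that Arrow API, separate from the patch (the demo object and
the "DemoSource" label are illustrative only):

    import org.apache.arrow.memory.{BufferAllocator, RootAllocator}

    object AllocatorNamingDemo {
      def main(args: Array[String]): Unit = {
        val root: BufferAllocator = new RootAllocator(Long.MaxValue)
        // Same three arguments the patch uses: name, initial reservation, max allocation.
        val child = root.newChildAllocator("StreamReader/DemoSource", 0, Long.MaxValue)
        val buf = child.buffer(64) // 64 bytes now charged to the named child
        println(s"allocated ${buf.capacity()} bytes via ${child.getName}")
        buf.close()
        child.close()
        root.close()
      }
    }

If buf were left unclosed, child.close() would fail with an allocator dump that
includes the child's name; with anonymous allocators that dump is much harder to
act on.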

diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 763ccff..eb731f9 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -33,6 +33,7 @@ class NativeUtil {
   import Utils._
 
   private val allocator = new RootAllocator(Long.MaxValue)
+    .newChildAllocator(this.getClass.getSimpleName, 0, Long.MaxValue)
   private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
   private val importer = new ArrowImporter(allocator)
 
diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
index 61d800b..4a08f05 100644
--- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
+++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
@@ -30,8 +30,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 /**
  * A reader that consumes Arrow data from an input channel, and produces Comet batches.
  */
-case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable {
+case class StreamReader(channel: ReadableByteChannel, source: String) extends AutoCloseable {
   private var allocator = new RootAllocator(Long.MaxValue)
+    .newChildAllocator(s"${this.getClass.getSimpleName}/$source", 0, Long.MaxValue)
   private val channelReader = new MessageChannelReader(new ReadChannel(channel), allocator)
   private var arrowReader = new ArrowStreamReader(channelReader, allocator)
   private var root = arrowReader.getVectorSchemaRoot
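
The allocator name here follows a ClassName/source convention, so a single
StreamReader class can still be told apart per call site. When several named
children are alive under one root, the root's verbose dump lists each of them by
name, which is where these labels pay off. A hypothetical debugging snippet (not
part of the patch; the child names are made up):

    import org.apache.arrow.memory.RootAllocator

    object AllocatorTreeDump {
      def main(args: Array[String]): Unit = {
        val root = new RootAllocator(Long.MaxValue)
        val fromShuffle =
          root.newChildAllocator("StreamReader/CometBlockStoreShuffleReader", 0, Long.MaxValue)
        val fromNative = root.newChildAllocator("NativeUtil", 0, Long.MaxValue)
        val buf = fromShuffle.buffer(128) // an outstanding allocation, so it shows in the dump
        // toVerboseString walks the allocator tree and prints each child by name,
        // with its allocated and reserved byte counts.
        println(root.toVerboseString)
        buf.close(); fromShuffle.close(); fromNative.close(); root.close()
      }
    }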
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
index 304c3ce..3c0fa15 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
@@ -25,9 +25,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.comet.vector._
 
-class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {
+class ArrowReaderIterator(channel: ReadableByteChannel, source: String)
+    extends Iterator[ColumnarBatch] {
 
-  private val reader = StreamReader(channel)
+  private val reader = StreamReader(channel, source)
   private var batch = nextBatch()
   private var currentBatch: ColumnarBatch = null
 
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
index b461b53..90e0bb1 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
@@ -108,7 +108,7 @@ class CometBlockStoreShuffleReader[K, C](
               // Closes previous read iterator.
               currentReadIterator.close()
             }
-            currentReadIterator = new ArrowReaderIterator(channel)
+            currentReadIterator = new ArrowReaderIterator(channel, this.getClass.getSimpleName)
             currentReadIterator.map((0, _)) // use 0 as key since it's not used
           }
       }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index a8322be..06c5898 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -284,7 +284,8 @@ class CometBatchRDD(
 
   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partition = split.asInstanceOf[CometBatchPartition]
-    partition.value.value.toIterator.flatMap(CometExec.decodeBatches)
+    partition.value.value.toIterator
+      .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName))
   }
 }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index a857975..39ffef1 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -86,7 +86,8 @@ abstract class CometExec extends CometPlan {
     val countsAndBytes = CometExec.getByteArrayRdd(this).collect()
     val total = countsAndBytes.map(_._1).sum
     val rows = countsAndBytes.iterator
-      .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2))
+      .flatMap(countAndBytes =>
+        CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName))
     (total, rows)
   }
 }
@@ -126,7 +127,7 @@ object CometExec {
   /**
    * Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
    */
-  def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
+  def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
     if (bytes.size == 0) {
       return Iterator.empty
     }
@@ -135,7 +136,7 @@ object CometExec {
     val cbbis = bytes.toInputStream()
     val ins = new DataInputStream(codec.compressedInputStream(cbbis))
 
-    new ArrowReaderIterator(Channels.newChannel(ins))
+    new ArrowReaderIterator(Channels.newChannel(ins), source)
   }
 }
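
Putting the pieces together: decodeBatches now threads the caller's simple class
name down through ArrowReaderIterator into StreamReader, where it ends up as the
"StreamReader/<source>" allocator name. A hedged, pure-Arrow reconstruction of
that read path (no Comet classes; the "DemoCaller" label stands in for
this.getClass.getSimpleName):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import java.nio.channels.Channels

    import org.apache.arrow.memory.RootAllocator
    import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}
    import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}

    object StreamReadWithNamedAllocator {
      def main(args: Array[String]): Unit = {
        val root = new RootAllocator(Long.MaxValue)

        // Write a one-column, three-row Arrow IPC stream into memory.
        val out = new ByteArrayOutputStream()
        val vec = new IntVector("id", root)
        vec.allocateNew(3)
        (0 until 3).foreach(i => vec.setSafe(i, i))
        vec.setValueCount(3)
        val writeRoot = VectorSchemaRoot.of(vec)
        val writer = new ArrowStreamWriter(writeRoot, null, out)
        writer.start(); writer.writeBatch(); writer.end(); writer.close()
        writeRoot.close()

        // Read it back through a child allocator named the way the patched
        // StreamReader names its allocator.
        val source = "DemoCaller"
        val child = root.newChildAllocator(s"StreamReader/$source", 0, Long.MaxValue)
        val channel = Channels.newChannel(new ByteArrayInputStream(out.toByteArray))
        val reader = new ArrowStreamReader(channel, child)
        while (reader.loadNextBatch()) {
          println(s"read a batch of ${reader.getVectorSchemaRoot.getRowCount} rows")
        }
        reader.close(); child.close(); root.close()
      }
    }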
 

