This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 897dde7 chore: Add allocation source to StreamReader (#332)
897dde7 is described below
commit 897dde7f7d430b572659d57197af04c22526ad72
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Apr 26 14:55:05 2024 -0700
chore: Add allocation source to StreamReader (#332)
* chore: Add allocation source to StreamReader
* Use simple name
---
common/src/main/scala/org/apache/comet/vector/NativeUtil.scala | 1 +
common/src/main/scala/org/apache/comet/vector/StreamReader.scala | 3 ++-
.../spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala | 5 +++--
.../sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala | 2 +-
.../org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala | 3 ++-
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala | 7 ++++---
6 files changed, 13 insertions(+), 8 deletions(-)
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 763ccff..eb731f9 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -33,6 +33,7 @@ class NativeUtil {
import Utils._
private val allocator = new RootAllocator(Long.MaxValue)
+ .newChildAllocator(this.getClass.getSimpleName, 0, Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new
CDataDictionaryProvider
private val importer = new ArrowImporter(allocator)
diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
index 61d800b..4a08f05 100644
--- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
+++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
@@ -30,8 +30,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
/**
* A reader that consumes Arrow data from an input channel, and produces Comet
batches.
*/
-case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable {
+case class StreamReader(channel: ReadableByteChannel, source: String) extends
AutoCloseable {
private var allocator = new RootAllocator(Long.MaxValue)
+ .newChildAllocator(s"${this.getClass.getSimpleName}/$source", 0,
Long.MaxValue)
private val channelReader = new MessageChannelReader(new
ReadChannel(channel), allocator)
private var arrowReader = new ArrowStreamReader(channelReader, allocator)
private var root = arrowReader.getVectorSchemaRoot
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
index 304c3ce..3c0fa15 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
@@ -25,9 +25,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.comet.vector._
-class ArrowReaderIterator(channel: ReadableByteChannel) extends
Iterator[ColumnarBatch] {
+class ArrowReaderIterator(channel: ReadableByteChannel, source: String)
+ extends Iterator[ColumnarBatch] {
- private val reader = StreamReader(channel)
+ private val reader = StreamReader(channel, source)
private var batch = nextBatch()
private var currentBatch: ColumnarBatch = null
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
index b461b53..90e0bb1 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
@@ -108,7 +108,7 @@ class CometBlockStoreShuffleReader[K, C](
// Closes previous read iterator.
currentReadIterator.close()
}
- currentReadIterator = new ArrowReaderIterator(channel)
+ currentReadIterator = new ArrowReaderIterator(channel,
this.getClass.getSimpleName)
currentReadIterator.map((0, _)) // use 0 as key since it's not used
}
}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index a8322be..06c5898 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -284,7 +284,8 @@ class CometBatchRDD(
override def compute(split: Partition, context: TaskContext):
Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]
- partition.value.value.toIterator.flatMap(CometExec.decodeBatches)
+ partition.value.value.toIterator
+ .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName))
}
}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index a857975..39ffef1 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -86,7 +86,8 @@ abstract class CometExec extends CometPlan {
val countsAndBytes = CometExec.getByteArrayRdd(this).collect()
val total = countsAndBytes.map(_._1).sum
val rows = countsAndBytes.iterator
- .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2))
+ .flatMap(countAndBytes =>
+ CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName))
(total, rows)
}
}
@@ -126,7 +127,7 @@ object CometExec {
/**
* Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
*/
- def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
+ def decodeBatches(bytes: ChunkedByteBuffer, source: String):
Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}
@@ -135,7 +136,7 @@ object CometExec {
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))
- new ArrowReaderIterator(Channels.newChannel(ins))
+ new ArrowReaderIterator(Channels.newChannel(ins), source)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]