This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new ba9b84289 Hoist some stuff out of NativeBatchDecoderIterator into
CometBlockStoreShuffleReader that can be reused. (#3627)
ba9b84289 is described below
commit ba9b8428951f415d0d76558dd2ce18905acc64f3
Author: Matt Butrovich <[email protected]>
AuthorDate: Tue Mar 3 18:51:53 2026 -0500
Hoist some stuff out of NativeBatchDecoderIterator into
CometBlockStoreShuffleReader that can be reused. (#3627)
---
.../shuffle/CometBlockStoreShuffleReader.scala | 17 ++++++++++++++---
.../shuffle/NativeBatchDecoderIterator.scala | 21 ++++++---------------
2 files changed, 20 insertions(+), 18 deletions(-)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
index 1283a745a..e95eb92d2 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
@@ -35,6 +35,9 @@ import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.ShuffleBlockFetcherIterator
import org.apache.spark.util.CompletionIterator
+import org.apache.comet.{CometConf, Native}
+import org.apache.comet.vector.NativeUtil
+
/**
* Shuffle reader that reads data from the block manager. It reads Arrow-serialized data (IPC
* format) and returns an iterator of ColumnarBatch.
@@ -86,8 +89,11 @@ class CometBlockStoreShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
var currentReadIterator: NativeBatchDecoderIterator = null
+ val nativeLib = new Native()
+ val nativeUtil = new NativeUtil()
+ val tracingEnabled = CometConf.COMET_TRACING_ENABLED.get()
- // Closes last read iterator after the task is finished.
+ // Closes last read iterator and shared resources after the task is finished.
// We need to close read iterator during iterating input streams,
// instead of one callback per read iterator. Otherwise if there are too many
// read iterators, it may blow up the call stack and cause OOM.
@@ -95,6 +101,7 @@ class CometBlockStoreShuffleReader[K, C](
if (currentReadIterator != null) {
currentReadIterator.close()
}
+ nativeUtil.close()
}
val recordIter: Iterator[(Int, ColumnarBatch)] = fetchIterator
@@ -102,8 +109,12 @@ class CometBlockStoreShuffleReader[K, C](
if (currentReadIterator != null) {
currentReadIterator.close()
}
- currentReadIterator =
- NativeBatchDecoderIterator(blockIdAndStream._2, context, dep.decodeTime)
+ currentReadIterator = NativeBatchDecoderIterator(
+ blockIdAndStream._2,
+ dep.decodeTime,
+ nativeLib,
+ nativeUtil,
+ tracingEnabled)
currentReadIterator
})
.map(b => (0, b))
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
index 126db2c63..f96c8f16d 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
@@ -23,11 +23,10 @@ import java.io.{EOFException, InputStream}
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.channels.{Channels, ReadableByteChannel}
-import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.comet.{CometConf, Native}
+import org.apache.comet.Native
import org.apache.comet.vector.NativeUtil
/**
@@ -37,26 +36,19 @@ import org.apache.comet.vector.NativeUtil
*/
case class NativeBatchDecoderIterator(
in: InputStream,
- taskContext: TaskContext,
- decodeTime: SQLMetric)
+ decodeTime: SQLMetric,
+ nativeLib: Native,
+ nativeUtil: NativeUtil,
+ tracingEnabled: Boolean)
extends Iterator[ColumnarBatch] {
private var isClosed = false
private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
- private val native = new Native()
- private val nativeUtil = new NativeUtil()
- private val tracingEnabled = CometConf.COMET_TRACING_ENABLED.get()
private var currentBatch: ColumnarBatch = null
private var batch = fetchNext()
import NativeBatchDecoderIterator._
- if (taskContext != null) {
- taskContext.addTaskCompletionListener[Unit](_ => {
- close()
- })
- }
-
private val channel: ReadableByteChannel = if (in != null) {
Channels.newChannel(in)
} else {
@@ -163,7 +155,7 @@ case class NativeBatchDecoderIterator(
val batch = nativeUtil.getNextBatch(
fieldCount,
(arrayAddrs, schemaAddrs) => {
- native.decodeShuffleBlock(
+ nativeLib.decodeShuffleBlock(
dataBuf,
bytesToRead.toInt,
arrayAddrs,
@@ -183,7 +175,6 @@ case class NativeBatchDecoderIterator(
currentBatch = null
}
in.close()
- nativeUtil.close()
resetDataBuf()
isClosed = true
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]