This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ed01a5c59 [GLUTEN-11133][VL] Refactor batch serialization API to defer the buffer copy from C++ code to Java code (#11127)
1ed01a5c59 is described below

commit 1ed01a5c5950c44d3c214498cf54f3194ef896dc
Author: Hongze Zhang <[email protected]>
AuthorDate: Fri Nov 21 10:30:42 2025 +0000

    [GLUTEN-11133][VL] Refactor batch serialization API to defer the buffer copy from C++ code to Java code (#11127)
---
 .../spark/sql/execution/BroadcastUtils.scala       |  7 ++-
 .../execution/ColumnarCachedBatchSerializer.scala  | 22 +++++---
 ...fferArray.scala => UnsafeByteBufferArray.scala} | 50 ++++++++---------
 .../unsafe/UnsafeColumnarBuildSideRelation.scala   | 52 ++++++++---------
 .../UnsafeColumnarBuildSideRelationTest.scala      |  4 +-
 cpp/core/jni/JniWrapper.cc                         | 33 +++++++----
 .../operators/serializer/ColumnarBatchSerializer.h |  7 ++-
 .../serializer/VeloxColumnarBatchSerializer.cc     | 44 +++++++++------
 .../serializer/VeloxColumnarBatchSerializer.h      |  9 ++-
 .../tests/VeloxColumnarBatchSerializerTest.cc      |  5 +-
 .../memory/arrow/alloc/ArrowBufferAllocators.java  | 14 +++++
 .../ColumnarBatchSerializerJniWrapper.java         |  4 +-
 .../spark/unsafe/memory/UnsafeByteBuffer.java      | 65 ++++++++++++++++++++++
 .../apache/spark/memory/GlobalOffHeapMemory.scala  |  2 +-
 14 files changed, 218 insertions(+), 100 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 342c9694f0..5fc9cc4464 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -175,13 +175,18 @@ object BroadcastUtils {
           val handle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, b)
           numRows += b.numRows()
           try {
-            ColumnarBatchSerializerJniWrapper
+            val unsafeBuffer = ColumnarBatchSerializerJniWrapper
               .create(
                 Runtimes
                   .contextInstance(
                     BackendsApiManager.getBackendName,
                     "BroadcastUtils#serializeStream"))
               .serialize(handle)
+            try {
+              unsafeBuffer.toByteArray
+            } finally {
+              unsafeBuffer.close()
+            }
           } finally {
             ColumnarBatches.release(b)
           }
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
index 80e5039e76..80ff907e61 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
@@ -172,15 +172,19 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with Logging {
 
           override def next(): CachedBatch = {
             val batch = veloxBatches.next()
-            val results =
-              ColumnarBatchSerializerJniWrapper
-                .create(
-                  Runtimes.contextInstance(
-                    BackendsApiManager.getBackendName,
-                    "ColumnarCachedBatchSerializer#serialize"))
-                .serialize(
-                  ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch))
-            CachedColumnarBatch(batch.numRows(), results.length, results)
+            val unsafeBuffer = ColumnarBatchSerializerJniWrapper
+              .create(
+                Runtimes.contextInstance(
+                  BackendsApiManager.getBackendName,
+                  "ColumnarCachedBatchSerializer#serialize"))
+              .serialize(ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch))
+            val bytes =
+              try {
+                unsafeBuffer.toByteArray
+              } finally {
+                unsafeBuffer.close()
+              }
+            CachedColumnarBatch(batch.numRows(), bytes.length, bytes)
           }
         }
     }
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
similarity index 69%
rename from backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala
rename to backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
index c0427c4407..45ff118d6d 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
@@ -29,18 +29,18 @@ import org.apache.spark.unsafe.memory.MemoryAllocator
  *
  * @param arraySize
  *   underlying array[array[byte]]'s length
- * @param bytesBufferLengths
- *   underlying array[array[byte]] per bytesBuffer length
+ * @param byteBufferLengths
+ *   underlying array[array[byte]] per byteBuffer length
  * @param totalBytes
- *   all bytesBuffer's length plus together
+ *   all byteBuffer's length plus together
  */
 // scalastyle:off no.finalize
 @Experimental
-case class UnsafeBytesBufferArray(arraySize: Int, bytesBufferLengths: Array[Int], totalBytes: Long)
+case class UnsafeByteBufferArray(arraySize: Int, byteBufferLengths: Array[Int], totalBytes: Long)
   extends Logging {
   {
     assert(
-      arraySize == bytesBufferLengths.length,
+      arraySize == byteBufferLengths.length,
       "Unsafe buffer array size " +
         "not equal to buffer lengths!")
     assert(totalBytes >= 0, "Unsafe buffer array total bytes can't be negative!")
@@ -48,29 +48,29 @@ case class UnsafeBytesBufferArray(arraySize: Int, bytesBufferLengths: Array[Int]
   private val allocatedBytes = (totalBytes + 7) / 8 * 8
 
   /**
-   * A single array to store all bytesBufferArray's value, it's inited once when first time get
+   * A single array to store all byteBufferArray's value, it's inited once when first time get
    * accessed.
    */
   private var longArray: LongArray = _
 
   /** Index the start of each byteBuffer's offset to underlying LongArray's initial position. */
-  private val bytesBufferOffset = if (bytesBufferLengths.isEmpty) {
+  private val byteBufferOffset = if (byteBufferLengths.isEmpty) {
     new Array(0)
   } else {
-    bytesBufferLengths.init.scanLeft(0L)(_ + _)
+    byteBufferLengths.init.scanLeft(0L)(_ + _)
   }
 
   /**
-   * Put bytesBuffer at specified array index.
+   * Put byteBuffer at specified array index.
    *
    * @param index
    *   index of the array.
-   * @param bytesBuffer
-   *   bytesBuffer to put.
+   * @param byteBuffer
+   *   byteBuffer to put.
    */
-  def putBytesBuffer(index: Int, bytesBuffer: Array[Byte]): Unit = this.synchronized {
+  def putByteBuffer(index: Int, byteBuffer: Array[Byte]): Unit = this.synchronized {
     assert(index < arraySize)
-    assert(bytesBuffer.length == bytesBufferLengths(index))
+    assert(byteBuffer.length == byteBufferLengths(index))
     // first to allocate underlying long array
     if (null == longArray && index == 0) {
       GlobalOffHeapMemory.acquire(allocatedBytes)
@@ -78,45 +78,45 @@ case class UnsafeBytesBufferArray(arraySize: Int, bytesBufferLengths: Array[Int]
     }
 
     Platform.copyMemory(
-      bytesBuffer,
+      byteBuffer,
       Platform.BYTE_ARRAY_OFFSET,
       longArray.getBaseObject,
-      longArray.getBaseOffset + bytesBufferOffset(index),
-      bytesBufferLengths(index))
+      longArray.getBaseOffset + byteBufferOffset(index),
+      byteBufferLengths(index))
   }
 
   /**
-   * Get bytesBuffer at specified index.
+   * Get byteBuffer at specified index.
    * @param index
    * @return
    */
-  def getBytesBuffer(index: Int): Array[Byte] = {
+  def getByteBuffer(index: Int): Array[Byte] = {
     assert(index < arraySize)
     if (null == longArray) {
       return new Array[Byte](0)
     }
-    val bytes = new Array[Byte](bytesBufferLengths(index))
+    val bytes = new Array[Byte](byteBufferLengths(index))
     Platform.copyMemory(
       longArray.getBaseObject,
-      longArray.getBaseOffset + bytesBufferOffset(index),
+      longArray.getBaseOffset + byteBufferOffset(index),
       bytes,
       Platform.BYTE_ARRAY_OFFSET,
-      bytesBufferLengths(index))
+      byteBufferLengths(index))
     bytes
   }
 
   /**
-   * Get the bytesBuffer memory address and length at specified index, usually used when read memory
+   * Get the byteBuffer memory address and length at specified index, usually used when read memory
    * direct from offheap.
    *
    * @param index
    * @return
    */
-  def getBytesBufferOffsetAndLength(index: Int): (Long, Int) = {
+  def getByteBufferOffsetAndLength(index: Int): (Long, Int) = {
     assert(index < arraySize)
     assert(longArray != null, "The broadcast data in offheap should not be null!")
-    val offset = longArray.getBaseOffset + bytesBufferOffset(index)
-    val length = bytesBufferLengths(index)
+    val offset = longArray.getBaseOffset + byteBufferOffset(index)
+    val length = byteBufferLengths(index)
     (offset, length)
   }
 
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index 308834657a..ebbd8226fd 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -50,7 +50,7 @@ object UnsafeColumnarBuildSideRelation {
   // Keep constructors with BroadcastMode for compatibility
   def apply(
       output: Seq[Attribute],
-      batches: UnsafeBytesBufferArray,
+      batches: UnsafeByteBufferArray,
       mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
     val boundMode = mode match {
       case HashedRelationBroadcastMode(keys, isNullAware) =>
@@ -65,7 +65,7 @@ object UnsafeColumnarBuildSideRelation {
   }
   def apply(
       output: Seq[Attribute],
-      bytesBufferArray: Array[Array[Byte]],
+      byteBufferArray: Array[Array[Byte]],
       mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
     val boundMode = mode match {
       case HashedRelationBroadcastMode(keys, isNullAware) =>
@@ -78,7 +78,7 @@ object UnsafeColumnarBuildSideRelation {
     }
     new UnsafeColumnarBuildSideRelation(
       output,
-      bytesBufferArray,
+      byteBufferArray,
       BroadcastModeUtils.toSafe(boundMode)
     )
   }
@@ -97,7 +97,7 @@ object UnsafeColumnarBuildSideRelation {
 @Experimental
 case class UnsafeColumnarBuildSideRelation(
     private var output: Seq[Attribute],
-    private var batches: UnsafeBytesBufferArray,
+    private var batches: UnsafeByteBufferArray,
     var safeBroadcastMode: SafeBroadcastMode)
   extends BuildSideRelation
   with Externalizable
@@ -118,27 +118,27 @@ case class UnsafeColumnarBuildSideRelation(
 
   /** needed for serialization. */
   def this() = {
-    this(null, null.asInstanceOf[UnsafeBytesBufferArray], null)
+    this(null, null.asInstanceOf[UnsafeByteBufferArray], null)
   }
 
   def this(
       output: Seq[Attribute],
-      bytesBufferArray: Array[Array[Byte]],
+      byteBufferArray: Array[Array[Byte]],
       safeMode: SafeBroadcastMode
   ) = {
     this(
       output,
-      UnsafeBytesBufferArray(
-        bytesBufferArray.length,
-        bytesBufferArray.map(_.length),
-        bytesBufferArray.map(_.length.toLong).sum
+      UnsafeByteBufferArray(
+        byteBufferArray.length,
+        byteBufferArray.map(_.length),
+        byteBufferArray.map(_.length.toLong).sum
       ),
       safeMode
     )
-    val batchesSize = bytesBufferArray.length
+    val batchesSize = byteBufferArray.length
     for (i <- 0 until batchesSize) {
       // copy the bytes to off-heap memory.
-      batches.putBytesBuffer(i, bytesBufferArray(i))
+      batches.putByteBuffer(i, byteBufferArray(i))
     }
   }
 
@@ -146,10 +146,10 @@ case class UnsafeColumnarBuildSideRelation(
     out.writeObject(output)
     out.writeObject(safeBroadcastMode)
     out.writeInt(batches.arraySize)
-    out.writeObject(batches.bytesBufferLengths)
+    out.writeObject(batches.byteBufferLengths)
     out.writeLong(batches.totalBytes)
     for (i <- 0 until batches.arraySize) {
-      val bytes = batches.getBytesBuffer(i)
+      val bytes = batches.getByteBuffer(i)
       out.write(bytes)
     }
   }
@@ -158,10 +158,10 @@ case class UnsafeColumnarBuildSideRelation(
     kryo.writeObject(out, output.toList)
     kryo.writeClassAndObject(out, safeBroadcastMode)
     out.writeInt(batches.arraySize)
-    kryo.writeObject(out, batches.bytesBufferLengths)
+    kryo.writeObject(out, batches.byteBufferLengths)
     out.writeLong(batches.totalBytes)
     for (i <- 0 until batches.arraySize) {
-      val bytes = batches.getBytesBuffer(i)
+      val bytes = batches.getByteBuffer(i)
       out.write(bytes)
     }
   }
@@ -170,7 +170,7 @@ case class UnsafeColumnarBuildSideRelation(
     output = in.readObject().asInstanceOf[Seq[Attribute]]
     safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode]
     val totalArraySize = in.readInt()
-    val bytesBufferLengths = in.readObject().asInstanceOf[Array[Int]]
+    val byteBufferLengths = in.readObject().asInstanceOf[Array[Int]]
     val totalBytes = in.readLong()
 
     // scalastyle:off
@@ -180,13 +180,13 @@ case class UnsafeColumnarBuildSideRelation(
      */
     // scalastyle:on
 
-    batches = UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes)
+    batches = UnsafeByteBufferArray(totalArraySize, byteBufferLengths, totalBytes)
 
     for (i <- 0 until totalArraySize) {
-      val length = bytesBufferLengths(i)
+      val length = byteBufferLengths(i)
       val tmpBuffer = new Array[Byte](length)
       in.readFully(tmpBuffer)
-      batches.putBytesBuffer(i, tmpBuffer)
+      batches.putByteBuffer(i, tmpBuffer)
     }
   }
 
@@ -194,16 +194,16 @@ case class UnsafeColumnarBuildSideRelation(
     output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
     safeBroadcastMode = kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode]
     val totalArraySize = in.readInt()
-    val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]])
+    val byteBufferLengths = kryo.readObject(in, classOf[Array[Int]])
     val totalBytes = in.readLong()
 
-    batches = UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes)
+    batches = UnsafeByteBufferArray(totalArraySize, byteBufferLengths, totalBytes)
 
     for (i <- 0 until totalArraySize) {
-      val length = bytesBufferLengths(i)
+      val length = byteBufferLengths(i)
       val tmpBuffer = new Array[Byte](length)
       in.read(tmpBuffer)
-      batches.putBytesBuffer(i, tmpBuffer)
+      batches.putByteBuffer(i, tmpBuffer)
     }
   }
 
@@ -252,7 +252,7 @@ case class UnsafeColumnarBuildSideRelation(
 
         override def next: ColumnarBatch = {
           val (offset, length) =
-            batches.getBytesBufferOffsetAndLength(batchId)
+            batches.getByteBufferOffsetAndLength(batchId)
           batchId += 1
           val handle =
             jniWrapper.deserializeDirect(serializerHandle, offset, length)
@@ -309,7 +309,7 @@ case class UnsafeColumnarBuildSideRelation(
         }
 
         override def next(): Iterator[InternalRow] = {
-          val (offset, length) = batches.getBytesBufferOffsetAndLength(batchId)
+          val (offset, length) = batches.getByteBufferOffsetAndLength(batchId)
           batchId += 1
           val batchHandle =
             serializerJniWrapper.deserializeDirect(serializerHandle, offset, length)
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
index 1017fd0723..903bb100b7 100644
--- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
@@ -41,12 +41,12 @@ class UnsafeColumnarBuildSideRelationTest extends SharedSparkSession {
     val totalArraySize = 1
     val perArraySize = new Array[Int](totalArraySize)
     perArraySize(0) = 10
-    val bytesArray = UnsafeBytesBufferArray(
+    val bytesArray = UnsafeByteBufferArray(
       1,
       perArraySize,
       10
     )
-    bytesArray.putBytesBuffer(0, "1234567890".getBytes())
+    bytesArray.putByteBuffer(0, "1234567890".getBytes())
     unsafeRelWithIdentityMode = UnsafeColumnarBuildSideRelation(
       output,
       bytesArray,
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 5bc311f782..f3b86f8ff6 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -44,6 +44,11 @@ namespace {
 
 jclass byteArrayClass;
 
+jclass unsafeByteBufferClass;
+jmethodID unsafeByteBufferAllocate;
+jmethodID unsafeByteBufferAddress;
+jmethodID unsafeByteBufferSize;
+
 jclass jniByteInputStreamClass;
 jmethodID jniByteInputStreamRead;
 jmethodID jniByteInputStreamTell;
@@ -245,6 +250,12 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
 
   byteArrayClass = createGlobalClassReferenceOrError(env, "[B");
 
+  unsafeByteBufferClass = createGlobalClassReferenceOrError(env, "Lorg/apache/spark/unsafe/memory/UnsafeByteBuffer;");
+  unsafeByteBufferAllocate =
+      env->GetStaticMethodID(unsafeByteBufferClass, "allocate", "(J)Lorg/apache/spark/unsafe/memory/UnsafeByteBuffer;");
+  unsafeByteBufferAddress = env->GetMethodID(unsafeByteBufferClass, "address", "()J");
+  unsafeByteBufferSize = env->GetMethodID(unsafeByteBufferClass, "size", "()J");
+
   jniByteInputStreamClass = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/vectorized/JniByteInputStream;");
   jniByteInputStreamRead = getMethodIdOrError(env, jniByteInputStreamClass, "read", "(JJ)J");
   jniByteInputStreamTell = getMethodIdOrError(env, jniByteInputStreamClass, "tell", "()J");
@@ -287,6 +298,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
   env->DeleteGlobalRef(splitResultClass);
   env->DeleteGlobalRef(nativeColumnarToRowInfoClass);
   env->DeleteGlobalRef(byteArrayClass);
+  env->DeleteGlobalRef(unsafeByteBufferClass);
   env->DeleteGlobalRef(shuffleReaderMetricsClass);
 
   getJniErrorState()->close();
@@ -1139,28 +1151,25 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper
   JNI_METHOD_END()
 }
 
-JNIEXPORT jbyteArray JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_serialize( // NOLINT
+JNIEXPORT jobject JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_serialize( // NOLINT
     JNIEnv* env,
     jobject wrapper,
     jlong handle) {
   JNI_METHOD_START
   auto ctx = getRuntime(env, wrapper);
 
-  std::vector<std::shared_ptr<ColumnarBatch>> batches;
   auto batch = ObjectStore::retrieve<ColumnarBatch>(handle);
   GLUTEN_DCHECK(batch != nullptr, "Cannot find the ColumnarBatch with handle " + std::to_string(handle));
-  batches.emplace_back(batch);
 
   auto serializer = ctx->createColumnarBatchSerializer(nullptr);
-  auto buffer = serializer->serializeColumnarBatches(batches);
-  auto bufferArr = env->NewByteArray(buffer->size());
-  GLUTEN_CHECK(
-      bufferArr != nullptr,
-      "Cannot construct a byte array of size " + std::to_string(buffer->size()) +
-          " byte(s) to serialize columnar batches");
-  env->SetByteArrayRegion(bufferArr, 0, buffer->size(), reinterpret_cast<const jbyte*>(buffer->data()));
-
-  return bufferArr;
+  serializer->append(batch);
+  auto serializedSize = serializer->maxSerializedSize();
+  auto byteBuffer = env->CallStaticObjectMethod(unsafeByteBufferClass, unsafeByteBufferAllocate, serializedSize);
+  auto byteBufferAddress = env->CallLongMethod(byteBuffer, unsafeByteBufferAddress);
+  auto byteBufferSize = env->CallLongMethod(byteBuffer, unsafeByteBufferSize);
+  serializer->serializeTo(reinterpret_cast<uint8_t*>(byteBufferAddress), byteBufferSize);
+
+  return byteBuffer;
   JNI_METHOD_END(nullptr)
 }
 
diff --git a/cpp/core/operators/serializer/ColumnarBatchSerializer.h b/cpp/core/operators/serializer/ColumnarBatchSerializer.h
index 25d23aacb5..08a76f9f23 100644
--- a/cpp/core/operators/serializer/ColumnarBatchSerializer.h
+++ b/cpp/core/operators/serializer/ColumnarBatchSerializer.h
@@ -29,8 +29,11 @@ class ColumnarBatchSerializer {
 
   virtual ~ColumnarBatchSerializer() = default;
 
-  virtual std::shared_ptr<arrow::Buffer> serializeColumnarBatches(
-      const std::vector<std::shared_ptr<ColumnarBatch>>& batches) = 0;
+  virtual void append(const std::shared_ptr<ColumnarBatch>& batch) = 0;
+
+  virtual int64_t maxSerializedSize() = 0;
+
+  virtual void serializeTo(uint8_t* address, int64_t size) = 0;
 
   virtual std::shared_ptr<ColumnarBatch> deserialize(uint8_t* data, int32_t size) = 0;
 
diff --git a/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.cc b/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.cc
index a35316702a..c12259420a 100644
--- a/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.cc
+++ b/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.cc
@@ -51,33 +51,41 @@ VeloxColumnarBatchSerializer::VeloxColumnarBatchSerializer(
     rowType_ = asRowType(importFromArrow(*cSchema));
     ArrowSchemaRelease(cSchema); // otherwise the c schema leaks memory
   }
+  arena_ = std::make_unique<StreamArena>(veloxPool_.get());
   serde_ = std::make_unique<serializer::presto::PrestoVectorSerde>();
   options_.useLosslessTimestamp = true;
 }
 
-std::shared_ptr<arrow::Buffer> VeloxColumnarBatchSerializer::serializeColumnarBatches(
-    const std::vector<std::shared_ptr<ColumnarBatch>>& batches) {
-  VELOX_DCHECK(batches.size() != 0, "Should serialize at least 1 vector");
-  const std::shared_ptr<VeloxColumnarBatch>& vb = VeloxColumnarBatch::from(veloxPool_.get(), batches[0]);
-  auto firstRowVector = vb->getRowVector();
-  auto numRows = firstRowVector->size();
-  auto arena = std::make_unique<StreamArena>(veloxPool_.get());
-  auto rowType = asRowType(firstRowVector->type());
-  auto serializer = serde_->createIterativeSerializer(rowType, numRows, arena.get(), &options_);
-  for (auto& batch : batches) {
-    auto rowVector = VeloxColumnarBatch::from(veloxPool_.get(), batch)->getRowVector();
-    const IndexRange allRows{0, rowVector->size()};
-    serializer->append(rowVector, folly::Range(&allRows, 1));
+void VeloxColumnarBatchSerializer::append(const std::shared_ptr<ColumnarBatch>& batch) {
+  auto rowVector = VeloxColumnarBatch::from(veloxPool_.get(), batch)->getRowVector();
+  if (serializer_ == nullptr) {
+    // Using first batch's schema to create the Velox serializer. This logic was introduced in
+    // https://github.com/apache/incubator-gluten/pull/1568. It's a bit suboptimal because the schemas
+    // across different batches may vary.
+    auto numRows = rowVector->size();
+    auto rowType = asRowType(rowVector->type());
+    serializer_ = serde_->createIterativeSerializer(rowType, numRows, arena_.get(), &options_);
   }
+  const IndexRange allRows{0, rowVector->size()};
+  serializer_->append(rowVector, folly::Range(&allRows, 1));
+}
+
+int64_t VeloxColumnarBatchSerializer::maxSerializedSize() {
+  VELOX_DCHECK(serializer_ != nullptr, "Should serialize at least 1 vector");
+  return serializer_->maxSerializedSize();
+}
 
-  std::shared_ptr<arrow::Buffer> valueBuffer;
-  GLUTEN_ASSIGN_OR_THROW(valueBuffer, arrow::AllocateResizableBuffer(serializer->maxSerializedSize(), arrowPool_));
+void VeloxColumnarBatchSerializer::serializeTo(uint8_t* address, int64_t size) {
+  VELOX_DCHECK(serializer_ != nullptr, "Should serialize at least 1 vector");
+  auto sizeNeeded = serializer_->maxSerializedSize();
+  GLUTEN_CHECK(
+      size >= sizeNeeded,
+      "The target buffer size is insufficient: " + std::to_string(size) + " vs." + std::to_string(sizeNeeded));
+  std::shared_ptr<arrow::MutableBuffer> valueBuffer = std::make_shared<arrow::MutableBuffer>(address, size);
   auto output = std::make_shared<arrow::io::FixedSizeBufferWriter>(valueBuffer);
   serializer::presto::PrestoOutputStreamListener listener;
   ArrowFixedSizeBufferOutputStream out(output, &listener);
-  serializer->flush(&out);
-  GLUTEN_THROW_NOT_OK(output->Close());
-  return valueBuffer;
+  serializer_->flush(&out);
 }
 
 std::shared_ptr<ColumnarBatch> VeloxColumnarBatchSerializer::deserialize(uint8_t* data, int32_t size) {
diff --git a/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.h b/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.h
index 18539a2450..73fec890cf 100644
--- a/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.h
+++ b/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.h
@@ -32,13 +32,18 @@ class VeloxColumnarBatchSerializer final : public ColumnarBatchSerializer {
       std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool,
       struct ArrowSchema* cSchema);
 
-  std::shared_ptr<arrow::Buffer> serializeColumnarBatches(
-      const std::vector<std::shared_ptr<ColumnarBatch>>& batches) override;
+  void append(const std::shared_ptr<ColumnarBatch>& batch) override;
+
+  int64_t maxSerializedSize() override;
+
+  void serializeTo(uint8_t* address, int64_t size) override;
 
   std::shared_ptr<ColumnarBatch> deserialize(uint8_t* data, int32_t size) override;
 
  private:
   std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool_;
+  std::unique_ptr<facebook::velox::StreamArena> arena_;
+  std::unique_ptr<facebook::velox::IterativeVectorSerializer> serializer_;
   facebook::velox::RowTypePtr rowType_;
   std::unique_ptr<facebook::velox::serializer::presto::PrestoVectorSerde> serde_;
   facebook::velox::serializer::presto::PrestoVectorSerde::PrestoOptions options_;
diff --git a/cpp/velox/tests/VeloxColumnarBatchSerializerTest.cc b/cpp/velox/tests/VeloxColumnarBatchSerializerTest.cc
index 7833b0cfea..35c18d8ec3 100644
--- a/cpp/velox/tests/VeloxColumnarBatchSerializerTest.cc
+++ b/cpp/velox/tests/VeloxColumnarBatchSerializerTest.cc
@@ -62,7 +62,10 @@ TEST_F(VeloxColumnarBatchSerializerTest, serialize) {
   auto vector = makeRowVector(children);
   auto batch = std::make_shared<VeloxColumnarBatch>(vector);
   auto serializer = std::make_shared<VeloxColumnarBatchSerializer>(arrowPool, pool_, nullptr);
-  auto buffer = serializer->serializeColumnarBatches({batch});
+  serializer->append(batch);
+  std::shared_ptr<arrow::Buffer> buffer;
+  GLUTEN_ASSIGN_OR_THROW(buffer, arrow::AllocateResizableBuffer(serializer->maxSerializedSize(), arrowPool));
+  serializer->serializeTo(reinterpret_cast<uint8_t*>(buffer->mutable_address()), buffer->size());
 
   ArrowSchema cSchema;
   exportToArrow(vector, cSchema, ArrowUtils::getBridgeOptions());
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java b/gluten-arrow/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java
index 3d13ccc613..7d72a263ac 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java
@@ -17,12 +17,14 @@
 package org.apache.gluten.memory.arrow.alloc;
 
 import org.apache.gluten.config.GlutenConfig;
+import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
 import org.apache.gluten.memory.memtarget.MemoryTargets;
 import org.apache.gluten.memory.memtarget.Spillers;
 
 import org.apache.arrow.memory.AllocationListener;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.spark.memory.GlobalOffHeapMemory;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.task.TaskResource;
 import org.apache.spark.task.TaskResources;
@@ -34,6 +36,14 @@ import java.util.List;
 import java.util.Vector;
 
 public class ArrowBufferAllocators {
+  private static final SimpleMemoryUsageRecorder GLOBAL_USAGE = new SimpleMemoryUsageRecorder();
+  private static final BufferAllocator GLOBAL_INSTANCE;
+
+  static {
+    final AllocationListener listener =
+        new ManagedAllocationListener(GlobalOffHeapMemory.target(), GLOBAL_USAGE);
+    GLOBAL_INSTANCE = new RootAllocator(listener, Long.MAX_VALUE);
+  }
 
   private ArrowBufferAllocators() {}
 
@@ -51,6 +61,10 @@ public class ArrowBufferAllocators {
         .managed;
   }
 
+  public static BufferAllocator globalInstance() {
+    return GLOBAL_INSTANCE;
+  }
+
   public static class ArrowBufferAllocatorManager implements TaskResource {
     private static Logger LOGGER = LoggerFactory.getLogger(ArrowBufferAllocatorManager.class);
     private static final List<BufferAllocator> LEAKED = new Vector<>();
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
index 97ab10082e..00fe02c629 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
@@ -19,6 +19,8 @@ package org.apache.gluten.vectorized;
 import org.apache.gluten.runtime.Runtime;
 import org.apache.gluten.runtime.RuntimeAware;
 
+import org.apache.spark.unsafe.memory.UnsafeByteBuffer;
+
 public class ColumnarBatchSerializerJniWrapper implements RuntimeAware {
   private final Runtime runtime;
 
@@ -35,7 +37,7 @@ public class ColumnarBatchSerializerJniWrapper implements RuntimeAware {
     return runtime.getHandle();
   }
 
-  public native byte[] serialize(long handle);
+  public native UnsafeByteBuffer serialize(long handle);
 
   // Return the native ColumnarBatchSerializer handle
   public native long init(long cSchema);
diff --git a/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java b/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java
new file mode 100644
index 0000000000..1cd598ab6c
--- /dev/null
+++ b/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.unsafe.memory;
+
+import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
+
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.spark.task.TaskResources;
+import org.apache.spark.unsafe.Platform;
+
+/** The API is for being called from C++ via JNI. */
+public class UnsafeByteBuffer {
+  private final ArrowBuf buffer;
+  private final long size;
+
+  private UnsafeByteBuffer(ArrowBuf buffer, long size) {
+    this.buffer = buffer;
+    this.size = size;
+  }
+
+  public static UnsafeByteBuffer allocate(long size) {
+    final BufferAllocator allocator;
+    if (TaskResources.inSparkTask()) {
+      allocator = ArrowBufferAllocators.contextInstance(UnsafeByteBuffer.class.getName());
+    } else {
+      allocator = ArrowBufferAllocators.globalInstance();
+    }
+    final ArrowBuf arrowBuf = allocator.buffer(size);
+    return new UnsafeByteBuffer(arrowBuf, size);
+  }
+
+  public long address() {
+    return buffer.memoryAddress();
+  }
+
+  public long size() {
+    return size;
+  }
+
+  public void close() {
+    buffer.close();
+  }
+
+  public byte[] toByteArray() {
+    final byte[] values = new byte[Math.toIntExact(size)];
+    Platform.copyMemory(
+        null, buffer.memoryAddress(), values, Platform.BYTE_ARRAY_OFFSET, values.length);
+    return values;
+  }
+}
diff --git a/gluten-core/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala b/gluten-core/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
index c0988b8eca..b2915084dd 100644
--- a/gluten-core/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
@@ -30,7 +30,7 @@ import org.apache.gluten.memory.memtarget.{MemoryTarget, NoopMemoryTarget}
  * BlockId to be extended by user, TestBlockId is chosen for the storage memory reservations.
  */
 object GlobalOffHeapMemory {
-  private val target: MemoryTarget = if (GlutenCoreConfig.get.memoryUntracked) {
+  val target: MemoryTarget = if (GlutenCoreConfig.get.memoryUntracked) {
     new NoopMemoryTarget()
   } else {
     new GlobalOffHeapMemoryTarget()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to