This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 1ed01a5c59 [GLUTEN-11133][VL] Refactor batch serialization API to
defer the buffer copy from C++ code to Java code (#11127)
1ed01a5c59 is described below
commit 1ed01a5c5950c44d3c214498cf54f3194ef896dc
Author: Hongze Zhang <[email protected]>
AuthorDate: Fri Nov 21 10:30:42 2025 +0000
[GLUTEN-11133][VL] Refactor batch serialization API to defer the buffer
copy from C++ code to Java code (#11127)
---
.../spark/sql/execution/BroadcastUtils.scala | 7 ++-
.../execution/ColumnarCachedBatchSerializer.scala | 22 +++++---
...fferArray.scala => UnsafeByteBufferArray.scala} | 50 ++++++++---------
.../unsafe/UnsafeColumnarBuildSideRelation.scala | 52 ++++++++---------
.../UnsafeColumnarBuildSideRelationTest.scala | 4 +-
cpp/core/jni/JniWrapper.cc | 33 +++++++----
.../operators/serializer/ColumnarBatchSerializer.h | 7 ++-
.../serializer/VeloxColumnarBatchSerializer.cc | 44 +++++++++------
.../serializer/VeloxColumnarBatchSerializer.h | 9 ++-
.../tests/VeloxColumnarBatchSerializerTest.cc | 5 +-
.../memory/arrow/alloc/ArrowBufferAllocators.java | 14 +++++
.../ColumnarBatchSerializerJniWrapper.java | 4 +-
.../spark/unsafe/memory/UnsafeByteBuffer.java | 65 ++++++++++++++++++++++
.../apache/spark/memory/GlobalOffHeapMemory.scala | 2 +-
14 files changed, 218 insertions(+), 100 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 342c9694f0..5fc9cc4464 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -175,13 +175,18 @@ object BroadcastUtils {
val handle =
ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, b)
numRows += b.numRows()
try {
- ColumnarBatchSerializerJniWrapper
+ val unsafeBuffer = ColumnarBatchSerializerJniWrapper
.create(
Runtimes
.contextInstance(
BackendsApiManager.getBackendName,
"BroadcastUtils#serializeStream"))
.serialize(handle)
+ try {
+ unsafeBuffer.toByteArray
+ } finally {
+ unsafeBuffer.close()
+ }
} finally {
ColumnarBatches.release(b)
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
index 80e5039e76..80ff907e61 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
@@ -172,15 +172,19 @@ class ColumnarCachedBatchSerializer extends
CachedBatchSerializer with Logging {
override def next(): CachedBatch = {
val batch = veloxBatches.next()
- val results =
- ColumnarBatchSerializerJniWrapper
- .create(
- Runtimes.contextInstance(
- BackendsApiManager.getBackendName,
- "ColumnarCachedBatchSerializer#serialize"))
- .serialize(
-
ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch))
- CachedColumnarBatch(batch.numRows(), results.length, results)
+ val unsafeBuffer = ColumnarBatchSerializerJniWrapper
+ .create(
+ Runtimes.contextInstance(
+ BackendsApiManager.getBackendName,
+ "ColumnarCachedBatchSerializer#serialize"))
+
.serialize(ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName,
batch))
+ val bytes =
+ try {
+ unsafeBuffer.toByteArray
+ } finally {
+ unsafeBuffer.close()
+ }
+ CachedColumnarBatch(batch.numRows(), bytes.length, bytes)
}
}
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
similarity index 69%
rename from
backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala
rename to
backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
index c0427c4407..45ff118d6d 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
@@ -29,18 +29,18 @@ import org.apache.spark.unsafe.memory.MemoryAllocator
*
* @param arraySize
* underlying array[array[byte]]'s length
- * @param bytesBufferLengths
- * underlying array[array[byte]] per bytesBuffer length
+ * @param byteBufferLengths
+ * the length of each byteBuffer in the underlying array[array[byte]]
* @param totalBytes
- * all bytesBuffer's length plus together
+ * the sum of all byteBuffers' lengths
*/
// scalastyle:off no.finalize
@Experimental
-case class UnsafeBytesBufferArray(arraySize: Int, bytesBufferLengths:
Array[Int], totalBytes: Long)
+case class UnsafeByteBufferArray(arraySize: Int, byteBufferLengths:
Array[Int], totalBytes: Long)
extends Logging {
{
assert(
- arraySize == bytesBufferLengths.length,
+ arraySize == byteBufferLengths.length,
"Unsafe buffer array size " +
"not equal to buffer lengths!")
assert(totalBytes >= 0, "Unsafe buffer array total bytes can't be
negative!")
@@ -48,29 +48,29 @@ case class UnsafeBytesBufferArray(arraySize: Int,
bytesBufferLengths: Array[Int]
private val allocatedBytes = (totalBytes + 7) / 8 * 8
/**
- * A single array to store all bytesBufferArray's value, it's inited once
when first time get
+ * A single array storing every byteBufferArray value; it's initialized once
when first
* accessed.
*/
private var longArray: LongArray = _
/** Index the start of each byteBuffer's offset to underlying LongArray's
initial position. */
- private val bytesBufferOffset = if (bytesBufferLengths.isEmpty) {
+ private val byteBufferOffset = if (byteBufferLengths.isEmpty) {
new Array(0)
} else {
- bytesBufferLengths.init.scanLeft(0L)(_ + _)
+ byteBufferLengths.init.scanLeft(0L)(_ + _)
}
/**
- * Put bytesBuffer at specified array index.
+ * Put byteBuffer at specified array index.
*
* @param index
* index of the array.
- * @param bytesBuffer
- * bytesBuffer to put.
+ * @param byteBuffer
+ * byteBuffer to put.
*/
- def putBytesBuffer(index: Int, bytesBuffer: Array[Byte]): Unit =
this.synchronized {
+ def putByteBuffer(index: Int, byteBuffer: Array[Byte]): Unit =
this.synchronized {
assert(index < arraySize)
- assert(bytesBuffer.length == bytesBufferLengths(index))
+ assert(byteBuffer.length == byteBufferLengths(index))
// first to allocate underlying long array
if (null == longArray && index == 0) {
GlobalOffHeapMemory.acquire(allocatedBytes)
@@ -78,45 +78,45 @@ case class UnsafeBytesBufferArray(arraySize: Int,
bytesBufferLengths: Array[Int]
}
Platform.copyMemory(
- bytesBuffer,
+ byteBuffer,
Platform.BYTE_ARRAY_OFFSET,
longArray.getBaseObject,
- longArray.getBaseOffset + bytesBufferOffset(index),
- bytesBufferLengths(index))
+ longArray.getBaseOffset + byteBufferOffset(index),
+ byteBufferLengths(index))
}
/**
- * Get bytesBuffer at specified index.
+ * Get byteBuffer at specified index.
* @param index
* @return
*/
- def getBytesBuffer(index: Int): Array[Byte] = {
+ def getByteBuffer(index: Int): Array[Byte] = {
assert(index < arraySize)
if (null == longArray) {
return new Array[Byte](0)
}
- val bytes = new Array[Byte](bytesBufferLengths(index))
+ val bytes = new Array[Byte](byteBufferLengths(index))
Platform.copyMemory(
longArray.getBaseObject,
- longArray.getBaseOffset + bytesBufferOffset(index),
+ longArray.getBaseOffset + byteBufferOffset(index),
bytes,
Platform.BYTE_ARRAY_OFFSET,
- bytesBufferLengths(index))
+ byteBufferLengths(index))
bytes
}
/**
- * Get the bytesBuffer memory address and length at specified index, usually
used when read memory
+ * Get the byteBuffer memory address and length at specified index, usually
used when read memory
* direct from offheap.
*
* @param index
* @return
*/
- def getBytesBufferOffsetAndLength(index: Int): (Long, Int) = {
+ def getByteBufferOffsetAndLength(index: Int): (Long, Int) = {
assert(index < arraySize)
assert(longArray != null, "The broadcast data in offheap should not be
null!")
- val offset = longArray.getBaseOffset + bytesBufferOffset(index)
- val length = bytesBufferLengths(index)
+ val offset = longArray.getBaseOffset + byteBufferOffset(index)
+ val length = byteBufferLengths(index)
(offset, length)
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index 308834657a..ebbd8226fd 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -50,7 +50,7 @@ object UnsafeColumnarBuildSideRelation {
// Keep constructors with BroadcastMode for compatibility
def apply(
output: Seq[Attribute],
- batches: UnsafeBytesBufferArray,
+ batches: UnsafeByteBufferArray,
mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
@@ -65,7 +65,7 @@ object UnsafeColumnarBuildSideRelation {
}
def apply(
output: Seq[Attribute],
- bytesBufferArray: Array[Array[Byte]],
+ byteBufferArray: Array[Array[Byte]],
mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
@@ -78,7 +78,7 @@ object UnsafeColumnarBuildSideRelation {
}
new UnsafeColumnarBuildSideRelation(
output,
- bytesBufferArray,
+ byteBufferArray,
BroadcastModeUtils.toSafe(boundMode)
)
}
@@ -97,7 +97,7 @@ object UnsafeColumnarBuildSideRelation {
@Experimental
case class UnsafeColumnarBuildSideRelation(
private var output: Seq[Attribute],
- private var batches: UnsafeBytesBufferArray,
+ private var batches: UnsafeByteBufferArray,
var safeBroadcastMode: SafeBroadcastMode)
extends BuildSideRelation
with Externalizable
@@ -118,27 +118,27 @@ case class UnsafeColumnarBuildSideRelation(
/** needed for serialization. */
def this() = {
- this(null, null.asInstanceOf[UnsafeBytesBufferArray], null)
+ this(null, null.asInstanceOf[UnsafeByteBufferArray], null)
}
def this(
output: Seq[Attribute],
- bytesBufferArray: Array[Array[Byte]],
+ byteBufferArray: Array[Array[Byte]],
safeMode: SafeBroadcastMode
) = {
this(
output,
- UnsafeBytesBufferArray(
- bytesBufferArray.length,
- bytesBufferArray.map(_.length),
- bytesBufferArray.map(_.length.toLong).sum
+ UnsafeByteBufferArray(
+ byteBufferArray.length,
+ byteBufferArray.map(_.length),
+ byteBufferArray.map(_.length.toLong).sum
),
safeMode
)
- val batchesSize = bytesBufferArray.length
+ val batchesSize = byteBufferArray.length
for (i <- 0 until batchesSize) {
// copy the bytes to off-heap memory.
- batches.putBytesBuffer(i, bytesBufferArray(i))
+ batches.putByteBuffer(i, byteBufferArray(i))
}
}
@@ -146,10 +146,10 @@ case class UnsafeColumnarBuildSideRelation(
out.writeObject(output)
out.writeObject(safeBroadcastMode)
out.writeInt(batches.arraySize)
- out.writeObject(batches.bytesBufferLengths)
+ out.writeObject(batches.byteBufferLengths)
out.writeLong(batches.totalBytes)
for (i <- 0 until batches.arraySize) {
- val bytes = batches.getBytesBuffer(i)
+ val bytes = batches.getByteBuffer(i)
out.write(bytes)
}
}
@@ -158,10 +158,10 @@ case class UnsafeColumnarBuildSideRelation(
kryo.writeObject(out, output.toList)
kryo.writeClassAndObject(out, safeBroadcastMode)
out.writeInt(batches.arraySize)
- kryo.writeObject(out, batches.bytesBufferLengths)
+ kryo.writeObject(out, batches.byteBufferLengths)
out.writeLong(batches.totalBytes)
for (i <- 0 until batches.arraySize) {
- val bytes = batches.getBytesBuffer(i)
+ val bytes = batches.getByteBuffer(i)
out.write(bytes)
}
}
@@ -170,7 +170,7 @@ case class UnsafeColumnarBuildSideRelation(
output = in.readObject().asInstanceOf[Seq[Attribute]]
safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode]
val totalArraySize = in.readInt()
- val bytesBufferLengths = in.readObject().asInstanceOf[Array[Int]]
+ val byteBufferLengths = in.readObject().asInstanceOf[Array[Int]]
val totalBytes = in.readLong()
// scalastyle:off
@@ -180,13 +180,13 @@ case class UnsafeColumnarBuildSideRelation(
*/
// scalastyle:on
- batches = UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths,
totalBytes)
+ batches = UnsafeByteBufferArray(totalArraySize, byteBufferLengths,
totalBytes)
for (i <- 0 until totalArraySize) {
- val length = bytesBufferLengths(i)
+ val length = byteBufferLengths(i)
val tmpBuffer = new Array[Byte](length)
in.readFully(tmpBuffer)
- batches.putBytesBuffer(i, tmpBuffer)
+ batches.putByteBuffer(i, tmpBuffer)
}
}
@@ -194,16 +194,16 @@ case class UnsafeColumnarBuildSideRelation(
output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
safeBroadcastMode =
kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode]
val totalArraySize = in.readInt()
- val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]])
+ val byteBufferLengths = kryo.readObject(in, classOf[Array[Int]])
val totalBytes = in.readLong()
- batches = UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths,
totalBytes)
+ batches = UnsafeByteBufferArray(totalArraySize, byteBufferLengths,
totalBytes)
for (i <- 0 until totalArraySize) {
- val length = bytesBufferLengths(i)
+ val length = byteBufferLengths(i)
val tmpBuffer = new Array[Byte](length)
in.read(tmpBuffer)
- batches.putBytesBuffer(i, tmpBuffer)
+ batches.putByteBuffer(i, tmpBuffer)
}
}
@@ -252,7 +252,7 @@ case class UnsafeColumnarBuildSideRelation(
override def next: ColumnarBatch = {
val (offset, length) =
- batches.getBytesBufferOffsetAndLength(batchId)
+ batches.getByteBufferOffsetAndLength(batchId)
batchId += 1
val handle =
jniWrapper.deserializeDirect(serializerHandle, offset, length)
@@ -309,7 +309,7 @@ case class UnsafeColumnarBuildSideRelation(
}
override def next(): Iterator[InternalRow] = {
- val (offset, length) = batches.getBytesBufferOffsetAndLength(batchId)
+ val (offset, length) = batches.getByteBufferOffsetAndLength(batchId)
batchId += 1
val batchHandle =
serializerJniWrapper.deserializeDirect(serializerHandle, offset,
length)
diff --git
a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
index 1017fd0723..903bb100b7 100644
---
a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
+++
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
@@ -41,12 +41,12 @@ class UnsafeColumnarBuildSideRelationTest extends
SharedSparkSession {
val totalArraySize = 1
val perArraySize = new Array[Int](totalArraySize)
perArraySize(0) = 10
- val bytesArray = UnsafeBytesBufferArray(
+ val bytesArray = UnsafeByteBufferArray(
1,
perArraySize,
10
)
- bytesArray.putBytesBuffer(0, "1234567890".getBytes())
+ bytesArray.putByteBuffer(0, "1234567890".getBytes())
unsafeRelWithIdentityMode = UnsafeColumnarBuildSideRelation(
output,
bytesArray,
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 5bc311f782..f3b86f8ff6 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -44,6 +44,11 @@ namespace {
jclass byteArrayClass;
+jclass unsafeByteBufferClass;
+jmethodID unsafeByteBufferAllocate;
+jmethodID unsafeByteBufferAddress;
+jmethodID unsafeByteBufferSize;
+
jclass jniByteInputStreamClass;
jmethodID jniByteInputStreamRead;
jmethodID jniByteInputStreamTell;
@@ -245,6 +250,12 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
byteArrayClass = createGlobalClassReferenceOrError(env, "[B");
+ unsafeByteBufferClass = createGlobalClassReferenceOrError(env,
"Lorg/apache/spark/unsafe/memory/UnsafeByteBuffer;");
+ unsafeByteBufferAllocate =
+ env->GetStaticMethodID(unsafeByteBufferClass, "allocate",
"(J)Lorg/apache/spark/unsafe/memory/UnsafeByteBuffer;");
+ unsafeByteBufferAddress = env->GetMethodID(unsafeByteBufferClass, "address",
"()J");
+ unsafeByteBufferSize = env->GetMethodID(unsafeByteBufferClass, "size",
"()J");
+
jniByteInputStreamClass = createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/vectorized/JniByteInputStream;");
jniByteInputStreamRead = getMethodIdOrError(env, jniByteInputStreamClass,
"read", "(JJ)J");
jniByteInputStreamTell = getMethodIdOrError(env, jniByteInputStreamClass,
"tell", "()J");
@@ -287,6 +298,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
env->DeleteGlobalRef(splitResultClass);
env->DeleteGlobalRef(nativeColumnarToRowInfoClass);
env->DeleteGlobalRef(byteArrayClass);
+ env->DeleteGlobalRef(unsafeByteBufferClass);
env->DeleteGlobalRef(shuffleReaderMetricsClass);
getJniErrorState()->close();
@@ -1139,28 +1151,25 @@ JNIEXPORT void JNICALL
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper
JNI_METHOD_END()
}
-JNIEXPORT jbyteArray JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_serialize(
// NOLINT
+JNIEXPORT jobject JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_serialize(
// NOLINT
JNIEnv* env,
jobject wrapper,
jlong handle) {
JNI_METHOD_START
auto ctx = getRuntime(env, wrapper);
- std::vector<std::shared_ptr<ColumnarBatch>> batches;
auto batch = ObjectStore::retrieve<ColumnarBatch>(handle);
GLUTEN_DCHECK(batch != nullptr, "Cannot find the ColumnarBatch with handle "
+ std::to_string(handle));
- batches.emplace_back(batch);
auto serializer = ctx->createColumnarBatchSerializer(nullptr);
- auto buffer = serializer->serializeColumnarBatches(batches);
- auto bufferArr = env->NewByteArray(buffer->size());
- GLUTEN_CHECK(
- bufferArr != nullptr,
- "Cannot construct a byte array of size " +
std::to_string(buffer->size()) +
- " byte(s) to serialize columnar batches");
- env->SetByteArrayRegion(bufferArr, 0, buffer->size(), reinterpret_cast<const
jbyte*>(buffer->data()));
-
- return bufferArr;
+ serializer->append(batch);
+ auto serializedSize = serializer->maxSerializedSize();
+ auto byteBuffer = env->CallStaticObjectMethod(unsafeByteBufferClass,
unsafeByteBufferAllocate, serializedSize);
+ auto byteBufferAddress = env->CallLongMethod(byteBuffer,
unsafeByteBufferAddress);
+ auto byteBufferSize = env->CallLongMethod(byteBuffer, unsafeByteBufferSize);
+ serializer->serializeTo(reinterpret_cast<uint8_t*>(byteBufferAddress),
byteBufferSize);
+
+ return byteBuffer;
JNI_METHOD_END(nullptr)
}
diff --git a/cpp/core/operators/serializer/ColumnarBatchSerializer.h
b/cpp/core/operators/serializer/ColumnarBatchSerializer.h
index 25d23aacb5..08a76f9f23 100644
--- a/cpp/core/operators/serializer/ColumnarBatchSerializer.h
+++ b/cpp/core/operators/serializer/ColumnarBatchSerializer.h
@@ -29,8 +29,11 @@ class ColumnarBatchSerializer {
virtual ~ColumnarBatchSerializer() = default;
- virtual std::shared_ptr<arrow::Buffer> serializeColumnarBatches(
- const std::vector<std::shared_ptr<ColumnarBatch>>& batches) = 0;
+ virtual void append(const std::shared_ptr<ColumnarBatch>& batch) = 0;
+
+ virtual int64_t maxSerializedSize() = 0;
+
+ virtual void serializeTo(uint8_t* address, int64_t size) = 0;
virtual std::shared_ptr<ColumnarBatch> deserialize(uint8_t* data, int32_t
size) = 0;
diff --git a/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.cc
b/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.cc
index a35316702a..c12259420a 100644
--- a/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.cc
+++ b/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.cc
@@ -51,33 +51,41 @@ VeloxColumnarBatchSerializer::VeloxColumnarBatchSerializer(
rowType_ = asRowType(importFromArrow(*cSchema));
ArrowSchemaRelease(cSchema); // otherwise the c schema leaks memory
}
+ arena_ = std::make_unique<StreamArena>(veloxPool_.get());
serde_ = std::make_unique<serializer::presto::PrestoVectorSerde>();
options_.useLosslessTimestamp = true;
}
-std::shared_ptr<arrow::Buffer>
VeloxColumnarBatchSerializer::serializeColumnarBatches(
- const std::vector<std::shared_ptr<ColumnarBatch>>& batches) {
- VELOX_DCHECK(batches.size() != 0, "Should serialize at least 1 vector");
- const std::shared_ptr<VeloxColumnarBatch>& vb =
VeloxColumnarBatch::from(veloxPool_.get(), batches[0]);
- auto firstRowVector = vb->getRowVector();
- auto numRows = firstRowVector->size();
- auto arena = std::make_unique<StreamArena>(veloxPool_.get());
- auto rowType = asRowType(firstRowVector->type());
- auto serializer = serde_->createIterativeSerializer(rowType, numRows,
arena.get(), &options_);
- for (auto& batch : batches) {
- auto rowVector = VeloxColumnarBatch::from(veloxPool_.get(),
batch)->getRowVector();
- const IndexRange allRows{0, rowVector->size()};
- serializer->append(rowVector, folly::Range(&allRows, 1));
+void VeloxColumnarBatchSerializer::append(const
std::shared_ptr<ColumnarBatch>& batch) {
+ auto rowVector = VeloxColumnarBatch::from(veloxPool_.get(),
batch)->getRowVector();
+ if (serializer_ == nullptr) {
+ // Using first batch's schema to create the Velox serializer. This logic
was introduced in
+ // https://github.com/apache/incubator-gluten/pull/1568. It's a bit
suboptimal because the schemas
+ // across different batches may vary.
+ auto numRows = rowVector->size();
+ auto rowType = asRowType(rowVector->type());
+ serializer_ = serde_->createIterativeSerializer(rowType, numRows,
arena_.get(), &options_);
}
+ const IndexRange allRows{0, rowVector->size()};
+ serializer_->append(rowVector, folly::Range(&allRows, 1));
+}
+
+int64_t VeloxColumnarBatchSerializer::maxSerializedSize() {
+ VELOX_DCHECK(serializer_ != nullptr, "Should serialize at least 1 vector");
+ return serializer_->maxSerializedSize();
+}
- std::shared_ptr<arrow::Buffer> valueBuffer;
- GLUTEN_ASSIGN_OR_THROW(valueBuffer,
arrow::AllocateResizableBuffer(serializer->maxSerializedSize(), arrowPool_));
+void VeloxColumnarBatchSerializer::serializeTo(uint8_t* address, int64_t size)
{
+ VELOX_DCHECK(serializer_ != nullptr, "Should serialize at least 1 vector");
+ auto sizeNeeded = serializer_->maxSerializedSize();
+ GLUTEN_CHECK(
+ size >= sizeNeeded,
+ "The target buffer size is insufficient: " + std::to_string(size) + "
vs. " + std::to_string(sizeNeeded));
+ std::shared_ptr<arrow::MutableBuffer> valueBuffer =
std::make_shared<arrow::MutableBuffer>(address, size);
auto output =
std::make_shared<arrow::io::FixedSizeBufferWriter>(valueBuffer);
serializer::presto::PrestoOutputStreamListener listener;
ArrowFixedSizeBufferOutputStream out(output, &listener);
- serializer->flush(&out);
- GLUTEN_THROW_NOT_OK(output->Close());
- return valueBuffer;
+ serializer_->flush(&out);
}
std::shared_ptr<ColumnarBatch>
VeloxColumnarBatchSerializer::deserialize(uint8_t* data, int32_t size) {
diff --git a/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.h
b/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.h
index 18539a2450..73fec890cf 100644
--- a/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.h
+++ b/cpp/velox/operators/serializer/VeloxColumnarBatchSerializer.h
@@ -32,13 +32,18 @@ class VeloxColumnarBatchSerializer final : public
ColumnarBatchSerializer {
std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool,
struct ArrowSchema* cSchema);
- std::shared_ptr<arrow::Buffer> serializeColumnarBatches(
- const std::vector<std::shared_ptr<ColumnarBatch>>& batches) override;
+ void append(const std::shared_ptr<ColumnarBatch>& batch) override;
+
+ int64_t maxSerializedSize() override;
+
+ void serializeTo(uint8_t* address, int64_t size) override;
std::shared_ptr<ColumnarBatch> deserialize(uint8_t* data, int32_t size)
override;
private:
std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool_;
+ std::unique_ptr<facebook::velox::StreamArena> arena_;
+ std::unique_ptr<facebook::velox::IterativeVectorSerializer> serializer_;
facebook::velox::RowTypePtr rowType_;
std::unique_ptr<facebook::velox::serializer::presto::PrestoVectorSerde>
serde_;
facebook::velox::serializer::presto::PrestoVectorSerde::PrestoOptions
options_;
diff --git a/cpp/velox/tests/VeloxColumnarBatchSerializerTest.cc
b/cpp/velox/tests/VeloxColumnarBatchSerializerTest.cc
index 7833b0cfea..35c18d8ec3 100644
--- a/cpp/velox/tests/VeloxColumnarBatchSerializerTest.cc
+++ b/cpp/velox/tests/VeloxColumnarBatchSerializerTest.cc
@@ -62,7 +62,10 @@ TEST_F(VeloxColumnarBatchSerializerTest, serialize) {
auto vector = makeRowVector(children);
auto batch = std::make_shared<VeloxColumnarBatch>(vector);
auto serializer = std::make_shared<VeloxColumnarBatchSerializer>(arrowPool,
pool_, nullptr);
- auto buffer = serializer->serializeColumnarBatches({batch});
+ serializer->append(batch);
+ std::shared_ptr<arrow::Buffer> buffer;
+ GLUTEN_ASSIGN_OR_THROW(buffer,
arrow::AllocateResizableBuffer(serializer->maxSerializedSize(), arrowPool));
+
serializer->serializeTo(reinterpret_cast<uint8_t*>(buffer->mutable_address()),
buffer->size());
ArrowSchema cSchema;
exportToArrow(vector, cSchema, ArrowUtils::getBridgeOptions());
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java
b/gluten-arrow/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java
index 3d13ccc613..7d72a263ac 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java
@@ -17,12 +17,14 @@
package org.apache.gluten.memory.arrow.alloc;
import org.apache.gluten.config.GlutenConfig;
+import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
import org.apache.gluten.memory.memtarget.MemoryTargets;
import org.apache.gluten.memory.memtarget.Spillers;
import org.apache.arrow.memory.AllocationListener;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
+import org.apache.spark.memory.GlobalOffHeapMemory;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.task.TaskResource;
import org.apache.spark.task.TaskResources;
@@ -34,6 +36,14 @@ import java.util.List;
import java.util.Vector;
public class ArrowBufferAllocators {
+ private static final SimpleMemoryUsageRecorder GLOBAL_USAGE = new
SimpleMemoryUsageRecorder();
+ private static final BufferAllocator GLOBAL_INSTANCE;
+
+ static {
+ final AllocationListener listener =
+ new ManagedAllocationListener(GlobalOffHeapMemory.target(),
GLOBAL_USAGE);
+ GLOBAL_INSTANCE = new RootAllocator(listener, Long.MAX_VALUE);
+ }
private ArrowBufferAllocators() {}
@@ -51,6 +61,10 @@ public class ArrowBufferAllocators {
.managed;
}
+ public static BufferAllocator globalInstance() {
+ return GLOBAL_INSTANCE;
+ }
+
public static class ArrowBufferAllocatorManager implements TaskResource {
private static Logger LOGGER =
LoggerFactory.getLogger(ArrowBufferAllocatorManager.class);
private static final List<BufferAllocator> LEAKED = new Vector<>();
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
index 97ab10082e..00fe02c629 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
@@ -19,6 +19,8 @@ package org.apache.gluten.vectorized;
import org.apache.gluten.runtime.Runtime;
import org.apache.gluten.runtime.RuntimeAware;
+import org.apache.spark.unsafe.memory.UnsafeByteBuffer;
+
public class ColumnarBatchSerializerJniWrapper implements RuntimeAware {
private final Runtime runtime;
@@ -35,7 +37,7 @@ public class ColumnarBatchSerializerJniWrapper implements
RuntimeAware {
return runtime.getHandle();
}
- public native byte[] serialize(long handle);
+ public native UnsafeByteBuffer serialize(long handle);
// Return the native ColumnarBatchSerializer handle
public native long init(long cSchema);
diff --git
a/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java
b/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java
new file mode 100644
index 0000000000..1cd598ab6c
--- /dev/null
+++
b/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.unsafe.memory;
+
+import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
+
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.spark.task.TaskResources;
+import org.apache.spark.unsafe.Platform;
+
+/** The API is for being called from C++ via JNI. */
+public class UnsafeByteBuffer {
+ private final ArrowBuf buffer;
+ private final long size;
+
+ private UnsafeByteBuffer(ArrowBuf buffer, long size) {
+ this.buffer = buffer;
+ this.size = size;
+ }
+
+ public static UnsafeByteBuffer allocate(long size) {
+ final BufferAllocator allocator;
+ if (TaskResources.inSparkTask()) {
+ allocator =
ArrowBufferAllocators.contextInstance(UnsafeByteBuffer.class.getName());
+ } else {
+ allocator = ArrowBufferAllocators.globalInstance();
+ }
+ final ArrowBuf arrowBuf = allocator.buffer(size);
+ return new UnsafeByteBuffer(arrowBuf, size);
+ }
+
+ public long address() {
+ return buffer.memoryAddress();
+ }
+
+ public long size() {
+ return size;
+ }
+
+ public void close() {
+ buffer.close();
+ }
+
+ public byte[] toByteArray() {
+ final byte[] values = new byte[Math.toIntExact(size)];
+ Platform.copyMemory(
+ null, buffer.memoryAddress(), values, Platform.BYTE_ARRAY_OFFSET,
values.length);
+ return values;
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
b/gluten-core/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
index c0988b8eca..b2915084dd 100644
---
a/gluten-core/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
+++
b/gluten-core/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
@@ -30,7 +30,7 @@ import org.apache.gluten.memory.memtarget.{MemoryTarget,
NoopMemoryTarget}
* BlockId to be extended by user, TestBlockId is chosen for the storage
memory reservations.
*/
object GlobalOffHeapMemory {
- private val target: MemoryTarget = if (GlutenCoreConfig.get.memoryUntracked)
{
+ val target: MemoryTarget = if (GlutenCoreConfig.get.memoryUntracked) {
new NoopMemoryTarget()
} else {
new GlobalOffHeapMemoryTarget()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]