This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 7ddc963956 [GLUTEN-11133][VL] Reduce on-heap memory allocation and
memory copy when off-heap BHJ is enabled (#11148)
7ddc963956 is described below
commit 7ddc963956067b034c5783dde484fdee7a761aa4
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Nov 26 16:38:18 2025 +0000
[GLUTEN-11133][VL] Reduce on-heap memory allocation and memory copy when
off-heap BHJ is enabled (#11148)
---
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 20 ++-
.../spark/sql/execution/BroadcastUtils.scala | 83 +++++------
.../execution/ColumnarCachedBatchSerializer.scala | 7 +-
.../execution/unsafe/UnsafeByteBufferArray.scala | 138 -----------------
.../unsafe/UnsafeColumnarBuildSideRelation.scala | 123 ++++-----------
.../UnsafeColumnarBuildSideRelationTest.scala | 53 +++++--
cpp/core/jni/JniWrapper.cc | 27 ++--
.../vectorized/ColumnarBatchSerializeResult.java | 55 +++++--
.../ColumnarBatchSerializerJniWrapper.java | 4 +-
.../sql/execution/unsafe/JniUnsafeByteBuffer.java | 89 +++++++++++
.../sql/execution/unsafe/UnsafeByteArray.java | 166 +++++++++++++++++++++
.../spark/unsafe/memory/UnsafeByteBuffer.java | 65 --------
12 files changed, 434 insertions(+), 396 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index fcd9b682a4..bb2c9c6c38 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -63,6 +63,8 @@ import javax.ws.rs.core.UriBuilder
import java.util.Locale
+import scala.collection.JavaConverters._
+
class VeloxSparkPlanExecApi extends SparkPlanExecApi {
/** Transform GetArrayItem to Substrait. */
@@ -670,26 +672,32 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
dataSize: SQLMetric): BuildSideRelation = {
val useOffheapBroadcastBuildRelation =
VeloxConfig.get.enableBroadcastBuildRelationInOffheap
- val serialized: Array[ColumnarBatchSerializeResult] = child
+ val serialized: Seq[ColumnarBatchSerializeResult] = child
.executeColumnar()
.mapPartitions(itr => Iterator(BroadcastUtils.serializeStream(itr)))
- .filter(_.getNumRows != 0)
+ .filter(_.numRows != 0)
.collect
- val rawSize = serialized.flatMap(_.getSerialized.map(_.length.toLong)).sum
+ val rawSize = serialized.map(_.sizeInBytes()).sum
if (rawSize >= GlutenConfig.get.maxBroadcastTableSize) {
throw new SparkException(
"Cannot broadcast the table that is larger than " +
s"${SparkMemoryUtil.bytesToString(GlutenConfig.get.maxBroadcastTableSize)}: " +
s"${SparkMemoryUtil.bytesToString(rawSize)}")
}
- numOutputRows += serialized.map(_.getNumRows).sum
+ numOutputRows += serialized.map(_.numRows).sum
dataSize += rawSize
if (useOffheapBroadcastBuildRelation) {
TaskResources.runUnsafe {
- UnsafeColumnarBuildSideRelation(child.output,
serialized.flatMap(_.getSerialized), mode)
+ UnsafeColumnarBuildSideRelation(
+ child.output,
+ serialized.flatMap(_.offHeapData().asScala),
+ mode)
}
} else {
- ColumnarBuildSideRelation(child.output,
serialized.flatMap(_.getSerialized), mode)
+ ColumnarBuildSideRelation(
+ child.output,
+ serialized.flatMap(_.onHeapData().asScala).toArray,
+ mode)
}
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 5fc9cc4464..ad066d47f9 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.task.TaskResources
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
// Utility methods to convert Vanilla broadcast relations from/to Velox
broadcast relations.
@@ -93,58 +94,44 @@ object BroadcastUtils {
schema: StructType,
from: Broadcast[F],
fn: Iterator[InternalRow] => Iterator[ColumnarBatch]): Broadcast[T] = {
- val useOffheapBuildRelation =
VeloxConfig.get.enableBroadcastBuildRelationInOffheap
+
+ def batchIterationToRelation(batchItr: () => Iterator[ColumnarBatch]):
BuildSideRelation = {
+ TaskResources.runUnsafe {
+ serializeStream(batchItr()) match {
+ case ColumnarBatchSerializeResult.EMPTY =>
+ ColumnarBuildSideRelation(
+ SparkShimLoader.getSparkShims.attributesFromStruct(schema),
+ Array[Array[Byte]](),
+ mode)
+ case result: ColumnarBatchSerializeResult =>
+ if (result.isOffHeap) {
+ UnsafeColumnarBuildSideRelation(
+ SparkShimLoader.getSparkShims.attributesFromStruct(schema),
+ result.offHeapData().asScala.toSeq,
+ mode)
+ } else {
+ ColumnarBuildSideRelation(
+ SparkShimLoader.getSparkShims.attributesFromStruct(schema),
+ result.onHeapData().asScala.toArray,
+ mode)
+ }
+ }
+ }
+ }
+
mode match {
case HashedRelationBroadcastMode(_, _) =>
// HashedRelation to ColumnarBuildSideRelation.
val fromBroadcast = from.asInstanceOf[Broadcast[HashedRelation]]
val fromRelation = fromBroadcast.value.asReadOnlyCopy()
- val toRelation = TaskResources.runUnsafe {
- val batchItr: Iterator[ColumnarBatch] =
fn(reconstructRows(fromRelation))
- val serialized: Array[Array[Byte]] = serializeStream(batchItr) match
{
- case ColumnarBatchSerializeResult.EMPTY =>
- Array()
- case result: ColumnarBatchSerializeResult =>
- result.getSerialized
- }
- if (useOffheapBuildRelation) {
- UnsafeColumnarBuildSideRelation(
- SparkShimLoader.getSparkShims.attributesFromStruct(schema),
- serialized,
- mode)
- } else {
- ColumnarBuildSideRelation(
- SparkShimLoader.getSparkShims.attributesFromStruct(schema),
- serialized,
- mode)
- }
- }
+ val toRelation = batchIterationToRelation(() =>
fn(reconstructRows(fromRelation)))
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
case IdentityBroadcastMode =>
// Array[InternalRow] to ColumnarBuildSideRelation.
val fromBroadcast = from.asInstanceOf[Broadcast[Array[InternalRow]]]
val fromRelation = fromBroadcast.value
- val toRelation = TaskResources.runUnsafe {
- val batchItr: Iterator[ColumnarBatch] = fn(fromRelation.iterator)
- val serialized: Array[Array[Byte]] = serializeStream(batchItr) match
{
- case ColumnarBatchSerializeResult.EMPTY =>
- Array()
- case result: ColumnarBatchSerializeResult =>
- result.getSerialized
- }
- if (useOffheapBuildRelation) {
- UnsafeColumnarBuildSideRelation(
- SparkShimLoader.getSparkShims.attributesFromStruct(schema),
- serialized,
- mode)
- } else {
- ColumnarBuildSideRelation(
- SparkShimLoader.getSparkShims.attributesFromStruct(schema),
- serialized,
- mode)
- }
- }
+ val toRelation = batchIterationToRelation(() =>
fn(fromRelation.iterator))
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
case _ => throw new IllegalStateException("Unexpected broadcast mode: "
+ mode)
@@ -175,25 +162,25 @@ object BroadcastUtils {
val handle =
ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, b)
numRows += b.numRows()
try {
- val unsafeBuffer = ColumnarBatchSerializerJniWrapper
+ ColumnarBatchSerializerJniWrapper
.create(
Runtimes
.contextInstance(
BackendsApiManager.getBackendName,
"BroadcastUtils#serializeStream"))
.serialize(handle)
- try {
- unsafeBuffer.toByteArray
- } finally {
- unsafeBuffer.close()
- }
} finally {
ColumnarBatches.release(b)
}
})
.toArray
if (values.nonEmpty) {
- new ColumnarBatchSerializeResult(numRows, values)
+ val useOffheapBroadcastBuildRelation =
+ VeloxConfig.get.enableBroadcastBuildRelationInOffheap
+ new ColumnarBatchSerializeResult(
+ useOffheapBroadcastBuildRelation,
+ numRows,
+ values.toSeq.asJava)
} else {
ColumnarBatchSerializeResult.EMPTY
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
index 80ff907e61..a04e7d68fb 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
@@ -178,12 +178,7 @@ class ColumnarCachedBatchSerializer extends
CachedBatchSerializer with Logging {
BackendsApiManager.getBackendName,
"ColumnarCachedBatchSerializer#serialize"))
.serialize(ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName,
batch))
- val bytes =
- try {
- unsafeBuffer.toByteArray
- } finally {
- unsafeBuffer.close()
- }
+ val bytes = unsafeBuffer.toByteArray
CachedColumnarBatch(batch.numRows(), bytes.length, bytes)
}
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
deleted file mode 100644
index 45ff118d6d..0000000000
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeByteBufferArray.scala
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.unsafe
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.internal.Logging
-import org.apache.spark.memory.GlobalOffHeapMemory
-import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.array.LongArray
-import org.apache.spark.unsafe.memory.MemoryAllocator
-
-/**
- * Used to store broadcast variable off-heap memory for broadcast variable.
The underlying data
- * structure is a LongArray allocated in off-heap memory.
- *
- * @param arraySize
- * underlying array[array[byte]]'s length
- * @param byteBufferLengths
- * underlying array[array[byte]] per byteBuffer length
- * @param totalBytes
- * all byteBuffer's length plus together
- */
-// scalastyle:off no.finalize
-@Experimental
-case class UnsafeByteBufferArray(arraySize: Int, byteBufferLengths:
Array[Int], totalBytes: Long)
- extends Logging {
- {
- assert(
- arraySize == byteBufferLengths.length,
- "Unsafe buffer array size " +
- "not equal to buffer lengths!")
- assert(totalBytes >= 0, "Unsafe buffer array total bytes can't be
negative!")
- }
- private val allocatedBytes = (totalBytes + 7) / 8 * 8
-
- /**
- * A single array to store all byteBufferArray's value, it's inited once
when first time get
- * accessed.
- */
- private var longArray: LongArray = _
-
- /** Index the start of each byteBuffer's offset to underlying LongArray's
initial position. */
- private val byteBufferOffset = if (byteBufferLengths.isEmpty) {
- new Array(0)
- } else {
- byteBufferLengths.init.scanLeft(0L)(_ + _)
- }
-
- /**
- * Put byteBuffer at specified array index.
- *
- * @param index
- * index of the array.
- * @param byteBuffer
- * byteBuffer to put.
- */
- def putByteBuffer(index: Int, byteBuffer: Array[Byte]): Unit =
this.synchronized {
- assert(index < arraySize)
- assert(byteBuffer.length == byteBufferLengths(index))
- // first to allocate underlying long array
- if (null == longArray && index == 0) {
- GlobalOffHeapMemory.acquire(allocatedBytes)
- longArray = new
LongArray(MemoryAllocator.UNSAFE.allocate(allocatedBytes))
- }
-
- Platform.copyMemory(
- byteBuffer,
- Platform.BYTE_ARRAY_OFFSET,
- longArray.getBaseObject,
- longArray.getBaseOffset + byteBufferOffset(index),
- byteBufferLengths(index))
- }
-
- /**
- * Get byteBuffer at specified index.
- * @param index
- * @return
- */
- def getByteBuffer(index: Int): Array[Byte] = {
- assert(index < arraySize)
- if (null == longArray) {
- return new Array[Byte](0)
- }
- val bytes = new Array[Byte](byteBufferLengths(index))
- Platform.copyMemory(
- longArray.getBaseObject,
- longArray.getBaseOffset + byteBufferOffset(index),
- bytes,
- Platform.BYTE_ARRAY_OFFSET,
- byteBufferLengths(index))
- bytes
- }
-
- /**
- * Get the byteBuffer memory address and length at specified index, usually
used when read memory
- * direct from offheap.
- *
- * @param index
- * @return
- */
- def getByteBufferOffsetAndLength(index: Int): (Long, Int) = {
- assert(index < arraySize)
- assert(longArray != null, "The broadcast data in offheap should not be
null!")
- val offset = longArray.getBaseOffset + byteBufferOffset(index)
- val length = byteBufferLengths(index)
- (offset, length)
- }
-
- /**
- * It's needed once the broadcast variable is garbage collected. Since now,
we don't have an
- * elegant way to free the underlying memory in offheap.
- */
- override def finalize(): Unit = {
- try {
- if (longArray != null) {
- longArray = null
- GlobalOffHeapMemory.release(allocatedBytes)
- }
- } finally {
- super.finalize()
- }
- }
-}
-// scalastyle:on no.finalize
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index ebbd8226fd..ba307415c5 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -47,10 +47,9 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import scala.collection.JavaConverters.asScalaIteratorConverter
object UnsafeColumnarBuildSideRelation {
- // Keep constructors with BroadcastMode for compatibility
def apply(
output: Seq[Attribute],
- batches: UnsafeByteBufferArray,
+ batches: Seq[UnsafeByteArray],
mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
@@ -63,25 +62,6 @@ object UnsafeColumnarBuildSideRelation {
}
new UnsafeColumnarBuildSideRelation(output, batches,
BroadcastModeUtils.toSafe(boundMode))
}
- def apply(
- output: Seq[Attribute],
- byteBufferArray: Array[Array[Byte]],
- mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
- val boundMode = mode match {
- case HashedRelationBroadcastMode(keys, isNullAware) =>
- // Bind each key to the build-side output so simple cols become
BoundReference
- val boundKeys: Seq[Expression] =
- keys.map(k => BindReferences.bindReference(k, AttributeSeq(output)))
- HashedRelationBroadcastMode(boundKeys, isNullAware)
- case m =>
- m // IdentityBroadcastMode, etc.
- }
- new UnsafeColumnarBuildSideRelation(
- output,
- byteBufferArray,
- BroadcastModeUtils.toSafe(boundMode)
- )
- }
}
/**
@@ -95,10 +75,10 @@ object UnsafeColumnarBuildSideRelation {
* the broadcast mode.
*/
@Experimental
-case class UnsafeColumnarBuildSideRelation(
+class UnsafeColumnarBuildSideRelation(
private var output: Seq[Attribute],
- private var batches: UnsafeByteBufferArray,
- var safeBroadcastMode: SafeBroadcastMode)
+ private var batches: Seq[UnsafeByteArray],
+ private var safeBroadcastMode: SafeBroadcastMode)
extends BuildSideRelation
with Externalizable
with Logging
@@ -118,93 +98,35 @@ case class UnsafeColumnarBuildSideRelation(
/** needed for serialization. */
def this() = {
- this(null, null.asInstanceOf[UnsafeByteBufferArray], null)
+ this(null, null, null)
}
- def this(
- output: Seq[Attribute],
- byteBufferArray: Array[Array[Byte]],
- safeMode: SafeBroadcastMode
- ) = {
- this(
- output,
- UnsafeByteBufferArray(
- byteBufferArray.length,
- byteBufferArray.map(_.length),
- byteBufferArray.map(_.length.toLong).sum
- ),
- safeMode
- )
- val batchesSize = byteBufferArray.length
- for (i <- 0 until batchesSize) {
- // copy the bytes to off-heap memory.
- batches.putByteBuffer(i, byteBufferArray(i))
- }
+ private[unsafe] def getBatches(): Seq[UnsafeByteArray] = {
+ batches
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException
{
out.writeObject(output)
out.writeObject(safeBroadcastMode)
- out.writeInt(batches.arraySize)
- out.writeObject(batches.byteBufferLengths)
- out.writeLong(batches.totalBytes)
- for (i <- 0 until batches.arraySize) {
- val bytes = batches.getByteBuffer(i)
- out.write(bytes)
- }
+ out.writeObject(batches.toArray)
}
override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
kryo.writeObject(out, output.toList)
kryo.writeClassAndObject(out, safeBroadcastMode)
- out.writeInt(batches.arraySize)
- kryo.writeObject(out, batches.byteBufferLengths)
- out.writeLong(batches.totalBytes)
- for (i <- 0 until batches.arraySize) {
- val bytes = batches.getByteBuffer(i)
- out.write(bytes)
- }
+ kryo.writeClassAndObject(out, batches.toArray)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
output = in.readObject().asInstanceOf[Seq[Attribute]]
safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode]
- val totalArraySize = in.readInt()
- val byteBufferLengths = in.readObject().asInstanceOf[Array[Int]]
- val totalBytes = in.readLong()
-
- // scalastyle:off
- /**
- * We use off-heap memory to reduce on-heap pressure Similar to
- *
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala#L389-L410
- */
- // scalastyle:on
-
- batches = UnsafeByteBufferArray(totalArraySize, byteBufferLengths,
totalBytes)
-
- for (i <- 0 until totalArraySize) {
- val length = byteBufferLengths(i)
- val tmpBuffer = new Array[Byte](length)
- in.readFully(tmpBuffer)
- batches.putByteBuffer(i, tmpBuffer)
- }
+ batches = in.readObject().asInstanceOf[Array[UnsafeByteArray]].toSeq
}
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
safeBroadcastMode =
kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode]
- val totalArraySize = in.readInt()
- val byteBufferLengths = kryo.readObject(in, classOf[Array[Int]])
- val totalBytes = in.readLong()
-
- batches = UnsafeByteBufferArray(totalArraySize, byteBufferLengths,
totalBytes)
-
- for (i <- 0 until totalArraySize) {
- val length = byteBufferLengths(i)
- val tmpBuffer = new Array[Byte](length)
- in.read(tmpBuffer)
- batches.putByteBuffer(i, tmpBuffer)
- }
+ batches =
kryo.readClassAndObject(in).asInstanceOf[Array[UnsafeByteArray]].toSeq
}
private def transformProjection: UnsafeProjection = safeBroadcastMode match {
@@ -247,15 +169,17 @@ case class UnsafeColumnarBuildSideRelation(
var batchId = 0
override def hasNext: Boolean = {
- batchId < batches.arraySize
+ batchId < batches.size
}
override def next: ColumnarBatch = {
- val (offset, length) =
- batches.getByteBufferOffsetAndLength(batchId)
+ val unsafeByteArray = batches(batchId)
batchId += 1
val handle =
- jniWrapper.deserializeDirect(serializerHandle, offset, length)
+ jniWrapper.deserializeDirect(
+ serializerHandle,
+ unsafeByteArray.address(),
+ Math.toIntExact(unsafeByteArray.size()))
ColumnarBatches.create(handle)
}
})
@@ -296,10 +220,10 @@ case class UnsafeColumnarBuildSideRelation(
val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
val c2rId = jniWrapper.nativeColumnarToRowInit()
var batchId = 0
- val iterator = if (batches.arraySize > 0) {
+ val iterator = if (batches.nonEmpty) {
val res: Iterator[Iterator[InternalRow]] = new
Iterator[Iterator[InternalRow]] {
override def hasNext: Boolean = {
- val itHasNext = batchId < batches.arraySize
+ val itHasNext = batchId < batches.size
if (!itHasNext && !closed) {
jniWrapper.nativeClose(c2rId)
serializerJniWrapper.close(serializerHandle)
@@ -309,10 +233,13 @@ case class UnsafeColumnarBuildSideRelation(
}
override def next(): Iterator[InternalRow] = {
- val (offset, length) = batches.getByteBufferOffsetAndLength(batchId)
+ val unsafeByteArray = batches(batchId)
batchId += 1
val batchHandle =
- serializerJniWrapper.deserializeDirect(serializerHandle, offset,
length)
+ serializerJniWrapper.deserializeDirect(
+ serializerHandle,
+ unsafeByteArray.address(),
+ Math.toIntExact(unsafeByteArray.size()))
val batch = ColumnarBatches.create(batchHandle)
if (batch.numRows == 0) {
batch.close()
@@ -370,7 +297,7 @@ case class UnsafeColumnarBuildSideRelation(
override def estimatedSize: Long = {
if (batches != null) {
- batches.totalBytes
+ batches.map(_.size()).sum
} else {
0L
}
diff --git
a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
index 903bb100b7..6d0448fd84 100644
---
a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
+++
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.execution.unsafe
+import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
+
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -23,6 +25,9 @@ import
org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
+import org.apache.spark.unsafe.Platform
+
+import java.util
class UnsafeColumnarBuildSideRelationTest extends SharedSparkSession {
override protected def sparkConf: SparkConf = {
@@ -33,28 +38,38 @@ class UnsafeColumnarBuildSideRelationTest extends
SharedSparkSession {
var unsafeRelWithIdentityMode: UnsafeColumnarBuildSideRelation = _
var unsafeRelWithHashMode: UnsafeColumnarBuildSideRelation = _
+ var sampleBytes: Array[Array[Byte]] = _
+
+ private def toUnsafeByteArray(bytes: Array[Byte]): UnsafeByteArray = {
+ val buf = ArrowBufferAllocators.globalInstance().buffer(bytes.length)
+ buf.setBytes(0, bytes, 0, bytes.length);
+ new UnsafeByteArray(buf, bytes.length.toLong)
+ }
+
+ private def toByteArray(unsafeByteArray: UnsafeByteArray): Array[Byte] = {
+ val byteArray = new Array[Byte](Math.toIntExact(unsafeByteArray.size()))
+ Platform.copyMemory(
+ null,
+ unsafeByteArray.address(),
+ byteArray,
+ Platform.BYTE_ARRAY_OFFSET,
+ byteArray.length)
+ byteArray
+ }
override def beforeAll(): Unit = {
super.beforeAll()
val a = AttributeReference("a", StringType, nullable = false, null)()
val output = Seq(a)
- val totalArraySize = 1
- val perArraySize = new Array[Int](totalArraySize)
- perArraySize(0) = 10
- val bytesArray = UnsafeByteBufferArray(
- 1,
- perArraySize,
- 10
- )
- bytesArray.putByteBuffer(0, "1234567890".getBytes())
+ sampleBytes = Array("12345".getBytes(), "7890".getBytes)
unsafeRelWithIdentityMode = UnsafeColumnarBuildSideRelation(
output,
- bytesArray,
+ sampleBytes.map(a => toUnsafeByteArray(a)),
IdentityBroadcastMode
)
unsafeRelWithHashMode = UnsafeColumnarBuildSideRelation(
output,
- bytesArray,
+ sampleBytes.map(a => toUnsafeByteArray(a)),
HashedRelationBroadcastMode(output, isNullAware = false)
)
}
@@ -68,12 +83,20 @@ class UnsafeColumnarBuildSideRelationTest extends
SharedSparkSession {
val obj =
serializerInstance.deserialize[UnsafeColumnarBuildSideRelation](buffer)
assert(obj != null)
assert(obj.mode == IdentityBroadcastMode)
+ assert(
+ util.Arrays.deepEquals(
+ obj.getBatches().map(toByteArray).toArray[AnyRef],
+ sampleBytes.asInstanceOf[Array[AnyRef]]))
// test unsafeRelWithHashMode
val buffer2 = serializerInstance.serialize(unsafeRelWithHashMode)
val obj2 =
serializerInstance.deserialize[UnsafeColumnarBuildSideRelation](buffer2)
assert(obj2 != null)
assert(obj2.mode.isInstanceOf[HashedRelationBroadcastMode])
+ assert(
+ util.Arrays.deepEquals(
+ obj2.getBatches().map(toByteArray).toArray[AnyRef],
+ sampleBytes.asInstanceOf[Array[AnyRef]]))
}
test("Kryo serialization") {
@@ -85,12 +108,20 @@ class UnsafeColumnarBuildSideRelationTest extends
SharedSparkSession {
val obj =
serializerInstance.deserialize[UnsafeColumnarBuildSideRelation](buffer)
assert(obj != null)
assert(obj.mode == IdentityBroadcastMode)
+ assert(
+ util.Arrays.deepEquals(
+ obj.getBatches().map(toByteArray).toArray[AnyRef],
+ sampleBytes.asInstanceOf[Array[AnyRef]]))
// test unsafeRelWithHashMode
val buffer2 = serializerInstance.serialize(unsafeRelWithHashMode)
val obj2 =
serializerInstance.deserialize[UnsafeColumnarBuildSideRelation](buffer2)
assert(obj2 != null)
assert(obj2.mode.isInstanceOf[HashedRelationBroadcastMode])
+ assert(
+ util.Arrays.deepEquals(
+ obj2.getBatches().map(toByteArray).toArray[AnyRef],
+ sampleBytes.asInstanceOf[Array[AnyRef]]))
}
}
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 4129e53ac7..307fa3c129 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -44,10 +44,10 @@ namespace {
jclass byteArrayClass;
-jclass unsafeByteBufferClass;
-jmethodID unsafeByteBufferAllocate;
-jmethodID unsafeByteBufferAddress;
-jmethodID unsafeByteBufferSize;
+jclass jniUnsafeByteBufferClass;
+jmethodID jniUnsafeByteBufferAllocate;
+jmethodID jniUnsafeByteBufferAddress;
+jmethodID jniUnsafeByteBufferSize;
jclass jniByteInputStreamClass;
jmethodID jniByteInputStreamRead;
@@ -250,11 +250,12 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
byteArrayClass = createGlobalClassReferenceOrError(env, "[B");
- unsafeByteBufferClass = createGlobalClassReferenceOrError(env,
"Lorg/apache/spark/unsafe/memory/UnsafeByteBuffer;");
- unsafeByteBufferAllocate =
- env->GetStaticMethodID(unsafeByteBufferClass, "allocate",
"(J)Lorg/apache/spark/unsafe/memory/UnsafeByteBuffer;");
- unsafeByteBufferAddress = env->GetMethodID(unsafeByteBufferClass, "address",
"()J");
- unsafeByteBufferSize = env->GetMethodID(unsafeByteBufferClass, "size",
"()J");
+ jniUnsafeByteBufferClass =
+ createGlobalClassReferenceOrError(env,
"Lorg/apache/spark/sql/execution/unsafe/JniUnsafeByteBuffer;");
+ jniUnsafeByteBufferAllocate = env->GetStaticMethodID(
+ jniUnsafeByteBufferClass, "allocate",
"(J)Lorg/apache/spark/sql/execution/unsafe/JniUnsafeByteBuffer;");
+ jniUnsafeByteBufferAddress = env->GetMethodID(jniUnsafeByteBufferClass,
"address", "()J");
+ jniUnsafeByteBufferSize = env->GetMethodID(jniUnsafeByteBufferClass, "size",
"()J");
jniByteInputStreamClass = createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/vectorized/JniByteInputStream;");
jniByteInputStreamRead = getMethodIdOrError(env, jniByteInputStreamClass,
"read", "(JJ)J");
@@ -298,7 +299,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
env->DeleteGlobalRef(splitResultClass);
env->DeleteGlobalRef(nativeColumnarToRowInfoClass);
env->DeleteGlobalRef(byteArrayClass);
- env->DeleteGlobalRef(unsafeByteBufferClass);
+ env->DeleteGlobalRef(jniUnsafeByteBufferClass);
env->DeleteGlobalRef(shuffleReaderMetricsClass);
getJniErrorState()->close();
@@ -1192,9 +1193,9 @@ JNIEXPORT jobject JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchSeriali
auto serializer = ctx->createColumnarBatchSerializer(nullptr);
serializer->append(batch);
auto serializedSize = serializer->maxSerializedSize();
- auto byteBuffer = env->CallStaticObjectMethod(unsafeByteBufferClass,
unsafeByteBufferAllocate, serializedSize);
- auto byteBufferAddress = env->CallLongMethod(byteBuffer,
unsafeByteBufferAddress);
- auto byteBufferSize = env->CallLongMethod(byteBuffer, unsafeByteBufferSize);
+ auto byteBuffer = env->CallStaticObjectMethod(jniUnsafeByteBufferClass,
jniUnsafeByteBufferAllocate, serializedSize);
+ auto byteBufferAddress = env->CallLongMethod(byteBuffer,
jniUnsafeByteBufferAddress);
+ auto byteBufferSize = env->CallLongMethod(byteBuffer,
jniUnsafeByteBufferSize);
serializer->serializeTo(reinterpret_cast<uint8_t*>(byteBufferAddress),
byteBufferSize);
return byteBuffer;
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializeResult.java
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializeResult.java
index 881048b052..3aa26afd43 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializeResult.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializeResult.java
@@ -16,26 +16,63 @@
*/
package org.apache.gluten.vectorized;
+import com.google.common.base.Preconditions;
+import org.apache.spark.sql.execution.unsafe.JniUnsafeByteBuffer;
+import org.apache.spark.sql.execution.unsafe.UnsafeByteArray;
+
import java.io.Serializable;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
public class ColumnarBatchSerializeResult implements Serializable {
public static final ColumnarBatchSerializeResult EMPTY =
- new ColumnarBatchSerializeResult(0, new byte[0][0]);
-
- private long numRows;
+ new ColumnarBatchSerializeResult(true, 0, Collections.emptyList());
- private byte[][] serialized;
+ private final boolean isOffHeap;
+ private final long numRows;
+ private final long sizeInBytes;
+ private final List<byte[]> onHeapData;
+ private final List<UnsafeByteArray> offHeapData;
- public ColumnarBatchSerializeResult(long numRows, byte[][] serialized) {
+ public ColumnarBatchSerializeResult(
+ boolean isOffHeap, long numRows, List<JniUnsafeByteBuffer> serialized) {
this.numRows = numRows;
- this.serialized = serialized;
+ this.isOffHeap = isOffHeap;
+ if (isOffHeap) {
+ onHeapData = null;
+ offHeapData =
+ serialized.stream()
+ .map(JniUnsafeByteBuffer::toUnsafeByteArray)
+ .collect(Collectors.toList());
+ sizeInBytes = offHeapData.stream().mapToInt(unsafe ->
Math.toIntExact(unsafe.size())).sum();
+ } else {
+ onHeapData =
+
serialized.stream().map(JniUnsafeByteBuffer::toByteArray).collect(Collectors.toList());
+ offHeapData = null;
+ sizeInBytes = onHeapData.stream().mapToInt(bytes -> bytes.length).sum();
+ }
+ }
+
+ public boolean isOffHeap() {
+ return isOffHeap;
}
- public long getNumRows() {
+ public long numRows() {
return numRows;
}
- public byte[][] getSerialized() {
- return serialized;
+ public long sizeInBytes() {
+ return sizeInBytes;
+ }
+
+ public List<byte[]> onHeapData() {
+ Preconditions.checkState(!isOffHeap);
+ return onHeapData;
+ }
+
+ public List<UnsafeByteArray> offHeapData() {
+ Preconditions.checkState(isOffHeap);
+ return offHeapData;
}
}
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
index 00fe02c629..909b5b411d 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
@@ -19,7 +19,7 @@ package org.apache.gluten.vectorized;
import org.apache.gluten.runtime.Runtime;
import org.apache.gluten.runtime.RuntimeAware;
-import org.apache.spark.unsafe.memory.UnsafeByteBuffer;
+import org.apache.spark.sql.execution.unsafe.JniUnsafeByteBuffer;
public class ColumnarBatchSerializerJniWrapper implements RuntimeAware {
private final Runtime runtime;
@@ -37,7 +37,7 @@ public class ColumnarBatchSerializerJniWrapper implements RuntimeAware {
return runtime.getHandle();
}
- public native UnsafeByteBuffer serialize(long handle);
+ public native JniUnsafeByteBuffer serialize(long handle);
// Return the native ColumnarBatchSerializer handle
public native long init(long cSchema);
diff --git
a/gluten-arrow/src/main/java/org/apache/spark/sql/execution/unsafe/JniUnsafeByteBuffer.java
b/gluten-arrow/src/main/java/org/apache/spark/sql/execution/unsafe/JniUnsafeByteBuffer.java
new file mode 100644
index 0000000000..86cbb5f7af
--- /dev/null
+++
b/gluten-arrow/src/main/java/org/apache/spark/sql/execution/unsafe/JniUnsafeByteBuffer.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.unsafe;
+
+import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
+
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * A temperate unsafe byte buffer implementation that is created and operated
from C++ via JNI. The
+ * buffer has to be converted either to a Java on-heap byte array or to a Java
off-heap unsafe byte
+ * array after Java code receives this object.
+ */
+public class JniUnsafeByteBuffer {
+ private ArrowBuf buffer;
+ private long size;
+ private boolean released = false;
+
+ private JniUnsafeByteBuffer(ArrowBuf buffer, long size) {
+ this.buffer = buffer;
+ this.size = size;
+ }
+
+ // Invoked by C++ code via JNI.
+ public static JniUnsafeByteBuffer allocate(long size) {
+ final ArrowBuf arrowBuf =
ArrowBufferAllocators.globalInstance().buffer(size);
+ return new JniUnsafeByteBuffer(arrowBuf, size);
+ }
+
+ // Invoked by C++ code via JNI.
+ public long address() {
+ ensureOpen();
+ return buffer.memoryAddress();
+ }
+
+ // Invoked by C++ code via JNI.
+ public long size() {
+ ensureOpen();
+ return size;
+ }
+
+ private synchronized void ensureOpen() {
+ if (released) {
+ throw new IllegalStateException("Already released");
+ }
+ }
+
+ private synchronized void release() {
+ ensureOpen();
+ buffer.close();
+ released = true;
+ buffer = null;
+ size = 0;
+ }
+
+ public synchronized byte[] toByteArray() {
+ ensureOpen();
+ final byte[] values = new byte[Math.toIntExact(size)];
+ Platform.copyMemory(
+ null, buffer.memoryAddress(), values, Platform.BYTE_ARRAY_OFFSET,
values.length);
+ release();
+ return values;
+ }
+
+ public synchronized UnsafeByteArray toUnsafeByteArray() {
+ final UnsafeByteArray out;
+ ensureOpen();
+ // We can safely release the buffer after UnsafeByteArray is constructed
because it keeps
+ // its own reference to the buffer.
+ out = new UnsafeByteArray(buffer, size);
+ release();
+ return out;
+ }
+}
diff --git
a/gluten-arrow/src/main/java/org/apache/spark/sql/execution/unsafe/UnsafeByteArray.java
b/gluten-arrow/src/main/java/org/apache/spark/sql/execution/unsafe/UnsafeByteArray.java
new file mode 100644
index 0000000000..8a25beec21
--- /dev/null
+++
b/gluten-arrow/src/main/java/org/apache/spark/sql/execution/unsafe/UnsafeByteArray.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.unsafe;
+
+import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.apache.arrow.memory.ArrowBuf;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+
+/** A serializable unsafe byte array. */
+public class UnsafeByteArray implements Externalizable, KryoSerializable {
+ private ArrowBuf buffer;
+ private long size;
+
+ UnsafeByteArray(ArrowBuf buffer, long size) {
+ this.buffer = buffer;
+ this.buffer.getReferenceManager().retain();
+ this.size = size;
+ }
+
+ public UnsafeByteArray() {}
+
+ public long address() {
+ return buffer.memoryAddress();
+ }
+
+ public long size() {
+ return size;
+ }
+
+ public void release() {
+ if (buffer != null) {
+ buffer.close();
+ buffer = null;
+ size = 0;
+ }
+ }
+
+ // ------------ KryoSerializable ------------
+
+ @Override
+ public void write(Kryo kryo, Output output) {
+ // write length first
+ output.writeLong(size);
+
+ // stream bytes out of ArrowBuf
+ final int chunkSize = 8 * 1024;
+ byte[] tmp = new byte[chunkSize];
+
+ long remaining = size;
+ int index = 0;
+ while (remaining > 0) {
+ int chunk = (int) Math.min(chunkSize, remaining);
+ buffer.getBytes(index, tmp, 0, chunk);
+ output.write(tmp, 0, chunk);
+ index += chunk;
+ remaining -= chunk;
+ }
+ }
+
+ @Override
+ public void read(Kryo kryo, Input input) {
+ // read length
+ this.size = input.readLong();
+
+ if (size > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("UnsafeByteArray size too large: " +
size);
+ }
+
+ // allocate ArrowBuf
+ this.buffer = ArrowBufferAllocators.globalInstance().buffer((int) size);
+
+ // stream bytes into ArrowBuf
+ final int chunkSize = 8 * 1024;
+ byte[] tmp = new byte[chunkSize];
+
+ long remaining = size;
+ int index = 0;
+ while (remaining > 0) {
+ int chunk = (int) Math.min(chunkSize, remaining);
+ input.readBytes(tmp, 0, chunk);
+ buffer.setBytes(index, tmp, 0, chunk);
+ index += chunk;
+ remaining -= chunk;
+ }
+ }
+
+ // ------------ Externalizable ------------
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // write length first
+ out.writeLong(size);
+
+ final int chunkSize = 8 * 1024;
+ byte[] tmp = new byte[chunkSize];
+
+ long remaining = size;
+ int index = 0;
+ while (remaining > 0) {
+ int chunk = (int) Math.min(chunkSize, remaining);
+ buffer.getBytes(index, tmp, 0, chunk);
+ out.write(tmp, 0, chunk);
+ index += chunk;
+ remaining -= chunk;
+ }
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ this.size = in.readLong();
+
+ if (size > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("UnsafeByteArray size too large: " +
size);
+ }
+
+ this.buffer = ArrowBufferAllocators.globalInstance().buffer((int) size);
+
+ final int chunkSize = 8 * 1024;
+ byte[] tmp = new byte[chunkSize];
+
+ long remaining = size;
+ int index = 0;
+ while (remaining > 0) {
+ int chunk = (int) Math.min(chunkSize, remaining);
+ // ObjectInput extends DataInput, so we can use readFully
+ in.readFully(tmp, 0, chunk);
+ buffer.setBytes(index, tmp, 0, chunk);
+ index += chunk;
+ remaining -= chunk;
+ }
+ }
+
+ /**
+ * It's needed once the broadcast variable is garbage collected. Since now,
we don't have an
+ * elegant way to free the underlying memory in off-heap.
+ *
+ * <p>Since: https://github.com/apache/incubator-gluten/pull/8127.
+ */
+ public void finalize() throws Throwable {
+ release();
+ super.finalize();
+ }
+}
diff --git
a/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java
b/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java
deleted file mode 100644
index 1cd598ab6c..0000000000
---
a/gluten-arrow/src/main/java/org/apache/spark/unsafe/memory/UnsafeByteBuffer.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.unsafe.memory;
-
-import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
-
-import org.apache.arrow.memory.ArrowBuf;
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.spark.task.TaskResources;
-import org.apache.spark.unsafe.Platform;
-
-/** The API is for being called from C++ via JNI. */
-public class UnsafeByteBuffer {
- private final ArrowBuf buffer;
- private final long size;
-
- private UnsafeByteBuffer(ArrowBuf buffer, long size) {
- this.buffer = buffer;
- this.size = size;
- }
-
- public static UnsafeByteBuffer allocate(long size) {
- final BufferAllocator allocator;
- if (TaskResources.inSparkTask()) {
- allocator =
ArrowBufferAllocators.contextInstance(UnsafeByteBuffer.class.getName());
- } else {
- allocator = ArrowBufferAllocators.globalInstance();
- }
- final ArrowBuf arrowBuf = allocator.buffer(size);
- return new UnsafeByteBuffer(arrowBuf, size);
- }
-
- public long address() {
- return buffer.memoryAddress();
- }
-
- public long size() {
- return size;
- }
-
- public void close() {
- buffer.close();
- }
-
- public byte[] toByteArray() {
- final byte[] values = new byte[Math.toIntExact(size)];
- Platform.copyMemory(
- null, buffer.memoryAddress(), values, Platform.BYTE_ARRAY_OFFSET,
values.length);
- return values;
- }
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]