This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new b7e44c210a [GLUTEN-10214][VL] Merge inputstream for shuffle reader
(#10499)
b7e44c210a is described below
commit b7e44c210ab28636c29703462f26e0d4d2833770
Author: Rong Ma <[email protected]>
AuthorDate: Sun Sep 7 09:04:39 2025 +0100
[GLUTEN-10214][VL] Merge inputstream for shuffle reader (#10499)
---
.../clickhouse/CHSparkPlanExecApi.scala | 7 +-
.../apache/spark/shuffle/utils/CHShuffleUtil.scala | 27 +-
.../VeloxCelebornColumnarBatchSerializer.scala | 12 +-
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 7 +-
.../vectorized/ColumnarBatchSerializer.scala | 35 ++-
.../ColumnarBatchSerializerInstance.scala | 47 ++++
.../spark/shuffle/ColumnarShuffleReader.scala | 127 +++++++++
.../apache/spark/shuffle/utils/ShuffleUtil.scala | 26 +-
cpp/core/jni/JniWrapper.cc | 51 +++-
cpp/core/shuffle/ShuffleReader.h | 11 +-
cpp/velox/benchmarks/GenericBenchmark.cc | 4 +-
cpp/velox/compute/VeloxRuntime.cc | 3 +-
cpp/velox/shuffle/VeloxShuffleReader.cc | 290 ++++++++++-----------
cpp/velox/shuffle/VeloxShuffleReader.h | 67 ++---
cpp/velox/tests/VeloxShuffleWriterTest.cc | 8 +-
.../tests/utils/TestStreamReader.h} | 19 +-
.../gluten/vectorized/ShuffleReaderJniWrapper.java | 2 +-
.../gluten/vectorized/ShuffleStreamReader.scala | 52 ++++
.../gluten/backendsapi/SparkPlanExecApi.scala | 5 +-
.../spark/shuffle/GlutenShuffleReaderWrapper.scala | 29 +--
.../apache/spark/shuffle/GlutenShuffleUtils.scala | 38 +++
.../spark/shuffle/GlutenShuffleWriterWrapper.scala | 16 --
.../shuffle/sort/ColumnarShuffleManager.scala | 39 +--
23 files changed, 627 insertions(+), 295 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index f327b3b00e..6b42bfc9ae 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{GenShuffleWriterParameters,
GlutenShuffleWriterWrapper, HashPartitioningWrapper}
+import org.apache.spark.shuffle.{GenShuffleReaderParameters,
GenShuffleWriterParameters, GlutenShuffleReaderWrapper,
GlutenShuffleWriterWrapper, HashPartitioningWrapper}
import org.apache.spark.shuffle.utils.CHShuffleUtil
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -441,6 +441,11 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
CHShuffleUtil.genColumnarShuffleWriter(parameters)
}
+ override def genColumnarShuffleReader[K, C](
+ parameters: GenShuffleReaderParameters[K, C]):
GlutenShuffleReaderWrapper[K, C] = {
+ CHShuffleUtil.genColumnarShuffleReader(parameters)
+ }
+
/**
* Generate ColumnarBatchSerializer for ColumnarShuffleExchangeExec.
*
diff --git
a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/CHShuffleUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/CHShuffleUtil.scala
index 4c0e7f07a0..fa5ae4bb9c 100644
---
a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/CHShuffleUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/CHShuffleUtil.scala
@@ -16,7 +16,9 @@
*/
package org.apache.spark.shuffle.utils
-import org.apache.spark.shuffle.{CHColumnarShuffleWriter,
GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.{BlockStoreShuffleReader,
CHColumnarShuffleWriter, GenShuffleReaderParameters,
GenShuffleWriterParameters, GlutenShuffleReaderWrapper,
GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
+import
org.apache.spark.shuffle.sort.ColumnarShuffleManager.bypassDecompressionSerializerManger
object CHShuffleUtil {
@@ -29,4 +31,27 @@ object CHShuffleUtil {
parameters.mapId,
parameters.metrics))
}
+
+ def genColumnarShuffleReader[K, C](
+ parameters: GenShuffleReaderParameters[K, C]):
GlutenShuffleReaderWrapper[K, C] = {
+ val reader = if (parameters.handle.isInstanceOf[ColumnarShuffleHandle[_,
_]]) {
+ new BlockStoreShuffleReader(
+ parameters.handle,
+ parameters.blocksByAddress,
+ parameters.context,
+ parameters.readMetrics,
+ serializerManager = bypassDecompressionSerializerManger,
+ shouldBatchFetch = parameters.shouldBatchFetch
+ )
+ } else {
+ new BlockStoreShuffleReader(
+ parameters.handle,
+ parameters.blocksByAddress,
+ parameters.context,
+ parameters.readMetrics,
+ shouldBatchFetch = parameters.shouldBatchFetch
+ )
+ }
+ GlutenShuffleReaderWrapper(reader)
+ }
}
diff --git
a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
index 0869ad3c30..ce32a9b7ad 100644
---
a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
+++
b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
@@ -125,7 +125,8 @@ private class CelebornColumnarBatchSerializerInstance(
private class TaskDeserializationStream(in: InputStream)
extends DeserializationStream
with TaskResource {
- private var byteIn: JniByteInputStream = _
+ private val streamReader = ShuffleStreamReader(Iterator((null, in)))
+
private var wrappedOut: ColumnarBatchOutIterator = _
private var cb: ColumnarBatch = _
@@ -247,23 +248,22 @@ private class CelebornColumnarBatchSerializerInstance(
readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal)
}
numOutputRows += numRowsTotal
- if (byteIn != null) {
+ if (wrappedOut != null) {
wrappedOut.close()
- byteIn.close()
}
+ streamReader.close()
if (cb != null) {
cb.close()
}
}
private def initStream(): Unit = {
- if (byteIn == null) {
- byteIn = JniByteInputStreams.create(in)
+ if (wrappedOut == null) {
wrappedOut = new ColumnarBatchOutIterator(
runtime,
ShuffleReaderJniWrapper
.create(runtime)
- .readStream(shuffleReaderHandle, byteIn))
+ .read(shuffleReaderHandle, streamReader))
}
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 0fe000c5c6..a822da014f 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -32,7 +32,7 @@ import
org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEva
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{GenShuffleWriterParameters,
GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.{GenShuffleReaderParameters,
GenShuffleWriterParameters, GlutenShuffleReaderWrapper,
GlutenShuffleWriterWrapper}
import org.apache.spark.shuffle.utils.ShuffleUtil
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -590,6 +590,11 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
ShuffleUtil.genColumnarShuffleWriter(parameters)
}
+ override def genColumnarShuffleReader[K, C](
+ parameters: GenShuffleReaderParameters[K, C]):
GlutenShuffleReaderWrapper[K, C] = {
+ ShuffleUtil.genColumnarShuffleReader(parameters)
+ }
+
override def createColumnarWriteFilesExec(
child: WriteFilesExecTransformer,
noop: SparkPlan,
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
index 53354f432b..17eea29238 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.utils.SparkSchemaUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.storage.BlockId
import org.apache.spark.task.{TaskResource, TaskResources}
import org.apache.arrow.c.ArrowSchema
@@ -56,7 +57,7 @@ class ColumnarBatchSerializer(
/** Creates a new [[SerializerInstance]]. */
override def newInstance(): SerializerInstance = {
- new ColumnarBatchSerializerInstance(
+ new ColumnarBatchSerializerInstanceImpl(
schema,
readBatchNumRows,
numOutputRows,
@@ -68,16 +69,21 @@ class ColumnarBatchSerializer(
override def supportsRelocationOfSerializedObjects: Boolean = true
}
-private class ColumnarBatchSerializerInstance(
+private class ColumnarBatchSerializerInstanceImpl(
schema: StructType,
readBatchNumRows: SQLMetric,
numOutputRows: SQLMetric,
deserializeTime: SQLMetric,
decompressTime: SQLMetric,
shuffleWriterType: ShuffleWriterType)
- extends SerializerInstance
+ extends ColumnarBatchSerializerInstance
with Logging {
+ private val runtime =
+ Runtimes.contextInstance(BackendsApiManager.getBackendName,
"ShuffleReader")
+
+ private val jniWrapper = ShuffleReaderJniWrapper.create(runtime)
+
private val shuffleReaderHandle = {
val allocator: BufferAllocator = ArrowBufferAllocators
.contextInstance(classOf[ColumnarBatchSerializerInstance].getSimpleName)
@@ -98,8 +104,6 @@ private class ColumnarBatchSerializerInstance(
val batchSize = GlutenConfig.get.maxBatchSize
val readerBufferSize = GlutenConfig.get.columnarShuffleReaderBufferSize
val deserializerBufferSize =
GlutenConfig.get.columnarSortShuffleDeserializerBufferSize
- val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName,
"ShuffleReader")
- val jniWrapper = ShuffleReaderJniWrapper.create(runtime)
val shuffleReaderHandle = jniWrapper.make(
cSchema.memoryAddress(),
compressionCodec,
@@ -128,20 +132,23 @@ private class ColumnarBatchSerializerInstance(
}
override def deserializeStream(in: InputStream): DeserializationStream = {
- new TaskDeserializationStream(in)
+ new TaskDeserializationStream(Iterator((null, in)))
}
- private class TaskDeserializationStream(in: InputStream)
+ override def deserializeStreams(
+ streams: Iterator[(BlockId, InputStream)]): DeserializationStream = {
+ new TaskDeserializationStream(streams)
+ }
+
+ private class TaskDeserializationStream(streams: Iterator[(BlockId,
InputStream)])
extends DeserializationStream
with TaskResource {
- private val byteIn: JniByteInputStream = JniByteInputStreams.create(in)
- private val runtime =
- Runtimes.contextInstance(BackendsApiManager.getBackendName,
"ShuffleReader")
+ private val streamReader = ShuffleStreamReader(streams)
+
private val wrappedOut: ClosableIterator = new ColumnarBatchOutIterator(
runtime,
- ShuffleReaderJniWrapper
- .create(runtime)
- .readStream(shuffleReaderHandle, byteIn))
+ jniWrapper
+ .read(shuffleReaderHandle, streamReader))
private var cb: ColumnarBatch = _
@@ -232,7 +239,7 @@ private class ColumnarBatchSerializerInstance(
}
numOutputRows += numRowsTotal
wrappedOut.close()
- byteIn.close()
+ streamReader.close()
if (cb != null) {
cb.close()
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala
new file mode 100644
index 0000000000..205d38b528
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized
+
+import org.apache.spark.serializer.{DeserializationStream,
SerializationStream, SerializerInstance}
+import org.apache.spark.storage.BlockId
+
+import java.io.{InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+abstract class ColumnarBatchSerializerInstance extends SerializerInstance {
+
+ /** Deserialize the streams of ColumnarBatches. */
+ def deserializeStreams(streams: Iterator[(BlockId, InputStream)]):
DeserializationStream
+
+ override def serialize[T: ClassTag](t: T): ByteBuffer = {
+ throw new UnsupportedOperationException
+ }
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+ throw new UnsupportedOperationException
+ }
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer, loader:
ClassLoader): T = {
+ throw new UnsupportedOperationException
+ }
+
+ override def serializeStream(s: OutputStream): SerializationStream = {
+ throw new UnsupportedOperationException
+ }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala
b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala
new file mode 100644
index 0000000000..1e514cf9f1
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.gluten.vectorized.ColumnarBatchSerializerInstance
+
+import org.apache.spark._
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId,
ShuffleBlockFetcherIterator}
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Fetches and reads the blocks from a shuffle by requesting them from other
nodes' block stores.
+ */
+class ColumnarShuffleReader[K, C](
+ handle: BaseShuffleHandle[K, _, C],
+ blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long,
Int)])],
+ context: TaskContext,
+ readMetrics: ShuffleReadMetricsReporter,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+ shouldBatchFetch: Boolean = false)
+ extends ShuffleReader[K, C]
+ with Logging {
+
+ private val dep = handle.dependency
+
+ private def fetchContinuousBlocksInBatch: Boolean = {
+ val conf = SparkEnv.get.conf
+ val serializerRelocatable =
dep.serializer.supportsRelocationOfSerializedObjects
+ val compressed = conf.get(config.SHUFFLE_COMPRESS)
+ val codecConcatenation = if (compressed) {
+
CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf))
+ } else {
+ true
+ }
+ val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)
+ // SPARK-34790: Fetching continuous blocks in batch is incompatible with
io encryption.
+ val ioEncryption = conf.get(config.IO_ENCRYPTION_ENABLED)
+
+ val doBatchFetch = shouldBatchFetch && serializerRelocatable &&
+ (!compressed || codecConcatenation) && !useOldFetchProtocol &&
!ioEncryption
+ if (shouldBatchFetch && !doBatchFetch) {
+ logDebug(
+ "The feature tag of continuous shuffle block fetching is set to true,
but " +
+ "we can not enable the feature because other conditions are not
satisfied. " +
+ s"Shuffle compress: $compressed, serializer relocatable:
$serializerRelocatable, " +
+ s"codec concatenation: $codecConcatenation, use old shuffle fetch
protocol: " +
+ s"$useOldFetchProtocol, io encryption: $ioEncryption.")
+ }
+ doBatchFetch
+ }
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ val wrappedStreams = new ShuffleBlockFetcherIterator(
+ context,
+ blockManager.blockStoreClient,
+ blockManager,
+ mapOutputTracker,
+ blocksByAddress,
+ serializerManager.wrapStream,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards
compatibility
+ SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
+ SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+ SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+ SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
+ SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
+ readMetrics,
+ fetchContinuousBlocksInBatch
+ ).toCompletionIterator
+
+ val recordIter = dep match {
+ case columnarDep: ColumnarShuffleDependency[K, _, C] =>
+ // If the dependency is a ColumnarShuffleDependency, we use the
columnar serializer.
+ columnarDep.serializer
+ .newInstance()
+ .asInstanceOf[ColumnarBatchSerializerInstance]
+ .deserializeStreams(wrappedStreams)
+ .asKeyValueIterator
+ case _ =>
+ val serializerInstance = dep.serializer.newInstance()
+ // Create a key/value iterator for each stream
+ wrappedStreams.flatMap {
+ case (blockId, wrappedStream) =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator
inside of a
+ // NextIterator. The NextIterator makes sure that close() is
called on the
+ // underlying InputStream when all records have been read.
+
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
+ }
+
+ // Update the context task metrics for each record read.
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+ recordIter.map {
+ record =>
+ readMetrics.incRecordsRead(1)
+ record
+ },
+ context.taskMetrics().mergeShuffleReadMetrics())
+
+ // An interruptible iterator must be used here in order to support task
cancellation
+ new InterruptibleIterator[(Any, Any)](context, metricIter)
+ .asInstanceOf[Iterator[Product2[K, C]]]
+ }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/shuffle/utils/ShuffleUtil.scala
b/backends-velox/src/main/scala/org/apache/spark/shuffle/utils/ShuffleUtil.scala
index d0589c90d6..6293d4e764 100644
---
a/backends-velox/src/main/scala/org/apache/spark/shuffle/utils/ShuffleUtil.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/shuffle/utils/ShuffleUtil.scala
@@ -16,7 +16,8 @@
*/
package org.apache.spark.shuffle.utils
-import org.apache.spark.shuffle.{ColumnarShuffleWriter,
GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.{ColumnarShuffleHandle,
ColumnarShuffleManager}
object ShuffleUtil {
@@ -29,4 +30,27 @@ object ShuffleUtil {
parameters.mapId,
parameters.metrics))
}
+
+ def genColumnarShuffleReader[K, C](
+ parameters: GenShuffleReaderParameters[K, C]):
GlutenShuffleReaderWrapper[K, C] = {
+ val reader = if (parameters.handle.isInstanceOf[ColumnarShuffleHandle[_,
_]]) {
+ new ColumnarShuffleReader[K, C](
+ parameters.handle,
+ parameters.blocksByAddress,
+ parameters.context,
+ parameters.readMetrics,
+ ColumnarShuffleManager.bypassDecompressionSerializerManger,
+ shouldBatchFetch = parameters.shouldBatchFetch
+ )
+ } else {
+ new BlockStoreShuffleReader(
+ parameters.handle,
+ parameters.blocksByAddress,
+ parameters.context,
+ parameters.readMetrics,
+ shouldBatchFetch = parameters.shouldBatchFetch
+ )
+ }
+ GlutenShuffleReaderWrapper(reader)
+ }
}
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index d124178c02..0d047b9f63 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -61,6 +61,9 @@ jclass shuffleReaderMetricsClass;
jmethodID shuffleReaderMetricsSetDecompressTime;
jmethodID shuffleReaderMetricsSetDeserializeTime;
+jclass shuffleStreamReaderClass;
+jmethodID shuffleStreamReaderNextStream;
+
class JavaInputStreamAdaptor final : public arrow::io::InputStream {
public:
JavaInputStreamAdaptor(JNIEnv* env, arrow::MemoryPool* pool, jobject jniIn)
: pool_(pool) {
@@ -190,6 +193,39 @@ void internalRuntimeReleaser(Runtime* runtime) {
delete runtime;
}
+class ShuffleStreamReader : public StreamReader {
+ public:
+ ShuffleStreamReader(JNIEnv* env, jobject reader) {
+ if (env->GetJavaVM(&vm_) != JNI_OK) {
+ throw GlutenException("Unable to get JavaVM instance");
+ }
+ ref_ = env->NewGlobalRef(reader);
+ }
+
+ ~ShuffleStreamReader() override {
+ JNIEnv* env = nullptr;
+ attachCurrentThreadAsDaemonOrThrow(vm_, &env);
+ env->DeleteGlobalRef(ref_);
+ }
+
+ std::shared_ptr<arrow::io::InputStream> readNextStream(arrow::MemoryPool*
pool) override {
+ JNIEnv* env = nullptr;
+ attachCurrentThreadAsDaemonOrThrow(vm_, &env);
+
+ jobject jniIn = env->CallObjectMethod(ref_, shuffleStreamReaderNextStream);
+ checkException(env);
+ if (jniIn == nullptr) {
+ return nullptr; // No more streams to read
+ }
+ std::shared_ptr<arrow::io::InputStream> in =
std::make_shared<JavaInputStreamAdaptor>(env, pool, jniIn);
+ return in;
+ }
+
+ private:
+ JavaVM* vm_;
+ jobject ref_;
+};
+
} // namespace
#ifdef __cplusplus
@@ -236,6 +272,11 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
shuffleReaderMetricsSetDeserializeTime =
getMethodIdOrError(env, shuffleReaderMetricsClass, "setDeserializeTime",
"(J)V");
+ shuffleStreamReaderClass =
+ createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/vectorized/ShuffleStreamReader;");
+ shuffleStreamReaderNextStream = getMethodIdOrError(
+ env, shuffleStreamReaderClass, "nextStream",
"()Lorg/apache/gluten/vectorized/JniByteInputStream;");
+
return jniVersion;
}
@@ -1061,16 +1102,18 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
JNI_METHOD_END(kInvalidObjectHandle)
}
-JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_readStream( // NOLINT
+JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_read( // NOLINT
JNIEnv* env,
jobject wrapper,
jlong shuffleReaderHandle,
- jobject jniIn) {
+ jobject jStreamReader) {
JNI_METHOD_START
auto ctx = getRuntime(env, wrapper);
auto reader = ObjectStore::retrieve<ShuffleReader>(shuffleReaderHandle);
- std::shared_ptr<arrow::io::InputStream> in =
std::make_shared<JavaInputStreamAdaptor>(env, reader->getPool(), jniIn);
- auto outItr = reader->readStream(in);
+
+ auto streamReader = std::make_shared<ShuffleStreamReader>(env,
jStreamReader);
+
+ auto outItr = reader->read(streamReader);
return ctx->saveObject(outItr);
JNI_METHOD_END(kInvalidObjectHandle)
}
diff --git a/cpp/core/shuffle/ShuffleReader.h b/cpp/core/shuffle/ShuffleReader.h
index 6e2b079fc7..101865d253 100644
--- a/cpp/core/shuffle/ShuffleReader.h
+++ b/cpp/core/shuffle/ShuffleReader.h
@@ -21,18 +21,23 @@
namespace gluten {
+class StreamReader {
+ public:
+ virtual ~StreamReader() = default;
+
+ virtual std::shared_ptr<arrow::io::InputStream>
readNextStream(arrow::MemoryPool* pool) = 0;
+};
+
class ShuffleReader {
public:
virtual ~ShuffleReader() = default;
// FIXME iterator should be unique_ptr or un-copyable singleton
- virtual std::shared_ptr<ResultIterator>
readStream(std::shared_ptr<arrow::io::InputStream> in) = 0;
+ virtual std::shared_ptr<ResultIterator> read(const
std::shared_ptr<StreamReader>& streamReader) = 0;
virtual int64_t getDecompressTime() const = 0;
virtual int64_t getDeserializeTime() const = 0;
-
- virtual arrow::MemoryPool* getPool() const = 0;
};
} // namespace gluten
diff --git a/cpp/velox/benchmarks/GenericBenchmark.cc
b/cpp/velox/benchmarks/GenericBenchmark.cc
index d1e08c67b4..1559c5c39f 100644
--- a/cpp/velox/benchmarks/GenericBenchmark.cc
+++ b/cpp/velox/benchmarks/GenericBenchmark.cc
@@ -37,6 +37,7 @@
#include "shuffle/rss/RssPartitionWriter.h"
#include "tests/utils/LocalRssClient.h"
#include "tests/utils/TestAllocationListener.h"
+#include "tests/utils/TestStreamReader.h"
#include "utils/Exception.h"
#include "utils/StringUtil.h"
#include "utils/Timer.h"
@@ -304,8 +305,9 @@ void runShuffle(
const auto reader = createShuffleReader(runtime, schema);
GLUTEN_ASSIGN_OR_THROW(auto in, arrow::io::ReadableFile::Open(dataFile));
+ auto streamReader = std::make_shared<TestStreamReader>(std::move(in));
// Read all partitions.
- auto iter = reader->readStream(in);
+ auto iter = reader->read(streamReader);
while (iter->hasNext()) {
// Read and discard.
auto cb = iter->next();
diff --git a/cpp/velox/compute/VeloxRuntime.cc
b/cpp/velox/compute/VeloxRuntime.cc
index 65195b7ea6..e0315cf508 100644
--- a/cpp/velox/compute/VeloxRuntime.cc
+++ b/cpp/velox/compute/VeloxRuntime.cc
@@ -299,8 +299,7 @@ std::shared_ptr<ShuffleReader>
VeloxRuntime::createShuffleReader(
options.batchSize,
options.readerBufferSize,
options.deserializerBufferSize,
- memoryManager()->defaultArrowMemoryPool(),
- memoryManager()->getLeafMemoryPool(),
+ memoryManager(),
options.shuffleWriterType);
return std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
diff --git a/cpp/velox/shuffle/VeloxShuffleReader.cc
b/cpp/velox/shuffle/VeloxShuffleReader.cc
index 73a31c76fb..a39566a9b6 100644
--- a/cpp/velox/shuffle/VeloxShuffleReader.cc
+++ b/cpp/velox/shuffle/VeloxShuffleReader.cc
@@ -453,50 +453,36 @@ class VeloxDictionaryReader {
};
VeloxHashShuffleReaderDeserializer::VeloxHashShuffleReaderDeserializer(
- std::shared_ptr<arrow::io::InputStream> in,
+ const std::shared_ptr<StreamReader>& streamReader,
const std::shared_ptr<arrow::Schema>& schema,
const std::shared_ptr<arrow::util::Codec>& codec,
const facebook::velox::RowTypePtr& rowType,
int32_t batchSize,
- int64_t bufferSize,
- arrow::MemoryPool* memoryPool,
- facebook::velox::memory::MemoryPool* veloxPool,
+ int64_t readerBufferSize,
+ VeloxMemoryManager* memoryManager,
std::vector<bool>* isValidityBuffer,
bool hasComplexType,
int64_t& deserializeTime,
int64_t& decompressTime)
- : schema_(schema),
+ : streamReader_(streamReader),
+ schema_(schema),
codec_(codec),
rowType_(rowType),
batchSize_(batchSize),
- memoryPool_(memoryPool),
- veloxPool_(veloxPool),
+ readerBufferSize_(readerBufferSize),
+ memoryManager_(memoryManager),
isValidityBuffer_(isValidityBuffer),
hasComplexType_(hasComplexType),
deserializeTime_(deserializeTime),
- decompressTime_(decompressTime) {
- GLUTEN_ASSIGN_OR_THROW(in_,
arrow::io::BufferedInputStream::Create(bufferSize, memoryPool, std::move(in)));
-}
-
-bool VeloxHashShuffleReaderDeserializer::shouldSkipMerge() const {
- // Complex type or dictionary encodings do not support merging.
- return hasComplexType_ || !dictionaryFields_.empty();
-}
-
-void VeloxHashShuffleReaderDeserializer::resolveNextBlockType() {
- if (blockTypeResolved_) {
- return;
- }
-
- blockTypeResolved_ = true;
+ decompressTime_(decompressTime) {}
+bool VeloxHashShuffleReaderDeserializer::resolveNextBlockType() {
GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get()));
switch (blockType) {
case BlockType::kEndOfStream:
- reachEos_ = true;
- break;
+ return false;
case BlockType::kDictionary: {
- VeloxDictionaryReader reader(rowType_, veloxPool_, codec_.get());
+ VeloxDictionaryReader reader(rowType_,
memoryManager_->getLeafMemoryPool().get(), codec_.get());
GLUTEN_ASSIGN_OR_THROW(dictionaryFields_, reader.readFields(in_.get()));
GLUTEN_ASSIGN_OR_THROW(dictionaries_, reader.readDictionaries(in_.get(),
dictionaryFields_));
@@ -515,114 +501,83 @@ void
VeloxHashShuffleReaderDeserializer::resolveNextBlockType() {
dictionaries_.clear();
}
} break;
+ default:
+ throw GlutenException(fmt::format("Unsupported block type: {}",
static_cast<int32_t>(blockType)));
}
+ return true;
}
-std::shared_ptr<ColumnarBatch> VeloxHashShuffleReaderDeserializer::next() {
- resolveNextBlockType();
-
- if (shouldSkipMerge()) {
- // We have leftover rows from the last mergeable read.
- if (merged_) {
- return makeColumnarBatch(rowType_, std::move(merged_), veloxPool_,
deserializeTime_);
- }
-
- if (reachEos_) {
- return nullptr;
- }
-
- uint32_t numRows = 0;
- GLUTEN_ASSIGN_OR_THROW(
- auto arrowBuffers,
- BlockPayload::deserialize(in_.get(), codec_, memoryPool_, numRows,
deserializeTime_, decompressTime_));
-
- blockTypeResolved_ = false;
-
- return makeColumnarBatch(
- rowType_, numRows, std::move(arrowBuffers), dictionaryFields_,
dictionaries_, veloxPool_, deserializeTime_);
+void VeloxHashShuffleReaderDeserializer::loadNextStream() {
+ if (reachedEos_) {
+ return;
}
- // TODO: Remove merging.
- if (reachEos_) {
- if (merged_) {
- return makeColumnarBatch(rowType_, std::move(merged_), veloxPool_,
deserializeTime_);
- }
- return nullptr;
+ auto in =
streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
+ if (in == nullptr) {
+ reachedEos_ = true;
+ return;
}
- std::vector<std::shared_ptr<arrow::Buffer>> arrowBuffers{};
- uint32_t numRows = 0;
- while (!merged_ || merged_->numRows() < batchSize_) {
- resolveNextBlockType();
-
- // Break the merging loop once we reach EOS or read a dictionary block.
- if (reachEos_ || !dictionaryFields_.empty()) {
- break;
- }
-
- GLUTEN_ASSIGN_OR_THROW(
- arrowBuffers,
- BlockPayload::deserialize(in_.get(), codec_, memoryPool_, numRows,
deserializeTime_, decompressTime_));
-
- blockTypeResolved_ = false;
+ GLUTEN_ASSIGN_OR_THROW(
+ in_,
+ arrow::io::BufferedInputStream::Create(
+ readerBufferSize_, memoryManager_->defaultArrowMemoryPool(),
std::move(in)));
+}
- if (!merged_) {
- merged_ = std::make_unique<InMemoryPayload>(numRows, isValidityBuffer_,
schema_, std::move(arrowBuffers));
- arrowBuffers.clear();
- continue;
- }
+std::shared_ptr<ColumnarBatch> VeloxHashShuffleReaderDeserializer::next() {
+ if (in_ == nullptr) {
+ loadNextStream();
- auto mergedRows = merged_->numRows() + numRows;
- if (mergedRows > batchSize_) {
- break;
+ if (reachedEos_) {
+ return nullptr;
}
-
- auto append = std::make_unique<InMemoryPayload>(numRows,
isValidityBuffer_, schema_, std::move(arrowBuffers));
- GLUTEN_ASSIGN_OR_THROW(merged_, InMemoryPayload::merge(std::move(merged_),
std::move(append), memoryPool_));
- arrowBuffers.clear();
- }
-
- // Reach EOS.
- if (reachEos_ && !merged_) {
- return nullptr;
}
- auto columnarBatch = makeColumnarBatch(rowType_, std::move(merged_),
veloxPool_, deserializeTime_);
+ while (!resolveNextBlockType()) {
+ loadNextStream();
- // Save remaining rows.
- if (!arrowBuffers.empty()) {
- merged_ = std::make_unique<InMemoryPayload>(numRows, isValidityBuffer_,
schema_, std::move(arrowBuffers));
+ if (reachedEos_) {
+ return nullptr;
+ }
}
- return columnarBatch;
+ uint32_t numRows = 0;
+ GLUTEN_ASSIGN_OR_THROW(
+ auto arrowBuffers,
+ BlockPayload::deserialize(
+ in_.get(), codec_, memoryManager_->defaultArrowMemoryPool(),
numRows, deserializeTime_, decompressTime_));
+
+ return makeColumnarBatch(
+ rowType_,
+ numRows,
+ std::move(arrowBuffers),
+ dictionaryFields_,
+ dictionaries_,
+ memoryManager_->getLeafMemoryPool().get(),
+ deserializeTime_);
}
VeloxSortShuffleReaderDeserializer::VeloxSortShuffleReaderDeserializer(
- std::shared_ptr<arrow::io::InputStream> in,
+ const std::shared_ptr<StreamReader>& streamReader,
const std::shared_ptr<arrow::Schema>& schema,
const std::shared_ptr<arrow::util::Codec>& codec,
const RowTypePtr& rowType,
int32_t batchSize,
int64_t readerBufferSize,
int64_t deserializerBufferSize,
- arrow::MemoryPool* memoryPool,
- facebook::velox::memory::MemoryPool* veloxPool,
+ VeloxMemoryManager* memoryManager,
int64_t& deserializeTime,
int64_t& decompressTime)
- : schema_(schema),
+ : streamReader_(streamReader),
+ schema_(schema),
codec_(codec),
rowType_(rowType),
batchSize_(batchSize),
+ readerBufferSize_(readerBufferSize),
deserializerBufferSize_(deserializerBufferSize),
deserializeTime_(deserializeTime),
decompressTime_(decompressTime),
- veloxPool_(veloxPool) {
- if (codec_ != nullptr) {
- GLUTEN_ASSIGN_OR_THROW(in_, CompressedInputStream::Make(codec_.get(),
std::move(in), memoryPool));
- } else {
- GLUTEN_ASSIGN_OR_THROW(in_,
arrow::io::BufferedInputStream::Create(readerBufferSize, memoryPool,
std::move(in)));
- }
-}
+ memoryManager_(memoryManager) {}
VeloxSortShuffleReaderDeserializer::~VeloxSortShuffleReaderDeserializer() {
if (auto in = std::dynamic_pointer_cast<CompressedInputStream>(in_)) {
@@ -631,13 +586,17 @@
VeloxSortShuffleReaderDeserializer::~VeloxSortShuffleReaderDeserializer() {
}
std::shared_ptr<ColumnarBatch> VeloxSortShuffleReaderDeserializer::next() {
+ if (in_ == nullptr) {
+ loadNextStream();
+ }
+
if (reachedEos_) {
return nullptr;
}
if (rowBuffer_ == nullptr) {
- rowBuffer_ =
- AlignedBuffer::allocate<char>(deserializerBufferSize_, veloxPool_,
std::nullopt, true /*allocateExact*/);
+ rowBuffer_ = AlignedBuffer::allocate<char>(
+ deserializerBufferSize_, memoryManager_->getLeafMemoryPool().get(),
std::nullopt, true /*allocateExact*/);
rowBufferPtr_ = rowBuffer_->asMutable<char>();
data_.reserve(batchSize_);
}
@@ -651,12 +610,17 @@ std::shared_ptr<ColumnarBatch>
VeloxSortShuffleReaderDeserializer::next() {
while (cachedRows_ < batchSize_) {
GLUTEN_ASSIGN_OR_THROW(auto bytes, in_->Read(sizeof(RowSizeType),
&lastRowSize_));
- if (bytes == 0) {
- reachedEos_ = true;
- if (bytesRead_ > 0) {
- return deserializeToBatch();
+ while (bytes == 0) {
+ // Current stream has no more data. Try to load the next stream.
+ loadNextStream();
+ if (reachedEos_) {
+ if (bytesRead_ > 0) {
+ return deserializeToBatch();
+ }
+ // If we reached EOS and have no rows, return nullptr.
+ return nullptr;
}
- return nullptr;
+ GLUTEN_ASSIGN_OR_THROW(bytes, in_->Read(sizeof(RowSizeType),
&lastRowSize_));
}
if (lastRowSize_ + bytesRead_ > rowBuffer_->size()) {
@@ -676,7 +640,8 @@ std::shared_ptr<ColumnarBatch>
VeloxSortShuffleReaderDeserializer::next() {
std::shared_ptr<ColumnarBatch>
VeloxSortShuffleReaderDeserializer::deserializeToBatch() {
ScopedTimer timer(&deserializeTime_);
- auto rowVector = facebook::velox::row::CompactRow::deserialize(data_,
rowType_, veloxPool_);
+ auto rowVector =
+ facebook::velox::row::CompactRow::deserialize(data_, rowType_,
memoryManager_->getLeafMemoryPool().get());
cachedRows_ = 0;
bytesRead_ = 0;
@@ -688,10 +653,33 @@ void
VeloxSortShuffleReaderDeserializer::reallocateRowBuffer() {
auto newSize = facebook::velox::bits::nextPowerOfTwo(lastRowSize_);
LOG(WARNING) << "Row size " << lastRowSize_ << " exceeds current buffer size
" << rowBuffer_->size()
<< ". Resizing buffer to " << newSize;
- rowBuffer_ = AlignedBuffer::allocate<char>(newSize, veloxPool_,
std::nullopt, true /*allocateExact*/);
+ rowBuffer_ = AlignedBuffer::allocate<char>(
+ newSize, memoryManager_->getLeafMemoryPool().get(), std::nullopt, true
/*allocateExact*/);
rowBufferPtr_ = rowBuffer_->asMutable<char>();
}
+void VeloxSortShuffleReaderDeserializer::loadNextStream() {
+ if (reachedEos_) {
+ return;
+ }
+
+ auto in =
streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
+ if (in == nullptr) {
+ reachedEos_ = true;
+ return;
+ }
+
+ if (codec_ != nullptr) {
+ GLUTEN_ASSIGN_OR_THROW(
+ in_, CompressedInputStream::Make(codec_.get(), std::move(in),
memoryManager_->defaultArrowMemoryPool()));
+ } else {
+ GLUTEN_ASSIGN_OR_THROW(
+ in_,
+ arrow::io::BufferedInputStream::Create(
+ readerBufferSize_, memoryManager_->defaultArrowMemoryPool(),
std::move(in)));
+ }
+}
+
void VeloxSortShuffleReaderDeserializer::readNextRow() {
GLUTEN_THROW_NOT_OK(in_->Read(lastRowSize_, rowBufferPtr_ + bytesRead_));
data_.push_back(std::string_view(rowBufferPtr_ + bytesRead_, lastRowSize_));
@@ -744,37 +732,37 @@ void
VeloxRssSortShuffleReaderDeserializer::VeloxInputStream::next(bool throwIfP
}
VeloxRssSortShuffleReaderDeserializer::VeloxRssSortShuffleReaderDeserializer(
- const std::shared_ptr<facebook::velox::memory::MemoryPool>& veloxPool,
+ const std::shared_ptr<StreamReader>& streamReader,
+ VeloxMemoryManager* memoryManager,
const RowTypePtr& rowType,
int32_t batchSize,
facebook::velox::common::CompressionKind veloxCompressionType,
- int64_t& deserializeTime,
- std::shared_ptr<arrow::io::InputStream> in)
- : veloxPool_(veloxPool),
+ int64_t& deserializeTime)
+ : streamReader_(streamReader),
+ memoryManager_(memoryManager),
rowType_(rowType),
batchSize_(batchSize),
veloxCompressionType_(veloxCompressionType),
serde_(getNamedVectorSerde(facebook::velox::VectorSerde::Kind::kPresto)),
- deserializeTime_(deserializeTime),
- arrowIn_(in) {
+ deserializeTime_(deserializeTime) {
serdeOptions_ = {false, veloxCompressionType_};
}
std::shared_ptr<ColumnarBatch> VeloxRssSortShuffleReaderDeserializer::next() {
- if (in_ == nullptr) {
- constexpr uint64_t kMaxReadBufferSize = (1 << 20) -
AlignedBuffer::kPaddedSize;
- auto buffer = AlignedBuffer::allocate<char>(kMaxReadBufferSize,
veloxPool_.get());
- in_ = std::make_unique<VeloxInputStream>(std::move(arrowIn_),
std::move(buffer));
- }
-
- if (!in_->hasNext()) {
- return nullptr;
+ if (in_ == nullptr || !in_->hasNext()) {
+ do {
+ loadNextStream();
+ if (reachedEos_) {
+ return nullptr;
+ }
+ } while (!in_->hasNext());
}
ScopedTimer timer(&deserializeTime_);
RowVectorPtr rowVector;
- VectorStreamGroup::read(in_.get(), veloxPool_.get(), rowType_, serde_,
&rowVector, &serdeOptions_);
+ VectorStreamGroup::read(
+ in_.get(), memoryManager_->getLeafMemoryPool().get(), rowType_, serde_,
&rowVector, &serdeOptions_);
if (rowVector->size() >= batchSize_) {
return std::make_shared<VeloxColumnarBatch>(std::move(rowVector));
@@ -782,13 +770,31 @@ std::shared_ptr<ColumnarBatch>
VeloxRssSortShuffleReaderDeserializer::next() {
while (rowVector->size() < batchSize_ && in_->hasNext()) {
RowVectorPtr rowVectorTemp;
- VectorStreamGroup::read(in_.get(), veloxPool_.get(), rowType_, serde_,
&rowVectorTemp, &serdeOptions_);
+ VectorStreamGroup::read(
+ in_.get(), memoryManager_->getLeafMemoryPool().get(), rowType_,
serde_, &rowVectorTemp, &serdeOptions_);
rowVector->append(rowVectorTemp.get());
}
return std::make_shared<VeloxColumnarBatch>(std::move(rowVector));
}
+void VeloxRssSortShuffleReaderDeserializer::loadNextStream() {
+ if (reachedEos_) {
+ return;
+ }
+
+ arrowIn_ =
streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
+
+ if (arrowIn_ == nullptr) {
+ reachedEos_ = true;
+ return;
+ }
+
+ constexpr uint64_t kMaxReadBufferSize = (1 << 20) -
AlignedBuffer::kPaddedSize;
+ auto buffer = AlignedBuffer::allocate<char>(kMaxReadBufferSize,
memoryManager_->getLeafMemoryPool().get());
+ in_ = std::make_unique<VeloxInputStream>(std::move(arrowIn_),
std::move(buffer));
+}
+
size_t
VeloxRssSortShuffleReaderDeserializer::VeloxInputStream::remainingSize() const {
return std::numeric_limits<unsigned long>::max();
}
@@ -801,8 +807,7 @@
VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory(
int32_t batchSize,
int64_t readerBufferSize,
int64_t deserializerBufferSize,
- arrow::MemoryPool* memoryPool,
- std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool,
+ VeloxMemoryManager* memoryManager,
ShuffleWriterType shuffleWriterType)
: schema_(schema),
codec_(codec),
@@ -811,53 +816,46 @@
VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory(
batchSize_(batchSize),
readerBufferSize_(readerBufferSize),
deserializerBufferSize_(deserializerBufferSize),
- memoryPool_(memoryPool),
- veloxPool_(veloxPool),
+ memoryManager_(memoryManager),
shuffleWriterType_(shuffleWriterType) {
initFromSchema();
}
std::unique_ptr<ColumnarBatchIterator>
VeloxShuffleReaderDeserializerFactory::createDeserializer(
- std::shared_ptr<arrow::io::InputStream> in) {
+ const std::shared_ptr<StreamReader>& streamReader) {
switch (shuffleWriterType_) {
case ShuffleWriterType::kHashShuffle:
return std::make_unique<VeloxHashShuffleReaderDeserializer>(
- std::move(in),
+ streamReader,
schema_,
codec_,
rowType_,
batchSize_,
readerBufferSize_,
- memoryPool_,
- veloxPool_.get(),
+ memoryManager_,
&isValidityBuffer_,
hasComplexType_,
deserializeTime_,
decompressTime_);
case ShuffleWriterType::kSortShuffle:
return std::make_unique<VeloxSortShuffleReaderDeserializer>(
- std::move(in),
+ streamReader,
schema_,
codec_,
rowType_,
batchSize_,
readerBufferSize_,
deserializerBufferSize_,
- memoryPool_,
- veloxPool_.get(),
+ memoryManager_,
deserializeTime_,
decompressTime_);
case ShuffleWriterType::kRssSortShuffle:
return std::make_unique<VeloxRssSortShuffleReaderDeserializer>(
- veloxPool_, rowType_, batchSize_, veloxCompressionType_,
deserializeTime_, std::move(in));
+ streamReader, memoryManager_, rowType_, batchSize_,
veloxCompressionType_, deserializeTime_);
}
GLUTEN_UNREACHABLE();
}
-arrow::MemoryPool* VeloxShuffleReaderDeserializerFactory::getPool() {
- return memoryPool_;
-}
-
int64_t VeloxShuffleReaderDeserializerFactory::getDecompressTime() {
return decompressTime_;
}
@@ -899,12 +897,8 @@ void
VeloxShuffleReaderDeserializerFactory::initFromSchema() {
VeloxShuffleReader::VeloxShuffleReader(std::unique_ptr<VeloxShuffleReaderDeserializerFactory>
factory)
: factory_(std::move(factory)) {}
-std::shared_ptr<ResultIterator>
VeloxShuffleReader::readStream(std::shared_ptr<arrow::io::InputStream> in) {
- return std::make_shared<ResultIterator>(factory_->createDeserializer(in));
-}
-
-arrow::MemoryPool* VeloxShuffleReader::getPool() const {
- return factory_->getPool();
+std::shared_ptr<ResultIterator> VeloxShuffleReader::read(const
std::shared_ptr<StreamReader>& streamReader) {
+ return
std::make_shared<ResultIterator>(factory_->createDeserializer(streamReader));
}
int64_t VeloxShuffleReader::getDecompressTime() const {
diff --git a/cpp/velox/shuffle/VeloxShuffleReader.h
b/cpp/velox/shuffle/VeloxShuffleReader.h
index 686253da25..26a1634f4d 100644
--- a/cpp/velox/shuffle/VeloxShuffleReader.h
+++ b/cpp/velox/shuffle/VeloxShuffleReader.h
@@ -30,14 +30,13 @@ namespace gluten {
class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator {
public:
VeloxHashShuffleReaderDeserializer(
- std::shared_ptr<arrow::io::InputStream> in,
+ const std::shared_ptr<StreamReader>& streamReader,
const std::shared_ptr<arrow::Schema>& schema,
const std::shared_ptr<arrow::util::Codec>& codec,
const facebook::velox::RowTypePtr& rowType,
int32_t batchSize,
- int64_t bufferSize,
- arrow::MemoryPool* memoryPool,
- facebook::velox::memory::MemoryPool* veloxPool,
+ int64_t readerBufferSize,
+ VeloxMemoryManager* memoryManager,
std::vector<bool>* isValidityBuffer,
bool hasComplexType,
int64_t& deserializeTime,
@@ -46,25 +45,27 @@ class VeloxHashShuffleReaderDeserializer final : public
ColumnarBatchIterator {
std::shared_ptr<ColumnarBatch> next() override;
private:
- bool shouldSkipMerge() const;
+ bool resolveNextBlockType();
- void resolveNextBlockType();
+ void loadNextStream();
- std::shared_ptr<arrow::io::InputStream> in_;
+ std::shared_ptr<StreamReader> streamReader_;
std::shared_ptr<arrow::Schema> schema_;
std::shared_ptr<arrow::util::Codec> codec_;
facebook::velox::RowTypePtr rowType_;
int32_t batchSize_;
- arrow::MemoryPool* memoryPool_;
- facebook::velox::memory::MemoryPool* veloxPool_;
+ int64_t readerBufferSize_;
+ VeloxMemoryManager* memoryManager_;
+
std::vector<bool>* isValidityBuffer_;
bool hasComplexType_;
int64_t& deserializeTime_;
int64_t& decompressTime_;
- std::unique_ptr<InMemoryPayload> merged_{nullptr};
- bool reachEos_{false};
+ std::shared_ptr<arrow::io::InputStream> in_{nullptr};
+
+ bool reachedEos_{false};
bool blockTypeResolved_{false};
std::vector<int32_t> dictionaryFields_{};
@@ -76,15 +77,14 @@ class VeloxSortShuffleReaderDeserializer final : public
ColumnarBatchIterator {
using RowSizeType = VeloxSortShuffleWriter::RowSizeType;
VeloxSortShuffleReaderDeserializer(
- std::shared_ptr<arrow::io::InputStream> in,
+ const std::shared_ptr<StreamReader>& streamReader,
const std::shared_ptr<arrow::Schema>& schema,
const std::shared_ptr<arrow::util::Codec>& codec,
const facebook::velox::RowTypePtr& rowType,
int32_t batchSize,
int64_t readerBufferSize,
int64_t deserializerBufferSize,
- arrow::MemoryPool* memoryPool,
- facebook::velox::memory::MemoryPool* veloxPool,
+ VeloxMemoryManager* memoryManager,
int64_t& deserializeTime,
int64_t& decompressTime);
@@ -99,16 +99,20 @@ class VeloxSortShuffleReaderDeserializer final : public
ColumnarBatchIterator {
void reallocateRowBuffer();
+ void loadNextStream();
+
+ std::shared_ptr<StreamReader> streamReader_;
std::shared_ptr<arrow::Schema> schema_;
std::shared_ptr<arrow::util::Codec> codec_;
facebook::velox::RowTypePtr rowType_;
uint32_t batchSize_;
+ int64_t readerBufferSize_;
int64_t deserializerBufferSize_;
int64_t& deserializeTime_;
int64_t& decompressTime_;
- facebook::velox::memory::MemoryPool* veloxPool_;
+ VeloxMemoryManager* memoryManager_;
facebook::velox::BufferPtr rowBuffer_{nullptr};
char* rowBufferPtr_{nullptr};
@@ -116,7 +120,7 @@ class VeloxSortShuffleReaderDeserializer final : public
ColumnarBatchIterator {
uint32_t lastRowSize_{0};
std::vector<std::string_view> data_;
- std::shared_ptr<arrow::io::InputStream> in_;
+ std::shared_ptr<arrow::io::InputStream> in_{nullptr};
uint32_t cachedRows_{0};
bool reachedEos_{false};
@@ -125,19 +129,22 @@ class VeloxSortShuffleReaderDeserializer final : public
ColumnarBatchIterator {
class VeloxRssSortShuffleReaderDeserializer : public ColumnarBatchIterator {
public:
VeloxRssSortShuffleReaderDeserializer(
- const std::shared_ptr<facebook::velox::memory::MemoryPool>& veloxPool,
+ const std::shared_ptr<StreamReader>& streamReader,
+ VeloxMemoryManager* memoryManager,
const facebook::velox::RowTypePtr& rowType,
int32_t batchSize,
facebook::velox::common::CompressionKind veloxCompressionType,
- int64_t& deserializeTime,
- std::shared_ptr<arrow::io::InputStream> in);
+ int64_t& deserializeTime);
std::shared_ptr<ColumnarBatch> next();
private:
class VeloxInputStream;
- std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool_;
+ void loadNextStream();
+
+ std::shared_ptr<StreamReader> streamReader_;
+ VeloxMemoryManager* memoryManager_;
facebook::velox::RowTypePtr rowType_;
std::vector<facebook::velox::RowVectorPtr> batches_;
int32_t batchSize_;
@@ -145,8 +152,10 @@ class VeloxRssSortShuffleReaderDeserializer : public
ColumnarBatchIterator {
facebook::velox::VectorSerde* const serde_;
facebook::velox::serializer::presto::PrestoVectorSerde::PrestoOptions
serdeOptions_;
int64_t& deserializeTime_;
- std::shared_ptr<VeloxInputStream> in_;
- std::shared_ptr<arrow::io::InputStream> arrowIn_;
+ std::shared_ptr<VeloxInputStream> in_{nullptr};
+ std::shared_ptr<arrow::io::InputStream> arrowIn_{nullptr};
+
+ bool reachedEos_{false};
};
class VeloxShuffleReaderDeserializerFactory {
@@ -159,13 +168,10 @@ class VeloxShuffleReaderDeserializerFactory {
int32_t batchSize,
int64_t readerBufferSize,
int64_t deserializerBufferSize,
- arrow::MemoryPool* memoryPool,
- std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool,
+ VeloxMemoryManager* memoryManager,
ShuffleWriterType shuffleWriterType);
- std::unique_ptr<ColumnarBatchIterator>
createDeserializer(std::shared_ptr<arrow::io::InputStream> in);
-
- arrow::MemoryPool* getPool();
+ std::unique_ptr<ColumnarBatchIterator> createDeserializer(const
std::shared_ptr<StreamReader>& streamReader);
int64_t getDecompressTime();
@@ -181,8 +187,7 @@ class VeloxShuffleReaderDeserializerFactory {
int32_t batchSize_;
int64_t readerBufferSize_;
int64_t deserializerBufferSize_;
- arrow::MemoryPool* memoryPool_;
- std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool_;
+ VeloxMemoryManager* memoryManager_;
std::vector<bool> isValidityBuffer_;
bool hasComplexType_{false};
@@ -197,14 +202,12 @@ class VeloxShuffleReader final : public ShuffleReader {
public:
VeloxShuffleReader(std::unique_ptr<VeloxShuffleReaderDeserializerFactory>
factory);
- std::shared_ptr<ResultIterator>
readStream(std::shared_ptr<arrow::io::InputStream> in) override;
+ std::shared_ptr<ResultIterator> read(const std::shared_ptr<StreamReader>&
streamReader) override;
int64_t getDecompressTime() const override;
int64_t getDeserializeTime() const override;
- arrow::MemoryPool* getPool() const override;
-
private:
std::unique_ptr<VeloxShuffleReaderDeserializerFactory> factory_;
};
diff --git a/cpp/velox/tests/VeloxShuffleWriterTest.cc
b/cpp/velox/tests/VeloxShuffleWriterTest.cc
index 61cf7210a4..0d62faafee 100644
--- a/cpp/velox/tests/VeloxShuffleWriterTest.cc
+++ b/cpp/velox/tests/VeloxShuffleWriterTest.cc
@@ -23,8 +23,9 @@
#include "shuffle/VeloxRssSortShuffleWriter.h"
#include "shuffle/VeloxSortShuffleWriter.h"
#include "tests/VeloxShuffleWriterTestBase.h"
+#include "tests/utils/TestAllocationListener.h"
+#include "tests/utils/TestStreamReader.h"
#include "tests/utils/TestUtils.h"
-#include "utils/TestAllocationListener.h"
#include "utils/VeloxArrowUtils.h"
#include "velox/vector/tests/utils/VectorTestBase.h"
@@ -303,13 +304,12 @@ class VeloxShuffleWriterTest : public
::testing::TestWithParam<ShuffleTestParams
kDefaultBatchSize,
kDefaultReadBufferSize,
GetParam().deserializerBufferSize,
- getDefaultMemoryManager()->defaultArrowMemoryPool(),
- pool_,
+ getDefaultMemoryManager(),
GetParam().shuffleWriterType);
const auto reader =
std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
- const auto iter = reader->readStream(in);
+ const auto iter =
reader->read(std::make_shared<TestStreamReader>(std::move(in)));
while (iter->hasNext()) {
auto vector =
std::dynamic_pointer_cast<VeloxColumnarBatch>(iter->next())->getRowVector();
vectors.emplace_back(vector);
diff --git a/cpp/core/shuffle/ShuffleReader.h
b/cpp/velox/tests/utils/TestStreamReader.h
similarity index 67%
copy from cpp/core/shuffle/ShuffleReader.h
copy to cpp/velox/tests/utils/TestStreamReader.h
index 6e2b079fc7..fd47574c6b 100644
--- a/cpp/core/shuffle/ShuffleReader.h
+++ b/cpp/velox/tests/utils/TestStreamReader.h
@@ -17,22 +17,21 @@
#pragma once
-#include "compute/ResultIterator.h"
+#include "shuffle/ShuffleReader.h"
+#include "shuffle/ShuffleWriter.h"
namespace gluten {
-class ShuffleReader {
+class TestStreamReader : public StreamReader {
public:
- virtual ~ShuffleReader() = default;
+ explicit TestStreamReader(const std::shared_ptr<arrow::io::InputStream>&
inputStream) : inputStream_(inputStream) {}
- // FIXME iterator should be unique_ptr or un-copyable singleton
- virtual std::shared_ptr<ResultIterator>
readStream(std::shared_ptr<arrow::io::InputStream> in) = 0;
+ std::shared_ptr<arrow::io::InputStream> readNextStream(arrow::MemoryPool*)
override {
+ return std::move(inputStream_);
+ }
- virtual int64_t getDecompressTime() const = 0;
-
- virtual int64_t getDeserializeTime() const = 0;
-
- virtual arrow::MemoryPool* getPool() const = 0;
+ private:
+ std::shared_ptr<arrow::io::InputStream> inputStream_;
};
} // namespace gluten
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
index 46787d4209..6a0f2130d7 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
@@ -44,7 +44,7 @@ public class ShuffleReaderJniWrapper implements RuntimeAware {
long deserializerBufferSize,
String shuffleWriterType);
- public native long readStream(long shuffleReaderHandle, JniByteInputStream
jniIn);
+ public native long read(long shuffleReaderHandle, ShuffleStreamReader
streamReader);
public native void populateMetrics(long shuffleReaderHandle,
ShuffleReaderMetrics metrics);
diff --git
a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala
b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala
new file mode 100644
index 0000000000..59a9f9e146
--- /dev/null
+++
b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized
+
+import org.apache.spark.storage.BlockId
+
+import java.io.InputStream
+
+case class ShuffleStreamReader(streams: Iterator[(BlockId, InputStream)]) {
+ private val jniStreams = streams.map {
+ case (blockId, in) =>
+ (blockId, JniByteInputStreams.create(in))
+ }
+
+ private var currentStream: JniByteInputStream = _
+
+ // Called from native side to get the next stream.
+ def nextStream(): JniByteInputStream = {
+ if (currentStream != null) {
+ currentStream.close()
+ }
+ if (!jniStreams.hasNext) {
+ currentStream = null
+ } else {
+ currentStream = jniStreams.next._2
+ }
+ currentStream
+ }
+
+ def close(): Unit = {
+ // The reader may not attempt to read all streams from `nextStream`, so we
need to close the
+ // current stream if it's not null.
+ if (currentStream != null) {
+ currentStream.close()
+ currentStream = null
+ }
+ }
+}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 58ad9d3e4b..e4de765672 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -27,7 +27,7 @@ import
org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode
import org.apache.spark.ShuffleDependency
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{GenShuffleWriterParameters,
GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.{GenShuffleReaderParameters,
GenShuffleWriterParameters, GlutenShuffleReaderWrapper,
GlutenShuffleWriterWrapper}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
@@ -385,6 +385,9 @@ trait SparkPlanExecApi {
def genColumnarShuffleWriter[K, V](
parameters: GenShuffleWriterParameters[K, V]):
GlutenShuffleWriterWrapper[K, V]
+ def genColumnarShuffleReader[K, C](
+ parameters: GenShuffleReaderParameters[K, C]):
GlutenShuffleReaderWrapper[K, C]
+
/**
* Generate ColumnarBatchSerializer for ColumnarShuffleExchangeExec.
*
diff --git a/cpp/core/shuffle/ShuffleReader.h
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleReaderWrapper.scala
similarity index 62%
copy from cpp/core/shuffle/ShuffleReader.h
copy to
gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleReaderWrapper.scala
index 6e2b079fc7..95dd8845fd 100644
--- a/cpp/core/shuffle/ShuffleReader.h
+++
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleReaderWrapper.scala
@@ -14,25 +14,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.shuffle
-#pragma once
+import org.apache.spark.TaskContext
+import org.apache.spark.storage.{BlockId, BlockManagerId}
-#include "compute/ResultIterator.h"
+case class GlutenShuffleReaderWrapper[K, C](shuffleReader: ShuffleReader[K, C])
-namespace gluten {
-
-class ShuffleReader {
- public:
- virtual ~ShuffleReader() = default;
-
- // FIXME iterator should be unique_ptr or un-copyable singleton
- virtual std::shared_ptr<ResultIterator>
readStream(std::shared_ptr<arrow::io::InputStream> in) = 0;
-
- virtual int64_t getDecompressTime() const = 0;
-
- virtual int64_t getDeserializeTime() const = 0;
-
- virtual arrow::MemoryPool* getPool() const = 0;
-};
-
-} // namespace gluten
+case class GenShuffleReaderParameters[K, C](
+ handle: BaseShuffleHandle[K, _, C],
+ blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long,
Int)])],
+ context: TaskContext,
+ readMetrics: ShuffleReadMetricsReporter,
+ shouldBatchFetch: Boolean = false)
diff --git
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
index ec67b07207..2ceb48f7c2 100644
---
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
+++
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
@@ -24,6 +24,8 @@ import org.apache.gluten.vectorized.NativePartitioning
import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.internal.config._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
+import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.storage.{BlockId, BlockManagerId}
import org.apache.spark.util.random.XORShiftRandom
@@ -135,4 +137,40 @@ object GlutenShuffleUtils {
SparkSortShuffleWriterUtil.create(other, mapId, context, metrics,
shuffleExecutorComponents)
}
}
+
+ def genColumnarShuffleWriter[K, V](
+ shuffleBlockResolver: IndexShuffleBlockResolver,
+ columnarShuffleHandle: ColumnarShuffleHandle[K, V],
+ mapId: Long,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genColumnarShuffleWriter(
+ GenShuffleWriterParameters(shuffleBlockResolver,
columnarShuffleHandle, mapId, metrics))
+ .shuffleWriter
+ }
+
+ def genColumnarShuffleReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ val (blocksByAddress, canEnableBatchFetch) = {
+ getReaderParam(handle, startMapIndex, endMapIndex, startPartition,
endPartition)
+ }
+ val shouldBatchFetch =
+ canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition,
context)
+
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genColumnarShuffleReader(
+ GenShuffleReaderParameters(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ blocksByAddress,
+ context,
+ metrics,
+ shouldBatchFetch))
+ .shuffleReader
+ }
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleWriterWrapper.scala
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleWriterWrapper.scala
index c5560df25a..e857b69c69 100644
---
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleWriterWrapper.scala
+++
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleWriterWrapper.scala
@@ -16,8 +16,6 @@
*/
package org.apache.spark.shuffle
-import org.apache.gluten.backendsapi.BackendsApiManager
-
import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
case class GlutenShuffleWriterWrapper[K, V](shuffleWriter: ShuffleWriter[K, V])
@@ -27,17 +25,3 @@ case class GenShuffleWriterParameters[K, V](
columnarShuffleHandle: ColumnarShuffleHandle[K, V],
mapId: Long,
metrics: ShuffleWriteMetricsReporter)
-
-object GlutenShuffleWriterWrapper {
-
- def genColumnarShuffleWriter[K, V](
- shuffleBlockResolver: IndexShuffleBlockResolver,
- columnarShuffleHandle: ColumnarShuffleHandle[K, V],
- mapId: Long,
- metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
- BackendsApiManager.getSparkPlanExecApiInstance
- .genColumnarShuffleWriter(
- GenShuffleWriterParameters(shuffleBlockResolver,
columnarShuffleHandle, mapId, metrics))
- .shuffleWriter
- }
-}
diff --git
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index 83dea82ccb..904c2dff6c 100644
---
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -23,7 +23,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.storage.BlockId
import org.apache.spark.util.collection.OpenHashSet
@@ -88,7 +87,7 @@ class ColumnarShuffleManager(conf: SparkConf)
val env = SparkEnv.get
handle match {
case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V
@unchecked] =>
- GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
+ GlutenShuffleUtils.genColumnarShuffleWriter(
shuffleBlockResolver,
columnarShuffleHandle,
mapId,
@@ -133,34 +132,14 @@ class ColumnarShuffleManager(conf: SparkConf)
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
- val (blocksByAddress, canEnableBatchFetch) = {
- GlutenShuffleUtils.getReaderParam(
- handle,
- startMapIndex,
- endMapIndex,
- startPartition,
- endPartition)
- }
- val shouldBatchFetch =
- canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition,
context)
- if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
- new BlockStoreShuffleReader(
- handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
- blocksByAddress,
- context,
- metrics,
- serializerManager = bypassDecompressionSerializerManger,
- shouldBatchFetch = shouldBatchFetch
- )
- } else {
- new BlockStoreShuffleReader(
- handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
- blocksByAddress,
- context,
- metrics,
- shouldBatchFetch = shouldBatchFetch
- )
- }
+ GlutenShuffleUtils.genColumnarShuffleReader(
+ handle,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition,
+ context,
+ metrics)
}
/** Remove a shuffle's metadata from the ShuffleManager. */
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]