This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new b7e44c210a [GLUTEN-10214][VL] Merge inputstream for shuffle reader 
(#10499)
b7e44c210a is described below

commit b7e44c210ab28636c29703462f26e0d4d2833770
Author: Rong Ma <[email protected]>
AuthorDate: Sun Sep 7 09:04:39 2025 +0100

    [GLUTEN-10214][VL] Merge inputstream for shuffle reader (#10499)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   7 +-
 .../apache/spark/shuffle/utils/CHShuffleUtil.scala |  27 +-
 .../VeloxCelebornColumnarBatchSerializer.scala     |  12 +-
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |   7 +-
 .../vectorized/ColumnarBatchSerializer.scala       |  35 ++-
 .../ColumnarBatchSerializerInstance.scala          |  47 ++++
 .../spark/shuffle/ColumnarShuffleReader.scala      | 127 +++++++++
 .../apache/spark/shuffle/utils/ShuffleUtil.scala   |  26 +-
 cpp/core/jni/JniWrapper.cc                         |  51 +++-
 cpp/core/shuffle/ShuffleReader.h                   |  11 +-
 cpp/velox/benchmarks/GenericBenchmark.cc           |   4 +-
 cpp/velox/compute/VeloxRuntime.cc                  |   3 +-
 cpp/velox/shuffle/VeloxShuffleReader.cc            | 290 ++++++++++-----------
 cpp/velox/shuffle/VeloxShuffleReader.h             |  67 ++---
 cpp/velox/tests/VeloxShuffleWriterTest.cc          |   8 +-
 .../tests/utils/TestStreamReader.h}                |  19 +-
 .../gluten/vectorized/ShuffleReaderJniWrapper.java |   2 +-
 .../gluten/vectorized/ShuffleStreamReader.scala    |  52 ++++
 .../gluten/backendsapi/SparkPlanExecApi.scala      |   5 +-
 .../spark/shuffle/GlutenShuffleReaderWrapper.scala |  29 +--
 .../apache/spark/shuffle/GlutenShuffleUtils.scala  |  38 +++
 .../spark/shuffle/GlutenShuffleWriterWrapper.scala |  16 --
 .../shuffle/sort/ColumnarShuffleManager.scala      |  39 +--
 23 files changed, 627 insertions(+), 295 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index f327b3b00e..6b42bfc9ae 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.memory.SparkMemoryUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{GenShuffleWriterParameters, 
GlutenShuffleWriterWrapper, HashPartitioningWrapper}
+import org.apache.spark.shuffle.{GenShuffleReaderParameters, 
GenShuffleWriterParameters, GlutenShuffleReaderWrapper, 
GlutenShuffleWriterWrapper, HashPartitioningWrapper}
 import org.apache.spark.shuffle.utils.CHShuffleUtil
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -441,6 +441,11 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with 
Logging {
     CHShuffleUtil.genColumnarShuffleWriter(parameters)
   }
 
+  override def genColumnarShuffleReader[K, C](
+      parameters: GenShuffleReaderParameters[K, C]): 
GlutenShuffleReaderWrapper[K, C] = {
+    CHShuffleUtil.genColumnarShuffleReader(parameters)
+  }
+
   /**
    * Generate ColumnarBatchSerializer for ColumnarShuffleExchangeExec.
    *
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/CHShuffleUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/CHShuffleUtil.scala
index 4c0e7f07a0..fa5ae4bb9c 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/CHShuffleUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/CHShuffleUtil.scala
@@ -16,7 +16,9 @@
  */
 package org.apache.spark.shuffle.utils
 
-import org.apache.spark.shuffle.{CHColumnarShuffleWriter, 
GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.{BlockStoreShuffleReader, 
CHColumnarShuffleWriter, GenShuffleReaderParameters, 
GenShuffleWriterParameters, GlutenShuffleReaderWrapper, 
GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
+import 
org.apache.spark.shuffle.sort.ColumnarShuffleManager.bypassDecompressionSerializerManger
 
 object CHShuffleUtil {
 
@@ -29,4 +31,27 @@ object CHShuffleUtil {
         parameters.mapId,
         parameters.metrics))
   }
+
+  def genColumnarShuffleReader[K, C](
+      parameters: GenShuffleReaderParameters[K, C]): 
GlutenShuffleReaderWrapper[K, C] = {
+    val reader = if (parameters.handle.isInstanceOf[ColumnarShuffleHandle[_, 
_]]) {
+      new BlockStoreShuffleReader(
+        parameters.handle,
+        parameters.blocksByAddress,
+        parameters.context,
+        parameters.readMetrics,
+        serializerManager = bypassDecompressionSerializerManger,
+        shouldBatchFetch = parameters.shouldBatchFetch
+      )
+    } else {
+      new BlockStoreShuffleReader(
+        parameters.handle,
+        parameters.blocksByAddress,
+        parameters.context,
+        parameters.readMetrics,
+        shouldBatchFetch = parameters.shouldBatchFetch
+      )
+    }
+    GlutenShuffleReaderWrapper(reader)
+  }
 }
diff --git 
a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
 
b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
index 0869ad3c30..ce32a9b7ad 100644
--- 
a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
+++ 
b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
@@ -125,7 +125,8 @@ private class CelebornColumnarBatchSerializerInstance(
   private class TaskDeserializationStream(in: InputStream)
     extends DeserializationStream
     with TaskResource {
-    private var byteIn: JniByteInputStream = _
+    private val streamReader = ShuffleStreamReader(Iterator((null, in)))
+
     private var wrappedOut: ColumnarBatchOutIterator = _
 
     private var cb: ColumnarBatch = _
@@ -247,23 +248,22 @@ private class CelebornColumnarBatchSerializerInstance(
         readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal)
       }
       numOutputRows += numRowsTotal
-      if (byteIn != null) {
+      if (wrappedOut != null) {
         wrappedOut.close()
-        byteIn.close()
       }
+      streamReader.close()
       if (cb != null) {
         cb.close()
       }
     }
 
     private def initStream(): Unit = {
-      if (byteIn == null) {
-        byteIn = JniByteInputStreams.create(in)
+      if (wrappedOut == null) {
         wrappedOut = new ColumnarBatchOutIterator(
           runtime,
           ShuffleReaderJniWrapper
             .create(runtime)
-            .readStream(shuffleReaderHandle, byteIn))
+            .read(shuffleReaderHandle, streamReader))
       }
     }
 
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 0fe000c5c6..a822da014f 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -32,7 +32,7 @@ import 
org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEva
 import org.apache.spark.memory.SparkMemoryUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{GenShuffleWriterParameters, 
GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.{GenShuffleReaderParameters, 
GenShuffleWriterParameters, GlutenShuffleReaderWrapper, 
GlutenShuffleWriterWrapper}
 import org.apache.spark.shuffle.utils.ShuffleUtil
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -590,6 +590,11 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     ShuffleUtil.genColumnarShuffleWriter(parameters)
   }
 
+  override def genColumnarShuffleReader[K, C](
+      parameters: GenShuffleReaderParameters[K, C]): 
GlutenShuffleReaderWrapper[K, C] = {
+    ShuffleUtil.genColumnarShuffleReader(parameters)
+  }
+
   override def createColumnarWriteFilesExec(
       child: WriteFilesExecTransformer,
       noop: SparkPlan,
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
index 53354f432b..17eea29238 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.utils.SparkSchemaUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.storage.BlockId
 import org.apache.spark.task.{TaskResource, TaskResources}
 
 import org.apache.arrow.c.ArrowSchema
@@ -56,7 +57,7 @@ class ColumnarBatchSerializer(
 
   /** Creates a new [[SerializerInstance]]. */
   override def newInstance(): SerializerInstance = {
-    new ColumnarBatchSerializerInstance(
+    new ColumnarBatchSerializerInstanceImpl(
       schema,
       readBatchNumRows,
       numOutputRows,
@@ -68,16 +69,21 @@ class ColumnarBatchSerializer(
   override def supportsRelocationOfSerializedObjects: Boolean = true
 }
 
-private class ColumnarBatchSerializerInstance(
+private class ColumnarBatchSerializerInstanceImpl(
     schema: StructType,
     readBatchNumRows: SQLMetric,
     numOutputRows: SQLMetric,
     deserializeTime: SQLMetric,
     decompressTime: SQLMetric,
     shuffleWriterType: ShuffleWriterType)
-  extends SerializerInstance
+  extends ColumnarBatchSerializerInstance
   with Logging {
 
+  private val runtime =
+    Runtimes.contextInstance(BackendsApiManager.getBackendName, 
"ShuffleReader")
+
+  private val jniWrapper = ShuffleReaderJniWrapper.create(runtime)
+
   private val shuffleReaderHandle = {
     val allocator: BufferAllocator = ArrowBufferAllocators
       .contextInstance(classOf[ColumnarBatchSerializerInstance].getSimpleName)
@@ -98,8 +104,6 @@ private class ColumnarBatchSerializerInstance(
     val batchSize = GlutenConfig.get.maxBatchSize
     val readerBufferSize = GlutenConfig.get.columnarShuffleReaderBufferSize
     val deserializerBufferSize = 
GlutenConfig.get.columnarSortShuffleDeserializerBufferSize
-    val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, 
"ShuffleReader")
-    val jniWrapper = ShuffleReaderJniWrapper.create(runtime)
     val shuffleReaderHandle = jniWrapper.make(
       cSchema.memoryAddress(),
       compressionCodec,
@@ -128,20 +132,23 @@ private class ColumnarBatchSerializerInstance(
   }
 
   override def deserializeStream(in: InputStream): DeserializationStream = {
-    new TaskDeserializationStream(in)
+    new TaskDeserializationStream(Iterator((null, in)))
   }
 
-  private class TaskDeserializationStream(in: InputStream)
+  override def deserializeStreams(
+      streams: Iterator[(BlockId, InputStream)]): DeserializationStream = {
+    new TaskDeserializationStream(streams)
+  }
+
+  private class TaskDeserializationStream(streams: Iterator[(BlockId, 
InputStream)])
     extends DeserializationStream
     with TaskResource {
-    private val byteIn: JniByteInputStream = JniByteInputStreams.create(in)
-    private val runtime =
-      Runtimes.contextInstance(BackendsApiManager.getBackendName, 
"ShuffleReader")
+    private val streamReader = ShuffleStreamReader(streams)
+
     private val wrappedOut: ClosableIterator = new ColumnarBatchOutIterator(
       runtime,
-      ShuffleReaderJniWrapper
-        .create(runtime)
-        .readStream(shuffleReaderHandle, byteIn))
+      jniWrapper
+        .read(shuffleReaderHandle, streamReader))
 
     private var cb: ColumnarBatch = _
 
@@ -232,7 +239,7 @@ private class ColumnarBatchSerializerInstance(
       }
       numOutputRows += numRowsTotal
       wrappedOut.close()
-      byteIn.close()
+      streamReader.close()
       if (cb != null) {
         cb.close()
       }
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala
new file mode 100644
index 0000000000..205d38b528
--- /dev/null
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized
+
+import org.apache.spark.serializer.{DeserializationStream, 
SerializationStream, SerializerInstance}
+import org.apache.spark.storage.BlockId
+
+import java.io.{InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+abstract class ColumnarBatchSerializerInstance extends SerializerInstance {
+
+  /** Deserialize the streams of ColumnarBatches. */
+  def deserializeStreams(streams: Iterator[(BlockId, InputStream)]): 
DeserializationStream
+
+  override def serialize[T: ClassTag](t: T): ByteBuffer = {
+    throw new UnsupportedOperationException
+  }
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+    throw new UnsupportedOperationException
+  }
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: 
ClassLoader): T = {
+    throw new UnsupportedOperationException
+  }
+
+  override def serializeStream(s: OutputStream): SerializationStream = {
+    throw new UnsupportedOperationException
+  }
+}
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala
 
b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala
new file mode 100644
index 0000000000..1e514cf9f1
--- /dev/null
+++ 
b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.gluten.vectorized.ColumnarBatchSerializerInstance
+
+import org.apache.spark._
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, 
ShuffleBlockFetcherIterator}
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Fetches and reads the blocks from a shuffle by requesting them from other 
nodes' block stores.
+ */
+class ColumnarShuffleReader[K, C](
+    handle: BaseShuffleHandle[K, _, C],
+    blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, 
Int)])],
+    context: TaskContext,
+    readMetrics: ShuffleReadMetricsReporter,
+    serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+    shouldBatchFetch: Boolean = false)
+  extends ShuffleReader[K, C]
+  with Logging {
+
+  private val dep = handle.dependency
+
+  private def fetchContinuousBlocksInBatch: Boolean = {
+    val conf = SparkEnv.get.conf
+    val serializerRelocatable = 
dep.serializer.supportsRelocationOfSerializedObjects
+    val compressed = conf.get(config.SHUFFLE_COMPRESS)
+    val codecConcatenation = if (compressed) {
+      
CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf))
+    } else {
+      true
+    }
+    val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)
+    // SPARK-34790: Fetching continuous blocks in batch is incompatible with 
io encryption.
+    val ioEncryption = conf.get(config.IO_ENCRYPTION_ENABLED)
+
+    val doBatchFetch = shouldBatchFetch && serializerRelocatable &&
+      (!compressed || codecConcatenation) && !useOldFetchProtocol && 
!ioEncryption
+    if (shouldBatchFetch && !doBatchFetch) {
+      logDebug(
+        "The feature tag of continuous shuffle block fetching is set to true, 
but " +
+          "we can not enable the feature because other conditions are not 
satisfied. " +
+          s"Shuffle compress: $compressed, serializer relocatable: 
$serializerRelocatable, " +
+          s"codec concatenation: $codecConcatenation, use old shuffle fetch 
protocol: " +
+          s"$useOldFetchProtocol, io encryption: $ioEncryption.")
+    }
+    doBatchFetch
+  }
+
+  /** Read the combined key-values for this reduce task */
+  override def read(): Iterator[Product2[K, C]] = {
+    val wrappedStreams = new ShuffleBlockFetcherIterator(
+      context,
+      blockManager.blockStoreClient,
+      blockManager,
+      mapOutputTracker,
+      blocksByAddress,
+      serializerManager.wrapStream,
+      // Note: we use getSizeAsMb when no suffix is provided for backwards 
compatibility
+      SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
+      SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
+      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+      SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
+      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+      SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
+      SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
+      readMetrics,
+      fetchContinuousBlocksInBatch
+    ).toCompletionIterator
+
+    val recordIter = dep match {
+      case columnarDep: ColumnarShuffleDependency[K, _, C] =>
+        // If the dependency is a ColumnarShuffleDependency, we use the 
columnar serializer.
+        columnarDep.serializer
+          .newInstance()
+          .asInstanceOf[ColumnarBatchSerializerInstance]
+          .deserializeStreams(wrappedStreams)
+          .asKeyValueIterator
+      case _ =>
+        val serializerInstance = dep.serializer.newInstance()
+        // Create a key/value iterator for each stream
+        wrappedStreams.flatMap {
+          case (blockId, wrappedStream) =>
+            // Note: the asKeyValueIterator below wraps a key/value iterator 
inside of a
+            // NextIterator. The NextIterator makes sure that close() is 
called on the
+            // underlying InputStream when all records have been read.
+            
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+        }
+    }
+
+    // Update the context task metrics for each record read.
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map {
+        record =>
+          readMetrics.incRecordsRead(1)
+          record
+      },
+      context.taskMetrics().mergeShuffleReadMetrics())
+
+    // An interruptible iterator must be used here in order to support task 
cancellation
+    new InterruptibleIterator[(Any, Any)](context, metricIter)
+      .asInstanceOf[Iterator[Product2[K, C]]]
+  }
+}
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/shuffle/utils/ShuffleUtil.scala
 
b/backends-velox/src/main/scala/org/apache/spark/shuffle/utils/ShuffleUtil.scala
index d0589c90d6..6293d4e764 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/shuffle/utils/ShuffleUtil.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/shuffle/utils/ShuffleUtil.scala
@@ -16,7 +16,8 @@
  */
 package org.apache.spark.shuffle.utils
 
-import org.apache.spark.shuffle.{ColumnarShuffleWriter, 
GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.{ColumnarShuffleHandle, 
ColumnarShuffleManager}
 
 object ShuffleUtil {
 
@@ -29,4 +30,27 @@ object ShuffleUtil {
         parameters.mapId,
         parameters.metrics))
   }
+
+  def genColumnarShuffleReader[K, C](
+      parameters: GenShuffleReaderParameters[K, C]): 
GlutenShuffleReaderWrapper[K, C] = {
+    val reader = if (parameters.handle.isInstanceOf[ColumnarShuffleHandle[_, 
_]]) {
+      new ColumnarShuffleReader[K, C](
+        parameters.handle,
+        parameters.blocksByAddress,
+        parameters.context,
+        parameters.readMetrics,
+        ColumnarShuffleManager.bypassDecompressionSerializerManger,
+        shouldBatchFetch = parameters.shouldBatchFetch
+      )
+    } else {
+      new BlockStoreShuffleReader(
+        parameters.handle,
+        parameters.blocksByAddress,
+        parameters.context,
+        parameters.readMetrics,
+        shouldBatchFetch = parameters.shouldBatchFetch
+      )
+    }
+    GlutenShuffleReaderWrapper(reader)
+  }
 }
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index d124178c02..0d047b9f63 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -61,6 +61,9 @@ jclass shuffleReaderMetricsClass;
 jmethodID shuffleReaderMetricsSetDecompressTime;
 jmethodID shuffleReaderMetricsSetDeserializeTime;
 
+jclass shuffleStreamReaderClass;
+jmethodID shuffleStreamReaderNextStream;
+
 class JavaInputStreamAdaptor final : public arrow::io::InputStream {
  public:
   JavaInputStreamAdaptor(JNIEnv* env, arrow::MemoryPool* pool, jobject jniIn) 
: pool_(pool) {
@@ -190,6 +193,39 @@ void internalRuntimeReleaser(Runtime* runtime) {
   delete runtime;
 }
 
+class ShuffleStreamReader : public StreamReader {
+ public:
+  ShuffleStreamReader(JNIEnv* env, jobject reader) {
+    if (env->GetJavaVM(&vm_) != JNI_OK) {
+      throw GlutenException("Unable to get JavaVM instance");
+    }
+    ref_ = env->NewGlobalRef(reader);
+  }
+
+  ~ShuffleStreamReader() override {
+    JNIEnv* env = nullptr;
+    attachCurrentThreadAsDaemonOrThrow(vm_, &env);
+    env->DeleteGlobalRef(ref_);
+  }
+
+  std::shared_ptr<arrow::io::InputStream> readNextStream(arrow::MemoryPool* 
pool) override {
+    JNIEnv* env = nullptr;
+    attachCurrentThreadAsDaemonOrThrow(vm_, &env);
+
+    jobject jniIn = env->CallObjectMethod(ref_, shuffleStreamReaderNextStream);
+    checkException(env);
+    if (jniIn == nullptr) {
+      return nullptr; // No more streams to read
+    }
+    std::shared_ptr<arrow::io::InputStream> in = 
std::make_shared<JavaInputStreamAdaptor>(env, pool, jniIn);
+    return in;
+  }
+
+ private:
+  JavaVM* vm_;
+  jobject ref_;
+};
+
 } // namespace
 
 #ifdef __cplusplus
@@ -236,6 +272,11 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
   shuffleReaderMetricsSetDeserializeTime =
       getMethodIdOrError(env, shuffleReaderMetricsClass, "setDeserializeTime", 
"(J)V");
 
+  shuffleStreamReaderClass =
+      createGlobalClassReferenceOrError(env, 
"Lorg/apache/gluten/vectorized/ShuffleStreamReader;");
+  shuffleStreamReaderNextStream = getMethodIdOrError(
+      env, shuffleStreamReaderClass, "nextStream", 
"()Lorg/apache/gluten/vectorized/JniByteInputStream;");
+
   return jniVersion;
 }
 
@@ -1061,16 +1102,18 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
   JNI_METHOD_END(kInvalidObjectHandle)
 }
 
-JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_readStream( // NOLINT
+JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_read( // NOLINT
     JNIEnv* env,
     jobject wrapper,
     jlong shuffleReaderHandle,
-    jobject jniIn) {
+    jobject jStreamReader) {
   JNI_METHOD_START
   auto ctx = getRuntime(env, wrapper);
   auto reader = ObjectStore::retrieve<ShuffleReader>(shuffleReaderHandle);
-  std::shared_ptr<arrow::io::InputStream> in = 
std::make_shared<JavaInputStreamAdaptor>(env, reader->getPool(), jniIn);
-  auto outItr = reader->readStream(in);
+
+  auto streamReader = std::make_shared<ShuffleStreamReader>(env, 
jStreamReader);
+
+  auto outItr = reader->read(streamReader);
   return ctx->saveObject(outItr);
   JNI_METHOD_END(kInvalidObjectHandle)
 }
diff --git a/cpp/core/shuffle/ShuffleReader.h b/cpp/core/shuffle/ShuffleReader.h
index 6e2b079fc7..101865d253 100644
--- a/cpp/core/shuffle/ShuffleReader.h
+++ b/cpp/core/shuffle/ShuffleReader.h
@@ -21,18 +21,23 @@
 
 namespace gluten {
 
+class StreamReader {
+ public:
+  virtual ~StreamReader() = default;
+
+  virtual std::shared_ptr<arrow::io::InputStream> 
readNextStream(arrow::MemoryPool* pool) = 0;
+};
+
 class ShuffleReader {
  public:
   virtual ~ShuffleReader() = default;
 
   // FIXME iterator should be unique_ptr or un-copyable singleton
-  virtual std::shared_ptr<ResultIterator> 
readStream(std::shared_ptr<arrow::io::InputStream> in) = 0;
+  virtual std::shared_ptr<ResultIterator> read(const 
std::shared_ptr<StreamReader>& streamReader) = 0;
 
   virtual int64_t getDecompressTime() const = 0;
 
   virtual int64_t getDeserializeTime() const = 0;
-
-  virtual arrow::MemoryPool* getPool() const = 0;
 };
 
 } // namespace gluten
diff --git a/cpp/velox/benchmarks/GenericBenchmark.cc 
b/cpp/velox/benchmarks/GenericBenchmark.cc
index d1e08c67b4..1559c5c39f 100644
--- a/cpp/velox/benchmarks/GenericBenchmark.cc
+++ b/cpp/velox/benchmarks/GenericBenchmark.cc
@@ -37,6 +37,7 @@
 #include "shuffle/rss/RssPartitionWriter.h"
 #include "tests/utils/LocalRssClient.h"
 #include "tests/utils/TestAllocationListener.h"
+#include "tests/utils/TestStreamReader.h"
 #include "utils/Exception.h"
 #include "utils/StringUtil.h"
 #include "utils/Timer.h"
@@ -304,8 +305,9 @@ void runShuffle(
     const auto reader = createShuffleReader(runtime, schema);
 
     GLUTEN_ASSIGN_OR_THROW(auto in, arrow::io::ReadableFile::Open(dataFile));
+    auto streamReader = std::make_shared<TestStreamReader>(std::move(in));
     // Read all partitions.
-    auto iter = reader->readStream(in);
+    auto iter = reader->read(streamReader);
     while (iter->hasNext()) {
       // Read and discard.
       auto cb = iter->next();
diff --git a/cpp/velox/compute/VeloxRuntime.cc 
b/cpp/velox/compute/VeloxRuntime.cc
index 65195b7ea6..e0315cf508 100644
--- a/cpp/velox/compute/VeloxRuntime.cc
+++ b/cpp/velox/compute/VeloxRuntime.cc
@@ -299,8 +299,7 @@ std::shared_ptr<ShuffleReader> 
VeloxRuntime::createShuffleReader(
       options.batchSize,
       options.readerBufferSize,
       options.deserializerBufferSize,
-      memoryManager()->defaultArrowMemoryPool(),
-      memoryManager()->getLeafMemoryPool(),
+      memoryManager(),
       options.shuffleWriterType);
 
   return std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
diff --git a/cpp/velox/shuffle/VeloxShuffleReader.cc 
b/cpp/velox/shuffle/VeloxShuffleReader.cc
index 73a31c76fb..a39566a9b6 100644
--- a/cpp/velox/shuffle/VeloxShuffleReader.cc
+++ b/cpp/velox/shuffle/VeloxShuffleReader.cc
@@ -453,50 +453,36 @@ class VeloxDictionaryReader {
 };
 
 VeloxHashShuffleReaderDeserializer::VeloxHashShuffleReaderDeserializer(
-    std::shared_ptr<arrow::io::InputStream> in,
+    const std::shared_ptr<StreamReader>& streamReader,
     const std::shared_ptr<arrow::Schema>& schema,
     const std::shared_ptr<arrow::util::Codec>& codec,
     const facebook::velox::RowTypePtr& rowType,
     int32_t batchSize,
-    int64_t bufferSize,
-    arrow::MemoryPool* memoryPool,
-    facebook::velox::memory::MemoryPool* veloxPool,
+    int64_t readerBufferSize,
+    VeloxMemoryManager* memoryManager,
     std::vector<bool>* isValidityBuffer,
     bool hasComplexType,
     int64_t& deserializeTime,
     int64_t& decompressTime)
-    : schema_(schema),
+    : streamReader_(streamReader),
+      schema_(schema),
       codec_(codec),
       rowType_(rowType),
       batchSize_(batchSize),
-      memoryPool_(memoryPool),
-      veloxPool_(veloxPool),
+      readerBufferSize_(readerBufferSize),
+      memoryManager_(memoryManager),
       isValidityBuffer_(isValidityBuffer),
       hasComplexType_(hasComplexType),
       deserializeTime_(deserializeTime),
-      decompressTime_(decompressTime) {
-  GLUTEN_ASSIGN_OR_THROW(in_, 
arrow::io::BufferedInputStream::Create(bufferSize, memoryPool, std::move(in)));
-}
-
-bool VeloxHashShuffleReaderDeserializer::shouldSkipMerge() const {
-  // Complex type or dictionary encodings do not support merging.
-  return hasComplexType_ || !dictionaryFields_.empty();
-}
-
-void VeloxHashShuffleReaderDeserializer::resolveNextBlockType() {
-  if (blockTypeResolved_) {
-    return;
-  }
-
-  blockTypeResolved_ = true;
+      decompressTime_(decompressTime) {}
 
+bool VeloxHashShuffleReaderDeserializer::resolveNextBlockType() {
   GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get()));
   switch (blockType) {
     case BlockType::kEndOfStream:
-      reachEos_ = true;
-      break;
+      return false;
     case BlockType::kDictionary: {
-      VeloxDictionaryReader reader(rowType_, veloxPool_, codec_.get());
+      VeloxDictionaryReader reader(rowType_, 
memoryManager_->getLeafMemoryPool().get(), codec_.get());
       GLUTEN_ASSIGN_OR_THROW(dictionaryFields_, reader.readFields(in_.get()));
       GLUTEN_ASSIGN_OR_THROW(dictionaries_, reader.readDictionaries(in_.get(), 
dictionaryFields_));
 
@@ -515,114 +501,83 @@ void 
VeloxHashShuffleReaderDeserializer::resolveNextBlockType() {
         dictionaries_.clear();
       }
     } break;
+    default:
+      throw GlutenException(fmt::format("Unsupported block type: {}", 
static_cast<int32_t>(blockType)));
   }
+  return true;
 }
 
-std::shared_ptr<ColumnarBatch> VeloxHashShuffleReaderDeserializer::next() {
-  resolveNextBlockType();
-
-  if (shouldSkipMerge()) {
-    // We have leftover rows from the last mergeable read.
-    if (merged_) {
-      return makeColumnarBatch(rowType_, std::move(merged_), veloxPool_, 
deserializeTime_);
-    }
-
-    if (reachEos_) {
-      return nullptr;
-    }
-
-    uint32_t numRows = 0;
-    GLUTEN_ASSIGN_OR_THROW(
-        auto arrowBuffers,
-        BlockPayload::deserialize(in_.get(), codec_, memoryPool_, numRows, 
deserializeTime_, decompressTime_));
-
-    blockTypeResolved_ = false;
-
-    return makeColumnarBatch(
-        rowType_, numRows, std::move(arrowBuffers), dictionaryFields_, 
dictionaries_, veloxPool_, deserializeTime_);
+void VeloxHashShuffleReaderDeserializer::loadNextStream() {
+  if (reachedEos_) {
+    return;
   }
 
-  // TODO: Remove merging.
-  if (reachEos_) {
-    if (merged_) {
-      return makeColumnarBatch(rowType_, std::move(merged_), veloxPool_, 
deserializeTime_);
-    }
-    return nullptr;
+  auto in = 
streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
+  if (in == nullptr) {
+    reachedEos_ = true;
+    return;
   }
 
-  std::vector<std::shared_ptr<arrow::Buffer>> arrowBuffers{};
-  uint32_t numRows = 0;
-  while (!merged_ || merged_->numRows() < batchSize_) {
-    resolveNextBlockType();
-
-    // Break the merging loop once we reach EOS or read a dictionary block.
-    if (reachEos_ || !dictionaryFields_.empty()) {
-      break;
-    }
-
-    GLUTEN_ASSIGN_OR_THROW(
-        arrowBuffers,
-        BlockPayload::deserialize(in_.get(), codec_, memoryPool_, numRows, 
deserializeTime_, decompressTime_));
-
-    blockTypeResolved_ = false;
+  GLUTEN_ASSIGN_OR_THROW(
+      in_,
+      arrow::io::BufferedInputStream::Create(
+          readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), 
std::move(in)));
+}
 
-    if (!merged_) {
-      merged_ = std::make_unique<InMemoryPayload>(numRows, isValidityBuffer_, 
schema_, std::move(arrowBuffers));
-      arrowBuffers.clear();
-      continue;
-    }
+std::shared_ptr<ColumnarBatch> VeloxHashShuffleReaderDeserializer::next() {
+  if (in_ == nullptr) {
+    loadNextStream();
 
-    auto mergedRows = merged_->numRows() + numRows;
-    if (mergedRows > batchSize_) {
-      break;
+    if (reachedEos_) {
+      return nullptr;
     }
-
-    auto append = std::make_unique<InMemoryPayload>(numRows, 
isValidityBuffer_, schema_, std::move(arrowBuffers));
-    GLUTEN_ASSIGN_OR_THROW(merged_, InMemoryPayload::merge(std::move(merged_), 
std::move(append), memoryPool_));
-    arrowBuffers.clear();
-  }
-
-  // Reach EOS.
-  if (reachEos_ && !merged_) {
-    return nullptr;
   }
 
-  auto columnarBatch = makeColumnarBatch(rowType_, std::move(merged_), 
veloxPool_, deserializeTime_);
+  while (!resolveNextBlockType()) {
+    loadNextStream();
 
-  // Save remaining rows.
-  if (!arrowBuffers.empty()) {
-    merged_ = std::make_unique<InMemoryPayload>(numRows, isValidityBuffer_, 
schema_, std::move(arrowBuffers));
+    if (reachedEos_) {
+      return nullptr;
+    }
   }
 
-  return columnarBatch;
+  uint32_t numRows = 0;
+  GLUTEN_ASSIGN_OR_THROW(
+      auto arrowBuffers,
+      BlockPayload::deserialize(
+          in_.get(), codec_, memoryManager_->defaultArrowMemoryPool(), 
numRows, deserializeTime_, decompressTime_));
+
+  return makeColumnarBatch(
+      rowType_,
+      numRows,
+      std::move(arrowBuffers),
+      dictionaryFields_,
+      dictionaries_,
+      memoryManager_->getLeafMemoryPool().get(),
+      deserializeTime_);
 }
 
 VeloxSortShuffleReaderDeserializer::VeloxSortShuffleReaderDeserializer(
-    std::shared_ptr<arrow::io::InputStream> in,
+    const std::shared_ptr<StreamReader>& streamReader,
     const std::shared_ptr<arrow::Schema>& schema,
     const std::shared_ptr<arrow::util::Codec>& codec,
     const RowTypePtr& rowType,
     int32_t batchSize,
     int64_t readerBufferSize,
     int64_t deserializerBufferSize,
-    arrow::MemoryPool* memoryPool,
-    facebook::velox::memory::MemoryPool* veloxPool,
+    VeloxMemoryManager* memoryManager,
     int64_t& deserializeTime,
     int64_t& decompressTime)
-    : schema_(schema),
+    : streamReader_(streamReader),
+      schema_(schema),
       codec_(codec),
       rowType_(rowType),
       batchSize_(batchSize),
+      readerBufferSize_(readerBufferSize),
       deserializerBufferSize_(deserializerBufferSize),
       deserializeTime_(deserializeTime),
       decompressTime_(decompressTime),
-      veloxPool_(veloxPool) {
-  if (codec_ != nullptr) {
-    GLUTEN_ASSIGN_OR_THROW(in_, CompressedInputStream::Make(codec_.get(), 
std::move(in), memoryPool));
-  } else {
-    GLUTEN_ASSIGN_OR_THROW(in_, 
arrow::io::BufferedInputStream::Create(readerBufferSize, memoryPool, 
std::move(in)));
-  }
-}
+      memoryManager_(memoryManager) {}
 
 VeloxSortShuffleReaderDeserializer::~VeloxSortShuffleReaderDeserializer() {
   if (auto in = std::dynamic_pointer_cast<CompressedInputStream>(in_)) {
@@ -631,13 +586,17 @@ 
VeloxSortShuffleReaderDeserializer::~VeloxSortShuffleReaderDeserializer() {
 }
 
 std::shared_ptr<ColumnarBatch> VeloxSortShuffleReaderDeserializer::next() {
+  if (in_ == nullptr) {
+    loadNextStream();
+  }
+
   if (reachedEos_) {
     return nullptr;
   }
 
   if (rowBuffer_ == nullptr) {
-    rowBuffer_ =
-        AlignedBuffer::allocate<char>(deserializerBufferSize_, veloxPool_, 
std::nullopt, true /*allocateExact*/);
+    rowBuffer_ = AlignedBuffer::allocate<char>(
+        deserializerBufferSize_, memoryManager_->getLeafMemoryPool().get(), 
std::nullopt, true /*allocateExact*/);
     rowBufferPtr_ = rowBuffer_->asMutable<char>();
     data_.reserve(batchSize_);
   }
@@ -651,12 +610,17 @@ std::shared_ptr<ColumnarBatch> 
VeloxSortShuffleReaderDeserializer::next() {
 
   while (cachedRows_ < batchSize_) {
     GLUTEN_ASSIGN_OR_THROW(auto bytes, in_->Read(sizeof(RowSizeType), 
&lastRowSize_));
-    if (bytes == 0) {
-      reachedEos_ = true;
-      if (bytesRead_ > 0) {
-        return deserializeToBatch();
+    while (bytes == 0) {
+      // Current stream has no more data. Try to load the next stream.
+      loadNextStream();
+      if (reachedEos_) {
+        if (bytesRead_ > 0) {
+          return deserializeToBatch();
+        }
+        // If we reached EOS and have no rows, return nullptr.
+        return nullptr;
       }
-      return nullptr;
+      GLUTEN_ASSIGN_OR_THROW(bytes, in_->Read(sizeof(RowSizeType), 
&lastRowSize_));
     }
 
     if (lastRowSize_ + bytesRead_ > rowBuffer_->size()) {
@@ -676,7 +640,8 @@ std::shared_ptr<ColumnarBatch> 
VeloxSortShuffleReaderDeserializer::next() {
 std::shared_ptr<ColumnarBatch> 
VeloxSortShuffleReaderDeserializer::deserializeToBatch() {
   ScopedTimer timer(&deserializeTime_);
 
-  auto rowVector = facebook::velox::row::CompactRow::deserialize(data_, 
rowType_, veloxPool_);
+  auto rowVector =
+      facebook::velox::row::CompactRow::deserialize(data_, rowType_, 
memoryManager_->getLeafMemoryPool().get());
 
   cachedRows_ = 0;
   bytesRead_ = 0;
@@ -688,10 +653,33 @@ void 
VeloxSortShuffleReaderDeserializer::reallocateRowBuffer() {
   auto newSize = facebook::velox::bits::nextPowerOfTwo(lastRowSize_);
   LOG(WARNING) << "Row size " << lastRowSize_ << " exceeds current buffer size 
" << rowBuffer_->size()
                << ". Resizing buffer to " << newSize;
-  rowBuffer_ = AlignedBuffer::allocate<char>(newSize, veloxPool_, 
std::nullopt, true /*allocateExact*/);
+  rowBuffer_ = AlignedBuffer::allocate<char>(
+      newSize, memoryManager_->getLeafMemoryPool().get(), std::nullopt, true 
/*allocateExact*/);
   rowBufferPtr_ = rowBuffer_->asMutable<char>();
 }
 
+void VeloxSortShuffleReaderDeserializer::loadNextStream() {
+  if (reachedEos_) {
+    return;
+  }
+
+  auto in = 
streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
+  if (in == nullptr) {
+    reachedEos_ = true;
+    return;
+  }
+
+  if (codec_ != nullptr) {
+    GLUTEN_ASSIGN_OR_THROW(
+        in_, CompressedInputStream::Make(codec_.get(), std::move(in), 
memoryManager_->defaultArrowMemoryPool()));
+  } else {
+    GLUTEN_ASSIGN_OR_THROW(
+        in_,
+        arrow::io::BufferedInputStream::Create(
+            readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), 
std::move(in)));
+  }
+}
+
 void VeloxSortShuffleReaderDeserializer::readNextRow() {
   GLUTEN_THROW_NOT_OK(in_->Read(lastRowSize_, rowBufferPtr_ + bytesRead_));
   data_.push_back(std::string_view(rowBufferPtr_ + bytesRead_, lastRowSize_));
@@ -744,37 +732,37 @@ void 
VeloxRssSortShuffleReaderDeserializer::VeloxInputStream::next(bool throwIfP
 }
 
 VeloxRssSortShuffleReaderDeserializer::VeloxRssSortShuffleReaderDeserializer(
-    const std::shared_ptr<facebook::velox::memory::MemoryPool>& veloxPool,
+    const std::shared_ptr<StreamReader>& streamReader,
+    VeloxMemoryManager* memoryManager,
     const RowTypePtr& rowType,
     int32_t batchSize,
     facebook::velox::common::CompressionKind veloxCompressionType,
-    int64_t& deserializeTime,
-    std::shared_ptr<arrow::io::InputStream> in)
-    : veloxPool_(veloxPool),
+    int64_t& deserializeTime)
+    : streamReader_(streamReader),
+      memoryManager_(memoryManager),
       rowType_(rowType),
       batchSize_(batchSize),
       veloxCompressionType_(veloxCompressionType),
       serde_(getNamedVectorSerde(facebook::velox::VectorSerde::Kind::kPresto)),
-      deserializeTime_(deserializeTime),
-      arrowIn_(in) {
+      deserializeTime_(deserializeTime) {
   serdeOptions_ = {false, veloxCompressionType_};
 }
 
 std::shared_ptr<ColumnarBatch> VeloxRssSortShuffleReaderDeserializer::next() {
-  if (in_ == nullptr) {
-    constexpr uint64_t kMaxReadBufferSize = (1 << 20) - 
AlignedBuffer::kPaddedSize;
-    auto buffer = AlignedBuffer::allocate<char>(kMaxReadBufferSize, 
veloxPool_.get());
-    in_ = std::make_unique<VeloxInputStream>(std::move(arrowIn_), 
std::move(buffer));
-  }
-
-  if (!in_->hasNext()) {
-    return nullptr;
+  if (in_ == nullptr || !in_->hasNext()) {
+    do {
+      loadNextStream();
+      if (reachedEos_) {
+        return nullptr;
+      }
+    } while (!in_->hasNext());
   }
 
   ScopedTimer timer(&deserializeTime_);
 
   RowVectorPtr rowVector;
-  VectorStreamGroup::read(in_.get(), veloxPool_.get(), rowType_, serde_, 
&rowVector, &serdeOptions_);
+  VectorStreamGroup::read(
+      in_.get(), memoryManager_->getLeafMemoryPool().get(), rowType_, serde_, 
&rowVector, &serdeOptions_);
 
   if (rowVector->size() >= batchSize_) {
     return std::make_shared<VeloxColumnarBatch>(std::move(rowVector));
@@ -782,13 +770,31 @@ std::shared_ptr<ColumnarBatch> 
VeloxRssSortShuffleReaderDeserializer::next() {
 
   while (rowVector->size() < batchSize_ && in_->hasNext()) {
     RowVectorPtr rowVectorTemp;
-    VectorStreamGroup::read(in_.get(), veloxPool_.get(), rowType_, serde_, 
&rowVectorTemp, &serdeOptions_);
+    VectorStreamGroup::read(
+        in_.get(), memoryManager_->getLeafMemoryPool().get(), rowType_, 
serde_, &rowVectorTemp, &serdeOptions_);
     rowVector->append(rowVectorTemp.get());
   }
 
   return std::make_shared<VeloxColumnarBatch>(std::move(rowVector));
 }
 
+void VeloxRssSortShuffleReaderDeserializer::loadNextStream() {
+  if (reachedEos_) {
+    return;
+  }
+
+  arrowIn_ = 
streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
+
+  if (arrowIn_ == nullptr) {
+    reachedEos_ = true;
+    return;
+  }
+
+  constexpr uint64_t kMaxReadBufferSize = (1 << 20) - 
AlignedBuffer::kPaddedSize;
+  auto buffer = AlignedBuffer::allocate<char>(kMaxReadBufferSize, 
memoryManager_->getLeafMemoryPool().get());
+  in_ = std::make_unique<VeloxInputStream>(std::move(arrowIn_), 
std::move(buffer));
+}
+
 size_t 
VeloxRssSortShuffleReaderDeserializer::VeloxInputStream::remainingSize() const {
   return std::numeric_limits<unsigned long>::max();
 }
@@ -801,8 +807,7 @@ 
VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory(
     int32_t batchSize,
     int64_t readerBufferSize,
     int64_t deserializerBufferSize,
-    arrow::MemoryPool* memoryPool,
-    std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool,
+    VeloxMemoryManager* memoryManager,
     ShuffleWriterType shuffleWriterType)
     : schema_(schema),
       codec_(codec),
@@ -811,53 +816,46 @@ 
VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory(
       batchSize_(batchSize),
       readerBufferSize_(readerBufferSize),
       deserializerBufferSize_(deserializerBufferSize),
-      memoryPool_(memoryPool),
-      veloxPool_(veloxPool),
+      memoryManager_(memoryManager),
       shuffleWriterType_(shuffleWriterType) {
   initFromSchema();
 }
 
 std::unique_ptr<ColumnarBatchIterator> 
VeloxShuffleReaderDeserializerFactory::createDeserializer(
-    std::shared_ptr<arrow::io::InputStream> in) {
+    const std::shared_ptr<StreamReader>& streamReader) {
   switch (shuffleWriterType_) {
     case ShuffleWriterType::kHashShuffle:
       return std::make_unique<VeloxHashShuffleReaderDeserializer>(
-          std::move(in),
+          streamReader,
           schema_,
           codec_,
           rowType_,
           batchSize_,
           readerBufferSize_,
-          memoryPool_,
-          veloxPool_.get(),
+          memoryManager_,
           &isValidityBuffer_,
           hasComplexType_,
           deserializeTime_,
           decompressTime_);
     case ShuffleWriterType::kSortShuffle:
       return std::make_unique<VeloxSortShuffleReaderDeserializer>(
-          std::move(in),
+          streamReader,
           schema_,
           codec_,
           rowType_,
           batchSize_,
           readerBufferSize_,
           deserializerBufferSize_,
-          memoryPool_,
-          veloxPool_.get(),
+          memoryManager_,
           deserializeTime_,
           decompressTime_);
     case ShuffleWriterType::kRssSortShuffle:
       return std::make_unique<VeloxRssSortShuffleReaderDeserializer>(
-          veloxPool_, rowType_, batchSize_, veloxCompressionType_, 
deserializeTime_, std::move(in));
+          streamReader, memoryManager_, rowType_, batchSize_, 
veloxCompressionType_, deserializeTime_);
   }
   GLUTEN_UNREACHABLE();
 }
 
-arrow::MemoryPool* VeloxShuffleReaderDeserializerFactory::getPool() {
-  return memoryPool_;
-}
-
 int64_t VeloxShuffleReaderDeserializerFactory::getDecompressTime() {
   return decompressTime_;
 }
@@ -899,12 +897,8 @@ void 
VeloxShuffleReaderDeserializerFactory::initFromSchema() {
 
VeloxShuffleReader::VeloxShuffleReader(std::unique_ptr<VeloxShuffleReaderDeserializerFactory>
 factory)
     : factory_(std::move(factory)) {}
 
-std::shared_ptr<ResultIterator> 
VeloxShuffleReader::readStream(std::shared_ptr<arrow::io::InputStream> in) {
-  return std::make_shared<ResultIterator>(factory_->createDeserializer(in));
-}
-
-arrow::MemoryPool* VeloxShuffleReader::getPool() const {
-  return factory_->getPool();
+std::shared_ptr<ResultIterator> VeloxShuffleReader::read(const 
std::shared_ptr<StreamReader>& streamReader) {
+  return 
std::make_shared<ResultIterator>(factory_->createDeserializer(streamReader));
 }
 
 int64_t VeloxShuffleReader::getDecompressTime() const {
diff --git a/cpp/velox/shuffle/VeloxShuffleReader.h 
b/cpp/velox/shuffle/VeloxShuffleReader.h
index 686253da25..26a1634f4d 100644
--- a/cpp/velox/shuffle/VeloxShuffleReader.h
+++ b/cpp/velox/shuffle/VeloxShuffleReader.h
@@ -30,14 +30,13 @@ namespace gluten {
 class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator {
  public:
   VeloxHashShuffleReaderDeserializer(
-      std::shared_ptr<arrow::io::InputStream> in,
+      const std::shared_ptr<StreamReader>& streamReader,
       const std::shared_ptr<arrow::Schema>& schema,
       const std::shared_ptr<arrow::util::Codec>& codec,
       const facebook::velox::RowTypePtr& rowType,
       int32_t batchSize,
-      int64_t bufferSize,
-      arrow::MemoryPool* memoryPool,
-      facebook::velox::memory::MemoryPool* veloxPool,
+      int64_t readerBufferSize,
+      VeloxMemoryManager* memoryManager,
       std::vector<bool>* isValidityBuffer,
       bool hasComplexType,
       int64_t& deserializeTime,
@@ -46,25 +45,27 @@ class VeloxHashShuffleReaderDeserializer final : public 
ColumnarBatchIterator {
   std::shared_ptr<ColumnarBatch> next() override;
 
  private:
-  bool shouldSkipMerge() const;
+  bool resolveNextBlockType();
 
-  void resolveNextBlockType();
+  void loadNextStream();
 
-  std::shared_ptr<arrow::io::InputStream> in_;
+  std::shared_ptr<StreamReader> streamReader_;
   std::shared_ptr<arrow::Schema> schema_;
   std::shared_ptr<arrow::util::Codec> codec_;
   facebook::velox::RowTypePtr rowType_;
   int32_t batchSize_;
-  arrow::MemoryPool* memoryPool_;
-  facebook::velox::memory::MemoryPool* veloxPool_;
+  int64_t readerBufferSize_;
+  VeloxMemoryManager* memoryManager_;
+
   std::vector<bool>* isValidityBuffer_;
   bool hasComplexType_;
 
   int64_t& deserializeTime_;
   int64_t& decompressTime_;
 
-  std::unique_ptr<InMemoryPayload> merged_{nullptr};
-  bool reachEos_{false};
+  std::shared_ptr<arrow::io::InputStream> in_{nullptr};
+
+  bool reachedEos_{false};
   bool blockTypeResolved_{false};
 
   std::vector<int32_t> dictionaryFields_{};
@@ -76,15 +77,14 @@ class VeloxSortShuffleReaderDeserializer final : public 
ColumnarBatchIterator {
   using RowSizeType = VeloxSortShuffleWriter::RowSizeType;
 
   VeloxSortShuffleReaderDeserializer(
-      std::shared_ptr<arrow::io::InputStream> in,
+      const std::shared_ptr<StreamReader>& streamReader,
       const std::shared_ptr<arrow::Schema>& schema,
       const std::shared_ptr<arrow::util::Codec>& codec,
       const facebook::velox::RowTypePtr& rowType,
       int32_t batchSize,
       int64_t readerBufferSize,
       int64_t deserializerBufferSize,
-      arrow::MemoryPool* memoryPool,
-      facebook::velox::memory::MemoryPool* veloxPool,
+      VeloxMemoryManager* memoryManager,
       int64_t& deserializeTime,
       int64_t& decompressTime);
 
@@ -99,16 +99,20 @@ class VeloxSortShuffleReaderDeserializer final : public 
ColumnarBatchIterator {
 
   void reallocateRowBuffer();
 
+  void loadNextStream();
+
+  std::shared_ptr<StreamReader> streamReader_;
   std::shared_ptr<arrow::Schema> schema_;
   std::shared_ptr<arrow::util::Codec> codec_;
   facebook::velox::RowTypePtr rowType_;
 
   uint32_t batchSize_;
+  int64_t readerBufferSize_;
   int64_t deserializerBufferSize_;
   int64_t& deserializeTime_;
   int64_t& decompressTime_;
 
-  facebook::velox::memory::MemoryPool* veloxPool_;
+  VeloxMemoryManager* memoryManager_;
 
   facebook::velox::BufferPtr rowBuffer_{nullptr};
   char* rowBufferPtr_{nullptr};
@@ -116,7 +120,7 @@ class VeloxSortShuffleReaderDeserializer final : public 
ColumnarBatchIterator {
   uint32_t lastRowSize_{0};
   std::vector<std::string_view> data_;
 
-  std::shared_ptr<arrow::io::InputStream> in_;
+  std::shared_ptr<arrow::io::InputStream> in_{nullptr};
 
   uint32_t cachedRows_{0};
   bool reachedEos_{false};
@@ -125,19 +129,22 @@ class VeloxSortShuffleReaderDeserializer final : public 
ColumnarBatchIterator {
 class VeloxRssSortShuffleReaderDeserializer : public ColumnarBatchIterator {
  public:
   VeloxRssSortShuffleReaderDeserializer(
-      const std::shared_ptr<facebook::velox::memory::MemoryPool>& veloxPool,
+      const std::shared_ptr<StreamReader>& streamReader,
+      VeloxMemoryManager* memoryManager,
       const facebook::velox::RowTypePtr& rowType,
       int32_t batchSize,
       facebook::velox::common::CompressionKind veloxCompressionType,
-      int64_t& deserializeTime,
-      std::shared_ptr<arrow::io::InputStream> in);
+      int64_t& deserializeTime);
 
   std::shared_ptr<ColumnarBatch> next();
 
  private:
   class VeloxInputStream;
 
-  std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool_;
+  void loadNextStream();
+
+  std::shared_ptr<StreamReader> streamReader_;
+  VeloxMemoryManager* memoryManager_;
   facebook::velox::RowTypePtr rowType_;
   std::vector<facebook::velox::RowVectorPtr> batches_;
   int32_t batchSize_;
@@ -145,8 +152,10 @@ class VeloxRssSortShuffleReaderDeserializer : public 
ColumnarBatchIterator {
   facebook::velox::VectorSerde* const serde_;
   facebook::velox::serializer::presto::PrestoVectorSerde::PrestoOptions 
serdeOptions_;
   int64_t& deserializeTime_;
-  std::shared_ptr<VeloxInputStream> in_;
-  std::shared_ptr<arrow::io::InputStream> arrowIn_;
+  std::shared_ptr<VeloxInputStream> in_{nullptr};
+  std::shared_ptr<arrow::io::InputStream> arrowIn_{nullptr};
+
+  bool reachedEos_{false};
 };
 
 class VeloxShuffleReaderDeserializerFactory {
@@ -159,13 +168,10 @@ class VeloxShuffleReaderDeserializerFactory {
       int32_t batchSize,
       int64_t readerBufferSize,
       int64_t deserializerBufferSize,
-      arrow::MemoryPool* memoryPool,
-      std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool,
+      VeloxMemoryManager* memoryManager,
       ShuffleWriterType shuffleWriterType);
 
-  std::unique_ptr<ColumnarBatchIterator> 
createDeserializer(std::shared_ptr<arrow::io::InputStream> in);
-
-  arrow::MemoryPool* getPool();
+  std::unique_ptr<ColumnarBatchIterator> createDeserializer(const 
std::shared_ptr<StreamReader>& streamReader);
 
   int64_t getDecompressTime();
 
@@ -181,8 +187,7 @@ class VeloxShuffleReaderDeserializerFactory {
   int32_t batchSize_;
   int64_t readerBufferSize_;
   int64_t deserializerBufferSize_;
-  arrow::MemoryPool* memoryPool_;
-  std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool_;
+  VeloxMemoryManager* memoryManager_;
 
   std::vector<bool> isValidityBuffer_;
   bool hasComplexType_{false};
@@ -197,14 +202,12 @@ class VeloxShuffleReader final : public ShuffleReader {
  public:
   VeloxShuffleReader(std::unique_ptr<VeloxShuffleReaderDeserializerFactory> 
factory);
 
-  std::shared_ptr<ResultIterator> 
readStream(std::shared_ptr<arrow::io::InputStream> in) override;
+  std::shared_ptr<ResultIterator> read(const std::shared_ptr<StreamReader>& 
streamReader) override;
 
   int64_t getDecompressTime() const override;
 
   int64_t getDeserializeTime() const override;
 
-  arrow::MemoryPool* getPool() const override;
-
  private:
   std::unique_ptr<VeloxShuffleReaderDeserializerFactory> factory_;
 };
diff --git a/cpp/velox/tests/VeloxShuffleWriterTest.cc 
b/cpp/velox/tests/VeloxShuffleWriterTest.cc
index 61cf7210a4..0d62faafee 100644
--- a/cpp/velox/tests/VeloxShuffleWriterTest.cc
+++ b/cpp/velox/tests/VeloxShuffleWriterTest.cc
@@ -23,8 +23,9 @@
 #include "shuffle/VeloxRssSortShuffleWriter.h"
 #include "shuffle/VeloxSortShuffleWriter.h"
 #include "tests/VeloxShuffleWriterTestBase.h"
+#include "tests/utils/TestAllocationListener.h"
+#include "tests/utils/TestStreamReader.h"
 #include "tests/utils/TestUtils.h"
-#include "utils/TestAllocationListener.h"
 #include "utils/VeloxArrowUtils.h"
 
 #include "velox/vector/tests/utils/VectorTestBase.h"
@@ -303,13 +304,12 @@ class VeloxShuffleWriterTest : public 
::testing::TestWithParam<ShuffleTestParams
         kDefaultBatchSize,
         kDefaultReadBufferSize,
         GetParam().deserializerBufferSize,
-        getDefaultMemoryManager()->defaultArrowMemoryPool(),
-        pool_,
+        getDefaultMemoryManager(),
         GetParam().shuffleWriterType);
 
     const auto reader = 
std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
 
-    const auto iter = reader->readStream(in);
+    const auto iter = 
reader->read(std::make_shared<TestStreamReader>(std::move(in)));
     while (iter->hasNext()) {
       auto vector = 
std::dynamic_pointer_cast<VeloxColumnarBatch>(iter->next())->getRowVector();
       vectors.emplace_back(vector);
diff --git a/cpp/core/shuffle/ShuffleReader.h 
b/cpp/velox/tests/utils/TestStreamReader.h
similarity index 67%
copy from cpp/core/shuffle/ShuffleReader.h
copy to cpp/velox/tests/utils/TestStreamReader.h
index 6e2b079fc7..fd47574c6b 100644
--- a/cpp/core/shuffle/ShuffleReader.h
+++ b/cpp/velox/tests/utils/TestStreamReader.h
@@ -17,22 +17,21 @@
 
 #pragma once
 
-#include "compute/ResultIterator.h"
+#include "shuffle/ShuffleReader.h"
+#include "shuffle/ShuffleWriter.h"
 
 namespace gluten {
 
-class ShuffleReader {
+class TestStreamReader : public StreamReader {
  public:
-  virtual ~ShuffleReader() = default;
+  explicit TestStreamReader(const std::shared_ptr<arrow::io::InputStream>& 
inputStream) : inputStream_(inputStream) {}
 
-  // FIXME iterator should be unique_ptr or un-copyable singleton
-  virtual std::shared_ptr<ResultIterator> 
readStream(std::shared_ptr<arrow::io::InputStream> in) = 0;
+  std::shared_ptr<arrow::io::InputStream> readNextStream(arrow::MemoryPool*) 
override {
+    return std::move(inputStream_);
+  }
 
-  virtual int64_t getDecompressTime() const = 0;
-
-  virtual int64_t getDeserializeTime() const = 0;
-
-  virtual arrow::MemoryPool* getPool() const = 0;
+ private:
+  std::shared_ptr<arrow::io::InputStream> inputStream_;
 };
 
 } // namespace gluten
diff --git 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
index 46787d4209..6a0f2130d7 100644
--- 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
+++ 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
@@ -44,7 +44,7 @@ public class ShuffleReaderJniWrapper implements RuntimeAware {
       long deserializerBufferSize,
       String shuffleWriterType);
 
-  public native long readStream(long shuffleReaderHandle, JniByteInputStream 
jniIn);
+  public native long read(long shuffleReaderHandle, ShuffleStreamReader 
streamReader);
 
   public native void populateMetrics(long shuffleReaderHandle, 
ShuffleReaderMetrics metrics);
 
diff --git 
a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala
 
b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala
new file mode 100644
index 0000000000..59a9f9e146
--- /dev/null
+++ 
b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized
+
+import org.apache.spark.storage.BlockId
+
+import java.io.InputStream
+
+case class ShuffleStreamReader(streams: Iterator[(BlockId, InputStream)]) {
+  private val jniStreams = streams.map {
+    case (blockId, in) =>
+      (blockId, JniByteInputStreams.create(in))
+  }
+
+  private var currentStream: JniByteInputStream = _
+
+  // Called from the native side; closes the previously returned stream, then returns the next one (or null when exhausted).
+  def nextStream(): JniByteInputStream = {
+    if (currentStream != null) {
+      currentStream.close()
+    }
+    if (!jniStreams.hasNext) {
+      currentStream = null
+    } else {
+      currentStream = jniStreams.next._2
+    }
+    currentStream
+  }
+
+  def close(): Unit = {
+    // The native reader may stop before exhausting `nextStream`, so close the
+    // last returned stream here if it is still open.
+    if (currentStream != null) {
+      currentStream.close()
+      currentStream = null
+    }
+  }
+}
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 58ad9d3e4b..e4de765672 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -27,7 +27,7 @@ import 
org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode
 import org.apache.spark.ShuffleDependency
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{GenShuffleWriterParameters, 
GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.{GenShuffleReaderParameters, 
GenShuffleWriterParameters, GlutenShuffleReaderWrapper, 
GlutenShuffleWriterWrapper}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
@@ -385,6 +385,9 @@ trait SparkPlanExecApi {
   def genColumnarShuffleWriter[K, V](
       parameters: GenShuffleWriterParameters[K, V]): 
GlutenShuffleWriterWrapper[K, V]
 
+  def genColumnarShuffleReader[K, C](
+      parameters: GenShuffleReaderParameters[K, C]): 
GlutenShuffleReaderWrapper[K, C]
+
   /**
    * Generate ColumnarBatchSerializer for ColumnarShuffleExchangeExec.
    *
diff --git a/cpp/core/shuffle/ShuffleReader.h 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleReaderWrapper.scala
similarity index 62%
copy from cpp/core/shuffle/ShuffleReader.h
copy to 
gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleReaderWrapper.scala
index 6e2b079fc7..95dd8845fd 100644
--- a/cpp/core/shuffle/ShuffleReader.h
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleReaderWrapper.scala
@@ -14,25 +14,16 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+package org.apache.spark.shuffle
 
-#pragma once
+import org.apache.spark.TaskContext
+import org.apache.spark.storage.{BlockId, BlockManagerId}
 
-#include "compute/ResultIterator.h"
+case class GlutenShuffleReaderWrapper[K, C](shuffleReader: ShuffleReader[K, C])
 
-namespace gluten {
-
-class ShuffleReader {
- public:
-  virtual ~ShuffleReader() = default;
-
-  // FIXME iterator should be unique_ptr or un-copyable singleton
-  virtual std::shared_ptr<ResultIterator> 
readStream(std::shared_ptr<arrow::io::InputStream> in) = 0;
-
-  virtual int64_t getDecompressTime() const = 0;
-
-  virtual int64_t getDeserializeTime() const = 0;
-
-  virtual arrow::MemoryPool* getPool() const = 0;
-};
-
-} // namespace gluten
+case class GenShuffleReaderParameters[K, C](
+    handle: BaseShuffleHandle[K, _, C],
+    blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, 
Int)])],
+    context: TaskContext,
+    readMetrics: ShuffleReadMetricsReporter,
+    shouldBatchFetch: Boolean = false)
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
index ec67b07207..2ceb48f7c2 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
@@ -24,6 +24,8 @@ import org.apache.gluten.vectorized.NativePartitioning
 import org.apache.spark.{SparkConf, TaskContext}
 import org.apache.spark.internal.config._
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
+import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
 import org.apache.spark.storage.{BlockId, BlockManagerId}
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -135,4 +137,40 @@ object GlutenShuffleUtils {
         SparkSortShuffleWriterUtil.create(other, mapId, context, metrics, 
shuffleExecutorComponents)
     }
   }
+
+  def genColumnarShuffleWriter[K, V](
+      shuffleBlockResolver: IndexShuffleBlockResolver,
+      columnarShuffleHandle: ColumnarShuffleHandle[K, V],
+      mapId: Long,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    BackendsApiManager.getSparkPlanExecApiInstance
+      .genColumnarShuffleWriter(
+        GenShuffleWriterParameters(shuffleBlockResolver, 
columnarShuffleHandle, mapId, metrics))
+      .shuffleWriter
+  }
+
+  def genColumnarShuffleReader[K, C](
+      handle: ShuffleHandle,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    val (blocksByAddress, canEnableBatchFetch) = {
+      getReaderParam(handle, startMapIndex, endMapIndex, startPartition, 
endPartition)
+    }
+    val shouldBatchFetch =
+      canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, 
context)
+
+    BackendsApiManager.getSparkPlanExecApiInstance
+      .genColumnarShuffleReader(
+        GenShuffleReaderParameters(
+          handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+          blocksByAddress,
+          context,
+          metrics,
+          shouldBatchFetch))
+      .shuffleReader
+  }
 }
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleWriterWrapper.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleWriterWrapper.scala
index c5560df25a..e857b69c69 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleWriterWrapper.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleWriterWrapper.scala
@@ -16,8 +16,6 @@
  */
 package org.apache.spark.shuffle
 
-import org.apache.gluten.backendsapi.BackendsApiManager
-
 import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
 
 case class GlutenShuffleWriterWrapper[K, V](shuffleWriter: ShuffleWriter[K, V])
@@ -27,17 +25,3 @@ case class GenShuffleWriterParameters[K, V](
     columnarShuffleHandle: ColumnarShuffleHandle[K, V],
     mapId: Long,
     metrics: ShuffleWriteMetricsReporter)
-
-object GlutenShuffleWriterWrapper {
-
-  def genColumnarShuffleWriter[K, V](
-      shuffleBlockResolver: IndexShuffleBlockResolver,
-      columnarShuffleHandle: ColumnarShuffleHandle[K, V],
-      mapId: Long,
-      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
-    BackendsApiManager.getSparkPlanExecApiInstance
-      .genColumnarShuffleWriter(
-        GenShuffleWriterParameters(shuffleBlockResolver, 
columnarShuffleHandle, mapId, metrics))
-      .shuffleWriter
-  }
-}
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index 83dea82ccb..904c2dff6c 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -23,7 +23,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.collection.OpenHashSet
 
@@ -88,7 +87,7 @@ class ColumnarShuffleManager(conf: SparkConf)
     val env = SparkEnv.get
     handle match {
       case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V 
@unchecked] =>
-        GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
+        GlutenShuffleUtils.genColumnarShuffleWriter(
           shuffleBlockResolver,
           columnarShuffleHandle,
           mapId,
@@ -133,34 +132,14 @@ class ColumnarShuffleManager(conf: SparkConf)
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
-    val (blocksByAddress, canEnableBatchFetch) = {
-      GlutenShuffleUtils.getReaderParam(
-        handle,
-        startMapIndex,
-        endMapIndex,
-        startPartition,
-        endPartition)
-    }
-    val shouldBatchFetch =
-      canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, 
context)
-    if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
-      new BlockStoreShuffleReader(
-        handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
-        blocksByAddress,
-        context,
-        metrics,
-        serializerManager = bypassDecompressionSerializerManger,
-        shouldBatchFetch = shouldBatchFetch
-      )
-    } else {
-      new BlockStoreShuffleReader(
-        handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
-        blocksByAddress,
-        context,
-        metrics,
-        shouldBatchFetch = shouldBatchFetch
-      )
-    }
+    GlutenShuffleUtils.genColumnarShuffleReader(
+      handle,
+      startMapIndex,
+      endMapIndex,
+      startPartition,
+      endPartition,
+      context,
+      metrics)
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to