This is an automated email from the ASF dual-hosted git repository.

chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new aeb2c67600 [GLUTEN-10933][VL] cuDF: Move lock to ShuffleReader (#10934)
aeb2c67600 is described below

commit aeb2c6760069b9dae2f871dcfe2f9bcdae041436
Author: Jin Chengcheng <[email protected]>
AuthorDate: Thu Oct 30 14:13:20 2025 +0000

    [GLUTEN-10933][VL] cuDF: Move lock to ShuffleReader (#10934)
    
    The lock in WholeStageResultIterator restricts the CPU threads from
producing batches; moving the lock here lets the threads produce the first
batch (up to 1 GB) in advance. Perhaps the threads should prepare even more
data for the GPU to consume; this depends on the GPU operator time.
    
    We need to restrict the total number of stages that can be offloaded to
the GPU; otherwise, after a fallback, the lock has no effect on GPU execution.
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   3 +-
 .../VeloxCelebornColumnarBatchSerializer.scala     |   3 +-
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |   6 +-
 .../gluten/extension/CudfNodeValidationRule.scala  |  83 +++--
 .../vectorized/ColumnarBatchSerializer.scala       |  12 +-
 cpp/core/jni/JniWrapper.cc                         |   4 +-
 cpp/core/shuffle/Options.h                         |   5 +
 cpp/velox/CMakeLists.txt                           |   3 +-
 cpp/velox/compute/VeloxRuntime.cc                  |   3 +-
 cpp/velox/compute/WholeStageResultIterator.cc      |  30 +-
 cpp/velox/compute/WholeStageResultIterator.h       |  19 +-
 cpp/velox/cudf/GpuLock.cc                          |  73 ++++
 cpp/velox/cudf/GpuLock.h                           |  34 ++
 cpp/velox/shuffle/GpuShuffleReader.cc              | 393 +++++++++++++++++++++
 cpp/velox/shuffle/GpuShuffleReader.h               |  75 ++++
 cpp/velox/shuffle/VeloxShuffleReader.cc            |  26 +-
 cpp/velox/shuffle/VeloxShuffleReader.h             |   4 +-
 cpp/velox/tests/VeloxShuffleWriterTest.cc          |   3 +-
 .../gluten/vectorized/ShuffleReaderJniWrapper.java |   3 +-
 .../gluten/backendsapi/SparkPlanExecApi.scala      |   3 +-
 .../gluten/execution/WholeStageTransformer.scala   |   2 +-
 .../execution/ColumnarShuffleExchangeExec.scala    | 131 +------
 ...scala => ColumnarShuffleExchangeExecBase.scala} |  41 +--
 .../execution/GPUColumnarShuffleExchangeExec.scala |  63 ++++
 24 files changed, 779 insertions(+), 243 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index dbaccb16c8..58b8e6e3f5 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -454,7 +454,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with 
Logging {
   override def createColumnarBatchSerializer(
       schema: StructType,
       metrics: Map[String, SQLMetric],
-      shuffleWriterType: ShuffleWriterType): Serializer = {
+      shuffleWriterType: ShuffleWriterType,
+      enableCudf: Boolean): Serializer = {
     val readBatchNumRows = metrics("avgReadBatchNumRows")
     val numOutputRows = metrics("numOutputRows")
     val dataSize = metrics("dataSize")
diff --git 
a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
 
b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
index ce32a9b7ad..529945d9f4 100644
--- 
a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
+++ 
b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
@@ -103,7 +103,8 @@ private class CelebornColumnarBatchSerializerInstance(
         batchSize,
         readerBufferSize,
         deserializerBufferSize,
-        shuffleWriterType.name
+        shuffleWriterType.name,
+        false
       )
     // Close shuffle reader instance as lately as the end of task processing,
     // since the native reader could hold a reference to memory pool that
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 20f52997ab..1c65cc778f 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -633,7 +633,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
   override def createColumnarBatchSerializer(
       schema: StructType,
       metrics: Map[String, SQLMetric],
-      shuffleWriterType: ShuffleWriterType): Serializer = {
+      shuffleWriterType: ShuffleWriterType,
+      enableCudf: Boolean): Serializer = {
     val numOutputRows = metrics("numOutputRows")
     val deserializeTime = metrics("deserializeTime")
     val readBatchNumRows = metrics("avgReadBatchNumRows")
@@ -658,7 +659,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
           numOutputRows,
           deserializeTime,
           decompressTime,
-          shuffleWriterType)
+          shuffleWriterType,
+          enableCudf)
     }
   }
 
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/extension/CudfNodeValidationRule.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/extension/CudfNodeValidationRule.scala
index a092b984c8..97419ece7f 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/extension/CudfNodeValidationRule.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/extension/CudfNodeValidationRule.scala
@@ -18,10 +18,11 @@ package org.apache.gluten.extension
 
 import org.apache.gluten.config.{GlutenConfig, VeloxConfig}
 import org.apache.gluten.cudf.VeloxCudfPlanValidatorJniWrapper
-import org.apache.gluten.execution.{CudfTag, LeafTransformSupport, 
TransformSupport, WholeStageTransformer}
+import org.apache.gluten.execution.{CudfTag, LeafTransformSupport, 
TransformSupport, VeloxResizeBatchesExec, WholeStageTransformer}
+import 
org.apache.gluten.extension.CudfNodeValidationRule.setTagForWholeStageTransformer
 
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, 
GPUColumnarShuffleExchangeExec, SparkPlan}
 
 // Add the node name prefix 'Cudf' to GlutenPlan when can offload to cudf
 case class CudfNodeValidationRule(glutenConf: GlutenConfig) extends 
Rule[SparkPlan] {
@@ -31,33 +32,65 @@ case class CudfNodeValidationRule(glutenConf: GlutenConfig) 
extends Rule[SparkPl
       return plan
     }
     plan.transformUp {
+      case shuffle @ ColumnarShuffleExchangeExec(
+            _,
+            v @ VeloxResizeBatchesExec(w: WholeStageTransformer, _, _),
+            _,
+            _,
+            _) =>
+        setTagForWholeStageTransformer(w)
+        if (w.isCudf) {
+          log.info("VeloxResizeBatchesExec is not supported in GPU")
+        }
+        GPUColumnarShuffleExchangeExec(
+          shuffle.outputPartitioning,
+          w,
+          shuffle.shuffleOrigin,
+          shuffle.projectOutputAttributes,
+          shuffle.advisoryPartitionSize)
+
+      case shuffle @ ColumnarShuffleExchangeExec(_, w: WholeStageTransformer, 
_, _, _) =>
+        setTagForWholeStageTransformer(w)
+        GPUColumnarShuffleExchangeExec(
+          shuffle.outputPartitioning,
+          w,
+          shuffle.shuffleOrigin,
+          shuffle.projectOutputAttributes,
+          shuffle.advisoryPartitionSize)
+
       case transformer: WholeStageTransformer =>
-        if (!VeloxConfig.get.cudfEnableTableScan) {
-          // Spark3.2 does not have exists
-          val hasLeaf = transformer.find {
-            case _: LeafTransformSupport => true
-            case _ => false
-          }.isDefined
-          if (!hasLeaf && VeloxConfig.get.cudfEnableValidation) {
-            if (
-              VeloxCudfPlanValidatorJniWrapper.validate(
-                transformer.substraitPlan.toProtobuf.toByteArray)
-            ) {
-              transformer.foreach {
-                case _: LeafTransformSupport =>
-                case t: TransformSupport =>
-                  t.setTagValue(CudfTag.CudfTag, true)
-                case _ =>
-              }
-              transformer.setTagValue(CudfTag.CudfTag, true)
-            }
-          } else {
-            transformer.setTagValue(CudfTag.CudfTag, !hasLeaf)
+        setTagForWholeStageTransformer(transformer)
+        transformer
+    }
+  }
+}
+
+object CudfNodeValidationRule {
+  def setTagForWholeStageTransformer(transformer: WholeStageTransformer): Unit 
= {
+    if (!VeloxConfig.get.cudfEnableTableScan) {
+      // Spark3.2 does not have exists
+      val hasLeaf = transformer.find {
+        case _: LeafTransformSupport => true
+        case _ => false
+      }.isDefined
+      if (!hasLeaf && VeloxConfig.get.cudfEnableValidation) {
+        if (
+          VeloxCudfPlanValidatorJniWrapper.validate(
+            transformer.substraitPlan.toProtobuf.toByteArray)
+        ) {
+          transformer.foreach {
+            case _: LeafTransformSupport =>
+            case t: TransformSupport =>
+              t.setTagValue(CudfTag.CudfTag, true)
+            case _ =>
           }
-        } else {
           transformer.setTagValue(CudfTag.CudfTag, true)
         }
-        transformer
+      } else {
+        transformer.setTagValue(CudfTag.CudfTag, !hasLeaf)
+      }
+    } else {
+      transformer.setTagValue(CudfTag.CudfTag, true)
     }
   }
 }
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
index 3b5fce63f8..2369cf3642 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
@@ -51,7 +51,8 @@ class ColumnarBatchSerializer(
     numOutputRows: SQLMetric,
     deserializeTime: SQLMetric,
     decompressTime: SQLMetric,
-    shuffleWriterType: ShuffleWriterType)
+    shuffleWriterType: ShuffleWriterType,
+    enableCudf: Boolean)
   extends Serializer
   with Serializable {
 
@@ -63,7 +64,8 @@ class ColumnarBatchSerializer(
       numOutputRows,
       deserializeTime,
       decompressTime,
-      shuffleWriterType)
+      shuffleWriterType,
+      enableCudf)
   }
 
   override def supportsRelocationOfSerializedObjects: Boolean = true
@@ -75,7 +77,8 @@ private class ColumnarBatchSerializerInstanceImpl(
     numOutputRows: SQLMetric,
     deserializeTime: SQLMetric,
     decompressTime: SQLMetric,
-    shuffleWriterType: ShuffleWriterType)
+    shuffleWriterType: ShuffleWriterType,
+    enableCudf: Boolean)
   extends ColumnarBatchSerializerInstance
   with Logging {
 
@@ -111,7 +114,8 @@ private class ColumnarBatchSerializerInstanceImpl(
       batchSize,
       readerBufferSize,
       deserializerBufferSize,
-      shuffleWriterType.name)
+      shuffleWriterType.name,
+      enableCudf)
     // Close shuffle reader instance as lately as the end of task processing,
     // since the native reader could hold a reference to memory pool that
     // was used to create all buffers read from shuffle reader. The pool
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 0420f5a6c3..f913f90f64 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -1083,7 +1083,8 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
     jint batchSize,
     jlong readerBufferSize,
     jlong deserializerBufferSize,
-    jstring shuffleWriterType) {
+    jstring shuffleWriterType,
+    jboolean enableCudf) {
   JNI_METHOD_START
   auto ctx = getRuntime(env, wrapper);
 
@@ -1095,6 +1096,7 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
   options.batchSize = batchSize;
   options.readerBufferSize = readerBufferSize;
   options.deserializerBufferSize = deserializerBufferSize;
+  options.enableCudf = enableCudf;
 
   options.shuffleWriterType = 
ShuffleWriter::stringToType(jStringToCString(env, shuffleWriterType));
   std::shared_ptr<arrow::Schema> schema =
diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h
index 717f75dea5..2273322f67 100644
--- a/cpp/core/shuffle/Options.h
+++ b/cpp/core/shuffle/Options.h
@@ -62,6 +62,11 @@ struct ShuffleReaderOptions {
 
   // Buffer size when deserializing rows into columnar batches. Only used for 
sort-based shuffle.
   int64_t deserializerBufferSize = kDefaultDeserializerBufferSize;
+
+  // When true, convert the buffers to a cudf table.
+  // A lock is added after the reader produces the Vector; the next operator
+  // should be CudfFromVelox. After the shuffle read operation is moved to the
+  // GPU, move the lock to the start of the read.
+  bool enableCudf = false;
 };
 
 struct ShuffleWriterOptions {
diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt
index b1d5139f18..973126a554 100644
--- a/cpp/velox/CMakeLists.txt
+++ b/cpp/velox/CMakeLists.txt
@@ -199,7 +199,8 @@ if(ENABLE_S3)
 endif()
 
 if(ENABLE_GPU)
-  list(APPEND VELOX_SRCS cudf/CudfPlanValidator.cc)
+  list(APPEND VELOX_SRCS cudf/CudfPlanValidator.cc cudf/GpuLock.cc
+       shuffle/GpuShuffleReader.cc)
 endif()
 
 if(ENABLE_ENHANCED_FEATURES)
diff --git a/cpp/velox/compute/VeloxRuntime.cc 
b/cpp/velox/compute/VeloxRuntime.cc
index bc858f85c6..5ddf146404 100644
--- a/cpp/velox/compute/VeloxRuntime.cc
+++ b/cpp/velox/compute/VeloxRuntime.cc
@@ -300,7 +300,8 @@ std::shared_ptr<ShuffleReader> 
VeloxRuntime::createShuffleReader(
       options.readerBufferSize,
       options.deserializerBufferSize,
       memoryManager(),
-      options.shuffleWriterType);
+      options.shuffleWriterType,
+      options.enableCudf);
 
   return std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
 }
diff --git a/cpp/velox/compute/WholeStageResultIterator.cc 
b/cpp/velox/compute/WholeStageResultIterator.cc
index dc0da03343..faae3aeffb 100644
--- a/cpp/velox/compute/WholeStageResultIterator.cc
+++ b/cpp/velox/compute/WholeStageResultIterator.cc
@@ -23,10 +23,10 @@
 #include "velox/exec/PlanNodeStats.h"
 #ifdef GLUTEN_ENABLE_GPU
 #include <cudf/io/types.hpp>
-#include <mutex>
 #include "velox/experimental/cudf/CudfConfig.h"
 #include "velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h"
 #include "velox/experimental/cudf/exec/ToCudf.h"
+#include "cudf/GpuLock.h"
 #endif
 
 using namespace facebook;
@@ -75,11 +75,11 @@ WholeStageResultIterator::WholeStageResultIterator(
     : memoryManager_(memoryManager),
       veloxCfg_(
           
std::make_shared<facebook::velox::config::ConfigBase>(std::unordered_map<std::string,
 std::string>(confMap))),
-      taskInfo_(taskInfo),
-      veloxPlan_(planNode),
 #ifdef GLUTEN_ENABLE_GPU
-      lock_(mutex_, std::defer_lock),
+      enableCudf_(veloxCfg_->get<bool>(kCudfEnabled, kCudfEnabledDefault)),
 #endif
+      taskInfo_(taskInfo),
+      veloxPlan_(planNode),
       scanNodeIds_(scanNodeIds),
       scanInfos_(scanInfos),
       streamIds_(streamIds) {
@@ -90,13 +90,6 @@ WholeStageResultIterator::WholeStageResultIterator(
   }
   getOrderedNodeIds(veloxPlan_, orderedNodeIds_);
 
-#ifdef GLUTEN_ENABLE_GPU
-  enableCudf_ = veloxCfg_->get<bool>(kCudfEnabled, kCudfEnabledDefault);
-  if (enableCudf_) {
-    lock_.lock();
-  }
-#endif
-
   auto fileSystem = velox::filesystems::getFileSystem(spillDir, nullptr);
   GLUTEN_CHECK(fileSystem != nullptr, "File System for spilling is null!");
   fileSystem->mkdir(spillDir);
@@ -213,10 +206,6 @@ WholeStageResultIterator::WholeStageResultIterator(
   }
 }
 
-#ifdef GLUTEN_ENABLE_GPU
-std::mutex WholeStageResultIterator::mutex_;
-#endif
-
 std::shared_ptr<velox::core::QueryCtx> 
WholeStageResultIterator::createNewVeloxQueryCtx() {
   std::unordered_map<std::string, std::shared_ptr<velox::config::ConfigBase>> 
connectorConfigs;
   connectorConfigs[kHiveConnectorId] = createConnectorConfig();
@@ -236,17 +225,6 @@ std::shared_ptr<velox::core::QueryCtx> 
WholeStageResultIterator::createNewVeloxQ
 }
 
 std::shared_ptr<ColumnarBatch> WholeStageResultIterator::next() {
-  auto result = nextInternal();
-#ifdef GLUTEN_ENABLE_GPU
-  if (result == nullptr && enableCudf_) {
-    lock_.unlock();
-  }
-#endif
-
-  return result;
-}
-
-std::shared_ptr<ColumnarBatch> WholeStageResultIterator::nextInternal() {
   tryAddSplitsToTask();
   if (task_->isFinished()) {
     return nullptr;
diff --git a/cpp/velox/compute/WholeStageResultIterator.h 
b/cpp/velox/compute/WholeStageResultIterator.h
index 671016596b..8bd8484a56 100644
--- a/cpp/velox/compute/WholeStageResultIterator.h
+++ b/cpp/velox/compute/WholeStageResultIterator.h
@@ -27,6 +27,9 @@
 #include "velox/connectors/hive/iceberg/IcebergSplit.h"
 #include "velox/core/PlanNode.h"
 #include "velox/exec/Task.h"
+#ifdef GLUTEN_ENABLE_GPU
+#include "cudf/GpuLock.h"
+#endif
 
 namespace gluten {
 
@@ -48,8 +51,8 @@ class WholeStageResultIterator : public ColumnarBatchIterator 
{
       task_->requestCancel().wait();
     }
 #ifdef GLUTEN_ENABLE_GPU
-    if (enableCudf_ && lock_.owns_lock()) {
-      lock_.unlock();
+    if (enableCudf_) {
+      unlockGpu();
     }
 #endif
   }
@@ -75,8 +78,6 @@ class WholeStageResultIterator : public ColumnarBatchIterator 
{
   }
 
  private:
-  std::shared_ptr<ColumnarBatch> nextInternal();
-
   /// Get the Spark confs to Velox query context.
   std::unordered_map<std::string, std::string> getQueryContextConf();
 
@@ -113,6 +114,9 @@ class WholeStageResultIterator : public 
ColumnarBatchIterator {
 
   /// Config, task and plan.
   std::shared_ptr<config::ConfigBase> veloxCfg_;
+#ifdef GLUTEN_ENABLE_GPU
+  const bool enableCudf_;
+#endif
   const SparkTaskInfo taskInfo_;
   std::shared_ptr<facebook::velox::exec::Task> task_;
   std::shared_ptr<const facebook::velox::core::PlanNode> veloxPlan_;
@@ -124,13 +128,6 @@ class WholeStageResultIterator : public 
ColumnarBatchIterator {
   /// Metrics
   std::unique_ptr<Metrics> metrics_{};
 
-#ifdef GLUTEN_ENABLE_GPU
-  // Mutex for thread safety.
-  static std::mutex mutex_;
-  std::unique_lock<std::mutex> lock_;
-  bool enableCudf_;
-#endif
-
   /// All the children plan node ids with postorder traversal.
   std::vector<facebook::velox::core::PlanNodeId> orderedNodeIds_;
 
diff --git a/cpp/velox/cudf/GpuLock.cc b/cpp/velox/cudf/GpuLock.cc
new file mode 100644
index 0000000000..47d383d940
--- /dev/null
+++ b/cpp/velox/cudf/GpuLock.cc
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GpuLock.h"
+#include <mutex>
+#include <condition_variable>
+#include <optional>
+#include <stdexcept>
+
+namespace gluten {
+
+namespace {
+struct GpuLockState {
+  std::mutex gGpuMutex;
+  std::condition_variable gGpuCv;
+  std::optional<std::thread::id> gGpuOwner;
+};
+
+GpuLockState& getGpuLockState() {
+  static GpuLockState gGpuLockState;
+  return gGpuLockState;
+}
+}
+
+void lockGpu() {
+    std::thread::id tid = std::this_thread::get_id();
+    std::unique_lock<std::mutex> lock(getGpuLockState().gGpuMutex);
+    if (getGpuLockState().gGpuOwner == tid) {
+        // Reentrant call from the same thread — do nothing
+        return;
+    }
+
+
+    // Wait until the GPU lock becomes available
+    getGpuLockState().gGpuCv.wait(lock, [] {
+        return !getGpuLockState().gGpuOwner.has_value();
+    });
+
+    // Acquire ownership
+    getGpuLockState().gGpuOwner = tid;
+}
+
+void unlockGpu() {
+    std::thread::id tid = std::this_thread::get_id();
+    std::unique_lock<std::mutex> lock(getGpuLockState().gGpuMutex);
+    if (!getGpuLockState().gGpuOwner.has_value() || 
getGpuLockState().gGpuOwner != tid) {
+        throw std::runtime_error("unlockGpu() called by non-owner thread!");
+    }
+
+    // Release ownership
+    getGpuLockState().gGpuOwner = std::nullopt;
+
+    // Notify one waiting thread
+    lock.unlock();
+    getGpuLockState().gGpuCv.notify_one();
+}
+
+
+} // namespace gluten
diff --git a/cpp/velox/cudf/GpuLock.h b/cpp/velox/cudf/GpuLock.h
new file mode 100644
index 0000000000..1cc8640df7
--- /dev/null
+++ b/cpp/velox/cudf/GpuLock.h
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <thread>
+
+namespace gluten {
+
+/**
+ * @brief Acquire the GPU lock (reentrant within the same thread)
+ */
+void lockGpu();
+
+/**
+ * @brief Release the GPU lock (must be called by the owning thread)
+ */
+void unlockGpu();
+
+} // namespace gluten
diff --git a/cpp/velox/shuffle/GpuShuffleReader.cc 
b/cpp/velox/shuffle/GpuShuffleReader.cc
new file mode 100644
index 0000000000..fefa473bf1
--- /dev/null
+++ b/cpp/velox/shuffle/GpuShuffleReader.cc
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "shuffle/GpuShuffleReader.h"
+
+#include <arrow/array/array_binary.h>
+#include <arrow/io/buffered.h>
+
+#include "memory/VeloxColumnarBatch.h"
+#include "shuffle/Payload.h"
+#include "shuffle/Utils.h"
+#include "utils/Common.h"
+#include "utils/Macros.h"
+#include "utils/Timer.h"
+#include "utils/VeloxArrowUtils.h"
+#include "velox/common/caching/AsyncDataCache.h"
+#include "velox/serializers/PrestoSerializer.h"
+#include "velox/vector/ComplexVector.h"
+#include "velox/vector/FlatVector.h"
+#include "velox/vector/arrow/Bridge.h"
+
+#include <algorithm>
+
+#include "cudf/GpuLock.h"
+
+using namespace facebook::velox;
+
+namespace gluten {
+namespace {
+
+arrow::Result<BlockType> readBlockType(arrow::io::InputStream* inputStream) {
+  BlockType type;
+  ARROW_ASSIGN_OR_RAISE(auto bytes, inputStream->Read(sizeof(BlockType), 
&type));
+  if (bytes == 0) {
+    // Reach EOS.
+    return BlockType::kEndOfStream;
+  }
+  return type;
+}
+
+struct BufferViewReleaser {
+  BufferViewReleaser() : BufferViewReleaser(nullptr) {}
+
+  BufferViewReleaser(std::shared_ptr<arrow::Buffer> arrowBuffer) : 
bufferReleaser_(std::move(arrowBuffer)) {}
+
+  void addRef() const {}
+
+  void release() const {}
+
+ private:
+  const std::shared_ptr<arrow::Buffer> bufferReleaser_;
+};
+
+BufferPtr wrapInBufferViewAsOwner(const void* buffer, size_t length, 
std::shared_ptr<arrow::Buffer> bufferReleaser) {
+  return BufferView<BufferViewReleaser>::create(
+      static_cast<const uint8_t*>(buffer), length, 
{std::move(bufferReleaser)});
+}
+
+BufferPtr convertToVeloxBuffer(std::shared_ptr<arrow::Buffer> buffer) {
+  if (buffer == nullptr) {
+    return nullptr;
+  }
+  return wrapInBufferViewAsOwner(buffer->data(), buffer->size(), buffer);
+}
+
+template <TypeKind Kind, typename T = typename TypeTraits<Kind>::NativeType>
+VectorPtr readFlatVector(
+    std::vector<BufferPtr>& buffers,
+    int32_t& bufferIdx,
+    uint32_t length,
+    std::shared_ptr<const Type> type,
+    const VectorPtr& dictionary,
+    memory::MemoryPool* pool) {
+  auto nulls = buffers[bufferIdx++];
+  auto valuesOrIndices = buffers[bufferIdx++];
+
+  nulls = nulls == nullptr || nulls->size() == 0 ? BufferPtr(nullptr) : nulls;
+
+  if (dictionary != nullptr) {
+    return BaseVector::wrapInDictionary(nulls, valuesOrIndices, length, 
dictionary);
+  }
+
+  return std::make_shared<FlatVector<T>>(
+      pool, type, nulls, length, std::move(valuesOrIndices), 
std::vector<BufferPtr>{});
+}
+
+template <>
+VectorPtr readFlatVector<TypeKind::UNKNOWN>(
+    std::vector<BufferPtr>& buffers,
+    int32_t& bufferIdx,
+    uint32_t length,
+    std::shared_ptr<const Type> type,
+    const VectorPtr& dictionary,
+    memory::MemoryPool* pool) {
+  return BaseVector::createNullConstant(type, length, pool);
+}
+
+template <>
+VectorPtr readFlatVector<TypeKind::HUGEINT>(
+    std::vector<BufferPtr>& buffers,
+    int32_t& bufferIdx,
+    uint32_t length,
+    std::shared_ptr<const Type> type,
+    const VectorPtr& dictionary,
+    memory::MemoryPool* pool) {
+  auto nulls = buffers[bufferIdx++];
+  auto valuesOrIndices = buffers[bufferIdx++];
+
+  // If the buffer is not compressed, it comes from Netty and its address may
+  // not be 16-byte aligned, which would make an `int128_t` store coredump on
+  // the movdqa instruction.
+  const auto* addr = valuesOrIndices->as<facebook::velox::int128_t>();
+  if ((reinterpret_cast<uintptr_t>(addr) & 0xf) != 0) {
+    auto alignedBuffer = 
AlignedBuffer::allocate<char>(valuesOrIndices->size(), pool);
+    fastCopy(alignedBuffer->asMutable<char>(), valuesOrIndices->as<char>(), 
valuesOrIndices->size());
+    valuesOrIndices = alignedBuffer;
+  }
+
+  nulls = nulls == nullptr || nulls->size() == 0 ? BufferPtr(nullptr) : nulls;
+
+  if (dictionary != nullptr) {
+    return BaseVector::wrapInDictionary(nulls, valuesOrIndices, length, 
dictionary);
+  }
+
+  return std::make_shared<FlatVector<int128_t>>(
+      pool, type, nulls, length, std::move(valuesOrIndices), 
std::vector<BufferPtr>{});
+}
+
+VectorPtr readFlatVectorStringView(
+    std::vector<BufferPtr>& buffers,
+    int32_t& bufferIdx,
+    uint32_t length,
+    std::shared_ptr<const Type> type,
+    const VectorPtr& dictionary,
+    memory::MemoryPool* pool) {
+  auto nulls = buffers[bufferIdx++];
+  auto lengthOrIndices = buffers[bufferIdx++];
+
+  nulls = nulls == nullptr || nulls->size() == 0 ? BufferPtr(nullptr) : nulls;
+
+  if (dictionary != nullptr) {
+    return BaseVector::wrapInDictionary(nulls, lengthOrIndices, length, 
dictionary);
+  }
+
+  auto valueBuffer = buffers[bufferIdx++];
+
+  const auto* rawLength = lengthOrIndices->as<StringLengthType>();
+  const auto* valueBufferPtr = valueBuffer->as<char>();
+
+  auto values = AlignedBuffer::allocate<char>(sizeof(StringView) * length, 
pool);
+  auto* rawValues = values->asMutable<StringView>();
+
+  uint64_t offset = 0;
+  for (int32_t i = 0; i < length; ++i) {
+    rawValues[i] = StringView(valueBufferPtr + offset, rawLength[i]);
+    offset += rawLength[i];
+  }
+
+  std::vector<BufferPtr> stringBuffers;
+  stringBuffers.emplace_back(valueBuffer);
+
+  return std::make_shared<FlatVector<StringView>>(
+      pool, type, nulls, length, std::move(values), std::move(stringBuffers));
+}
+
+template <>
+VectorPtr readFlatVector<TypeKind::VARCHAR>(
+    std::vector<BufferPtr>& buffers,
+    int32_t& bufferIdx,
+    uint32_t length,
+    std::shared_ptr<const Type> type,
+    const VectorPtr& dictionary,
+    memory::MemoryPool* pool) {
+  return readFlatVectorStringView(buffers, bufferIdx, length, type, 
dictionary, pool);
+}
+
+template <>
+VectorPtr readFlatVector<TypeKind::VARBINARY>(
+    std::vector<BufferPtr>& buffers,
+    int32_t& bufferIdx,
+    uint32_t length,
+    std::shared_ptr<const Type> type,
+    const VectorPtr& dictionary,
+    memory::MemoryPool* pool) {
+  return readFlatVectorStringView(buffers, bufferIdx, length, type, 
dictionary, pool);
+}
+
+std::unique_ptr<ByteInputStream> toByteStream(uint8_t* data, int32_t size) {
+  std::vector<ByteRange> byteRanges;
+  byteRanges.push_back(ByteRange{data, size, 0});
+  auto byteStream = std::make_unique<BufferInputStream>(byteRanges);
+  return byteStream;
+}
+
+RowVectorPtr readComplexType(BufferPtr buffer, RowTypePtr& rowType, 
memory::MemoryPool* pool) {
+  RowVectorPtr result;
+  auto byteStream = toByteStream(const_cast<uint8_t*>(buffer->as<uint8_t>()), 
buffer->size());
+  auto serde = std::make_unique<serializer::presto::PrestoVectorSerde>();
+  serializer::presto::PrestoVectorSerde::PrestoOptions options;
+  options.useLosslessTimestamp = true;
+  serde->deserialize(byteStream.get(), pool, rowType, &result, &options);
+  return result;
+}
+
+RowTypePtr getComplexWriteType(const std::vector<TypePtr>& types) {
+  std::vector<std::string> complexTypeColNames;
+  std::vector<TypePtr> complexTypeChildrens;
+  for (int32_t i = 0; i < types.size(); ++i) {
+    auto kind = types[i]->kind();
+    switch (kind) {
+      case TypeKind::ROW:
+      case TypeKind::MAP:
+      case TypeKind::ARRAY: {
+        complexTypeColNames.emplace_back(types[i]->name());
+        complexTypeChildrens.emplace_back(types[i]);
+      } break;
+      default:
+        break;
+    }
+  }
+  return std::make_shared<const RowType>(std::move(complexTypeColNames), 
std::move(complexTypeChildrens));
+}
+
// Reassembles a RowVector from the shuffle payload buffers.
//
// Scalar columns are decoded from 'buffers' in column-declaration order via
// readFlatVector (bufferIdx advances as buffers are consumed). All
// complex-typed columns (ROW/MAP/ARRAY) were serialized together into the
// LAST buffer as a Presto-encoded row and are consumed positionally from
// that deserialized row.
//
// NOTE(review): dictionary matching assumes 'dictionaryFields' is sorted
// ascending by column index and 'dictionaries' is parallel to it — confirm
// against the writer.
RowVectorPtr deserialize(
    RowTypePtr type,
    uint32_t numRows,
    std::vector<BufferPtr>& buffers,
    const std::vector<int32_t>& dictionaryFields,
    const std::vector<VectorPtr>& dictionaries,
    memory::MemoryPool* pool) {
  std::vector<VectorPtr> children;
  auto types = type->as<TypeKind::ROW>().children();

  // Decode all complex columns up front from the trailing buffer, if any.
  std::vector<VectorPtr> complexChildren;
  auto complexRowType = getComplexWriteType(types);
  if (complexRowType->children().size() > 0) {
    complexChildren = readComplexType(buffers[buffers.size() - 1], complexRowType, pool)->children();
  }

  int32_t bufferIdx = 0;      // next flat buffer to consume
  int32_t complexIdx = 0;     // next pre-decoded complex child to take
  int32_t dictionaryIdx = 0;  // next dictionary-encoded field to match
  for (size_t i = 0; i < types.size(); ++i) {
    const auto kind = types[i]->kind();
    switch (kind) {
      case TypeKind::ROW:
      case TypeKind::MAP:
      case TypeKind::ARRAY: {
        children.emplace_back(std::move(complexChildren[complexIdx]));
        complexIdx++;
      } break;
      default: {
        // A non-null dictionary makes readFlatVector produce a
        // dictionary-wrapped vector for this column.
        VectorPtr dictionary{nullptr};
        if (!dictionaryFields.empty() && dictionaryIdx < dictionaryFields.size() &&
            dictionaryFields[dictionaryIdx] == i) {
          dictionary = dictionaries[dictionaryIdx++];
        }
        auto res = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL(
            readFlatVector, kind, buffers, bufferIdx, numRows, types[i], dictionary, pool);
        children.emplace_back(std::move(res));
      } break;
    }
  }

  // Null nulls-buffer: the top-level row itself has no null rows.
  return std::make_shared<RowVector>(pool, type, BufferPtr(nullptr), numRows, children);
}
+
+std::shared_ptr<VeloxColumnarBatch> makeColumnarBatch(
+    RowTypePtr type,
+    uint32_t numRows,
+    std::vector<std::shared_ptr<arrow::Buffer>> arrowBuffers,
+    const std::vector<int32_t>& dictionaryFields,
+    const std::vector<VectorPtr>& dictionaries,
+    memory::MemoryPool* pool,
+    int64_t& deserializeTime) {
+  ScopedTimer timer(&deserializeTime);
+  std::vector<BufferPtr> veloxBuffers;
+  veloxBuffers.reserve(arrowBuffers.size());
+  for (auto& buffer : arrowBuffers) {
+    veloxBuffers.push_back(convertToVeloxBuffer(std::move(buffer)));
+  }
+  auto rowVector = deserialize(type, numRows, veloxBuffers, dictionaryFields, 
dictionaries, pool);
+  return std::make_shared<VeloxColumnarBatch>(std::move(rowVector));
+}
+
+} // namespace
+
+
// Constructs the GPU hash-shuffle deserializer. No stream is opened here;
// the first input stream is pulled lazily on the first call to next().
// 'deserializeTime' and 'decompressTime' are caller-owned metric
// accumulators held by reference, so they must outlive this object.
GpuHashShuffleReaderDeserializer::GpuHashShuffleReaderDeserializer(
    const std::shared_ptr<StreamReader>& streamReader,
    const std::shared_ptr<arrow::Schema>& schema,
    const std::shared_ptr<arrow::util::Codec>& codec,
    const facebook::velox::RowTypePtr& rowType,
    int32_t batchSize,
    int64_t readerBufferSize,
    VeloxMemoryManager* memoryManager,
    std::vector<bool>* isValidityBuffer,
    bool hasComplexType,
    int64_t& deserializeTime,
    int64_t& decompressTime)
    : streamReader_(streamReader),
      schema_(schema),
      codec_(codec),
      rowType_(rowType),
      batchSize_(batchSize),
      readerBufferSize_(readerBufferSize),
      memoryManager_(memoryManager),
      isValidityBuffer_(isValidityBuffer),
      hasComplexType_(hasComplexType),
      deserializeTime_(deserializeTime),
      decompressTime_(decompressTime) {}
+
+bool GpuHashShuffleReaderDeserializer::resolveNextBlockType() {
+  GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get()));
+  switch (blockType) {
+    case BlockType::kEndOfStream:
+      return false;
+    case BlockType::kPlainPayload:
+      return true;
+    default:
+      throw GlutenException(fmt::format("Unsupported block type: {}", 
static_cast<int32_t>(blockType)));
+  }
+  return true;
+}
+
+void GpuHashShuffleReaderDeserializer::loadNextStream() {
+  if (reachedEos_) {
+    return;
+  }
+
+  auto in = 
streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
+  if (in == nullptr) {
+    reachedEos_ = true;
+    return;
+  }
+
+  GLUTEN_ASSIGN_OR_THROW(
+      in_,
+      arrow::io::BufferedInputStream::Create(
+          readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), 
std::move(in)));
+}
+
+std::shared_ptr<ColumnarBatch> GpuHashShuffleReaderDeserializer::next() {
+  if (in_ == nullptr) {
+    loadNextStream();
+
+    if (reachedEos_) {
+      return nullptr;
+    }
+  }
+
+  while (!resolveNextBlockType()) {
+    loadNextStream();
+
+    if (reachedEos_) {
+      return nullptr;
+    }
+  }
+
+  uint32_t numRows = 0;
+  GLUTEN_ASSIGN_OR_THROW(
+      auto arrowBuffers,
+      BlockPayload::deserialize(
+          in_.get(), codec_, memoryManager_->defaultArrowMemoryPool(), 
numRows, deserializeTime_, decompressTime_));
+
+  auto batch  = makeColumnarBatch(
+      rowType_,
+      numRows,
+      std::move(arrowBuffers),
+      dictionaryFields_,
+      dictionaries_,
+      memoryManager_->getLeafMemoryPool().get(),
+      deserializeTime_);
+
+  lockGpu();
+
+  return batch;
+}
+
+}
diff --git a/cpp/velox/shuffle/GpuShuffleReader.h 
b/cpp/velox/shuffle/GpuShuffleReader.h
new file mode 100644
index 0000000000..3de5a8228d
--- /dev/null
+++ b/cpp/velox/shuffle/GpuShuffleReader.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "memory/VeloxMemoryManager.h"
+#include "shuffle/Payload.h"
+#include "shuffle/ShuffleReader.h"
+
+#include "velox/serializers/PrestoSerializer.h"
+#include "velox/type/Type.h"
+#include "velox/vector/ComplexVector.h"
+
+namespace gluten {
+
+class GpuHashShuffleReaderDeserializer final : public ColumnarBatchIterator {
+ public:
+  GpuHashShuffleReaderDeserializer(
+      const std::shared_ptr<StreamReader>& streamReader,
+      const std::shared_ptr<arrow::Schema>& schema,
+      const std::shared_ptr<arrow::util::Codec>& codec,
+      const facebook::velox::RowTypePtr& rowType,
+      int32_t batchSize,
+      int64_t readerBufferSize,
+      VeloxMemoryManager* memoryManager,
+      std::vector<bool>* isValidityBuffer,
+      bool hasComplexType,
+      int64_t& deserializeTime,
+      int64_t& decompressTime);
+
+  std::shared_ptr<ColumnarBatch> next() override;
+
+ private:
+  bool resolveNextBlockType();
+
+  void loadNextStream();
+
+  std::shared_ptr<StreamReader> streamReader_;
+  std::shared_ptr<arrow::Schema> schema_;
+  std::shared_ptr<arrow::util::Codec> codec_;
+  facebook::velox::RowTypePtr rowType_;
+  int32_t batchSize_;
+  int64_t readerBufferSize_;
+  VeloxMemoryManager* memoryManager_;
+
+  std::vector<bool>* isValidityBuffer_;
+  bool hasComplexType_;
+
+  int64_t& deserializeTime_;
+  int64_t& decompressTime_;
+
+  std::shared_ptr<arrow::io::InputStream> in_{nullptr};
+
+  bool reachedEos_{false};
+  bool blockTypeResolved_{false};
+
+  // Not used.
+  std::vector<int32_t> dictionaryFields_{};
+  std::vector<facebook::velox::VectorPtr> dictionaries_{};
+};
+} // namespace gluten
diff --git a/cpp/velox/shuffle/VeloxShuffleReader.cc 
b/cpp/velox/shuffle/VeloxShuffleReader.cc
index 5e7b78f00d..b6c1bf4b42 100644
--- a/cpp/velox/shuffle/VeloxShuffleReader.cc
+++ b/cpp/velox/shuffle/VeloxShuffleReader.cc
@@ -37,6 +37,10 @@
 
 #include <algorithm>
 
+#ifdef GLUTEN_ENABLE_GPU
+#include "GpuShuffleReader.h"
+#endif
+
 using namespace facebook::velox;
 
 namespace gluten {
@@ -799,8 +803,10 @@ 
VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory(
     int64_t readerBufferSize,
     int64_t deserializerBufferSize,
     VeloxMemoryManager* memoryManager,
-    ShuffleWriterType shuffleWriterType)
-    : schema_(schema),
+    ShuffleWriterType shuffleWriterType,
+    bool enableCudf)
+    : enableCudf_(enableCudf),
+      schema_(schema),
       codec_(codec),
       veloxCompressionType_(veloxCompressionType),
       rowType_(rowType),
@@ -816,6 +822,22 @@ std::unique_ptr<ColumnarBatchIterator> 
VeloxShuffleReaderDeserializerFactory::cr
     const std::shared_ptr<StreamReader>& streamReader) {
   switch (shuffleWriterType_) {
     case ShuffleWriterType::kHashShuffle:
+ #ifdef GLUTEN_ENABLE_GPU       
+      if (enableCudf_) {
+        return std::make_unique<GpuHashShuffleReaderDeserializer>(
+          streamReader,
+          schema_,
+          codec_,
+          rowType_,
+          batchSize_,
+          readerBufferSize_,
+          memoryManager_,
+          &isValidityBuffer_,
+          hasComplexType_,
+          deserializeTime_,
+          decompressTime_);
+      }
+#endif
       return std::make_unique<VeloxHashShuffleReaderDeserializer>(
           streamReader,
           schema_,
diff --git a/cpp/velox/shuffle/VeloxShuffleReader.h 
b/cpp/velox/shuffle/VeloxShuffleReader.h
index 26a1634f4d..bdc67d2643 100644
--- a/cpp/velox/shuffle/VeloxShuffleReader.h
+++ b/cpp/velox/shuffle/VeloxShuffleReader.h
@@ -169,7 +169,8 @@ class VeloxShuffleReaderDeserializerFactory {
       int64_t readerBufferSize,
       int64_t deserializerBufferSize,
       VeloxMemoryManager* memoryManager,
-      ShuffleWriterType shuffleWriterType);
+      ShuffleWriterType shuffleWriterType,
+      bool enableCudf);
 
   std::unique_ptr<ColumnarBatchIterator> createDeserializer(const 
std::shared_ptr<StreamReader>& streamReader);
 
@@ -180,6 +181,7 @@ class VeloxShuffleReaderDeserializerFactory {
  private:
   void initFromSchema();
 
+  const bool enableCudf_;
   std::shared_ptr<arrow::Schema> schema_;
   std::shared_ptr<arrow::util::Codec> codec_;
   facebook::velox::common::CompressionKind veloxCompressionType_;
diff --git a/cpp/velox/tests/VeloxShuffleWriterTest.cc 
b/cpp/velox/tests/VeloxShuffleWriterTest.cc
index 0d62faafee..c2906c67d8 100644
--- a/cpp/velox/tests/VeloxShuffleWriterTest.cc
+++ b/cpp/velox/tests/VeloxShuffleWriterTest.cc
@@ -305,7 +305,8 @@ class VeloxShuffleWriterTest : public 
::testing::TestWithParam<ShuffleTestParams
         kDefaultReadBufferSize,
         GetParam().deserializerBufferSize,
         getDefaultMemoryManager(),
-        GetParam().shuffleWriterType);
+        GetParam().shuffleWriterType,
+        false);
 
     const auto reader = 
std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
 
diff --git 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
index 6a0f2130d7..a1b497c8bb 100644
--- 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
+++ 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java
@@ -42,7 +42,8 @@ public class ShuffleReaderJniWrapper implements RuntimeAware {
       int batchSize,
       long readerBufferSize,
       long deserializerBufferSize,
-      String shuffleWriterType);
+      String shuffleWriterType,
+      boolean enableCudf);
 
   public native long read(long shuffleReaderHandle, ShuffleStreamReader 
streamReader);
 
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 2ee41f0d8d..0f2cc5b30c 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -398,7 +398,8 @@ trait SparkPlanExecApi {
   def createColumnarBatchSerializer(
       schema: StructType,
       metrics: Map[String, SQLMetric],
-      shuffleWriterType: ShuffleWriterType): Serializer
+      shuffleWriterType: ShuffleWriterType,
+      enableCudf: Boolean = false): Serializer
 
   /** Create broadcast relation for BroadcastExchangeExec */
   def createBroadcastRelation(
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
index 264195f93f..dc2d3897c6 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
@@ -69,7 +69,7 @@ trait TransformSupport extends ValidatablePlan {
       s"${this.getClass.getSimpleName} doesn't support doExecute")
   }
 
-  protected def isCudf: Boolean = 
getTagValue[Boolean](CudfTag.CudfTag).getOrElse(false)
+  def isCudf: Boolean = getTagValue[Boolean](CudfTag.CudfTag).getOrElse(false)
 
   // Use super.nodeName will cause exception scala 213 Super calls can only 
target methods
   // for FileSourceScan.
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index 055b943392..bdb9f92de4 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -17,27 +17,13 @@
 package org.apache.spark.sql.execution
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.config.ShuffleWriterType
-import org.apache.gluten.execution.ValidatablePlan
-import org.apache.gluten.execution.ValidationResult
-import org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.gluten.sql.shims.SparkShimLoader
 
-import org.apache.spark._
 import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.exchange._
-import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
-import org.apache.spark.sql.metric.SQLColumnarShuffleReadMetricsReporter
-import org.apache.spark.sql.vectorized.ColumnarBatch
-
-import scala.concurrent.Future
 
 case class ColumnarShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
@@ -45,129 +31,18 @@ case class ColumnarShuffleExchangeExec(
     shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
     projectOutputAttributes: Seq[Attribute],
     advisoryPartitionSize: Option[Long] = None)
-  extends ShuffleExchangeLike
-  with ValidatablePlan {
-  private[sql] lazy val writeMetrics =
-    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
-
-  private[sql] lazy val readMetrics =
-    
SQLColumnarShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
-
-  val shuffleWriterType: ShuffleWriterType =
-    
BackendsApiManager.getSparkPlanExecApiInstance.getShuffleWriterType(outputPartitioning,
 output)
-
-  // Note: "metrics" is made transient to avoid sending driver-side metrics to 
tasks.
-  @transient override lazy val metrics =
-    BackendsApiManager.getMetricsApiInstance
-      .genColumnarShuffleExchangeMetrics(
-        sparkContext,
-        shuffleWriterType) ++ readMetrics ++ writeMetrics
-
-  @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = 
child.executeColumnar()
-
-  // 'mapOutputStatisticsFuture' is only needed when enable AQE.
-  @transient override lazy val mapOutputStatisticsFuture: 
Future[MapOutputStatistics] = {
-    if (inputColumnarRDD.getNumPartitions == 0) {
-      Future.successful(null)
-    } else {
-      sparkContext.submitMapStage(columnarShuffleDependency)
-    }
-  }
-
-  /**
-   * A [[ShuffleDependency]] that will partition rows of its child based on 
the partitioning scheme
-   * defined in `newPartitioning`. Those partitions of the returned 
ShuffleDependency will be the
-   * input of shuffle.
-   */
-  @transient
-  lazy val columnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, 
ColumnarBatch] = {
-    BackendsApiManager.getSparkPlanExecApiInstance.genShuffleDependency(
-      inputColumnarRDD,
-      child.output,
-      projectOutputAttributes,
-      outputPartitioning,
-      serializer,
-      writeMetrics,
-      metrics,
-      shuffleWriterType)
-  }
+  extends ColumnarShuffleExchangeExecBase(outputPartitioning, child, 
projectOutputAttributes) {
 
   // super.stringArgs ++ Iterator(output.map(o => 
s"${o}#${o.dataType.simpleString}"))
   val serializer: Serializer = BackendsApiManager.getSparkPlanExecApiInstance
     .createColumnarBatchSerializer(schema, metrics, shuffleWriterType)
 
-  var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
-
-  override protected def doValidateInternal(): ValidationResult = {
-    BackendsApiManager.getValidatorApiInstance
-      .doColumnarShuffleExchangeExecValidate(output, outputPartitioning, child)
-      .map {
-        reason =>
-          ValidationResult.failed(
-            s"Found schema check failure for schema ${child.schema} due to: 
$reason")
-      }
-      .getOrElse(ValidationResult.succeeded)
-  }
-
   override def nodeName: String = "ColumnarExchange"
 
-  override def numMappers: Int = inputColumnarRDD.getNumPartitions
-
-  override def numPartitions: Int = 
columnarShuffleDependency.partitioner.numPartitions
-
-  override def runtimeStatistics: Statistics = {
-    val dataSize = metrics("dataSize").value
-    val rowCount = 
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
-    Statistics(dataSize, Some(rowCount))
-  }
-
-  // Required for Spark 4.0 to implement a trait method.
-  // The "override" keyword is omitted to maintain compatibility with earlier 
Spark versions.
-  def shuffleId: Int = columnarShuffleDependency.shuffleId
-
-  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): 
RDD[ColumnarBatch] = {
-    new ShuffledColumnarBatchRDD(columnarShuffleDependency, readMetrics, 
partitionSpecs)
-  }
-
-  override def stringArgs: Iterator[Any] = {
-    super.stringArgs ++ 
Iterator(s"[shuffle_writer_type=${shuffleWriterType.name}]")
-  }
-
-  override def batchType(): Convention.BatchType = 
BackendsApiManager.getSettings.primaryBatchType
-
-  override def rowType0(): Convention.RowType = Convention.RowType.None
-
-  override def doExecute(): RDD[InternalRow] = {
-    throw new UnsupportedOperationException()
-  }
-
-  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-    if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = new 
ShuffledColumnarBatchRDD(columnarShuffleDependency, readMetrics)
-    }
-    cachedShuffleRDD
-  }
-
-  override def verboseString(maxFields: Int): String =
-    toString(super.verboseString(maxFields), maxFields)
-
-  private def toString(original: String, maxFields: Int): String = {
-    original + ", [output=" + truncatedString(
-      output.map(_.verboseString(maxFields)),
-      "[",
-      ", ",
-      "]",
-      maxFields) + "]"
-  }
-
-  override def output: Seq[Attribute] = if (projectOutputAttributes != null) {
-    projectOutputAttributes
-  } else {
-    child.output
-  }
-
   protected def withNewChildInternal(newChild: SparkPlan): 
ColumnarShuffleExchangeExec =
     copy(child = newChild)
+
+  override def getSerializer: Serializer = serializer
 }
 
 object ColumnarShuffleExchangeExec extends Logging {
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
similarity index 82%
copy from 
gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
copy to 
gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
index 055b943392..af8ba75393 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
@@ -18,13 +18,10 @@ package org.apache.spark.sql.execution
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.config.ShuffleWriterType
-import org.apache.gluten.execution.ValidatablePlan
-import org.apache.gluten.execution.ValidationResult
+import org.apache.gluten.execution.{ValidatablePlan, ValidationResult}
 import org.apache.gluten.extension.columnar.transition.Convention
-import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark._
-import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.catalyst.InternalRow
@@ -39,12 +36,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import scala.concurrent.Future
 
-case class ColumnarShuffleExchangeExec(
+abstract class ColumnarShuffleExchangeExecBase(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
-    projectOutputAttributes: Seq[Attribute],
-    advisoryPartitionSize: Option[Long] = None)
+    projectOutputAttributes: Seq[Attribute])
   extends ShuffleExchangeLike
   with ValidatablePlan {
   private[sql] lazy val writeMetrics =
@@ -86,16 +81,12 @@ case class ColumnarShuffleExchangeExec(
       child.output,
       projectOutputAttributes,
       outputPartitioning,
-      serializer,
+      getSerializer,
       writeMetrics,
       metrics,
       shuffleWriterType)
   }
 
-  // super.stringArgs ++ Iterator(output.map(o => 
s"${o}#${o.dataType.simpleString}"))
-  val serializer: Serializer = BackendsApiManager.getSparkPlanExecApiInstance
-    .createColumnarBatchSerializer(schema, metrics, shuffleWriterType)
-
   var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
 
   override protected def doValidateInternal(): ValidationResult = {
@@ -109,8 +100,6 @@ case class ColumnarShuffleExchangeExec(
       .getOrElse(ValidationResult.succeeded)
   }
 
-  override def nodeName: String = "ColumnarExchange"
-
   override def numMappers: Int = inputColumnarRDD.getNumPartitions
 
   override def numPartitions: Int = 
columnarShuffleDependency.partitioner.numPartitions
@@ -121,6 +110,8 @@ case class ColumnarShuffleExchangeExec(
     Statistics(dataSize, Some(rowCount))
   }
 
+  def getSerializer: Serializer
+
   // Required for Spark 4.0 to implement a trait method.
   // The "override" keyword is omitted to maintain compatibility with earlier 
Spark versions.
   def shuffleId: Int = columnarShuffleDependency.shuffleId
@@ -165,24 +156,4 @@ case class ColumnarShuffleExchangeExec(
   } else {
     child.output
   }
-
-  protected def withNewChildInternal(newChild: SparkPlan): 
ColumnarShuffleExchangeExec =
-    copy(child = newChild)
-}
-
-object ColumnarShuffleExchangeExec extends Logging {
-
-  def apply(
-      plan: ShuffleExchangeExec,
-      child: SparkPlan,
-      shuffleOutputAttributes: Seq[Attribute]): ColumnarShuffleExchangeExec = {
-    ColumnarShuffleExchangeExec(
-      plan.outputPartitioning,
-      child,
-      plan.shuffleOrigin,
-      shuffleOutputAttributes,
-      advisoryPartitionSize = 
SparkShimLoader.getSparkShims.getShuffleAdvisoryPartitionSize(plan)
-    )
-  }
-
 }
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GPUColumnarShuffleExchangeExec.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GPUColumnarShuffleExchangeExec.scala
new file mode 100644
index 0000000000..0c1f9cfe31
--- /dev/null
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GPUColumnarShuffleExchangeExec.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.sql.shims.SparkShimLoader
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.exchange._
+
// The write side produces Velox RowVectors; the reader deserializes them for
// the GPU (cuDF) path — the serializer is created with enableCudf = true.
case class GPUColumnarShuffleExchangeExec(
    override val outputPartitioning: Partitioning,
    child: SparkPlan,
    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
    projectOutputAttributes: Seq[Attribute],
    advisoryPartitionSize: Option[Long] = None)
  extends ColumnarShuffleExchangeExecBase(outputPartitioning, child, projectOutputAttributes) {

  // enableCudf = true selects the GPU shuffle reader deserializer natively.
  // super.stringArgs ++ Iterator(output.map(o => s"${o}#${o.dataType.simpleString}"))
  val serializer: Serializer = BackendsApiManager.getSparkPlanExecApiInstance
    .createColumnarBatchSerializer(schema, metrics, shuffleWriterType, true)

  override def nodeName: String = "CudfColumnarExchange"

  protected def withNewChildInternal(newChild: SparkPlan): GPUColumnarShuffleExchangeExec =
    copy(child = newChild)

  override def getSerializer: Serializer = serializer
}

object GPUColumnarShuffleExchangeExec extends Logging {

  // Builds the GPU exchange from a row-based ShuffleExchangeExec, carrying
  // over its partitioning, shuffle origin and advisory partition size.
  def apply(
      plan: ShuffleExchangeExec,
      child: SparkPlan,
      shuffleOutputAttributes: Seq[Attribute]): GPUColumnarShuffleExchangeExec = {
    GPUColumnarShuffleExchangeExec(
      plan.outputPartitioning,
      child,
      plan.shuffleOrigin,
      shuffleOutputAttributes,
      advisoryPartitionSize = SparkShimLoader.getSparkShims.getShuffleAdvisoryPartitionSize(plan)
    )
  }
}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to