This is an automated email from the ASF dual-hosted git repository.

kerwinzhang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 804ab4d344 [GLUTEN-10236][VL] Support both sort and rss_sort shuffle writer for Celeborn (#10244)
804ab4d344 is described below

commit 804ab4d3447802043e0c86196fd3bbfb55d89269
Author: Rong Ma <[email protected]>
AuthorDate: Fri Jul 25 03:22:00 2025 +0100

    [GLUTEN-10236][VL] Support both sort and rss_sort shuffle writer for Celeborn (#10244)
---
 .github/workflows/velox_backend_x86.yml            |  5 ++
 .../shuffle/CHCelebornColumnarShuffleWriter.scala  |  2 +-
 .../backendsapi/clickhouse/CHMetricsApi.scala      |  3 +-
 .../clickhouse/CHSparkPlanExecApi.scala            | 10 +--
 .../VeloxCelebornColumnarBatchSerializer.scala     | 19 +++---
 .../VeloxCelebornColumnarShuffleWriter.scala       | 56 +++++++++------
 .../gluten/uniffle/UniffleShuffleManager.java      |  3 +-
 .../writer/VeloxUniffleColumnarShuffleWriter.java  |  8 +--
 .../gluten/backendsapi/velox/VeloxMetricsApi.scala | 34 ++++++----
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  | 55 +++++++++++----
 ...AppendBatchResizeForShuffleInputAndOutput.scala |  3 +-
 .../vectorized/ColumnarBatchSerializer.scala       | 12 ++--
 .../spark/shuffle/ColumnarShuffleWriter.scala      | 14 +++-
 .../spark/sql/execution/utils/ExecUtil.scala       |  6 +-
 cpp/core/shuffle/rss/RssPartitionWriter.cc         | 22 +++---
 cpp/core/shuffle/rss/RssPartitionWriter.h          | 79 +---------------------
 docs/Configuration.md                              |  1 +
 docs/get-started/Velox.md                          |  8 +++
 .../shuffle/CelebornColumnarShuffleWriter.scala    |  7 --
 .../sql/perf/DeltaOptimizedWriterTransformer.scala |  7 +-
 .../org/apache/gluten/backendsapi/MetricsApi.scala |  3 +-
 .../gluten/backendsapi/SparkPlanExecApi.scala      | 11 ++-
 .../org/apache/gluten/config/GlutenConfig.scala    | 28 ++++++++
 .../execution/ColumnarCollectLimitBaseExec.scala   | 11 +--
 .../execution/ColumnarCollectTailBaseExec.scala    | 11 +--
 .../spark/shuffle/ColumnarShuffleDependency.scala  |  3 +-
 .../execution/ColumnarShuffleExchangeExec.scala    | 18 ++---
 27 files changed, 230 insertions(+), 209 deletions(-)

diff --git a/.github/workflows/velox_backend_x86.yml b/.github/workflows/velox_backend_x86.yml
index ce603b3a8c..9f695a634b 100644
--- a/.github/workflows/velox_backend_x86.yml
+++ b/.github/workflows/velox_backend_x86.yml
@@ -613,6 +613,11 @@ jobs:
             --local --preset=velox-with-celeborn --extra-conf=spark.celeborn.client.spark.shuffle.writer=sort \
             --extra-conf=spark.celeborn.push.sortMemory.threshold=8m --benchmark-type=ds --error-on-memleak \
             --off-heap-size=10g -s=1.0 --threads=8 --iterations=1
+          GLUTEN_IT_JVM_ARGS=-Xmx10G sbin/gluten-it.sh queries-compare \
+            --local --preset=velox-with-celeborn --extra-conf=spark.celeborn.client.spark.shuffle.writer=sort \
+            --extra-conf=spark.gluten.sql.columnar.shuffle.celeborn.useRssSort=false \
+            --extra-conf=spark.celeborn.push.sortMemory.threshold=8m --benchmark-type=ds --error-on-memleak \
+            --off-heap-size=10g -s=1.0 --threads=8 --iterations=1
 
   spark-test-spark32:
     needs: build-native-lib-centos-7
diff --git a/backends-clickhouse/src-celeborn/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala b/backends-clickhouse/src-celeborn/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
index 7b9567fc62..0312c396c3 100644
--- a/backends-clickhouse/src-celeborn/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
+++ b/backends-clickhouse/src-celeborn/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
@@ -82,7 +82,7 @@ class CHCelebornColumnarShuffleWriter[K, V](
       CHBackendSettings.shuffleHashAlgorithm,
       celebornPartitionPusher,
       CHConfig.get.chColumnarForceMemorySortShuffle
-        || ShuffleMode.SORT.name.equalsIgnoreCase(shuffleWriterType)
+        || ShuffleMode.SORT.name.equalsIgnoreCase(dep.shuffleWriterType.name)
     )
 
     splitResult = jniWrapper.stop(nativeShuffleWriter)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
index 6ff629aafb..7d370a2a01 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.backendsapi.clickhouse
 
 import org.apache.gluten.backendsapi.MetricsApi
+import org.apache.gluten.config.ShuffleWriterType
 import org.apache.gluten.logging.LogLevelUtil
 import org.apache.gluten.metrics._
 import org.apache.gluten.substrait.{AggregationParams, JoinParams}
@@ -251,7 +252,7 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
 
   override def genColumnarShuffleExchangeMetrics(
       sparkContext: SparkContext,
-      isSort: Boolean): Map[String, SQLMetric] =
+      shuffleWriterType: ShuffleWriterType): Map[String, SQLMetric] =
     Map(
       "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
       "bytesSpilled" -> SQLMetrics.createSizeMetric(sparkContext, "shuffle bytes spilled"),
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 4fa68ff654..12e775086c 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.backendsapi.clickhouse
 
 import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
-import org.apache.gluten.config.GlutenConfig
+import org.apache.gluten.config.{GlutenConfig, ShuffleWriterType}
 import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
 import org.apache.gluten.execution._
 import org.apache.gluten.expression._
@@ -415,7 +415,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       serializer: Serializer,
       writeMetrics: Map[String, SQLMetric],
       metrics: Map[String, SQLMetric],
-      isSort: Boolean
+      shuffleWriterType: ShuffleWriterType
   ): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
     CHExecUtil.genShuffleDependency(
       rdd,
@@ -429,10 +429,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
   }
   // scalastyle:on argcount
 
-  /** Determine whether to use sort-based shuffle based on shuffle partitioning and output. */
-  override def useSortBasedShuffle(partitioning: Partitioning, output: Seq[Attribute]): Boolean =
-    false
-
   /**
    * Generate ColumnarShuffleWriter for ColumnarShuffleManager.
    *
@@ -451,7 +447,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
   override def createColumnarBatchSerializer(
       schema: StructType,
       metrics: Map[String, SQLMetric],
-      isSort: Boolean): Serializer = {
+      shuffleWriterType: ShuffleWriterType): Serializer = {
     val readBatchNumRows = metrics("avgReadBatchNumRows")
     val numOutputRows = metrics("numOutputRows")
     val dataSize = metrics("dataSize")
diff --git a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
index 2db65383df..0869ad3c30 100644
--- a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
+++ b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala
@@ -17,8 +17,7 @@
 package org.apache.spark.shuffle
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.config.GlutenConfig
-import org.apache.gluten.config.ReservedKeys.{GLUTEN_RSS_SORT_SHUFFLE_WRITER, GLUTEN_SORT_SHUFFLE_WRITER}
+import org.apache.gluten.config.{GlutenConfig, ShuffleWriterType}
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
 import org.apache.gluten.runtime.Runtimes
 import org.apache.gluten.utils.ArrowAbiUtil
@@ -49,20 +48,26 @@ import scala.reflect.ClassTag
 class CelebornColumnarBatchSerializer(
     schema: StructType,
     readBatchNumRows: SQLMetric,
-    numOutputRows: SQLMetric)
+    numOutputRows: SQLMetric,
+    shuffleWriterType: ShuffleWriterType)
   extends Serializer
   with Serializable {
 
   /** Creates a new [[SerializerInstance]]. */
   override def newInstance(): SerializerInstance = {
-    new CelebornColumnarBatchSerializerInstance(schema, readBatchNumRows, numOutputRows)
+    new CelebornColumnarBatchSerializerInstance(
+      schema,
+      readBatchNumRows,
+      numOutputRows,
+      shuffleWriterType)
   }
 }
 
 private class CelebornColumnarBatchSerializerInstance(
     schema: StructType,
     readBatchNumRows: SQLMetric,
-    numOutputRows: SQLMetric)
+    numOutputRows: SQLMetric,
+    shuffleWriterType: ShuffleWriterType)
   extends SerializerInstance
   with Logging {
 
@@ -86,8 +91,6 @@ private class CelebornColumnarBatchSerializerInstance(
       }
     val compressionCodecBackend =
       GlutenConfig.get.columnarShuffleCodecBackend.orNull
-    val shuffleWriterType = GlutenConfig.get.celebornShuffleWriterType
-      .replace(GLUTEN_SORT_SHUFFLE_WRITER, GLUTEN_RSS_SORT_SHUFFLE_WRITER)
     val jniWrapper = ShuffleReaderJniWrapper.create(runtime)
     val batchSize = GlutenConfig.get.maxBatchSize
     val readerBufferSize = GlutenConfig.get.columnarShuffleReaderBufferSize
@@ -100,7 +103,7 @@ private class CelebornColumnarBatchSerializerInstance(
         batchSize,
         readerBufferSize,
         deserializerBufferSize,
-        shuffleWriterType
+        shuffleWriterType.name
       )
     // Close shuffle reader instance as lately as the end of task processing,
     // since the native reader could hold a reference to memory pool that
diff --git a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
index 3dd9920bda..881d9cf660 100644
--- a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
+++ b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
@@ -18,16 +18,16 @@ package org.apache.spark.shuffle
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.columnarbatch.ColumnarBatches
-import org.apache.gluten.config.{GlutenConfig, ReservedKeys}
+import org.apache.gluten.config.{GlutenConfig, HashShuffleWriterType, RssSortShuffleWriterType, SortShuffleWriterType}
 import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
 import org.apache.gluten.runtime.Runtimes
 import org.apache.gluten.vectorized._
 
 import org.apache.spark._
+import org.apache.spark.internal.config.{SHUFFLE_DISK_WRITE_BUFFER_SIZE, SHUFFLE_SORT_INIT_BUFFER_SIZE, SHUFFLE_SORT_USE_RADIXSORT}
 import org.apache.spark.memory.SparkMemoryUtil
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
-import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SparkResourceUtil
 
@@ -62,14 +62,6 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
     SparkMemoryUtil.getCurrentAvailableOffHeapMemory / SparkResourceUtil.getTaskSlots(conf)
   }
 
-  private val nativeMetrics: SQLMetric = {
-    if (dep.isSort) {
-      dep.metrics("sortTime")
-    } else {
-      dep.metrics("splitTime")
-    }
-  }
-
   @throws[IOException]
   override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
     if (!records.hasNext) {
@@ -108,11 +100,25 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
     splitResult = shuffleWriterJniWrapper.stop(nativeShuffleWriter)
 
     dep.metrics("shuffleWallTime").add(System.nanoTime() - startTime)
-    nativeMetrics
-      .add(
-        dep.metrics("shuffleWallTime").value - splitResult.getTotalPushTime -
-          splitResult.getTotalWriteTime -
-          splitResult.getTotalCompressTime)
+    dep.shuffleWriterType match {
+      case HashShuffleWriterType =>
+        dep
+          .metrics("splitTime")
+          .add(
+            dep.metrics("shuffleWallTime").value - splitResult.getTotalPushTime -
+              splitResult.getTotalWriteTime -
+              splitResult.getTotalCompressTime)
+      case RssSortShuffleWriterType =>
+        dep
+          .metrics("sortTime")
+          .add(
+            dep.metrics("shuffleWallTime").value - splitResult.getTotalPushTime -
+              splitResult.getTotalWriteTime -
+              splitResult.getTotalCompressTime)
+      case SortShuffleWriterType =>
+        dep.metrics("sortTime").add(splitResult.getSortTime)
+        dep.metrics("c2rTime").add(splitResult.getC2RTime)
+    }
     dep.metrics("dataSize").add(splitResult.getRawPartitionLengths.sum)
     writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
     writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalPushTime)
@@ -135,8 +141,8 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
       celebornPartitionPusher
     )
 
-    nativeShuffleWriter = shuffleWriterType match {
-      case ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER =>
+    nativeShuffleWriter = dep.shuffleWriterType match {
+      case HashShuffleWriterType =>
         shuffleWriterJniWrapper.createHashShuffleWriter(
           numPartitions,
           dep.nativePartitioning.getShortName,
@@ -145,7 +151,17 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
           GlutenConfig.get.columnarShuffleReallocThreshold,
           partitionWriterHandle
         )
-      case ReservedKeys.GLUTEN_RSS_SORT_SHUFFLE_WRITER =>
+      case SortShuffleWriterType =>
+        shuffleWriterJniWrapper.createSortShuffleWriter(
+          numPartitions,
+          dep.nativePartitioning.getShortName,
+          GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId),
+          conf.get(SHUFFLE_DISK_WRITE_BUFFER_SIZE).toInt,
+          conf.get(SHUFFLE_SORT_INIT_BUFFER_SIZE).toInt,
+          conf.get(SHUFFLE_SORT_USE_RADIXSORT),
+          partitionWriterHandle
+        )
+      case RssSortShuffleWriterType =>
         shuffleWriterJniWrapper.createRssSortShuffleWriter(
           numPartitions,
           dep.nativePartitioning.getShortName,
@@ -155,9 +171,9 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
           compressionCodec.orNull,
           partitionWriterHandle
         )
-      case _ =>
+      case other =>
         throw new UnsupportedOperationException(
-          s"Unsupported celeborn shuffle writer type: $shuffleWriterType")
+          s"Unsupported celeborn shuffle writer type: ${other.name}")
     }
 
     runtime
diff --git a/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/gluten/uniffle/UniffleShuffleManager.java b/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/gluten/uniffle/UniffleShuffleManager.java
index 70c9621157..d63b3954dc 100644
--- a/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/gluten/uniffle/UniffleShuffleManager.java
+++ b/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/gluten/uniffle/UniffleShuffleManager.java
@@ -79,8 +79,7 @@ public class UniffleShuffleManager extends RssShuffleManager implements Supports
           shuffleWriteClient,
           rssHandle,
           this::markFailedTask,
-          context,
-          dependency.isSort());
+          context);
     } else {
       return super.getWriter(handle, mapId, context, metrics);
     }
diff --git a/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java b/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
index 96aaa5121a..9562dcbc4b 100644
--- a/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
+++ b/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.writer;
 import org.apache.gluten.backendsapi.BackendsApiManager;
 import org.apache.gluten.columnarbatch.ColumnarBatches;
 import org.apache.gluten.config.GlutenConfig;
+import org.apache.gluten.config.SortShuffleWriterType$;
 import org.apache.gluten.memory.memtarget.MemoryTarget;
 import org.apache.gluten.memory.memtarget.Spiller;
 import org.apache.gluten.memory.memtarget.Spillers;
@@ -80,7 +81,6 @@ public class VeloxUniffleColumnarShuffleWriter<K, V> extends RssShuffleWriter<K,
       ShuffleWriterJniWrapper.create(runtime);
   private final int nativeBufferSize = GlutenConfig.get().maxBatchSize();
   private final int bufferSize;
-  private final Boolean isSort;
   private final int numPartitions;
 
   private final ColumnarShuffleDependency<K, V, V> columnarDep;
@@ -103,8 +103,7 @@ public class VeloxUniffleColumnarShuffleWriter<K, V> extends RssShuffleWriter<K,
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, V> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context,
-      Boolean isSort) {
+      TaskContext context) {
     super(
         appId,
         shuffleId,
@@ -120,7 +119,6 @@ public class VeloxUniffleColumnarShuffleWriter<K, V> extends RssShuffleWriter<K,
     columnarDep = (ColumnarShuffleDependency<K, V, V>) rssHandle.getDependency();
     this.partitionId = partitionId;
     this.sparkConf = sparkConf;
-    this.isSort = isSort;
     this.numPartitions = columnarDep.nativePartitioning().getNumPartitions();
     bufferSize =
         (int)
@@ -166,7 +164,7 @@ public class VeloxUniffleColumnarShuffleWriter<K, V> extends RssShuffleWriter<K,
                   bufferSize,
                   partitionPusher);
 
-          if (isSort) {
+          if (columnarDep.shuffleWriterType().equals(SortShuffleWriterType$.MODULE$)) {
             nativeShuffleWriter =
                 shuffleWriterJniWrapper.createSortShuffleWriter(
                     numPartitions,
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
index 00c6ddb587..f74e6c08be 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.backendsapi.velox
 
 import org.apache.gluten.backendsapi.MetricsApi
+import org.apache.gluten.config.{HashShuffleWriterType, RssSortShuffleWriterType, ShuffleWriterType, SortShuffleWriterType}
 import org.apache.gluten.metrics._
 import org.apache.gluten.substrait.{AggregationParams, JoinParams}
 
@@ -292,7 +293,7 @@ class VeloxMetricsApi extends MetricsApi with Logging {
 
   override def genColumnarShuffleExchangeMetrics(
       sparkContext: SparkContext,
-      isSort: Boolean): Map[String, SQLMetric] = {
+      shuffleWriterType: ShuffleWriterType): Map[String, SQLMetric] = {
     val baseMetrics = Map(
       "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions"),
       "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -314,18 +315,25 @@ class VeloxMetricsApi extends MetricsApi with Logging {
       // row buffer + sort buffer size.
       "peakBytes" -> SQLMetrics.createSizeMetric(sparkContext, "peak bytes allocated")
     )
-    if (isSort) {
-      baseMetrics ++ Map(
-        "sortTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time to shuffle sort"),
-        "c2rTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time to shuffle c2r")
-      )
-    } else {
-      baseMetrics ++ Map(
-        "splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time to split"),
-        "avgDictionaryFields" -> SQLMetrics
-          .createAverageMetric(sparkContext, "avg dictionary fields"),
-        "dictionarySize" -> SQLMetrics.createSizeMetric(sparkContext, "dictionary size")
-      )
+    shuffleWriterType match {
+      case HashShuffleWriterType =>
+        baseMetrics ++ Map(
+          "splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time to split"),
+          "avgDictionaryFields" -> SQLMetrics
+            .createAverageMetric(sparkContext, "avg dictionary fields"),
+          "dictionarySize" -> SQLMetrics.createSizeMetric(sparkContext, "dictionary size")
+        )
+      case SortShuffleWriterType =>
+        baseMetrics ++ Map(
+          "sortTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time to shuffle sort"),
+          "c2rTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time to shuffle c2r")
+        )
+      case RssSortShuffleWriterType =>
+        baseMetrics ++ Map(
+          "sortTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time to shuffle sort")
+        )
+      case _ =>
+        baseMetrics
     }
   }
 
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 72fb3e3b13..111efd946f 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.backendsapi.velox
 
 import org.apache.gluten.backendsapi.SparkPlanExecApi
-import org.apache.gluten.config.{GlutenConfig, ReservedKeys, VeloxConfig}
+import org.apache.gluten.config.{GlutenConfig, HashShuffleWriterType, ReservedKeys, RssSortShuffleWriterType, ShuffleWriterType, SortShuffleWriterType, VeloxConfig}
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.expression._
@@ -534,7 +534,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       serializer: Serializer,
       writeMetrics: Map[String, SQLMetric],
       metrics: Map[String, SQLMetric],
-      isSort: Boolean): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+      shuffleWriterType: ShuffleWriterType)
+      : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
     // scalastyle:on argcount
     ExecUtil.genShuffleDependency(
       rdd,
@@ -543,19 +544,39 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       serializer,
       writeMetrics,
       metrics,
-      isSort)
+      shuffleWriterType)
   }
   // scalastyle:on argcount
 
   /** Determine whether to use sort-based shuffle based on shuffle partitioning and output. */
-  override def useSortBasedShuffle(partitioning: Partitioning, output: Seq[Attribute]): Boolean = {
+  override def getShuffleWriterType(
+      partitioning: Partitioning,
+      output: Seq[Attribute]): ShuffleWriterType = {
     val conf = GlutenConfig.get
-    lazy val isCelebornSortBasedShuffle = conf.isUseCelebornShuffleManager &&
-      conf.celebornShuffleWriterType == ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER
-    partitioning != SinglePartition &&
-    (partitioning.numPartitions >= GlutenConfig.get.columnarShuffleSortPartitionsThreshold ||
-      output.size >= GlutenConfig.get.columnarShuffleSortColumnsThreshold) ||
-    isCelebornSortBasedShuffle
+    if (conf.isUseCelebornShuffleManager) {
+      if (conf.celebornShuffleWriterType == ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER) {
+        if (conf.useCelebornRssSort) {
+          RssSortShuffleWriterType
+        } else if (partitioning != SinglePartition) {
+          SortShuffleWriterType
+        } else {
+          // If not using rss sort, we still use hash shuffle writer for single partitioning.
+          HashShuffleWriterType
+        }
+      } else {
+        HashShuffleWriterType
+      }
+    } else {
+      if (
+        partitioning != SinglePartition &&
+        (partitioning.numPartitions >= GlutenConfig.get.columnarShuffleSortPartitionsThreshold ||
+          output.size >= GlutenConfig.get.columnarShuffleSortColumnsThreshold)
+      ) {
+        SortShuffleWriterType
+      } else {
+        HashShuffleWriterType
+      }
+    }
   }
 
   /**
@@ -603,7 +624,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
   override def createColumnarBatchSerializer(
       schema: StructType,
       metrics: Map[String, SQLMetric],
-      isSort: Boolean): Serializer = {
+      shuffleWriterType: ShuffleWriterType): Serializer = {
     val numOutputRows = metrics("numOutputRows")
     val deserializeTime = metrics("deserializeTime")
     val readBatchNumRows = metrics("avgReadBatchNumRows")
@@ -611,8 +632,14 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     if (GlutenConfig.get.isUseCelebornShuffleManager) {
       val clazz = ClassUtils.getClass("org.apache.spark.shuffle.CelebornColumnarBatchSerializer")
       val constructor =
-        clazz.getConstructor(classOf[StructType], classOf[SQLMetric], classOf[SQLMetric])
-      constructor.newInstance(schema, readBatchNumRows, numOutputRows).asInstanceOf[Serializer]
+        clazz.getConstructor(
+          classOf[StructType],
+          classOf[SQLMetric],
+          classOf[SQLMetric],
+          classOf[ShuffleWriterType])
+      constructor
+        .newInstance(schema, readBatchNumRows, numOutputRows, shuffleWriterType)
+        .asInstanceOf[Serializer]
     } else {
       new ColumnarBatchSerializer(
         schema,
@@ -620,7 +647,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
         numOutputRows,
         deserializeTime,
         decompressTime,
-        isSort)
+        shuffleWriterType)
     }
   }
 
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala
index e6ae9f1e54..aea311c01d 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.extension
 
+import org.apache.gluten.config.HashShuffleWriterType
 import org.apache.gluten.config.VeloxConfig
 import org.apache.gluten.execution.VeloxResizeBatchesExec
 
@@ -32,7 +33,7 @@ case class AppendBatchResizeForShuffleInputAndOutput() extends Rule[SparkPlan] {
     val range = VeloxConfig.get.veloxResizeBatchesShuffleInputOutputRange
     plan.transformUp {
       case shuffle: ColumnarShuffleExchangeExec
-          if !shuffle.useSortBasedShuffle &&
+          if shuffle.shuffleWriterType == HashShuffleWriterType &&
             VeloxConfig.get.veloxResizeBatchesShuffleInput =>
         val appendBatches =
           VeloxResizeBatchesExec(shuffle.child, range.min, range.max)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
index 82cb64d959..53354f432b 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala
@@ -17,8 +17,7 @@
 package org.apache.gluten.vectorized
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.config.GlutenConfig
-import org.apache.gluten.config.ReservedKeys
+import org.apache.gluten.config.{GlutenConfig, ShuffleWriterType}
 import org.apache.gluten.iterator.ClosableIterator
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
 import org.apache.gluten.runtime.Runtimes
@@ -51,13 +50,10 @@ class ColumnarBatchSerializer(
     numOutputRows: SQLMetric,
     deserializeTime: SQLMetric,
     decompressTime: SQLMetric,
-    isSort: Boolean)
+    shuffleWriterType: ShuffleWriterType)
   extends Serializer
   with Serializable {
 
-  private val shuffleWriterType =
-    if (isSort) ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER else ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER
-
   /** Creates a new [[SerializerInstance]]. */
   override def newInstance(): SerializerInstance = {
     new ColumnarBatchSerializerInstance(
@@ -78,7 +74,7 @@ private class ColumnarBatchSerializerInstance(
     numOutputRows: SQLMetric,
     deserializeTime: SQLMetric,
     decompressTime: SQLMetric,
-    shuffleWriterType: String)
+    shuffleWriterType: ShuffleWriterType)
   extends SerializerInstance
   with Logging {
 
@@ -111,7 +107,7 @@ private class ColumnarBatchSerializerInstance(
       batchSize,
       readerBufferSize,
       deserializerBufferSize,
-      shuffleWriterType)
+      shuffleWriterType.name)
     // Close shuffle reader instance as lately as the end of task processing,
     // since the native reader could hold a reference to memory pool that
     // was used to create all buffers read from shuffle reader. The pool
diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
index 6f5348ac0b..34895ceca3 100644
--- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -18,7 +18,7 @@ package org.apache.spark.shuffle
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.columnarbatch.ColumnarBatches
-import org.apache.gluten.config.GlutenConfig
+import org.apache.gluten.config.{GlutenConfig, HashShuffleWriterType, SortShuffleWriterType}
 import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
 import org.apache.gluten.runtime.Runtimes
 import org.apache.gluten.vectorized._
@@ -42,7 +42,17 @@ class ColumnarShuffleWriter[K, V](
   with Logging {
 
   private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]
-  protected val isSort: Boolean = dep.isSort
+
+  dep.shuffleWriterType match {
+    case HashShuffleWriterType | SortShuffleWriterType =>
+    // Valid shuffle writer types
+    case _ =>
+      throw new IllegalArgumentException(
+        s"Unsupported shuffle writer type: ${dep.shuffleWriterType.name}, " +
+          s"expected one of: ${HashShuffleWriterType.name}, ${SortShuffleWriterType.name}")
+  }
+
+  protected val isSort: Boolean = dep.shuffleWriterType == SortShuffleWriterType
 
   private val numPartitions: Int = dep.partitioner.numPartitions
 
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala
index e3a84d4f07..dcfa0ee525 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.utils
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
+import org.apache.gluten.config.ShuffleWriterType
 import org.apache.gluten.iterator.Iterators
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
 import org.apache.gluten.runtime.Runtimes
@@ -88,7 +89,8 @@ object ExecUtil {
       serializer: Serializer,
       writeMetrics: Map[String, SQLMetric],
       metrics: Map[String, SQLMetric],
-      isSort: Boolean): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = 
{
+      shuffleWriterType: ShuffleWriterType)
+      : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
     metrics("numPartitions").set(newPartitioning.numPartitions)
     val executionId = 
rdd.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLMetrics.postDriverMetricUpdates(
@@ -209,7 +211,7 @@ object ExecUtil {
         shuffleWriterProcessor = 
ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
         nativePartitioning = nativePartitioning,
         metrics = metrics,
-        isSort = isSort
+        shuffleWriterType = shuffleWriterType
       )
 
     dependency
diff --git a/cpp/core/shuffle/rss/RssPartitionWriter.cc 
b/cpp/core/shuffle/rss/RssPartitionWriter.cc
index 996c137671..3ded2fc0ba 100644
--- a/cpp/core/shuffle/rss/RssPartitionWriter.cc
+++ b/cpp/core/shuffle/rss/RssPartitionWriter.cc
@@ -36,10 +36,9 @@ arrow::Status RssPartitionWriter::stop(ShuffleWriterMetrics* 
metrics) {
       compressTime_ = compressedOs_->compressTime();
       spillTime_ -= compressTime_;
     }
-    RETURN_NOT_OK(rssOs_->Flush());
-    ARROW_ASSIGN_OR_RAISE(const auto evicted, rssOs_->Tell());
-    bytesEvicted_[lastEvictedPartitionId_] += evicted;
-    RETURN_NOT_OK(rssOs_->Close());
+    ARROW_ASSIGN_OR_RAISE(const auto buffer, rssOs_->Finish());
+    bytesEvicted_[lastEvictedPartitionId_] +=
+        rssClient_->pushPartitionData(lastEvictedPartitionId_, 
buffer->data_as<char>(), buffer->size());
   }
 
   rssClient_->stop();
@@ -73,19 +72,18 @@ RssPartitionWriter::sortEvict(uint32_t partitionId, 
std::unique_ptr<InMemoryPayl
   ScopedTimer timer(&spillTime_);
   if (lastEvictedPartitionId_ != partitionId) {
     if (lastEvictedPartitionId_ != -1) {
-      GLUTEN_DCHECK(rssOs_ != nullptr && !rssOs_->closed(), 
"RssPartitionWriterOutputStream should not be null");
+      GLUTEN_DCHECK(rssOs_ != nullptr && !rssOs_->closed(), "rssOs_ should not 
be null");
       if (compressedOs_ != nullptr) {
         RETURN_NOT_OK(compressedOs_->Flush());
       }
-      RETURN_NOT_OK(rssOs_->Flush());
-      ARROW_ASSIGN_OR_RAISE(const auto evicted, rssOs_->Tell());
-      bytesEvicted_[lastEvictedPartitionId_] += evicted;
-      RETURN_NOT_OK(rssOs_->Close());
+
+      ARROW_ASSIGN_OR_RAISE(const auto buffer, rssOs_->Finish());
+      bytesEvicted_[lastEvictedPartitionId_] +=
+          rssClient_->pushPartitionData(lastEvictedPartitionId_, 
buffer->data_as<char>(), buffer->size());
     }
 
-    rssOs_ =
-        std::make_shared<RssPartitionWriterOutputStream>(partitionId, 
rssClient_.get(), options_->pushBufferMaxSize);
-    RETURN_NOT_OK(rssOs_->init());
+    ARROW_ASSIGN_OR_RAISE(
+        rssOs_, 
arrow::io::BufferOutputStream::Create(options_->pushBufferMaxSize, 
arrow::default_memory_pool()));
     if (codec_ != nullptr) {
       ARROW_ASSIGN_OR_RAISE(
           compressedOs_,
diff --git a/cpp/core/shuffle/rss/RssPartitionWriter.h 
b/cpp/core/shuffle/rss/RssPartitionWriter.h
index 3e3a078aa5..0b4d740984 100644
--- a/cpp/core/shuffle/rss/RssPartitionWriter.h
+++ b/cpp/core/shuffle/rss/RssPartitionWriter.h
@@ -67,83 +67,8 @@ class RssPartitionWriter final : public PartitionWriter {
   std::vector<int64_t> rawPartitionLengths_;
 
   int32_t lastEvictedPartitionId_{-1};
-  std::shared_ptr<RssPartitionWriterOutputStream> rssOs_;
-  std::shared_ptr<ShuffleCompressedOutputStream> compressedOs_;
+  std::shared_ptr<arrow::io::BufferOutputStream> rssOs_{nullptr};
+  std::shared_ptr<ShuffleCompressedOutputStream> compressedOs_{nullptr};
 };
 
-class RssPartitionWriterOutputStream final : public arrow::io::OutputStream {
- public:
-  RssPartitionWriterOutputStream(int32_t partitionId, RssClient* rssClient, 
int64_t pushBufferSize)
-      : partitionId_(partitionId), rssClient_(rssClient), 
bufferSize_(pushBufferSize) {}
-
-  arrow::Status init() {
-    ARROW_ASSIGN_OR_RAISE(pushBuffer_, arrow::AllocateBuffer(bufferSize_, 
arrow::default_memory_pool()));
-    pushBufferPtr_ = pushBuffer_->mutable_data();
-    return arrow::Status::OK();
-  }
-
-  arrow::Status Close() override {
-    RETURN_NOT_OK(Flush());
-    pushBuffer_.reset();
-    return arrow::Status::OK();
-  }
-
-  bool closed() const override {
-    return pushBuffer_ == nullptr;
-  }
-
-  arrow::Result<int64_t> Tell() const override {
-    return bytesEvicted_ + bufferPos_;
-  }
-
-  arrow::Status Write(const void* data, int64_t nbytes) override {
-    auto dataPtr = static_cast<const char*>(data);
-    if (nbytes < 0) {
-      return arrow::Status::Invalid("write count should be >= 0");
-    }
-    if (nbytes == 0) {
-      return arrow::Status::OK();
-    }
-
-    if (nbytes + bufferPos_ <= bufferSize_) {
-      std::memcpy(pushBufferPtr_ + bufferPos_, dataPtr, nbytes);
-      bufferPos_ += nbytes;
-      return arrow::Status::OK();
-    }
-
-    int64_t bytesWritten = 0;
-    while (bytesWritten < nbytes) {
-      auto remaining = nbytes - bytesWritten;
-      if (remaining <= bufferSize_ - bufferPos_) {
-        std::memcpy(pushBufferPtr_ + bufferPos_, dataPtr + bytesWritten, 
remaining);
-        bufferPos_ += remaining;
-        return arrow::Status::OK();
-      }
-      auto toWrite = bufferSize_ - bufferPos_;
-      std::memcpy(pushBufferPtr_ + bufferPos_, dataPtr + bytesWritten, 
toWrite);
-      bytesWritten += toWrite;
-      bufferPos_ += toWrite;
-      RETURN_NOT_OK(Flush());
-    }
-    return arrow::Status::OK();
-  }
-
-  arrow::Status Flush() override {
-    if (bufferPos_ > 0) {
-      bytesEvicted_ += rssClient_->pushPartitionData(partitionId_, 
reinterpret_cast<char*>(pushBufferPtr_), bufferPos_);
-      bufferPos_ = 0;
-    }
-    return arrow::Status::OK();
-  }
-
- private:
-  int32_t partitionId_;
-  RssClient* rssClient_;
-  int64_t bufferSize_{kDefaultPushMemoryThreshold};
-
-  std::shared_ptr<arrow::Buffer> pushBuffer_;
-  uint8_t* pushBufferPtr_{nullptr};
-  int64_t bufferPos_{0};
-  int64_t bytesEvicted_{0};
-};
 } // namespace gluten
diff --git a/docs/Configuration.md b/docs/Configuration.md
index 63619aee90..cdf809ed42 100644
--- a/docs/Configuration.md
+++ b/docs/Configuration.md
@@ -107,6 +107,7 @@ nav_order: 15
 | spark.gluten.sql.columnar.scanOnly                                 | false   
          | When enabled, only scan and the filter after scan will be offloaded 
to native.                                                                      
                                                                                
                                                                                
                                                                                
              [...]
 | spark.gluten.sql.columnar.shuffle                                  | true    
          | Enable or disable columnar shuffle.                                 
                                                                                
                                                                                
                                                                                
                                                                                
              [...]
 | spark.gluten.sql.columnar.shuffle.celeborn.fallback.enabled        | true    
          | If enabled, fall back to ColumnarShuffleManager when celeborn 
service is unavailable. Otherwise, throw an exception.                           
                                                                                
                                                                                
                                                                                
                    [...]
+| spark.gluten.sql.columnar.shuffle.celeborn.useRssSort              | true    
          | If true, use RSS sort implementation for Celeborn sort-based 
shuffle. If false, use Gluten's row-based sort implementation. Only valid when 
`spark.celeborn.client.spark.shuffle.writer` is set to `sort`.                  
                                                                                
                                                                                
                       [...]
 | spark.gluten.sql.columnar.shuffle.codec                            | 
&lt;undefined&gt; | By default, the supported codecs are lz4 and zstd. When 
spark.gluten.sql.columnar.shuffle.codecBackend=qat,the supported codecs are 
gzip and zstd. When spark.gluten.sql.columnar.shuffle.codecBackend=iaa,the 
supported codec is gzip.                                                        
                                                                                
                                   [...]
 | spark.gluten.sql.columnar.shuffle.codecBackend                     | 
&lt;undefined&gt; |
 | spark.gluten.sql.columnar.shuffle.compression.threshold            | 100     
          | If number of rows in a batch falls below this threshold, will copy 
all buffers into one buffer to compress.                                        
                                                                                
                                                                                
                                                                                
               [...]
diff --git a/docs/get-started/Velox.md b/docs/get-started/Velox.md
index 7c742b9c60..a6d5f6a2bb 100644
--- a/docs/get-started/Velox.md
+++ b/docs/get-started/Velox.md
@@ -284,6 +284,14 @@ spark.celeborn.storage.hdfs.dir hdfs://<namenode>/celeborn
 spark.dynamicAllocation.enabled false
 ```
 
+Additionally, for sort-based shuffle, Celeborn supports two types of shuffle 
writers: Gluten's row-based sort shuffle writer and the RSS sort shuffle 
writer.
+By default, Celeborn uses the RSS sort shuffle writer. You can switch to 
Gluten's row-based sort shuffle writer
+by setting the following configuration (only effective when `spark.celeborn.client.spark.shuffle.writer` is `sort`):
+
+```
+spark.gluten.sql.columnar.shuffle.celeborn.useRssSort false
+```
+
 ## Uniffle support
 
 Uniffle with velox backend supports 
[Uniffle](https://github.com/apache/incubator-uniffle) as remote shuffle 
service. Currently, the supported Uniffle versions are `0.9.2`.
diff --git 
a/gluten-celeborn/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
 
b/gluten-celeborn/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
index 31a3c772f8..451a020f6f 100644
--- 
a/gluten-celeborn/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
+++ 
b/gluten-celeborn/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.shuffle
 
 import org.apache.gluten.config.GlutenConfig
-import org.apache.gluten.config.ReservedKeys
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
@@ -31,7 +30,6 @@ import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.common.CelebornConf
 
 import java.io.IOException
-import java.util.Locale
 
 abstract class CelebornColumnarShuffleWriter[K, V](
     shuffleId: Int,
@@ -71,11 +69,6 @@ abstract class CelebornColumnarShuffleWriter[K, V](
 
   protected val clientPushSortMemoryThreshold: Long = 
celebornConf.clientPushSortMemoryThreshold
 
-  protected val shuffleWriterType: String =
-    celebornConf.shuffleWriterMode.name
-      .toLowerCase(Locale.ROOT)
-      .replace(ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER, 
ReservedKeys.GLUTEN_RSS_SORT_SHUFFLE_WRITER)
-
   protected val celebornPartitionPusher = new CelebornPartitionPusher(
     shuffleId,
     numMappers,
diff --git 
a/gluten-delta/src-delta33/main/scala/org/apache/spark/sql/perf/DeltaOptimizedWriterTransformer.scala
 
b/gluten-delta/src-delta33/main/scala/org/apache/spark/sql/perf/DeltaOptimizedWriterTransformer.scala
index 9f601d8466..800e4d1766 100644
--- 
a/gluten-delta/src-delta33/main/scala/org/apache/spark/sql/perf/DeltaOptimizedWriterTransformer.scala
+++ 
b/gluten-delta/src-delta33/main/scala/org/apache/spark/sql/perf/DeltaOptimizedWriterTransformer.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.perf
 
 import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.config.ShuffleWriterType
 import org.apache.gluten.execution.GlutenPlan
 import org.apache.gluten.extension.columnar.transition.Convention
 
@@ -70,8 +71,8 @@ case class DeltaOptimizedWriterTransformer(
   with GlutenPlan
   with DeltaLogging {
 
-  lazy val useSortBasedShuffle: Boolean =
-    
BackendsApiManager.getSparkPlanExecApiInstance.useSortBasedShuffle(outputPartitioning,
 output)
+  lazy val shuffleWriterType: ShuffleWriterType =
+    
BackendsApiManager.getSparkPlanExecApiInstance.getShuffleWriterType(outputPartitioning,
 output)
 
   override def output: Seq[Attribute] = child.output
 
@@ -86,7 +87,7 @@ case class DeltaOptimizedWriterTransformer(
     BackendsApiManager.getMetricsApiInstance
       .genColumnarShuffleExchangeMetrics(
         sparkContext,
-        useSortBasedShuffle) ++ readMetrics ++ writeMetrics
+        shuffleWriterType) ++ readMetrics ++ writeMetrics
 
   @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = 
child.executeColumnar()
 
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
index 453cfab4e4..a5944f50d8 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.backendsapi
 
+import org.apache.gluten.config.ShuffleWriterType
 import org.apache.gluten.metrics.{IMetrics, MetricsUpdater}
 import org.apache.gluten.substrait.{AggregationParams, JoinParams}
 
@@ -84,7 +85,7 @@ trait MetricsApi extends Serializable {
 
   def genColumnarShuffleExchangeMetrics(
       sparkContext: SparkContext,
-      isSort: Boolean): Map[String, SQLMetric]
+      shuffleWriterType: ShuffleWriterType): Map[String, SQLMetric]
 
   def genWindowTransformerMetrics(sparkContext: SparkContext): Map[String, 
SQLMetric]
 
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 2dd22de760..53d6f9211e 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.backendsapi
 
+import org.apache.gluten.config.{HashShuffleWriterType, ShuffleWriterType}
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.expression._
@@ -345,10 +346,14 @@ trait SparkPlanExecApi {
       serializer: Serializer,
       writeMetrics: Map[String, SQLMetric],
       metrics: Map[String, SQLMetric],
-      isSort: Boolean): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch]
+      shuffleWriterType: ShuffleWriterType): ShuffleDependency[Int, 
ColumnarBatch, ColumnarBatch]
 
   /** Determine the shuffle writer type to use based on shuffle 
partitioning and output. Defaults to the hash shuffle writer. */
-  def useSortBasedShuffle(partitioning: Partitioning, output: Seq[Attribute]): 
Boolean
+  def getShuffleWriterType(
+      partitioning: Partitioning,
+      output: Seq[Attribute]): ShuffleWriterType = {
+    HashShuffleWriterType
+  }
 
   /**
    * Generate ColumnarShuffleWriter for ColumnarShuffleManager.
@@ -366,7 +371,7 @@ trait SparkPlanExecApi {
   def createColumnarBatchSerializer(
       schema: StructType,
       metrics: Map[String, SQLMetric],
-      isSort: Boolean): Serializer
+      shuffleWriterType: ShuffleWriterType): Serializer
 
   /** Create broadcast relation for BroadcastExchangeExec */
   def createBroadcastRelation(
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala 
b/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
index 1c094faa4c..01c4730d8c 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
@@ -34,6 +34,22 @@ case class GlutenNumaBindingInfo(
     totalCoreRange: Array[String] = null,
     numCoresPerExecutor: Int = -1) {}
 
+trait ShuffleWriterType {
+  val name: String
+}
+
+case object HashShuffleWriterType extends ShuffleWriterType {
+  override val name: String = ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER
+}
+
+case object SortShuffleWriterType extends ShuffleWriterType {
+  override val name: String = ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER
+}
+
+case object RssSortShuffleWriterType extends ShuffleWriterType {
+  override val name: String = ReservedKeys.GLUTEN_RSS_SORT_SHUFFLE_WRITER
+}
+
 class GlutenConfig(conf: SQLConf) extends GlutenCoreConfig(conf) {
   import GlutenConfig._
 
@@ -334,6 +350,8 @@ class GlutenConfig(conf: SQLConf) extends 
GlutenCoreConfig(conf) {
 
   def enableCelebornFallback: Boolean = getConf(CELEBORN_FALLBACK_ENABLED)
 
+  def useCelebornRssSort: Boolean = getConf(CELEBORN_USE_RSS_SORT)
+
   def enableHdfsViewfs: Boolean = getConf(HDFS_VIEWFS_ENABLED)
 
   def parquetEncryptionValidationEnabled: Boolean = 
getConf(ENCRYPTED_PARQUET_FALLBACK_ENABLED)
@@ -1526,6 +1544,16 @@ object GlutenConfig {
       .booleanConf
       .createWithDefault(true)
 
+  val CELEBORN_USE_RSS_SORT =
+    buildConf("spark.gluten.sql.columnar.shuffle.celeborn.useRssSort")
+      .internal()
+      .doc(
+        "If true, use RSS sort implementation for Celeborn sort-based 
shuffle." +
+          "If false, use Gluten's row-based sort implementation. " +
+          "Only valid when `spark.celeborn.client.spark.shuffle.writer` is set 
to `sort`.")
+      .booleanConf
+      .createWithDefault(true)
+
   val HDFS_VIEWFS_ENABLED =
     buildStaticConf("spark.gluten.storage.hdfsViewfs.enabled")
       .internal()
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitBaseExec.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitBaseExec.scala
index 8941b98ff8..da1999a24c 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitBaseExec.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitBaseExec.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.execution
 
 import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.config.ShuffleWriterType
 import org.apache.gluten.extension.columnar.transition.Convention
 
 import org.apache.spark.rdd.RDD
@@ -46,17 +47,17 @@ abstract class ColumnarCollectLimitBaseExec(
   private lazy val readMetrics =
     
SQLColumnarShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
 
-  private lazy val useSortBasedShuffle: Boolean =
+  private lazy val shuffleWriterType: ShuffleWriterType =
     BackendsApiManager.getSparkPlanExecApiInstance
-      .useSortBasedShuffle(outputPartitioning, child.output)
+      .getShuffleWriterType(outputPartitioning, child.output)
 
   @transient private lazy val serializer: Serializer =
     BackendsApiManager.getSparkPlanExecApiInstance
-      .createColumnarBatchSerializer(child.schema, metrics, 
useSortBasedShuffle)
+      .createColumnarBatchSerializer(child.schema, metrics, shuffleWriterType)
 
   @transient override lazy val metrics: Map[String, SQLMetric] =
     BackendsApiManager.getMetricsApiInstance
-      .genColumnarShuffleExchangeMetrics(sparkContext, useSortBasedShuffle) ++
+      .genColumnarShuffleExchangeMetrics(sparkContext, shuffleWriterType) ++
       readMetrics ++ writeMetrics
 
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
@@ -99,7 +100,7 @@ abstract class ColumnarCollectLimitBaseExec(
         serializer,
         writeMetrics,
         metrics,
-        useSortBasedShuffle
+        shuffleWriterType
       ),
       readMetrics
     )
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarCollectTailBaseExec.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarCollectTailBaseExec.scala
index f88f156708..56f9ce69ee 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarCollectTailBaseExec.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarCollectTailBaseExec.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.execution
 
 import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.config.ShuffleWriterType
 import org.apache.gluten.extension.columnar.transition.Convention
 
 import org.apache.spark.rdd.RDD
@@ -44,17 +45,17 @@ abstract class ColumnarCollectTailBaseExec(
   private lazy val readMetrics =
     
SQLColumnarShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
 
-  private lazy val useSortBasedShuffle: Boolean =
+  private lazy val shuffleWriterType: ShuffleWriterType =
     BackendsApiManager.getSparkPlanExecApiInstance
-      .useSortBasedShuffle(outputPartitioning, child.output)
+      .getShuffleWriterType(outputPartitioning, child.output)
 
   @transient private lazy val serializer: Serializer =
     BackendsApiManager.getSparkPlanExecApiInstance
-      .createColumnarBatchSerializer(child.schema, metrics, 
useSortBasedShuffle)
+      .createColumnarBatchSerializer(child.schema, metrics, shuffleWriterType)
 
   @transient override lazy val metrics: Map[String, SQLMetric] =
     BackendsApiManager.getMetricsApiInstance
-      .genColumnarShuffleExchangeMetrics(sparkContext, useSortBasedShuffle) ++
+      .genColumnarShuffleExchangeMetrics(sparkContext, shuffleWriterType) ++
       readMetrics ++ writeMetrics
 
   override def rowType0(): Convention.RowType = Convention.RowType.None
@@ -98,7 +99,7 @@ abstract class ColumnarCollectTailBaseExec(
         serializer,
         writeMetrics,
         metrics,
-        useSortBasedShuffle
+        shuffleWriterType
       ),
       readMetrics
     )
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
index 9f9f867ff6..7ce554d335 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.shuffle
 
+import org.apache.gluten.config.{HashShuffleWriterType, ShuffleWriterType}
 import org.apache.gluten.vectorized.NativePartitioning
 
 import org.apache.spark.{Aggregator, Partitioner, ShuffleDependency, SparkEnv}
@@ -59,7 +60,7 @@ class ColumnarShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
     override val shuffleWriterProcessor: ShuffleWriteProcessor = new 
ShuffleWriteProcessor,
     val nativePartitioning: NativePartitioning,
     val metrics: Map[String, SQLMetric],
-    val isSort: Boolean = false)
+    val shuffleWriterType: ShuffleWriterType = HashShuffleWriterType)
   extends ShuffleDependency[K, V, C](
     _rdd,
     partitioner,
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index 5c48a5c045..8611410592 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.config.ReservedKeys
+import org.apache.gluten.config.ShuffleWriterType
 import org.apache.gluten.execution.ValidatablePlan
 import org.apache.gluten.execution.ValidationResult
 import org.apache.gluten.extension.columnar.transition.Convention
@@ -53,15 +53,15 @@ case class ColumnarShuffleExchangeExec(
   private[sql] lazy val readMetrics =
     
SQLColumnarShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
 
-  val useSortBasedShuffle: Boolean =
-    
BackendsApiManager.getSparkPlanExecApiInstance.useSortBasedShuffle(outputPartitioning,
 output)
+  val shuffleWriterType: ShuffleWriterType =
+    
BackendsApiManager.getSparkPlanExecApiInstance.getShuffleWriterType(outputPartitioning,
 output)
 
   // Note: "metrics" is made transient to avoid sending driver-side metrics to 
tasks.
   @transient override lazy val metrics =
     BackendsApiManager.getMetricsApiInstance
       .genColumnarShuffleExchangeMetrics(
         sparkContext,
-        useSortBasedShuffle) ++ readMetrics ++ writeMetrics
+        shuffleWriterType) ++ readMetrics ++ writeMetrics
 
   @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = 
child.executeColumnar()
 
@@ -89,12 +89,12 @@ case class ColumnarShuffleExchangeExec(
       serializer,
       writeMetrics,
       metrics,
-      useSortBasedShuffle)
+      shuffleWriterType)
   }
 
   // super.stringArgs ++ Iterator(output.map(o => 
s"${o}#${o.dataType.simpleString}"))
   val serializer: Serializer = BackendsApiManager.getSparkPlanExecApiInstance
-    .createColumnarBatchSerializer(schema, metrics, useSortBasedShuffle)
+    .createColumnarBatchSerializer(schema, metrics, shuffleWriterType)
 
   var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
 
@@ -126,11 +126,7 @@ case class ColumnarShuffleExchangeExec(
   }
 
   override def stringArgs: Iterator[Any] = {
-    val shuffleWriterType = {
-      if (useSortBasedShuffle) ReservedKeys.GLUTEN_SORT_SHUFFLE_WRITER
-      else ReservedKeys.GLUTEN_HASH_SHUFFLE_WRITER
-    }
-    super.stringArgs ++ Iterator(s"[shuffle_writer_type=$shuffleWriterType]")
+    super.stringArgs ++ 
Iterator(s"[shuffle_writer_type=${shuffleWriterType.name}]")
   }
 
   override def batchType(): Convention.BatchType = 
BackendsApiManager.getSettings.primaryBatchType


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to