This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 30bddd0c9 [CELEBORN] CHCelebornColumnarShuffleWriter supports 
celeborn.client.spark.shuffle.writer to use memory sort shuffle in ClickHouse 
backend (#6432)
30bddd0c9 is described below

commit 30bddd0c9ba57c828202283a48349b6d1f11b230
Author: Nicholas Jiang <[email protected]>
AuthorDate: Mon Jul 15 10:29:06 2024 +0800

    [CELEBORN] CHCelebornColumnarShuffleWriter supports 
celeborn.client.spark.shuffle.writer to use memory sort shuffle in ClickHouse 
backend (#6432)
---
 .../shuffle/CHCelebornColumnarShuffleWriter.scala  | 88 ++++++++++-----------
 .../shuffle/CelebornColumnarShuffleWriter.scala    | 36 +++++++--
 .../VeloxCelebornColumnarShuffleWriter.scala       | 91 ++++++++--------------
 3 files changed, 103 insertions(+), 112 deletions(-)

diff --git 
a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
 
b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
index 7276e0f2c..e5cd3d22f 100644
--- 
a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
+++ 
b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.ShuffleMode
 
 import java.io.IOException
 import java.util.Locale
@@ -55,61 +56,16 @@ class CHCelebornColumnarShuffleWriter[K, V](
 
   private var splitResult: CHSplitResult = _
 
-  private val nativeBufferSize: Int = 
GlutenConfig.getConf.shuffleWriterBufferSize
-
   @throws[IOException]
   override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
-    if (!records.hasNext) {
-      handleEmptyIterator()
-      return
-    }
-
-    if (nativeShuffleWriter == -1L) {
-      nativeShuffleWriter = jniWrapper.makeForRSS(
-        dep.nativePartitioning,
-        shuffleId,
-        mapId,
-        nativeBufferSize,
-        customizedCompressCodec,
-        GlutenConfig.getConf.chColumnarShuffleSpillThreshold,
-        CHBackendSettings.shuffleHashAlgorithm,
-        celebornPartitionPusher,
-        GlutenConfig.getConf.chColumnarThrowIfMemoryExceed,
-        GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict,
-        GlutenConfig.getConf.chColumnarForceExternalSortShuffle,
-        GlutenConfig.getConf.chColumnarForceMemorySortShuffle
-      )
-      CHNativeMemoryAllocators.createSpillable(
-        "CelebornShuffleWriter",
-        new Spiller() {
-          override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
-            if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
-              return 0L
-            }
-            if (nativeShuffleWriter == -1L) {
-              throw new IllegalStateException(
-                "Fatal: spill() called before a celeborn shuffle writer " +
-                  "is created. This behavior should be" +
-                  "optimized by moving memory " +
-                  "allocations from make() to split()")
-            }
-            logInfo(s"Gluten shuffle writer: Trying to push $size bytes of 
data")
-            val spilled = jniWrapper.evict(nativeShuffleWriter)
-            logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of 
data")
-            spilled
-          }
-        }
-      )
-    }
     while (records.hasNext) {
       val cb = records.next()._2.asInstanceOf[ColumnarBatch]
       if (cb.numRows == 0 || cb.numCols == 0) {
         logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} 
cols")
       } else {
+        initShuffleWriter(cb)
         val col = cb.column(0).asInstanceOf[CHColumnVector]
-        val block = col.getBlockAddress
-        jniWrapper
-          .split(nativeShuffleWriter, block)
+        jniWrapper.split(nativeShuffleWriter, col.getBlockAddress)
         dep.metrics("numInputRows").add(cb.numRows)
         dep.metrics("inputBatches").add(1)
         // This metric is important, AQE use it to decide if EliminateLimit
@@ -117,6 +73,7 @@ class CHCelebornColumnarShuffleWriter[K, V](
       }
     }
 
+    assert(nativeShuffleWriter != -1L)
     splitResult = jniWrapper.stop(nativeShuffleWriter)
 
     dep.metrics("splitTime").add(splitResult.getSplitTime)
@@ -135,6 +92,43 @@ class CHCelebornColumnarShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, 
splitResult.getRawPartitionLengths, mapId)
   }
 
+  override def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {
+    nativeShuffleWriter = jniWrapper.makeForRSS(
+      dep.nativePartitioning,
+      shuffleId,
+      mapId,
+      nativeBufferSize,
+      customizedCompressCodec,
+      GlutenConfig.getConf.chColumnarShuffleSpillThreshold,
+      CHBackendSettings.shuffleHashAlgorithm,
+      celebornPartitionPusher,
+      GlutenConfig.getConf.chColumnarThrowIfMemoryExceed,
+      GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict,
+      GlutenConfig.getConf.chColumnarForceExternalSortShuffle,
+      GlutenConfig.getConf.chColumnarForceMemorySortShuffle
+        || ShuffleMode.SORT.name.equalsIgnoreCase(shuffleWriterType)
+    )
+    CHNativeMemoryAllocators.createSpillable(
+      "CelebornShuffleWriter",
+      new Spiller() {
+        override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
+          if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+            return 0L
+          }
+          if (nativeShuffleWriter == -1L) {
+            throw new IllegalStateException(
+              "Fatal: spill() called before a celeborn shuffle writer is 
created. " +
+                "This behavior should be optimized by moving memory 
allocations from make() to split()")
+          }
+          logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
+          val spilled = jniWrapper.evict(nativeShuffleWriter)
+          logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of 
data")
+          spilled
+        }
+      }
+    )
+  }
+
   override def closeShuffleWriter(): Unit = {
     jniWrapper.close(nativeShuffleWriter)
   }
diff --git 
a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
 
b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
index d58eeb195..f5ed8c3d8 100644
--- 
a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
+++ 
b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.SHUFFLE_COMPRESS
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.BlockManager
 
 import org.apache.celeborn.client.ShuffleClient
@@ -52,12 +53,23 @@ abstract class CelebornColumnarShuffleWriter[K, V](
 
   protected val mapId: Int = context.partitionId()
 
+  protected lazy val nativeBufferSize: Int = {
+    val bufferSize = GlutenConfig.getConf.shuffleWriterBufferSize
+    val maxBatchSize = GlutenConfig.getConf.maxBatchSize
+    if (bufferSize > maxBatchSize) {
+      logInfo(
+        s"${GlutenConfig.SHUFFLE_WRITER_BUFFER_SIZE.key} ($bufferSize) exceeds 
max " +
+          s" batch size. Limited to 
${GlutenConfig.COLUMNAR_MAX_BATCH_SIZE.key} ($maxBatchSize).")
+      maxBatchSize
+    } else {
+      bufferSize
+    }
+  }
+
   protected val clientPushBufferMaxSize: Int = 
celebornConf.clientPushBufferMaxSize
 
   protected val clientPushSortMemoryThreshold: Long = 
celebornConf.clientPushSortMemoryThreshold
 
-  protected val clientSortMemoryMaxSize: Long = 
celebornConf.clientPushSortMemoryThreshold
-
   protected val shuffleWriterType: String =
     celebornConf.shuffleWriterMode.name.toLowerCase(Locale.ROOT)
 
@@ -96,6 +108,12 @@ abstract class CelebornColumnarShuffleWriter[K, V](
 
   @throws[IOException]
   final override def write(records: Iterator[Product2[K, V]]): Unit = {
+    if (!records.hasNext) {
+      partitionLengths = new Array[Long](dep.partitioner.numPartitions)
+      client.mapperEnd(shuffleId, mapId, context.attemptNumber, numMappers)
+      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, 
mapId)
+      return
+    }
     internalWrite(records)
   }
 
@@ -122,10 +140,18 @@ abstract class CelebornColumnarShuffleWriter[K, V](
     }
   }
 
+  def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {}
+
   def closeShuffleWriter(): Unit = {}
 
   def getPartitionLengths: Array[Long] = partitionLengths
 
+  def initShuffleWriter(columnarBatch: ColumnarBatch): Unit = {
+    if (nativeShuffleWriter == -1L) {
+      createShuffleWriter(columnarBatch)
+    }
+  }
+
   def pushMergedDataToCeleborn(): Unit = {
     val pushMergedDataTime = System.nanoTime
     client.prepareForMergeData(shuffleId, mapId, context.attemptNumber())
@@ -133,10 +159,4 @@ abstract class CelebornColumnarShuffleWriter[K, V](
     client.mapperEnd(shuffleId, mapId, context.attemptNumber, numMappers)
     writeMetrics.incWriteTime(System.nanoTime - pushMergedDataTime)
   }
-
-  def handleEmptyIterator(): Unit = {
-    partitionLengths = new Array[Long](dep.partitioner.numPartitions)
-    client.mapperEnd(shuffleId, mapId, context.attemptNumber, numMappers)
-    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, 
mapId)
-  }
 }
diff --git 
a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
 
b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
index c93255eaa..baf61b8a1 100644
--- 
a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
+++ 
b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
@@ -55,25 +55,6 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
 
   private var splitResult: SplitResult = _
 
-  private lazy val nativeBufferSize = {
-    val bufferSize = GlutenConfig.getConf.shuffleWriterBufferSize
-    val maxBatchSize = GlutenConfig.getConf.maxBatchSize
-    if (bufferSize > maxBatchSize) {
-      logInfo(
-        s"${GlutenConfig.SHUFFLE_WRITER_BUFFER_SIZE.key} ($bufferSize) exceeds 
max " +
-          s" batch size. Limited to 
${GlutenConfig.COLUMNAR_MAX_BATCH_SIZE.key} ($maxBatchSize).")
-      maxBatchSize
-    } else {
-      bufferSize
-    }
-  }
-
-  private val memoryLimit: Long = if ("sort".equals(shuffleWriterType)) {
-    Math.min(clientSortMemoryMaxSize, clientPushBufferMaxSize * numPartitions)
-  } else {
-    availableOffHeapPerTask()
-  }
-
   private def availableOffHeapPerTask(): Long = {
     val perTask =
       SparkMemoryUtil.getCurrentAvailableOffHeapMemory / 
SparkResourceUtil.getTaskSlots(conf)
@@ -82,49 +63,13 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
 
   @throws[IOException]
   override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
-    if (!records.hasNext) {
-      handleEmptyIterator()
-      return
-    }
-
     while (records.hasNext) {
       val cb = records.next()._2.asInstanceOf[ColumnarBatch]
       if (cb.numRows == 0 || cb.numCols == 0) {
         logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} 
cols")
       } else {
+        initShuffleWriter(cb)
         val handle = ColumnarBatches.getNativeHandle(cb)
-        if (nativeShuffleWriter == -1L) {
-          nativeShuffleWriter = jniWrapper.makeForRSS(
-            dep.nativePartitioning,
-            nativeBufferSize,
-            customizedCompressionCodec,
-            compressionLevel,
-            bufferCompressThreshold,
-            GlutenConfig.getConf.columnarShuffleCompressionMode,
-            clientPushBufferMaxSize,
-            clientPushSortMemoryThreshold,
-            celebornPartitionPusher,
-            handle,
-            context.taskAttemptId(),
-            GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, 
context.partitionId),
-            "celeborn",
-            shuffleWriterType,
-            GlutenConfig.getConf.columnarShuffleReallocThreshold
-          )
-          runtime.addSpiller(new Spiller() {
-            override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
-              if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
-                return 0L
-              }
-              logInfo(s"Gluten shuffle writer: Trying to push $size bytes of 
data")
-              // fixme pass true when being called by self
-              val pushed =
-                jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
-              logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of 
data")
-              pushed
-            }
-          })
-        }
         val startTime = System.nanoTime()
         jniWrapper.write(nativeShuffleWriter, cb.numRows, handle, 
availableOffHeapPerTask())
         dep.metrics("splitTime").add(System.nanoTime() - startTime)
@@ -135,8 +80,8 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
       }
     }
 
-    val startTime = System.nanoTime()
     assert(nativeShuffleWriter != -1L)
+    val startTime = System.nanoTime()
     splitResult = jniWrapper.stop(nativeShuffleWriter)
 
     dep
@@ -155,6 +100,38 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, 
mapId)
   }
 
+  override def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {
+    nativeShuffleWriter = jniWrapper.makeForRSS(
+      dep.nativePartitioning,
+      nativeBufferSize,
+      customizedCompressionCodec,
+      compressionLevel,
+      bufferCompressThreshold,
+      GlutenConfig.getConf.columnarShuffleCompressionMode,
+      clientPushBufferMaxSize,
+      clientPushSortMemoryThreshold,
+      celebornPartitionPusher,
+      ColumnarBatches.getNativeHandle(columnarBatch),
+      context.taskAttemptId(),
+      GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, 
context.partitionId),
+      "celeborn",
+      shuffleWriterType,
+      GlutenConfig.getConf.columnarShuffleReallocThreshold
+    )
+    runtime.addSpiller(new Spiller() {
+      override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
+        if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+          return 0L
+        }
+        logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
+        // fixme pass true when being called by self
+        val pushed = jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
+        logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
+        pushed
+      }
+    })
+  }
+
   override def closeShuffleWriter(): Unit = {
     jniWrapper.close(nativeShuffleWriter)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to