This is an automated email from the ASF dual-hosted git repository.

binjieyang pushed a commit to branch CELEBORN-1768
in repository https://gitbox.apache.org/repos/asf/celeborn.git

commit f080f7ce1622d0c75fdfb8a704deec8b71849197
Author: binjie yang <[email protected]>
AuthorDate: Mon Dec 9 20:35:23 2024 +0800

    [CELEBORN-1768][WRITER] Refactoring Shuffle Writer to extract common methods
---
 .../celeborn/ColumnarHashBasedShuffleWriter.java   |   2 +-
 .../spark/shuffle/celeborn/BasedShuffleWriter.java | 223 +++++++++++++++++++
 .../shuffle/celeborn/HashBasedShuffleWriter.java   | 209 ++---------------
 .../shuffle/celeborn/SortBasedShuffleWriter.java   | 246 ++-------------------
 .../shuffle/celeborn/SparkShuffleManager.java      |   3 +-
 .../celeborn/SortBasedShuffleWriterSuiteJ.java     |  17 +-
 6 files changed, 279 insertions(+), 421 deletions(-)

diff --git 
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
 
b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
index b09b1306c..d28673911 100644
--- 
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
+++ 
b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
@@ -132,7 +132,7 @@ public class ColumnarHashBasedShuffleWriter<K, V, C> 
extends HashBasedShuffleWri
   }
 
   @Override
-  protected void closeWrite() throws IOException {
+  protected void closeWrite() throws IOException, InterruptedException {
     if (canUseFastWrite() && isColumnarShuffle) {
       closeColumnarWrite();
     } else {
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java
new file mode 100644
index 000000000..f732a68ad
--- /dev/null
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.LongAdder;
+
+import javax.annotation.Nullable;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.sql.execution.UnsafeRowSerializer;
+import org.apache.spark.storage.BlockManagerId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.util.Utils;
+
+/**
+ * Shared base of the Celeborn Spark 3 shuffle writers ({@link HashBasedShuffleWriter} and
+ * {@link SortBasedShuffleWriter}).
+ *
+ * <p>It owns the state every writer needs (serialization buffer, per-partition byte counters,
+ * metrics reporter, Celeborn client) and implements the common write / close / stop lifecycle.
+ * Subclasses plug in how records are buffered and pushed through the abstract hooks.
+ */
+public abstract class BasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
+
+  private static final Logger logger = LoggerFactory.getLogger(BasedShuffleWriter.class);
+
+  protected static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+  protected static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
+
+  protected final int PUSH_BUFFER_INIT_SIZE;
+  protected final int PUSH_BUFFER_MAX_SIZE;
+  protected final ShuffleDependency<K, V, C> dep;
+  protected final Partitioner partitioner;
+  protected final ShuffleWriteMetricsReporter writeMetrics;
+  protected final int shuffleId;
+  protected final int mapId;
+  protected final int encodedAttemptId;
+  protected final TaskContext taskContext;
+  protected final ShuffleClient shuffleClient;
+  protected final int numMappers;
+  protected final int numPartitions;
+  protected final OpenByteArrayOutputStream serBuffer;
+  protected final SerializationStream serOutputStream;
+  private final boolean unsafeRowFastWrite;
+
+  /** Bytes pushed so far for each reduce partition; reported to Spark via the MapStatus. */
+  protected final LongAdder[] mapStatusLengths;
+
+  /**
+   * Set by {@link #close()} after all data has been pushed and mapperEnd has returned. Stays
+   * {@code null} when write() did not complete, so {@link #stop(boolean)} can reject stop(true).
+   */
+  @Nullable private MapStatus mapStatus;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true and
+   * then call stop() with success = false if they get an exception, we want to make sure we don't
+   * try deleting files, etc. twice.
+   */
+  private volatile boolean stopping = false;
+
+  protected long peakMemoryUsedBytes = 0;
+  protected long tmpRecordsWritten = 0;
+
+  /**
+   * Initializes the state shared by all writers from the shuffle handle and configuration.
+   *
+   * @param shuffleId Celeborn shuffle id this map task writes to
+   * @param handle Spark shuffle handle carrying the dependency and the number of mappers
+   * @param taskContext context of the running map task
+   * @param conf Celeborn client configuration (push buffer sizes, unsafe-row fast write)
+   * @param client client used to push data to Celeborn workers
+   * @param metrics reporter receiving bytes/records/time written
+   */
+  public BasedShuffleWriter(
+      int shuffleId,
+      CelebornShuffleHandle<K, V, C> handle,
+      TaskContext taskContext,
+      CelebornConf conf,
+      ShuffleClient client,
+      ShuffleWriteMetricsReporter metrics) {
+    PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize();
+    PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize();
+    this.dep = handle.dependency();
+    this.partitioner = dep.partitioner();
+    this.writeMetrics = metrics;
+    this.shuffleId = shuffleId;
+    this.mapId = taskContext.partitionId();
+    // [CELEBORN-1496] using the encoded attempt number instead of task attempt number
+    this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext);
+    this.taskContext = taskContext;
+    this.shuffleClient = client;
+    this.numMappers = handle.numMappers();
+    this.numPartitions = partitioner.numPartitions();
+    SerializerInstance serializer = dep.serializer().newInstance();
+    serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
+    serOutputStream = serializer.serializeStream(serBuffer);
+    unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
+
+    mapStatusLengths = new LongAdder[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      mapStatusLengths[i] = new LongAdder();
+    }
+  }
+
+  /**
+   * Routes records to the fast unsafe-row path, the map-side-combine path, or the plain path.
+   *
+   * @throws UnsupportedOperationException if map-side combine is requested without an aggregator
+   */
+  protected void doWrite(scala.collection.Iterator<Product2<K, V>> records)
+      throws IOException, InterruptedException {
+    if (canUseFastWrite()) {
+      fastWrite0(records);
+    } else if (dep.mapSideCombine()) {
+      if (dep.aggregator().isEmpty()) {
+        throw new UnsupportedOperationException(
+            "When using map side combine, an aggregator must be specified.");
+      }
+      write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
+    } else {
+      write0(records);
+    }
+  }
+
+  /**
+   * Writes all records, then flushes them via {@link #close()}. If anything fails before close()
+   * completes, {@link #cleanupPusher()} is invoked; an InterruptedException is handed to
+   * TaskInterruptedHelper.throwTaskKillException().
+   */
+  @Override
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
+    boolean needCleanupPusher = true;
+    try {
+      doWrite(records);
+      close();
+      needCleanupPusher = false;
+    } catch (InterruptedException e) {
+      TaskInterruptedHelper.throwTaskKillException();
+    } finally {
+      if (needCleanupPusher) {
+        cleanupPusher();
+      }
+    }
+  }
+
+  /**
+   * Finishes the map task. Every call adds the peak memory to the task metrics and calls
+   * {@code shuffleClient.cleanup} for this map attempt. The first call returns the MapStatus
+   * built by {@link #close()} when {@code success} is true; any other call returns empty.
+   *
+   * @throws IllegalStateException if {@code success} is true but write() never completed
+   */
+  @Override
+  public Option<MapStatus> stop(boolean success) {
+    try {
+      taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
+
+      if (stopping) {
+        return Option.empty();
+      } else {
+        stopping = true;
+        if (success) {
+          if (mapStatus == null) {
+            throw new IllegalStateException("Cannot call stop(true) without having called write()");
+          }
+          return Option.apply(mapStatus);
+        } else {
+          return Option.empty();
+        }
+      }
+    } finally {
+      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
+    }
+  }
+
+  // Added in SPARK-32917, for Spark 3.2 and above
+  @SuppressWarnings("MissingOverride")
+  public long[] getPartitionLengths() {
+    throw new UnsupportedOperationException(
+        "Celeborn is not compatible with Spark push mode, please set spark.shuffle.push.enabled to false");
+  }
+
+  /** Pushes records whose key is already the partition id (unsafe-row fast path). */
+  abstract void fastWrite0(scala.collection.Iterator iterator)
+      throws IOException, InterruptedException;
+
+  /** Serializes and pushes generic key/value records. */
+  abstract void write0(scala.collection.Iterator iterator) throws IOException, InterruptedException;
+
+  /** Refreshes {@link #peakMemoryUsedBytes}; may be a no-op for writers that do not track it. */
+  abstract void updatePeakMemoryUsed();
+
+  /** Releases pusher resources; called from {@link #write} when the write did not complete. */
+  abstract void cleanupPusher() throws IOException;
+
+  /**
+   * Pushes any data still buffered by the subclass. Called at the start of {@link #close()}.
+   *
+   * <p>Implementations should not add their own elapsed time to {@code writeMetrics}: close()
+   * already times this call together with pushMergedData, so doing both counts it twice.
+   */
+  abstract void closeWrite() throws IOException, InterruptedException;
+
+  /**
+   * Whether records can take the unsafe-row fast path: fast write is enabled, the serializer is an
+   * UnsafeRowSerializer and the partitioner is Spark's PartitionIdPassthrough.
+   */
+  @VisibleForTesting
+  boolean canUseFastWrite() {
+    boolean keyIsPartitionId = false;
+    if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer) {
+      // SPARK-39391 renames PartitionIdPassthrough's package
+      String partitionerClassName = partitioner.getClass().getSimpleName();
+      keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName);
+    }
+    return keyIsPartitionId;
+  }
+
+  /** Return the peak memory used so far, in bytes. */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
+  }
+
+  /**
+   * Pushes a single record that exceeds the push buffer directly to Celeborn and accounts for the
+   * bytes written in both the partition length and the write metrics.
+   */
+  protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
+    logger.debug("Push giant record, size {}.", Utils.bytesToString(numBytes));
+    int bytesWritten =
+        shuffleClient.pushData(
+            shuffleId,
+            mapId,
+            encodedAttemptId,
+            partitionId,
+            buffer,
+            0,
+            numBytes,
+            numMappers,
+            numPartitions);
+    mapStatusLengths[partitionId].add(bytesWritten);
+    writeMetrics.incBytesWritten(bytesWritten);
+  }
+
+  /**
+   * Pushes the remaining data, closes the pushers and sends the MapperEnd RPC to the
+   * LifecycleManager so it records this task attempt, then builds the MapStatus returned by
+   * {@link #stop(boolean)}. Must only be called once the task has written all of its records.
+   */
+  protected void close() throws IOException, InterruptedException {
+    long pushMergedDataTime = System.nanoTime();
+    closeWrite();
+    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
+    writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
+    updateRecordsWrittenMetrics();
+
+    long waitStartTime = System.nanoTime();
+    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
+    writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
+
+    // Only built once everything has been pushed, so stop(true) can detect an incomplete write.
+    BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
+    mapStatus =
+        SparkUtils.createMapStatus(
+            bmId, SparkUtils.unwrap(mapStatusLengths), taskContext.taskAttemptId());
+  }
+
+  /** Flushes the locally counted records into the metrics reporter and resets the counter. */
+  protected void updateRecordsWrittenMetrics() {
+    writeMetrics.incRecordsWritten(tmpRecordsWritten);
+    tmpRecordsWritten = 0;
+  }
+}
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 4c5e6739b..6a0dc24b1 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -19,30 +19,15 @@ package org.apache.spark.shuffle.celeborn;
 
 import java.io.IOException;
 import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.atomic.LongAdder;
 
-import javax.annotation.Nullable;
-
-import scala.Option;
 import scala.Product2;
-import scala.reflect.ClassTag;
-import scala.reflect.ClassTag$;
 
-import com.google.common.annotations.VisibleForTesting;
-import org.apache.spark.Partitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.annotation.Private;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.execution.UnsafeRowSerializer;
 import org.apache.spark.sql.execution.metric.SQLMetric;
-import org.apache.spark.storage.BlockManagerId;
 import org.apache.spark.unsafe.Platform;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -54,50 +39,14 @@ import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.util.Utils;
 
 @Private
-public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
+public class HashBasedShuffleWriter<K, V, C> extends BasedShuffleWriter<K, V, 
C> {
 
   private static final Logger logger = 
LoggerFactory.getLogger(HashBasedShuffleWriter.class);
 
-  private static final ClassTag<Object> OBJECT_CLASS_TAG = 
ClassTag$.MODULE$.Object();
-  private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
-
-  private final int PUSH_BUFFER_INIT_SIZE;
-  private final int PUSH_BUFFER_MAX_SIZE;
-  private final ShuffleDependency<K, V, C> dep;
-  private final Partitioner partitioner;
-  private final ShuffleWriteMetricsReporter writeMetrics;
-  private final int shuffleId;
-  private final int mapId;
-  private final int encodedAttemptId;
-  private final TaskContext taskContext;
-  private final ShuffleClient shuffleClient;
-  private final int numMappers;
-  private final int numPartitions;
-
-  @Nullable private MapStatus mapStatus;
-  private long peakMemoryUsedBytes = 0;
-
-  private final OpenByteArrayOutputStream serBuffer;
-  private final SerializationStream serOutputStream;
-
   private byte[][] sendBuffers;
   private int[] sendOffsets;
-
-  private final LongAdder[] mapStatusLengths;
-  protected long tmpRecordsWritten = 0;
-
-  private final SendBufferPool sendBufferPool;
-
-  /**
-   * Are we in the process of stopping? Because map tasks can call stop() with 
success = true and
-   * then call stop() with success = false if they get an exception, we want 
to make sure we don't
-   * try deleting files, etc. twice.
-   */
-  private volatile boolean stopping = false;
-
   private DataPusher dataPusher;
-
-  private final boolean unsafeRowFastWrite;
+  private final SendBufferPool sendBufferPool;
 
   // In order to facilitate the writing of unit test code, ShuffleClient needs 
to be passed in as
   // parameters. By the way, simplify the passed parameters.
@@ -110,31 +59,9 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       ShuffleWriteMetricsReporter metrics,
       SendBufferPool sendBufferPool)
       throws IOException {
-    this.mapId = taskContext.partitionId();
-    this.dep = handle.dependency();
-    this.shuffleId = shuffleId;
-    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
-    SerializerInstance serializer = dep.serializer().newInstance();
-    this.partitioner = dep.partitioner();
-    this.writeMetrics = metrics;
-    this.taskContext = taskContext;
-    this.numMappers = handle.numMappers();
-    this.numPartitions = dep.partitioner().numPartitions();
-    this.shuffleClient = client;
-
-    unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
-    serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
-    serOutputStream = serializer.serializeStream(serBuffer);
-
-    mapStatusLengths = new LongAdder[numPartitions];
-    for (int i = 0; i < numPartitions; i++) {
-      mapStatusLengths[i] = new LongAdder();
-    }
-
-    PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize();
-    PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize();
-
+    super(shuffleId, handle, taskContext, conf, client, metrics);
     this.sendBufferPool = sendBufferPool;
+
     sendBuffers = sendBufferPool.acquireBuffer(numPartitions);
     sendOffsets = new int[numPartitions];
 
@@ -159,42 +86,6 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   }
 
   @Override
-  public void write(scala.collection.Iterator<Product2<K, V>> records) throws 
IOException {
-    boolean needCleanupPusher = true;
-    try {
-      if (canUseFastWrite()) {
-        fastWrite0(records);
-      } else if (dep.mapSideCombine()) {
-        if (dep.aggregator().isEmpty()) {
-          throw new UnsupportedOperationException(
-              "When using map side combine, an aggregator must be specified.");
-        }
-        write0(dep.aggregator().get().combineValuesByKey(records, 
taskContext));
-      } else {
-        write0(records);
-      }
-      close();
-      needCleanupPusher = false;
-    } catch (InterruptedException e) {
-      TaskInterruptedHelper.throwTaskKillException();
-    } finally {
-      if (needCleanupPusher) {
-        cleanupPusher();
-      }
-    }
-  }
-
-  @VisibleForTesting
-  boolean canUseFastWrite() {
-    boolean keyIsPartitionId = false;
-    if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer) 
{
-      // SPARK-39391 renames PartitionIdPassthrough's package
-      String partitionerClassName = partitioner.getClass().getSimpleName();
-      keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName);
-    }
-    return keyIsPartitionId;
-  }
-
   protected void fastWrite0(scala.collection.Iterator iterator)
       throws IOException, InterruptedException {
     final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = 
iterator;
@@ -238,7 +129,9 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     }
   }
 
-  private void write0(scala.collection.Iterator iterator) throws IOException, 
InterruptedException {
+  @Override
+  protected void write0(scala.collection.Iterator iterator)
+      throws IOException, InterruptedException {
     final scala.collection.Iterator<Product2<K, ?>> records = iterator;
 
     while (records.hasNext()) {
@@ -265,6 +158,11 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     }
   }
 
+  @Override
+  void updatePeakMemoryUsed() {
+    // No-op: this writer does not sample a new peak-memory value when asked for it.
+  }
+
   private byte[] getOrCreateBuffer(int partitionId) {
     byte[] buffer = sendBuffers[partitionId];
     if (buffer == null) {
@@ -275,23 +173,6 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     return buffer;
   }
 
-  protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) 
throws IOException {
-    logger.debug("Push giant record, size {}.", numBytes);
-    int bytesWritten =
-        shuffleClient.pushData(
-            shuffleId,
-            mapId,
-            encodedAttemptId,
-            partitionId,
-            buffer,
-            0,
-            numBytes,
-            numMappers,
-            numPartitions);
-    mapStatusLengths[partitionId].add(bytesWritten);
-    writeMetrics.incBytesWritten(bytesWritten);
-  }
-
   private int getOrUpdateOffset(int partitionId, int serializedRecordSize)
       throws IOException, InterruptedException {
     int offset = sendOffsets[partitionId];
@@ -322,7 +203,12 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     writeMetrics.incWriteTime(System.nanoTime() - start);
   }
 
-  protected void closeWrite() throws IOException {
+  @Override
+  protected void closeWrite() throws IOException, InterruptedException {
+    // here we wait for all the in-flight batches to return which sent by 
dataPusher thread
+    dataPusher.waitOnTermination();
+    sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
+    shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
     // merge and push residual data to reduce network traffic
     // NB: since dataPusher thread have no in-flight data at this point,
     //     we now push merged data by task thread will not introduce any 
contention
@@ -356,7 +242,8 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     writeMetrics.incBytesWritten(bytesWritten);
   }
 
-  private void cleanupPusher() throws IOException {
+  @Override
+  protected void cleanupPusher() throws IOException {
     try {
       dataPusher.waitOnTermination();
       sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
@@ -364,60 +251,4 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       TaskInterruptedHelper.throwTaskKillException();
     }
   }
-
-  private void close() throws IOException, InterruptedException {
-    // here we wait for all the in-flight batches to return which sent by 
dataPusher thread
-    long pushMergedDataTime = System.nanoTime();
-    dataPusher.waitOnTermination();
-    sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
-    shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
-    closeWrite();
-    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
-    writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
-    updateRecordsWrittenMetrics();
-
-    long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
-    writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
-
-    BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
-    mapStatus =
-        SparkUtils.createMapStatus(
-            bmId, SparkUtils.unwrap(mapStatusLengths), 
taskContext.taskAttemptId());
-  }
-
-  private void updateRecordsWrittenMetrics() {
-    writeMetrics.incRecordsWritten(tmpRecordsWritten);
-    tmpRecordsWritten = 0;
-  }
-
-  @Override
-  public Option<MapStatus> stop(boolean success) {
-    try {
-      taskContext.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes);
-
-      if (stopping) {
-        return Option.empty();
-      } else {
-        stopping = true;
-        if (success) {
-          if (mapStatus == null) {
-            throw new IllegalStateException("Cannot call stop(true) without 
having called write()");
-          }
-          return Option.apply(mapStatus);
-        } else {
-          return Option.empty();
-        }
-      }
-    } finally {
-      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
-    }
-  }
-
-  // Added in SPARK-32917, for Spark 3.2 and above
-  @SuppressWarnings("MissingOverride")
-  public long[] getPartitionLengths() {
-    throw new UnsupportedOperationException(
-        "Celeborn is not compatible with Spark push mode, please set 
spark.shuffle.push.enabled to false");
-  }
 }
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 5717910ee..b6fd2f407 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -18,28 +18,15 @@
 package org.apache.spark.shuffle.celeborn;
 
 import java.io.IOException;
-import java.util.concurrent.atomic.LongAdder;
 
-import scala.Option;
 import scala.Product2;
-import scala.reflect.ClassTag;
-import scala.reflect.ClassTag$;
 
-import com.google.common.annotations.VisibleForTesting;
-import org.apache.spark.Partitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.annotation.Private;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.execution.UnsafeRowSerializer;
 import org.apache.spark.sql.execution.metric.SQLMetric;
-import org.apache.spark.storage.BlockManagerId;
 import org.apache.spark.unsafe.Platform;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -47,96 +34,40 @@ import org.slf4j.LoggerFactory;
 import org.apache.celeborn.client.ShuffleClient;
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.exception.CelebornIOException;
-import org.apache.celeborn.common.util.Utils;
 
 @Private
-public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
+public class SortBasedShuffleWriter<K, V, C> extends BasedShuffleWriter<K, V, 
C> {
 
   private static final Logger logger = 
LoggerFactory.getLogger(SortBasedShuffleWriter.class);
-
-  private static final ClassTag<Object> OBJECT_CLASS_TAG = 
ClassTag$.MODULE$.Object();
-  private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
-
-  private final ShuffleDependency<K, V, C> dep;
-  private final Partitioner partitioner;
-  private final ShuffleWriteMetricsReporter writeMetrics;
-  private final int shuffleId;
-  private final int mapId;
-  private final int encodedAttemptId;
-  private final TaskContext taskContext;
-  private final ShuffleClient shuffleClient;
-  private final int numMappers;
-  private final int numPartitions;
-
-  private final long pushBufferMaxSize;
+  private final SendBufferPool sendBufferPool;
 
   private final SortBasedPusher pusher;
-  private long peakMemoryUsedBytes = 0;
-
-  private final OpenByteArrayOutputStream serBuffer;
-  private final SerializationStream serOutputStream;
-
-  private final LongAdder[] mapStatusLengths;
-  private long tmpRecordsWritten = 0;
-
-  /**
-   * Are we in the process of stopping? Because map tasks can call stop() with 
success = true and
-   * then call stop() with success = false if they get an exception, we want 
to make sure we don't
-   * try deleting files, etc. twice.
-   */
-  private volatile boolean stopping = false;
-
-  private final boolean unsafeRowFastWrite;
 
   public SortBasedShuffleWriter(
       int shuffleId,
-      ShuffleDependency<K, V, C> dep,
-      int numMappers,
+      CelebornShuffleHandle<K, V, C> handle,
       TaskContext taskContext,
       CelebornConf conf,
       ShuffleClient client,
       ShuffleWriteMetricsReporter metrics,
       SendBufferPool sendBufferPool)
       throws IOException {
-    this(shuffleId, dep, numMappers, taskContext, conf, client, metrics, 
sendBufferPool, null);
+    this(shuffleId, handle, taskContext, conf, client, metrics, 
sendBufferPool, null);
   }
 
   // In order to facilitate the writing of unit test code, ShuffleClient needs 
to be passed in as
   // parameters. By the way, simplify the passed parameters.
   public SortBasedShuffleWriter(
       int shuffleId,
-      ShuffleDependency<K, V, C> dep,
-      int numMappers,
+      CelebornShuffleHandle<K, V, C> handle,
       TaskContext taskContext,
       CelebornConf conf,
       ShuffleClient client,
       ShuffleWriteMetricsReporter metrics,
       SendBufferPool sendBufferPool,
-      SortBasedPusher pusher)
-      throws IOException {
-    this.mapId = taskContext.partitionId();
-    this.dep = dep;
-    this.shuffleId = shuffleId;
-    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
-    SerializerInstance serializer = dep.serializer().newInstance();
-    this.partitioner = dep.partitioner();
-    this.writeMetrics = metrics;
-    this.taskContext = taskContext;
-    this.numMappers = numMappers;
-    this.numPartitions = dep.partitioner().numPartitions();
-    this.shuffleClient = client;
-    unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
-
-    serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
-    serOutputStream = serializer.serializeStream(serBuffer);
-
-    this.mapStatusLengths = new LongAdder[numPartitions];
-    for (int i = 0; i < numPartitions; i++) {
-      this.mapStatusLengths[i] = new LongAdder();
-    }
-
-    pushBufferMaxSize = conf.clientPushBufferMaxSize();
-
+      SortBasedPusher pusher) {
+    super(shuffleId, handle, taskContext, conf, client, metrics);
+    this.sendBufferPool = sendBufferPool;
     if (pusher == null) {
       this.pusher =
           new SortBasedPusher(
@@ -159,99 +90,16 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     }
   }
 
-  public SortBasedShuffleWriter(
-      CelebornShuffleHandle<K, V, C> handle,
-      TaskContext taskContext,
-      CelebornConf conf,
-      ShuffleClient client,
-      ShuffleWriteMetricsReporter metrics,
-      SendBufferPool sendBufferPool)
-      throws IOException {
-    this(
-        SparkUtils.celebornShuffleId(client, handle, taskContext, true),
-        handle.dependency(),
-        handle.numMappers(),
-        taskContext,
-        conf,
-        client,
-        metrics,
-        sendBufferPool);
-  }
-
-  public SortBasedShuffleWriter(
-      CelebornShuffleHandle<K, V, C> handle,
-      TaskContext taskContext,
-      CelebornConf conf,
-      ShuffleClient client,
-      ShuffleWriteMetricsReporter metrics,
-      SendBufferPool sendBufferPool,
-      SortBasedPusher pusher)
-      throws IOException {
-    this(
-        SparkUtils.celebornShuffleId(client, handle, taskContext, true),
-        handle.dependency(),
-        handle.numMappers(),
-        taskContext,
-        conf,
-        client,
-        metrics,
-        sendBufferPool,
-        pusher);
-  }
-
-  private void updatePeakMemoryUsed() {
+  @Override
+  protected void updatePeakMemoryUsed() {
     long mem = pusher.getPeakMemoryUsedBytes();
     if (mem > peakMemoryUsedBytes) {
       peakMemoryUsedBytes = mem;
     }
   }
 
-  /** Return the peak memory used so far, in bytes. */
-  public long getPeakMemoryUsedBytes() {
-    updatePeakMemoryUsed();
-    return peakMemoryUsedBytes;
-  }
-
-  void doWrite(scala.collection.Iterator<Product2<K, V>> records) throws 
IOException {
-    if (canUseFastWrite()) {
-      fastWrite0(records);
-    } else if (dep.mapSideCombine()) {
-      if (dep.aggregator().isEmpty()) {
-        throw new UnsupportedOperationException(
-            "When using map side combine, an aggregator must be specified.");
-      }
-      write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
-    } else {
-      write0(records);
-    }
-  }
-
   @Override
-  public void write(scala.collection.Iterator<Product2<K, V>> records) throws 
IOException {
-    boolean needCleanupPusher = true;
-    try {
-      doWrite(records);
-      close();
-      needCleanupPusher = false;
-    } finally {
-      if (needCleanupPusher) {
-        cleanupPusher();
-      }
-    }
-  }
-
-  @VisibleForTesting
-  boolean canUseFastWrite() {
-    boolean keyIsPartitionId = false;
-    if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer) 
{
-      // SPARK-39391 renames PartitionIdPassthrough's package
-      String partitionerClassName = partitioner.getClass().getSimpleName();
-      keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName);
-    }
-    return keyIsPartitionId;
-  }
-
-  private void fastWrite0(scala.collection.Iterator iterator) throws 
IOException {
+  protected void fastWrite0(scala.collection.Iterator iterator) throws 
IOException {
     final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = 
iterator;
 
     SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) 
dep.serializer());
@@ -267,7 +115,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         dataSize.add(serializedRecordSize);
       }
 
-      if (serializedRecordSize > pushBufferMaxSize) {
+      if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) {
         byte[] giantBuffer = new byte[serializedRecordSize];
         Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, 
Integer.reverseBytes(rowSize));
         Platform.copyMemory(
@@ -301,7 +149,8 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     writeMetrics.incWriteTime(System.nanoTime() - start);
   }
 
-  private void write0(scala.collection.Iterator iterator) throws IOException {
+  @Override
+  protected void write0(scala.collection.Iterator iterator) throws IOException 
{
     final scala.collection.Iterator<Product2<K, ?>> records = iterator;
 
     while (records.hasNext()) {
@@ -316,7 +165,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       final int serializedRecordSize = serBuffer.size();
       assert (serializedRecordSize > 0);
 
-      if (serializedRecordSize > pushBufferMaxSize) {
+      if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) {
         pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize);
       } else {
         boolean success =
@@ -344,75 +193,18 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     }
   }
 
-  private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) 
throws IOException {
-    logger.debug("Push giant record, size {}.", Utils.bytesToString(numBytes));
-    int bytesWritten =
-        shuffleClient.pushData(
-            shuffleId,
-            mapId,
-            encodedAttemptId,
-            partitionId,
-            buffer,
-            0,
-            numBytes,
-            numMappers,
-            numPartitions);
-    mapStatusLengths[partitionId].add(bytesWritten);
-    writeMetrics.incBytesWritten(bytesWritten);
-  }
-
-  private void cleanupPusher() throws IOException {
+  @Override
+  protected void cleanupPusher() throws IOException {
     if (pusher != null) {
       pusher.close(false);
     }
   }
 
-  private void close() throws IOException {
-    logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
+  @Override
+  protected void closeWrite() throws IOException, InterruptedException {
     long pushStartTime = System.nanoTime();
     pusher.pushData(false);
     pusher.close(true);
-
-    shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
-    writeMetrics.incRecordsWritten(tmpRecordsWritten);
-
-    long waitStartTime = System.nanoTime();
-    shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
-    writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
-  }
-
-  @Override
-  public Option<MapStatus> stop(boolean success) {
-    try {
-      
taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
-
-      if (stopping) {
-        return Option.empty();
-      } else {
-        stopping = true;
-        if (success) {
-          BlockManagerId bmId = 
SparkEnv.get().blockManager().shuffleServerId();
-          MapStatus mapStatus =
-              SparkUtils.createMapStatus(
-                  bmId, SparkUtils.unwrap(mapStatusLengths), 
taskContext.taskAttemptId());
-          if (mapStatus == null) {
-            throw new IllegalStateException("Cannot call stop(true) without 
having called write()");
-          }
-          return Option.apply(mapStatus);
-        } else {
-          return Option.empty();
-        }
-      }
-    } finally {
-      shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
-    }
-  }
-
-  // Added in SPARK-32917, for Spark 3.2 and above
-  @SuppressWarnings("MissingOverride")
-  public long[] getPartitionLengths() {
-    throw new UnsupportedOperationException(
-        "Celeborn is not compatible with push-based shuffle, please set 
spark.shuffle.push.enabled to false");
   }
 }
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index af3c400ec..8e4190335 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -277,8 +277,7 @@ public class SparkShuffleManager implements ShuffleManager {
         if (ShuffleMode.SORT.equals(shuffleMode)) {
           return new SortBasedShuffleWriter<>(
               shuffleId,
-              h.dependency(),
-              h.numMappers(),
+              h,
               context,
               celebornConf,
               shuffleClient,
diff --git 
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
 
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
index 0963737c0..c0d44007c 100644
--- 
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
+++ 
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
@@ -64,7 +64,13 @@ public class SortBasedShuffleWriterSuiteJ extends 
CelebornShuffleWriterSuiteBase
       ShuffleWriteMetricsReporter metrics)
       throws IOException {
     return new SortBasedShuffleWriter<Integer, String, String>(
-        handle, context, conf, client, metrics, SendBufferPool.get(4, 30, 60));
+        SparkUtils.celebornShuffleId(client, handle, taskContext, true),
+        handle,
+        context,
+        conf,
+        client,
+        metrics,
+        SendBufferPool.get(4, 30, 60));
   }
 
   private SortBasedShuffleWriter<Integer, String, String> 
createShuffleWriterWithPusher(
@@ -76,7 +82,14 @@ public class SortBasedShuffleWriterSuiteJ extends 
CelebornShuffleWriterSuiteBase
       SortBasedPusher pusher)
       throws Exception {
     return new SortBasedShuffleWriter<Integer, String, String>(
-        handle, context, conf, client, metrics, SendBufferPool.get(4, 30, 60), 
pusher);
+        SparkUtils.celebornShuffleId(client, handle, taskContext, true),
+        handle,
+        context,
+        conf,
+        client,
+        metrics,
+        SendBufferPool.get(4, 30, 60),
+        pusher);
   }
 
   private SortBasedPusher createSortBasedPusher(


Reply via email to