FMX commented on code in PR #3306:
URL: https://github.com/apache/celeborn/pull/3306#discussion_r2125994833


##########
client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java:
##########
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.LongAdder;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.execution.metric.SQLMetric;
+import org.apache.spark.unsafe.Platform;
+import scala.Option;
+import scala.Product2;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.sql.execution.UnsafeRowSerializer;
+import org.apache.spark.storage.BlockManagerId;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public abstract class BasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
+
+  protected static final ClassTag<Object> OBJECT_CLASS_TAG = 
ClassTag$.MODULE$.Object();
+  protected static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
+
+  protected final int PUSH_BUFFER_INIT_SIZE;
+  protected final int PUSH_BUFFER_MAX_SIZE;
+  protected final ShuffleDependency<K, V, C> dep;
+  protected final Partitioner partitioner;
+  protected final ShuffleWriteMetricsReporter writeMetrics;
+  protected final int shuffleId;
+  protected final int mapId;
+  protected final int encodedAttemptId;
+  protected final TaskContext taskContext;
+  protected final ShuffleClient shuffleClient;
+  protected final int numMappers;
+  protected final int numPartitions;
+  protected final OpenByteArrayOutputStream serBuffer;
+  protected final SerializationStream serOutputStream;
+  private final boolean unsafeRowFastWrite;
+
+  protected final LongAdder[] mapStatusLengths;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true and
+   * then call stop() with success = false if they get an exception, we want to make sure we don't
+   * try deleting files, etc. twice.
+   */
+  private volatile boolean stopping = false;
+
+  protected long peakMemoryUsedBytes = 0;
+  protected long tmpRecordsWritten = 0;
+
+  public BasedShuffleWriter(
+      int shuffleId,
+      CelebornShuffleHandle<K, V, C> handle,
+      TaskContext taskContext,
+      CelebornConf conf,
+      ShuffleClient client,
+      ShuffleWriteMetricsReporter metrics) {
+    PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize();
+    PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize();
+    this.dep = handle.dependency();
+    this.partitioner = dep.partitioner();
+    this.writeMetrics = metrics;
+    this.shuffleId = shuffleId;
+    this.mapId = taskContext.partitionId();
+    // [CELEBORN-1496] using the encoded attempt number instead of task attempt number
+    this.encodedAttemptId = 
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
+    this.taskContext = taskContext;
+    this.shuffleClient = client;
+    this.numMappers = handle.numMappers();
+    this.numPartitions = dep.partitioner().numPartitions();
+    SerializerInstance serializer = dep.serializer().newInstance();
+    serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
+    serOutputStream = serializer.serializeStream(serBuffer);
+    unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
+
+    mapStatusLengths = new LongAdder[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      mapStatusLengths[i] = new LongAdder();
+    }
+  }
+
+  protected void doWrite(scala.collection.Iterator<Product2<K, V>> records)
+      throws IOException, InterruptedException {
+    if (canUseFastWrite()) {
+      fastWrite0(records);
+    } else if (dep.mapSideCombine()) {
+      if (dep.aggregator().isEmpty()) {
+        throw new UnsupportedOperationException(
+            "When using map side combine, an aggregator must be specified.");
+      }
+      write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
+    } else {
+      write0(records);
+    }
+  }
+
+  @Override
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
+    boolean needCleanupPusher = true;
+    try {
+      doWrite(records);
+      close();
+      needCleanupPusher = false;
+    } catch (InterruptedException e) {
+      TaskInterruptedHelper.throwTaskKillException();
+    } finally {
+      if (needCleanupPusher) {
+        cleanupPusher();
+      }
+    }
+  }
+
+  protected void fastWrite0(scala.collection.Iterator iterator)
+      throws IOException, InterruptedException {
+    final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = 
iterator;
+
+    SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) 
dep.serializer());
+    while (records.hasNext()) {
+      final Product2<Integer, UnsafeRow> record = records.next();
+      final int partitionId = record._1();
+      final UnsafeRow row = record._2();
+
+      final int rowSize = row.getSizeInBytes();
+      final int serializedRecordSize = 4 + rowSize;
+
+      if (dataSize != null) {
+        dataSize.add(serializedRecordSize);

Review Comment:
   I think we should use `serializedRecordSize` in both the hash-based 
shuffle writer and the sort-based shuffle writer. After checking the earliest 
commit logs, I believe that the difference between the two writers is unintentional.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to