This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 90db37b73 [CELEBORN-1422] Remove tmpRecords array when collecting 
written count metrics
90db37b73 is described below

commit 90db37b73346cb067cc08bbd710705d7fd6d6389
Author: onebox-li <[email protected]>
AuthorDate: Thu May 16 09:43:34 2024 +0800

    [CELEBORN-1422] Remove tmpRecords array when collecting written count 
metrics
    
    ### What changes were proposed in this pull request?
    For the spark3 client, use a long variable to help count written records 
instead of a `tmpRecords` array.
    
    ### Why are the changes needed?
    There is no need to use an array for spark3.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Cluster test. Both shuffle writers count records correctly.
    
    Closes #2508 from onebox-li/remove_tmpRecords.
    
    Authored-by: onebox-li <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../celeborn/ColumnarHashBasedShuffleWriter.java   |  2 +-
 .../shuffle/celeborn/HashBasedShuffleWriter.java   | 26 +++++++---------------
 .../shuffle/celeborn/SortBasedShuffleWriter.java   | 17 ++++----------
 3 files changed, 13 insertions(+), 32 deletions(-)

diff --git 
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
 
b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
index b468c5b96..b09b1306c 100644
--- 
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
+++ 
b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
@@ -127,7 +127,7 @@ public class ColumnarHashBasedShuffleWriter<K, V, C> 
extends HashBasedShuffleWri
         }
         celebornBatchBuilders[partitionId].newBuilders();
       }
-      incRecordsWritten(partitionId);
+      tmpRecordsWritten++;
     }
   }
 
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index a3024bf00..34377ee4a 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -83,7 +83,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private int[] sendOffsets;
 
   private final LongAdder[] mapStatusLengths;
-  private final long[] tmpRecords;
+  protected long tmpRecordsWritten = 0;
 
   private final SendBufferPool sendBufferPool;
 
@@ -128,7 +128,6 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     for (int i = 0; i < numPartitions; i++) {
       mapStatusLengths[i] = new LongAdder();
     }
-    tmpRecords = new long[numPartitions];
 
     PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize();
     PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize();
@@ -227,14 +226,10 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             rowSize);
         sendOffsets[partitionId] = offset + serializedRecordSize;
       }
-      incRecordsWritten(partitionId);
+      tmpRecordsWritten++;
     }
   }
 
-  protected void incRecordsWritten(int partitionId) {
-    tmpRecords[partitionId] += 1;
-  }
-
   private void write0(scala.collection.Iterator iterator) throws IOException, 
InterruptedException {
     final scala.collection.Iterator<Product2<K, ?>> records = iterator;
 
@@ -258,7 +253,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         System.arraycopy(serBuffer.getBuf(), 0, buffer, offset, 
serializedRecordSize);
         sendOffsets[partitionId] = offset + serializedRecordSize;
       }
-      incRecordsWritten(partitionId);
+      tmpRecordsWritten++;
     }
   }
 
@@ -305,7 +300,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
     if ((buffer.length - offset) < serializedRecordSize) {
       flushSendBuffer(partitionId, buffer, offset);
-      updateMapStatus();
+      updateRecordsWrittenMetrics();
       offset = 0;
     }
     return offset;
@@ -362,8 +357,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     closeWrite();
     shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
     writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
-
-    updateMapStatus();
+    updateRecordsWrittenMetrics();
 
     long waitStartTime = System.nanoTime();
     shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
@@ -375,13 +369,9 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             bmId, SparkUtils.unwrap(mapStatusLengths), 
taskContext.taskAttemptId());
   }
 
-  private void updateMapStatus() {
-    long recordsWritten = 0;
-    for (int i = 0; i < partitioner.numPartitions(); i++) {
-      recordsWritten += tmpRecords[i];
-      tmpRecords[i] = 0;
-    }
-    writeMetrics.incRecordsWritten(recordsWritten);
+  private void updateRecordsWrittenMetrics() {
+    writeMetrics.incRecordsWritten(tmpRecordsWritten);
+    tmpRecordsWritten = 0;
   }
 
   @Override
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 9e198f54a..2b810b190 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -76,7 +76,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final SerializationStream serOutputStream;
 
   private final LongAdder[] mapStatusLengths;
-  private final long[] tmpRecords;
+  private long tmpRecordsWritten = 0;
 
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with 
success = true and
@@ -132,7 +132,6 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     for (int i = 0; i < numPartitions; i++) {
       this.mapStatusLengths[i] = new LongAdder();
     }
-    tmpRecords = new long[numPartitions];
 
     pushBufferMaxSize = conf.clientPushBufferMaxSize();
 
@@ -282,7 +281,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
           }
         }
       }
-      tmpRecords[partitionId] += 1;
+      tmpRecordsWritten++;
     }
   }
 
@@ -331,7 +330,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
           }
         }
       }
-      tmpRecords[partitionId] += 1;
+      tmpRecordsWritten++;
     }
   }
 
@@ -360,21 +359,13 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
     shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
-
-    updateMapStatus();
+    writeMetrics.incRecordsWritten(tmpRecordsWritten);
 
     long waitStartTime = System.nanoTime();
     shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), 
numMappers);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
   }
 
-  private void updateMapStatus() {
-    for (int i = 0; i < tmpRecords.length; i++) {
-      writeMetrics.incRecordsWritten(tmpRecords[i]);
-      tmpRecords[i] = 0;
-    }
-  }
-
   @Override
   public Option<MapStatus> stop(boolean success) {
     try {

Reply via email to