This is an automated email from the ASF dual-hosted git repository.

feiwang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 4205f83da [CELEBORN-1995] Optimize memory usage for push failed batches
4205f83da is described below

commit 4205f83da33036c587c6ecfa407462d2f492e136
Author: mingji <[email protected]>
AuthorDate: Sun May 18 07:19:26 2025 -0700

    [CELEBORN-1995] Optimize memory usage for push failed batches
    
    ### What changes were proposed in this pull request?
    Aggregate push failed batch for the same map ID and attempt ID.
    
    ### Why are the changes needed?
    To reduce memory usage.
    
    ### Does this PR introduce _any_ user-facing change?
    NO.
    
    ### How was this patch tested?
    GA and cluster run.
    
    Closes #3253 from FMX/b1995.
    
    Authored-by: mingji <[email protected]>
    Signed-off-by: Wang, Fei <[email protected]>
---
 .../apache/celeborn/client/DummyShuffleClient.java |   5 +-
 .../org/apache/celeborn/client/ShuffleClient.java  |   5 +-
 .../apache/celeborn/client/ShuffleClientImpl.java  |  23 ++---
 .../celeborn/client/read/CelebornInputStream.java  |  27 ++---
 .../org/apache/celeborn/client/CommitManager.scala |   4 +-
 .../apache/celeborn/client/LifecycleManager.scala  |  16 ++-
 .../celeborn/client/commit/CommitHandler.scala     |   4 +-
 .../client/commit/MapPartitionCommitHandler.scala  |   4 +-
 .../commit/ReducePartitionCommitHandler.scala      |  26 ++---
 .../common/write/LocationPushFailedBatches.java    |  94 +++++++++++++++++
 .../celeborn/common/write/PushFailedBatch.java     |  84 ---------------
 .../apache/celeborn/common/write/PushState.java    |  12 +--
 common/src/main/proto/TransportMessages.proto      |  14 ++-
 .../common/protocol/message/ControlMessages.scala  |  15 +--
 .../apache/celeborn/common/util/PbSerDeUtils.scala |  50 ++++-----
 .../org/apache/celeborn/common/util/Utils.scala    |  11 ++
 .../write/LocationPushFailedBatchesSuiteJ.java     | 114 +++++++++++++++++++++
 .../common/write/PushFailedBatchSuiteJ.java        |  79 --------------
 .../celeborn/common/util/PbSerDeUtilsTest.scala    |  19 ++--
 ....scala => LocationPushFailedBatchesSuite.scala} |   8 +-
 20 files changed, 328 insertions(+), 286 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
index dd1a032c8..648309489 100644
--- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -28,7 +28,6 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -48,7 +47,7 @@ import org.apache.celeborn.common.protocol.PbStreamHandler;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
 import org.apache.celeborn.common.util.ExceptionMaker;
 import org.apache.celeborn.common.util.JavaUtils;
-import org.apache.celeborn.common.write.PushFailedBatch;
+import org.apache.celeborn.common.write.LocationPushFailedBatches;
 import org.apache.celeborn.common.write.PushState;
 
 public class DummyShuffleClient extends ShuffleClient {
@@ -144,7 +143,7 @@ public class DummyShuffleClient extends ShuffleClient {
       ExceptionMaker exceptionMaker,
       ArrayList<PartitionLocation> locations,
       ArrayList<PbStreamHandler> streamHandlers,
-      Map<String, Set<PushFailedBatch>> failedBatchSetMap,
+      Map<String, LocationPushFailedBatches> failedBatchSetMap,
       Map<String, Pair<Integer, Integer>> chunksRange,
       int[] mapAttempts,
       MetricsCallback metricsCallback)
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index bf0192e4a..6363e9004 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -21,7 +21,6 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Map;
 import java.util.Optional;
-import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.function.BiFunction;
@@ -46,7 +45,7 @@ import 
org.apache.celeborn.common.protocol.message.ControlMessages;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
 import org.apache.celeborn.common.util.CelebornHadoopUtils;
 import org.apache.celeborn.common.util.ExceptionMaker;
-import org.apache.celeborn.common.write.PushFailedBatch;
+import org.apache.celeborn.common.write.LocationPushFailedBatches;
 import org.apache.celeborn.common.write.PushState;
 
 /**
@@ -269,7 +268,7 @@ public abstract class ShuffleClient {
       ExceptionMaker exceptionMaker,
       ArrayList<PartitionLocation> locations,
       ArrayList<PbStreamHandler> streamHandlers,
-      Map<String, Set<PushFailedBatch>> failedBatchSetMap,
+      Map<String, LocationPushFailedBatches> failedBatchSetMap,
       Map<String, Pair<Integer, Integer>> chunksRange,
       int[] mapAttempts,
       MetricsCallback metricsCallback)
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 81329e8d1..6e80f90c2 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -68,7 +68,7 @@ import org.apache.celeborn.common.rpc.RpcEnv;
 import org.apache.celeborn.common.unsafe.Platform;
 import org.apache.celeborn.common.util.*;
 import org.apache.celeborn.common.write.DataBatches;
-import org.apache.celeborn.common.write.PushFailedBatch;
+import org.apache.celeborn.common.write.LocationPushFailedBatches;
 import org.apache.celeborn.common.write.PushState;
 
 public class ShuffleClientImpl extends ShuffleClient {
@@ -150,7 +150,7 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   public static class ReduceFileGroups {
     public Map<Integer, Set<PartitionLocation>> partitionGroups;
-    public Map<String, Set<PushFailedBatch>> pushFailedBatches;
+    public Map<String, LocationPushFailedBatches> pushFailedBatches;
     public int[] mapAttempts;
     public Set<Integer> partitionIds;
 
@@ -158,7 +158,7 @@ public class ShuffleClientImpl extends ShuffleClient {
         Map<Integer, Set<PartitionLocation>> partitionGroups,
         int[] mapAttempts,
         Set<Integer> partitionIds,
-        Map<String, Set<PushFailedBatch>> pushFailedBatches) {
+        Map<String, LocationPushFailedBatches> pushFailedBatches) {
       this.partitionGroups = partitionGroups;
       this.mapAttempts = mapAttempts;
       this.partitionIds = partitionIds;
@@ -1122,8 +1122,8 @@ public class ShuffleClientImpl extends ShuffleClient {
                       partitionId,
                       nextBatchId);
                   if (dataPushFailureTrackingEnabled && pushReplicateEnabled) {
-                    pushState.addFailedBatch(
-                        latest.getUniqueId(), new PushFailedBatch(mapId, 
attemptId, nextBatchId));
+                    pushState.recordFailedBatch(
+                        latest.getUniqueId(), mapId, attemptId, nextBatchId);
                   }
                   ReviveRequest reviveRequest =
                       new ReviveRequest(
@@ -1192,8 +1192,7 @@ public class ShuffleClientImpl extends ShuffleClient {
             @Override
             public void onFailure(Throwable e) {
               if (dataPushFailureTrackingEnabled) {
-                pushState.addFailedBatch(
-                    latest.getUniqueId(), new PushFailedBatch(mapId, 
attemptId, nextBatchId));
+                pushState.recordFailedBatch(latest.getUniqueId(), mapId, 
attemptId, nextBatchId);
               }
               if (pushState.exception.get() != null) {
                 return;
@@ -1563,9 +1562,8 @@ public class ShuffleClientImpl extends ShuffleClient {
               } else {
                 if (dataPushFailureTrackingEnabled && pushReplicateEnabled) {
                   for (DataBatches.DataBatch resubmitBatch : 
batchesNeedResubmit) {
-                    pushState.addFailedBatch(
-                        resubmitBatch.loc.getUniqueId(),
-                        new PushFailedBatch(mapId, attemptId, 
resubmitBatch.batchId));
+                    pushState.recordFailedBatch(
+                        resubmitBatch.loc.getUniqueId(), mapId, attemptId, 
resubmitBatch.batchId);
                   }
                 }
                 ReviveRequest[] requests =
@@ -1625,8 +1623,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           public void onFailure(Throwable e) {
             if (dataPushFailureTrackingEnabled) {
               for (int i = 0; i < numBatches; i++) {
-                pushState.addFailedBatch(
-                    partitionUniqueIds[i], new PushFailedBatch(mapId, 
attemptId, batchIds[i]));
+                pushState.recordFailedBatch(partitionUniqueIds[i], mapId, 
attemptId, batchIds[i]);
               }
             }
             if (pushState.exception.get() != null) {
@@ -1915,7 +1912,7 @@ public class ShuffleClientImpl extends ShuffleClient {
       ExceptionMaker exceptionMaker,
       ArrayList<PartitionLocation> locations,
       ArrayList<PbStreamHandler> streamHandlers,
-      Map<String, Set<PushFailedBatch>> failedBatchSetMap,
+      Map<String, LocationPushFailedBatches> failedBatchSetMap,
       Map<String, Pair<Integer, Integer>> chunksRange,
       int[] mapAttempts,
       MetricsCallback metricsCallback)
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java 
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index ffcc50dde..5bed692b4 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -48,7 +48,7 @@ import org.apache.celeborn.common.protocol.*;
 import org.apache.celeborn.common.unsafe.Platform;
 import org.apache.celeborn.common.util.ExceptionMaker;
 import org.apache.celeborn.common.util.Utils;
-import org.apache.celeborn.common.write.PushFailedBatch;
+import org.apache.celeborn.common.write.LocationPushFailedBatches;
 
 public abstract class CelebornInputStream extends InputStream {
   private static final Logger logger = 
LoggerFactory.getLogger(CelebornInputStream.class);
@@ -60,7 +60,7 @@ public abstract class CelebornInputStream extends InputStream 
{
       ArrayList<PartitionLocation> locations,
       ArrayList<PbStreamHandler> streamHandlers,
       int[] attempts,
-      Map<String, Set<PushFailedBatch>> failedBatchSetMap,
+      Map<String, LocationPushFailedBatches> failedBatchSetMap,
       Map<String, Pair<Integer, Integer>> chunksRange,
       int attemptNumber,
       long taskId,
@@ -176,7 +176,7 @@ public abstract class CelebornInputStream extends 
InputStream {
 
     private Map<Integer, Set<Integer>> batchesRead = new HashMap<>();
 
-    private final Map<String, Set<PushFailedBatch>> failedBatches;
+    private final Map<String, LocationPushFailedBatches> failedBatches;
 
     private byte[] compressedBuf;
     private byte[] rawDataBuf;
@@ -223,7 +223,7 @@ public abstract class CelebornInputStream extends 
InputStream {
         ArrayList<PartitionLocation> locations,
         ArrayList<PbStreamHandler> streamHandlers,
         int[] attempts,
-        Map<String, Set<PushFailedBatch>> failedBatchSet,
+        Map<String, LocationPushFailedBatches> failedBatchSet,
         int attemptNumber,
         long taskId,
         Map<String, Pair<Integer, Integer>> partitionLocationToChunkRange,
@@ -266,7 +266,7 @@ public abstract class CelebornInputStream extends 
InputStream {
         ArrayList<PartitionLocation> locations,
         ArrayList<PbStreamHandler> streamHandlers,
         int[] attempts,
-        Map<String, Set<PushFailedBatch>> failedBatchSet,
+        Map<String, LocationPushFailedBatches> failedBatchSet,
         int attemptNumber,
         long taskId,
         int startMapIndex,
@@ -758,7 +758,7 @@ public abstract class CelebornInputStream extends 
InputStream {
           return false;
         }
 
-        PushFailedBatch failedBatch = new PushFailedBatch(-1, -1, -1);
+        LocationPushFailedBatches failedBatch = new 
LocationPushFailedBatches();
         boolean hasData = false;
         while (currentChunk.isReadable() || moveToNextChunk()) {
           currentChunk.readBytes(sizeBuf);
@@ -784,14 +784,15 @@ public abstract class CelebornInputStream extends 
InputStream {
           // de-duplicate
           if (attemptId == attempts[mapId]) {
             if (readSkewPartitionWithoutMapRange) {
-              Set<PushFailedBatch> failedBatchSet =
+              LocationPushFailedBatches locationPushFailedBatches =
                   
this.failedBatches.get(currentReader.getLocation().getUniqueId());
-              if (null != failedBatchSet) {
-                failedBatch.setMapId(mapId);
-                failedBatch.setAttemptId(attemptId);
-                failedBatch.setBatchId(batchId);
-                if (failedBatchSet.contains(failedBatch)) {
-                  logger.warn("Skip duplicated batch: {}.", failedBatch);
+              if (null != locationPushFailedBatches) {
+                if (locationPushFailedBatches.contains(mapId, attemptId, 
batchId)) {
+                  logger.warn(
+                      "Skip duplicated batch: mapId={}, attemptId={}, 
batchId={}",
+                      mapId,
+                      attemptId,
+                      batchId);
                   continue;
                 }
               }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index bffd05430..702441dce 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -42,7 +42,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext
 import org.apache.celeborn.common.util.FunctionConverter._
 import org.apache.celeborn.common.util.JavaUtils
 import org.apache.celeborn.common.util.ThreadUtils
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 
 case class ShuffleCommittedInfo(
     // partition id -> unique partition ids
@@ -219,7 +219,7 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
       attemptId: Int,
       numMappers: Int,
       partitionId: Int = -1,
-      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = 
Collections.emptyMap())
+      pushFailedBatches: util.Map[String, LocationPushFailedBatches] = 
Collections.emptyMap())
       : (Boolean, Boolean) = {
     getCommitHandler(shuffleId).finishMapperAttempt(
       shuffleId,
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 20e1099d2..bcb018570 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -59,7 +59,7 @@ import org.apache.celeborn.common.util.{JavaUtils, 
PbSerDeUtils, ThreadUtils, Ut
 import org.apache.celeborn.common.util.FunctionConverter._
 import org.apache.celeborn.common.util.ThreadUtils.awaitResult
 import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 
 object LifecycleManager {
   // shuffle id -> partition id -> partition locations
@@ -67,7 +67,7 @@ object LifecycleManager {
     ConcurrentHashMap[Int, ConcurrentHashMap[Integer, 
util.Set[PartitionLocation]]]
   // shuffle id -> partition uniqueId -> PushFailedBatch set
   type ShufflePushFailedBatches =
-    ConcurrentHashMap[Int, util.HashMap[String, util.Set[PushFailedBatch]]]
+    ConcurrentHashMap[Int, util.HashMap[String, LocationPushFailedBatches]]
   type ShuffleAllocatedWorkers =
     ConcurrentHashMap[Int, ConcurrentHashMap[String, 
ShufflePartitionLocationInfo]]
   type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]
@@ -825,7 +825,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
       mapId: Int,
       attemptId: Int,
       numMappers: Int,
-      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]]): Unit = {
+      pushFailedBatches: util.Map[String, LocationPushFailedBatches]): Unit = {
 
     val (mapperAttemptFinishedSuccess, allMapperFinished) =
       commitManager.finishMapperAttempt(
@@ -1118,12 +1118,8 @@ class LifecycleManager(val appUniqueId: String, val 
conf: CelebornConf) extends
       }
     }
 
-    val (mapperAttemptFinishedSuccess, _) = commitManager.finishMapperAttempt(
-      shuffleId,
-      mapId,
-      attemptId,
-      numMappers,
-      partitionId)
+    val (mapperAttemptFinishedSuccess, _) =
+      commitManager.finishMapperAttempt(shuffleId, mapId, attemptId, 
numMappers, partitionId)
     reply(mapperAttemptFinishedSuccess)
   }
 
@@ -1817,7 +1813,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
     }
   }
 
-  // Once a partition is released, it will be never needed anymore
+  // Once a partition is released, it will never be needed anymore
   def releasePartition(shuffleId: Int, partitionId: Int): Unit = {
     commitManager.releasePartitionResource(shuffleId, partitionId)
     val partitionLocation = latestPartitionLocation.get(shuffleId)
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index c8b563372..42ea0379f 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -43,7 +43,7 @@ import org.apache.celeborn.common.util.{CollectionUtils, 
JavaUtils, Utils}
 // Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
 import org.apache.celeborn.common.util.ThreadUtils.awaitResult
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 
 case class CommitFilesParam(
     worker: WorkerInfo,
@@ -206,7 +206,7 @@ abstract class CommitHandler(
       attemptId: Int,
       numMappers: Int,
       partitionId: Int,
-      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
+      pushFailedBatches: util.Map[String, LocationPushFailedBatches],
       recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean)
 
   def registerShuffle(
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 4f31018e5..715950531 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -40,7 +40,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext
 import org.apache.celeborn.common.util.FunctionConverter._
 import org.apache.celeborn.common.util.JavaUtils
 import org.apache.celeborn.common.util.Utils
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 
 /**
  * This commit handler is for MapPartition ShuffleType, which means that a Map 
Partition contains all data produced
@@ -186,7 +186,7 @@ class MapPartitionCommitHandler(
       attemptId: Int,
       numMappers: Int,
       partitionId: Int,
-      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
+      pushFailedBatches: util.Map[String, LocationPushFailedBatches],
       recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = 
{
     val inProcessingPartitionIds =
       inProcessMapPartitionEndIds.computeIfAbsent(
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index a65f7e1eb..45c371ee5 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -41,7 +41,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc.RpcCallContext
 import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, 
RemoteNettyRpcCallContext}
 import org.apache.celeborn.common.util.JavaUtils
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 
 /**
  * This commit handler is for ReducePartition ShuffleType, which means that a 
Reduce Partition contains all data
@@ -94,18 +94,18 @@ class ReducePartitionCommitHandler(
     .build().asInstanceOf[Cache[Int, ByteBuffer]]
 
   private val newShuffleId2PushFailedBatchMapFunc
-      : function.Function[Int, util.HashMap[String, 
util.Set[PushFailedBatch]]] =
-    new util.function.Function[Int, util.HashMap[String, 
util.Set[PushFailedBatch]]]() {
-      override def apply(s: Int): util.HashMap[String, 
util.Set[PushFailedBatch]] = {
-        new util.HashMap[String, util.Set[PushFailedBatch]]()
+      : function.Function[Int, util.HashMap[String, 
LocationPushFailedBatches]] =
+    new util.function.Function[Int, util.HashMap[String, 
LocationPushFailedBatches]]() {
+      override def apply(s: Int): util.HashMap[String, 
LocationPushFailedBatches] = {
+        new util.HashMap[String, LocationPushFailedBatches]()
       }
     }
 
   private val uniqueId2PushFailedBatchMapFunc
-      : function.Function[String, util.Set[PushFailedBatch]] =
-    new util.function.Function[String, util.Set[PushFailedBatch]]() {
-      override def apply(s: String): util.Set[PushFailedBatch] = {
-        Sets.newHashSet[PushFailedBatch]()
+      : function.Function[String, LocationPushFailedBatches] =
+    new util.function.Function[String, LocationPushFailedBatches]() {
+      override def apply(s: String): LocationPushFailedBatches = {
+        new LocationPushFailedBatches()
       }
     }
 
@@ -267,7 +267,7 @@ class ReducePartitionCommitHandler(
       attemptId: Int,
       numMappers: Int,
       partitionId: Int,
-      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
+      pushFailedBatches: util.Map[String, LocationPushFailedBatches],
       recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = 
{
     shuffleMapperAttempts.synchronized {
       if (getMapperAttempts(shuffleId) == null) {
@@ -282,11 +282,11 @@ class ReducePartitionCommitHandler(
           val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent(
             shuffleId,
             newShuffleId2PushFailedBatchMapFunc)
-          for ((partitionUniqId, pushFailedBatchSet) <- 
pushFailedBatches.asScala) {
+          for ((partitionUniqId, locationPushFailedBatches) <- 
pushFailedBatches.asScala) {
             val partitionPushFailedBatches = 
pushFailedBatchesMap.computeIfAbsent(
               partitionUniqId,
               uniqueId2PushFailedBatchMapFunc)
-            partitionPushFailedBatches.addAll(pushFailedBatchSet)
+            partitionPushFailedBatches.merge(locationPushFailedBatches)
           }
         }
         // Mapper with this attemptId finished, also check all other mapper 
finished or not.
@@ -361,7 +361,7 @@ class ReducePartitionCommitHandler(
                 pushFailedBatches =
                   shufflePushFailedBatches.getOrDefault(
                     shuffleId,
-                    new util.HashMap[String, util.Set[PushFailedBatch]]()),
+                    new util.HashMap[String, LocationPushFailedBatches]()),
                 serdeVersion = serdeVersion)
 
               val serializedMsg =
diff --git 
a/common/src/main/java/org/apache/celeborn/common/write/LocationPushFailedBatches.java
 
b/common/src/main/java/org/apache/celeborn/common/write/LocationPushFailedBatches.java
new file mode 100644
index 000000000..2273a3698
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/write/LocationPushFailedBatches.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.write;
+
+import java.io.Serializable;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.commons.lang3.StringUtils;
+
+import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.Utils;
+
+public class LocationPushFailedBatches implements Serializable {
+
+  // map-attempt-id -> failed batch ids
+  private final Map<String, Set<Integer>> failedBatches = 
JavaUtils.newConcurrentHashMap();
+
+  public Map<String, Set<Integer>> getFailedBatches() {
+    return failedBatches;
+  }
+
+  public boolean contains(int mapId, int attemptId, int batchId) {
+    Set<Integer> batches = failedBatches.get(Utils.makeAttemptKey(mapId, 
attemptId));
+    return batches != null && batches.contains(batchId);
+  }
+
+  public void merge(LocationPushFailedBatches batches) {
+    Map<String, Set<Integer>> otherFailedBatchesMap = 
batches.getFailedBatches();
+    otherFailedBatchesMap.forEach(
+        (k, v) -> {
+          Set<Integer> failedBatches =
+              this.failedBatches.computeIfAbsent(k, (s) -> 
ConcurrentHashMap.newKeySet());
+          failedBatches.addAll(v);
+        });
+  }
+
+  public void addFailedBatch(int mapId, int attemptId, int batchId) {
+    String attemptKey = Utils.makeAttemptKey(mapId, attemptId);
+    Set<Integer> failedBatches =
+        this.failedBatches.computeIfAbsent(attemptKey, (s) -> 
ConcurrentHashMap.newKeySet());
+    failedBatches.add(batchId);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (o == null || getClass() != o.getClass()) return false;
+    LocationPushFailedBatches that = (LocationPushFailedBatches) o;
+    if (that.failedBatches.size() != failedBatches.size()) {
+      return false;
+    }
+
+    return failedBatches.entrySet().stream()
+        .allMatch(
+            item -> {
+              Set<Integer> failedBatchesSet = 
that.failedBatches.get(item.getKey());
+              if (failedBatchesSet == null) return false;
+              return failedBatchesSet.equals(item.getValue());
+            });
+  }
+
+  @Override
+  public int hashCode() {
+    return failedBatches.entrySet().hashCode();
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder stringBuilder = new StringBuilder();
+    failedBatches.forEach(
+        (attemptKey, value) -> {
+          stringBuilder.append("failed attemptKey:");
+          stringBuilder.append(attemptKey);
+          stringBuilder.append(" fail batch Ids:");
+          stringBuilder.append(StringUtils.join(value, ","));
+        });
+    return "LocationPushFailedBatches{" + "failedBatches=" + stringBuilder + 
'}';
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java 
b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
deleted file mode 100644
index ccee8bf11..000000000
--- a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.celeborn.common.write;
-
-import java.io.Serializable;
-
-import com.google.common.base.Objects;
-import org.apache.commons.lang3.builder.ToStringBuilder;
-import org.apache.commons.lang3.builder.ToStringStyle;
-
-public class PushFailedBatch implements Serializable {
-
-  private int mapId;
-  private int attemptId;
-  private int batchId;
-
-  public PushFailedBatch(int mapId, int attemptId, int batchId) {
-    this.mapId = mapId;
-    this.attemptId = attemptId;
-    this.batchId = batchId;
-  }
-
-  public int getMapId() {
-    return mapId;
-  }
-
-  public void setMapId(int mapId) {
-    this.mapId = mapId;
-  }
-
-  public int getAttemptId() {
-    return attemptId;
-  }
-
-  public void setAttemptId(int attemptId) {
-    this.attemptId = attemptId;
-  }
-
-  public int getBatchId() {
-    return batchId;
-  }
-
-  public void setBatchId(int batchId) {
-    this.batchId = batchId;
-  }
-
-  @Override
-  public boolean equals(Object other) {
-    if (!(other instanceof PushFailedBatch)) {
-      return false;
-    }
-    PushFailedBatch o = (PushFailedBatch) other;
-    return mapId == o.mapId && attemptId == o.attemptId && batchId == 
o.batchId;
-  }
-
-  @Override
-  public int hashCode() {
-    return Objects.hashCode(mapId, attemptId, batchId);
-  }
-
-  @Override
-  public String toString() {
-    return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
-        .append("mapId", mapId)
-        .append("attemptId", attemptId)
-        .append("batchId", batchId)
-        .toString();
-  }
-}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/write/PushState.java 
b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
index 3977e5d54..afa22bb8a 100644
--- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java
+++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
@@ -19,11 +19,9 @@ package org.apache.celeborn.common.write;
 
 import java.io.IOException;
 import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicReference;
 
-import com.google.common.collect.Sets;
 import org.apache.commons.lang3.tuple.Pair;
 
 import org.apache.celeborn.common.CelebornConf;
@@ -36,7 +34,7 @@ public class PushState {
   public AtomicReference<IOException> exception = new AtomicReference<>();
   private final InFlightRequestTracker inFlightRequestTracker;
 
-  private final Map<String, Set<PushFailedBatch>> failedBatchMap;
+  private final Map<String, LocationPushFailedBatches> failedBatchMap;
 
   public PushState(CelebornConf conf) {
     pushBufferMaxSize = conf.clientPushBufferMaxSize();
@@ -95,13 +93,13 @@ public class PushState {
     return inFlightRequestTracker.remainingAllowPushes(hostAndPushPort);
   }
 
-  public void addFailedBatch(String partitionId, PushFailedBatch failedBatch) {
+  public void recordFailedBatch(String partitionId, int mapId, int attemptId, int batchId) {
     this.failedBatchMap
-        .computeIfAbsent(partitionId, (s) -> Sets.newConcurrentHashSet())
-        .add(failedBatch);
+        .computeIfAbsent(partitionId, (s) -> new LocationPushFailedBatches())
+        .addFailedBatch(mapId, attemptId, batchId);
   }
 
-  public Map<String, Set<PushFailedBatch>> getFailedBatches() {
+  public Map<String, LocationPushFailedBatches> getFailedBatches() {
     return this.failedBatchMap;
   }
 }
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 7b0d0bec2..c120b9952 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -356,17 +356,15 @@ message PbMapperEnd {
   int32 attemptId = 3;
   int32 numMappers = 4;
   int32 partitionId = 5;
-  map<string, PbPushFailedBatchSet> pushFailureBatches= 6;
+  map<string, PbLocationPushFailedBatches> pushFailureBatches= 6;
 }
 
-message PbPushFailedBatchSet {
-  repeated PbPushFailedBatch failureBatches = 1;
+message PbLocationPushFailedBatches {
+  map<string, PbFailedBatches> failedBatches = 1;
 }
 
-message PbPushFailedBatch {
-  int32 mapId = 1;
-  int32 attemptId = 2;
-  int32 batchId = 3;
+message PbFailedBatches{
+  repeated int32 failedBatches = 1;
 }
 
 message PbMapperEndResponse {
@@ -389,7 +387,7 @@ message PbGetReducerFileGroupResponse {
   // only map partition mode has succeed partitionIds
   repeated int32 partitionIds = 4;
 
-  map<string, PbPushFailedBatchSet> pushFailedBatches = 5;
+  map<string, PbLocationPushFailedBatches> pushFailedBatches = 5;
 
   bytes broadcast = 6;
 }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 57831f213..ccd5d79a2 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -33,7 +33,7 @@ import org.apache.celeborn.common.protocol._
 import org.apache.celeborn.common.protocol.MessageType._
 import org.apache.celeborn.common.quota.ResourceConsumption
 import org.apache.celeborn.common.util.{PbSerDeUtils, Utils}
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 
 sealed trait Message extends Serializable
 
@@ -274,7 +274,7 @@ object ControlMessages extends Logging {
       attemptId: Int,
       numMappers: Int,
       partitionId: Int,
-      failedBatchSet: util.Map[String, util.Set[PushFailedBatch]])
+      failedBatchSet: util.Map[String, LocationPushFailedBatches])
     extends MasterMessage
 
   case class MapperEndResponse(status: StatusCode) extends MasterMessage
@@ -292,7 +292,8 @@ object ControlMessages extends Logging {
       fileGroup: util.Map[Integer, util.Set[PartitionLocation]] = Collections.emptyMap(),
       attempts: Array[Int] = Array.emptyIntArray,
       partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
-      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap(),
+      pushFailedBatches: util.Map[String, LocationPushFailedBatches] =
+        Collections.emptyMap(),
       broadcast: Array[Byte] = Array.emptyByteArray,
       serdeVersion: SerdeVersion = SerdeVersion.V1)
     extends MasterMessage
@@ -732,7 +733,7 @@ object ControlMessages extends Logging {
 
     case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) =>
       val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
-        val resultValue = PbSerDeUtils.toPbPushFailedBatchSet(v)
+        val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
         (k, resultValue)
       }.toMap.asJava
       val payload = PbMapperEnd.newBuilder()
@@ -781,7 +782,7 @@ object ControlMessages extends Logging {
       builder.putAllPushFailedBatches(
         failedBatches.asScala.map {
           case (uniqueId, pushFailedBatchSet) =>
-            (uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet))
+            (uniqueId, PbSerDeUtils.toPbLocationPushFailedBatches(pushFailedBatchSet))
         }.asJava)
       builder.setBroadcast(ByteString.copyFrom(broadcast))
       val payload = builder.build().toByteArray
@@ -1171,7 +1172,7 @@ object ControlMessages extends Logging {
           pbMapperEnd.getPartitionId,
           pbMapperEnd.getPushFailureBatchesMap.asScala.map {
             case (partitionId, pushFailedBatchSet) =>
-              (partitionId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
+              (partitionId, PbSerDeUtils.fromPbLocationPushFailedBatches(pushFailedBatchSet))
           }.toMap.asJava)
 
       case MAPPER_END_RESPONSE_VALUE =>
@@ -1211,7 +1212,7 @@ object ControlMessages extends Logging {
         val partitionIds = new 
util.HashSet(pbGetReducerFileGroupResponse.getPartitionIdsList)
         val pushFailedBatches = 
pbGetReducerFileGroupResponse.getPushFailedBatchesMap.asScala.map {
           case (uniqueId, pushFailedBatchSet) =>
-            (uniqueId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
+            (uniqueId, PbSerDeUtils.fromPbLocationPushFailedBatches(pushFailedBatchSet))
         }.toMap.asJava
         val broadcast = pbGetReducerFileGroupResponse.getBroadcast.toByteArray
         GetReducerFileGroupResponse(
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala 
b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
index 3e0d1afd9..6ba547684 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
@@ -32,7 +32,7 @@ import 
org.apache.celeborn.common.protocol.PartitionLocation.Mode
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
 import org.apache.celeborn.common.quota.ResourceConsumption
 import org.apache.celeborn.common.util.{CollectionUtils => 
localCollectionUtils}
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 
 object PbSerDeUtils {
 
@@ -477,7 +477,7 @@ object PbSerDeUtils {
   }
 
   def fromPbApplicationMeta(pbApplicationMeta: PbApplicationMeta): 
ApplicationMeta = {
-    new ApplicationMeta(pbApplicationMeta.getAppId, 
pbApplicationMeta.getSecret)
+    ApplicationMeta(pbApplicationMeta.getAppId, pbApplicationMeta.getSecret)
   }
 
   def toPbWorkerStatus(workerStatus: WorkerStatus): PbWorkerStatus = {
@@ -501,7 +501,7 @@ object PbSerDeUtils {
   def fromPbWorkerEventInfo(pbWorkerEventInfo: PbWorkerEventInfo): 
WorkerEventInfo = {
     new WorkerEventInfo(
       pbWorkerEventInfo.getWorkerEventType.getNumber,
-      pbWorkerEventInfo.getEventStartTime())
+      pbWorkerEventInfo.getEventStartTime)
   }
 
   private def toPackedPartitionLocation(
@@ -695,34 +695,36 @@ object PbSerDeUtils {
     }.asJava
   }
 
-  def toPbPushFailedBatch(pushFailedBatch: PushFailedBatch): PbPushFailedBatch = {
-    PbPushFailedBatch.newBuilder()
-      .setMapId(pushFailedBatch.getMapId)
-      .setAttemptId(pushFailedBatch.getAttemptId)
-      .setBatchId(pushFailedBatch.getBatchId)
+  def toPbFailedBatches(failedBatches: util.Set[Int]): PbFailedBatches = {
+    PbFailedBatches.newBuilder()
+      .addAllFailedBatches(failedBatches.asScala.map(new Integer(_)).asJava)
       .build()
   }
 
-  def fromPbPushFailedBatch(pbPushFailedBatch: PbPushFailedBatch): 
PushFailedBatch = {
-    new PushFailedBatch(
-      pbPushFailedBatch.getMapId,
-      pbPushFailedBatch.getAttemptId,
-      pbPushFailedBatch.getBatchId)
+  def fromPbFailedBatches(failedBatches: PbFailedBatches): util.Set[Int] = {
+    failedBatches.getFailedBatchesList.asScala.map(_.intValue()).toSet.asJava
   }
 
-  def toPbPushFailedBatchSet(failedBatchSet: util.Set[PushFailedBatch]): 
PbPushFailedBatchSet = {
-    val builder = PbPushFailedBatchSet.newBuilder()
-    failedBatchSet.asScala.foreach(batch => 
builder.addFailureBatches(toPbPushFailedBatch(batch)))
-
+  def toPbLocationPushFailedBatches(locationPushFailedBatches: 
LocationPushFailedBatches)
+      : PbLocationPushFailedBatches = {
+    val builder = PbLocationPushFailedBatches.newBuilder()
+    builder.putAllFailedBatches(
+      locationPushFailedBatches
+        .getFailedBatches
+        .asScala
+        .map(item =>
+          (item._1, toPbFailedBatches(item._2.asScala.map(_.intValue()).toSet.asJava))).asJava)
     builder.build()
   }
 
-  def fromPbPushFailedBatchSet(pbFailedBatchSet: PbPushFailedBatchSet)
-      : util.Set[PushFailedBatch] = {
-    val failedBatchSet = new util.HashSet[PushFailedBatch]()
-    pbFailedBatchSet.getFailureBatchesList.asScala.foreach(batch =>
-      failedBatchSet.add(fromPbPushFailedBatch(batch)))
-
-    failedBatchSet
+  def fromPbLocationPushFailedBatches(pbLocationPushFailedBatches: 
PbLocationPushFailedBatches)
+      : LocationPushFailedBatches = {
+    val batches = new LocationPushFailedBatches()
+    pbLocationPushFailedBatches.getFailedBatchesMap.asScala.foreach { case 
(key, value) =>
+      batches.getFailedBatches.put(
+        key,
+        fromPbFailedBatches(value).asScala.map(new Integer(_)).asJava)
+    }
+    batches
   }
 }
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala 
b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index e825abce3..73f070ab9 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -717,6 +717,17 @@ object Utils extends Logging {
     s"$shuffleId-$mapId-$attemptId"
   }
 
+  def makeAttemptKey(mapId: Int, attemptId: Int): String = {
+    s"$mapId-$attemptId"
+  }
+
+  def splitAttemptKey(attemptKey: String): (Int, Int) = {
+    val splits = attemptKey.split("-")
+    val mapId = splits(0).toInt
+    val attemptId = splits(1).toInt
+    (mapId, attemptId)
+  }
+
   def shuffleKeyPrefix(shuffleKey: String): String = {
     shuffleKey + "-"
   }
diff --git 
a/common/src/test/java/org/apache/celeborn/common/write/LocationPushFailedBatchesSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/write/LocationPushFailedBatchesSuiteJ.java
new file mode 100644
index 000000000..6fc58eb17
--- /dev/null
+++ 
b/common/src/test/java/org/apache/celeborn/common/write/LocationPushFailedBatchesSuiteJ.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.write;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class LocationPushFailedBatchesSuiteJ {
+
+  @Test
+  public void equalsReturnsTrueForIdenticalBatches() {
+    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
+    batch1.addFailedBatch(1, 2, 3);
+    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
+    batch2.addFailedBatch(1, 2, 3);
+    Assert.assertEquals(batch1, batch2);
+  }
+
+  @Test
+  public void equalsReturnsFalseForDifferentBatches() {
+    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
+    batch1.addFailedBatch(1, 2, 3);
+    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
+    batch2.addFailedBatch(4, 5, 6);
+    Assert.assertNotEquals(batch1, batch2);
+  }
+
+  @Test
+  public void hashCodeDiffersForDifferentBatches() {
+    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
+    batch1.addFailedBatch(1, 2, 3);
+    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
+    batch2.addFailedBatch(4, 5, 6);
+    Assert.assertNotEquals(batch1.hashCode(), batch2.hashCode());
+  }
+
+  @Test
+  public void hashCodeSameForIdenticalBatches() {
+    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
+    batch1.addFailedBatch(1, 2, 3);
+    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
+    batch2.addFailedBatch(1, 2, 3);
+    Assert.assertEquals(batch1.hashCode(), batch2.hashCode());
+  }
+
+  @Test
+  public void hashCodeIsConsistent() {
+    LocationPushFailedBatches batch = new LocationPushFailedBatches();
+    batch.addFailedBatch(1, 2, 3);
+    int hashCode1 = batch.hashCode();
+    int hashCode2 = batch.hashCode();
+    Assert.assertEquals(hashCode1, hashCode2);
+  }
+
+  @Test
+  public void toStringReturnsExpectedFormat() {
+    LocationPushFailedBatches batch = new LocationPushFailedBatches();
+    batch.addFailedBatch(1, 2, 3);
+    String expected =
+        "LocationPushFailedBatches{failedBatches=failed attemptKey:1-2 fail batch Ids:3}";
+    Assert.assertEquals(expected, batch.toString());
+  }
+
+  @Test
+  public void hashCodeAndEqualsWorkInSet() {
+    Set<LocationPushFailedBatches> set = new HashSet<>();
+    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
+    batch1.addFailedBatch(1, 2, 3);
+    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
+    batch2.addFailedBatch(1, 2, 3);
+    set.add(batch1);
+    assertTrue(set.contains(batch2));
+  }
+
+  @Test
+  public void concurrentAddAndGetShouldNotConflict() throws InterruptedException {
+    LocationPushFailedBatches batches = new LocationPushFailedBatches();
+    ExecutorService executor = Executors.newFixedThreadPool(4);
+
+    int totalFailedBatches = 1000;
+    for (int i = 0; i < totalFailedBatches; i++) {
+      final int tIdx = i;
+      executor.submit(() -> batches.addFailedBatch(tIdx % 10, tIdx % 5, tIdx));
+    }
+    executor.shutdown();
+    executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS);
+    assertTrue(!batches.getFailedBatches().isEmpty());
+    assertEquals(
+        totalFailedBatches, batches.getFailedBatches().values().stream().mapToInt(Set::size).sum());
+  }
+}
diff --git 
a/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
deleted file mode 100644
index fcfc6b799..000000000
--- 
a/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.celeborn.common.write;
-
-import java.util.HashSet;
-import java.util.Set;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class PushFailedBatchSuiteJ {
-
-  @Test
-  public void equalsReturnsTrueForIdenticalBatches() {
-    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
-    PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
-    Assert.assertEquals(batch1, batch2);
-  }
-
-  @Test
-  public void equalsReturnsFalseForDifferentBatches() {
-    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
-    PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6);
-    Assert.assertNotEquals(batch1, batch2);
-  }
-
-  @Test
-  public void hashCodeDiffersForDifferentBatches() {
-    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
-    PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6);
-    Assert.assertNotEquals(batch1.hashCode(), batch2.hashCode());
-  }
-
-  @Test
-  public void hashCodeSameForIdenticalBatches() {
-    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
-    PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
-    Assert.assertEquals(batch1.hashCode(), batch2.hashCode());
-  }
-
-  @Test
-  public void hashCodeIsConsistent() {
-    PushFailedBatch batch = new PushFailedBatch(1, 2, 3);
-    int hashCode1 = batch.hashCode();
-    int hashCode2 = batch.hashCode();
-    Assert.assertEquals(hashCode1, hashCode2);
-  }
-
-  @Test
-  public void toStringReturnsExpectedFormat() {
-    PushFailedBatch batch = new PushFailedBatch(1, 2, 3);
-    String expected = "PushFailedBatch[mapId=1,attemptId=2,batchId=3]";
-    Assert.assertEquals(expected, batch.toString());
-  }
-
-  @Test
-  public void hashCodeAndEqualsWorkInSet() {
-    Set<PushFailedBatch> set = new HashSet<>();
-    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
-    PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
-    set.add(batch1);
-    Assert.assertTrue(set.contains(batch2));
-  }
-}
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala 
b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
index 3136584b1..9d4059ff9 100644
--- 
a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
+++ 
b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
 
-import com.google.common.collect.{Lists, Sets}
+import com.google.common.collect.Lists
 import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils
 
 import org.apache.celeborn.CelebornFunSuite
@@ -36,7 +36,7 @@ import 
org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse,
 WorkerResource}
 import org.apache.celeborn.common.quota.ResourceConsumption
 import 
org.apache.celeborn.common.util.PbSerDeUtils.{fromPbPackedPartitionLocationsPair,
 toPbPackedPartitionLocationsPair}
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 
 class PbSerDeUtilsTest extends CelebornFunSuite {
 
@@ -664,19 +664,12 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
   }
 
   test("fromAndToPushFailedBatch") {
-    val failedBatch = new PushFailedBatch(1, 1, 2)
-    val pbPushFailedBatch = PbSerDeUtils.toPbPushFailedBatch(failedBatch)
-    val restoredFailedBatch = 
PbSerDeUtils.fromPbPushFailedBatch(pbPushFailedBatch)
+    val failedBatch = new LocationPushFailedBatches()
+    failedBatch.addFailedBatch(1, 1, 2)
+    val pbPushFailedBatch = 
PbSerDeUtils.toPbLocationPushFailedBatches(failedBatch)
+    val restoredFailedBatch = 
PbSerDeUtils.fromPbLocationPushFailedBatches(pbPushFailedBatch)
 
     assert(restoredFailedBatch.equals(failedBatch))
   }
 
-  test("fromAndToPushFailedBatchSet") {
-    val failedBatchSet = Sets.newHashSet(new PushFailedBatch(1, 1, 2), new PushFailedBatch(2, 2, 3))
-    val pbPushFailedBatchSet = 
PbSerDeUtils.toPbPushFailedBatchSet(failedBatchSet)
-    val restoredFailedBatchSet = 
PbSerDeUtils.fromPbPushFailedBatchSet(pbPushFailedBatchSet)
-
-    assert(restoredFailedBatchSet.equals(failedBatchSet))
-  }
-
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/LocationPushFailedBatchesSuite.scala
similarity index 93%
rename from 
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
rename to 
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/LocationPushFailedBatchesSuite.scala
index 716ad4e36..0cb31db48 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/LocationPushFailedBatchesSuite.scala
@@ -27,10 +27,10 @@ import org.scalatest.funsuite.AnyFunSuite
 import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.protocol.ShuffleMode
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
 import org.apache.celeborn.service.deploy.worker.PushDataHandler
 
-class PushFailedBatchSuite extends AnyFunSuite
+class LocationPushFailedBatchesSuite extends AnyFunSuite
   with SparkTestBase
   with BeforeAndAfterEach {
 
@@ -79,9 +79,11 @@ class PushFailedBatchSuite extends AnyFunSuite
     // and PartitionLocation uniqueId will be 0-0
     val pushFailedBatch = 
manager.commitManager.getCommitHandler(0).getShuffleFailedBatches()
     assert(!pushFailedBatch.isEmpty)
+    val failedBatchObj = new LocationPushFailedBatches()
+    failedBatchObj.addFailedBatch(0, 0, 1)
     Assert.assertEquals(
       pushFailedBatch.get(0).get("0-0"),
-      Sets.newHashSet(new PushFailedBatch(0, 0, 1)))
+      failedBatchObj)
 
     sc.stop()
   }

Reply via email to