This is an automated email from the ASF dual-hosted git repository.
feiwang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 4205f83da [CELEBORN-1995] Optimize memory usage for push failed batches
4205f83da is described below
commit 4205f83da33036c587c6ecfa407462d2f492e136
Author: mingji <[email protected]>
AuthorDate: Sun May 18 07:19:26 2025 -0700
[CELEBORN-1995] Optimize memory usage for push failed batches
### What changes were proposed in this pull request?
Aggregate push failed batches for the same map ID and attempt ID.
### Why are the changes needed?
To reduce memory usage.
### Does this PR introduce _any_ user-facing change?
NO.
### How was this patch tested?
GA and cluster run.
Closes #3253 from FMX/b1995.
Authored-by: mingji <[email protected]>
Signed-off-by: Wang, Fei <[email protected]>
---
.../apache/celeborn/client/DummyShuffleClient.java | 5 +-
.../org/apache/celeborn/client/ShuffleClient.java | 5 +-
.../apache/celeborn/client/ShuffleClientImpl.java | 23 ++---
.../celeborn/client/read/CelebornInputStream.java | 27 ++---
.../org/apache/celeborn/client/CommitManager.scala | 4 +-
.../apache/celeborn/client/LifecycleManager.scala | 16 ++-
.../celeborn/client/commit/CommitHandler.scala | 4 +-
.../client/commit/MapPartitionCommitHandler.scala | 4 +-
.../commit/ReducePartitionCommitHandler.scala | 26 ++---
.../common/write/LocationPushFailedBatches.java | 94 +++++++++++++++++
.../celeborn/common/write/PushFailedBatch.java | 84 ---------------
.../apache/celeborn/common/write/PushState.java | 12 +--
common/src/main/proto/TransportMessages.proto | 14 ++-
.../common/protocol/message/ControlMessages.scala | 15 +--
.../apache/celeborn/common/util/PbSerDeUtils.scala | 50 ++++-----
.../org/apache/celeborn/common/util/Utils.scala | 11 ++
.../write/LocationPushFailedBatchesSuiteJ.java | 114 +++++++++++++++++++++
.../common/write/PushFailedBatchSuiteJ.java | 79 --------------
.../celeborn/common/util/PbSerDeUtilsTest.scala | 19 ++--
....scala => LocationPushFailedBatchesSuite.scala} | 8 +-
20 files changed, 328 insertions(+), 286 deletions(-)
diff --git
a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
index dd1a032c8..648309489 100644
--- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -28,7 +28,6 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
@@ -48,7 +47,7 @@ import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.ExceptionMaker;
import org.apache.celeborn.common.util.JavaUtils;
-import org.apache.celeborn.common.write.PushFailedBatch;
+import org.apache.celeborn.common.write.LocationPushFailedBatches;
import org.apache.celeborn.common.write.PushState;
public class DummyShuffleClient extends ShuffleClient {
@@ -144,7 +143,7 @@ public class DummyShuffleClient extends ShuffleClient {
ExceptionMaker exceptionMaker,
ArrayList<PartitionLocation> locations,
ArrayList<PbStreamHandler> streamHandlers,
- Map<String, Set<PushFailedBatch>> failedBatchSetMap,
+ Map<String, LocationPushFailedBatches> failedBatchSetMap,
Map<String, Pair<Integer, Integer>> chunksRange,
int[] mapAttempts,
MetricsCallback metricsCallback)
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index bf0192e4a..6363e9004 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -21,7 +21,6 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
-import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.BiFunction;
@@ -46,7 +45,7 @@ import
org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.CelebornHadoopUtils;
import org.apache.celeborn.common.util.ExceptionMaker;
-import org.apache.celeborn.common.write.PushFailedBatch;
+import org.apache.celeborn.common.write.LocationPushFailedBatches;
import org.apache.celeborn.common.write.PushState;
/**
@@ -269,7 +268,7 @@ public abstract class ShuffleClient {
ExceptionMaker exceptionMaker,
ArrayList<PartitionLocation> locations,
ArrayList<PbStreamHandler> streamHandlers,
- Map<String, Set<PushFailedBatch>> failedBatchSetMap,
+ Map<String, LocationPushFailedBatches> failedBatchSetMap,
Map<String, Pair<Integer, Integer>> chunksRange,
int[] mapAttempts,
MetricsCallback metricsCallback)
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 81329e8d1..6e80f90c2 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -68,7 +68,7 @@ import org.apache.celeborn.common.rpc.RpcEnv;
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.*;
import org.apache.celeborn.common.write.DataBatches;
-import org.apache.celeborn.common.write.PushFailedBatch;
+import org.apache.celeborn.common.write.LocationPushFailedBatches;
import org.apache.celeborn.common.write.PushState;
public class ShuffleClientImpl extends ShuffleClient {
@@ -150,7 +150,7 @@ public class ShuffleClientImpl extends ShuffleClient {
public static class ReduceFileGroups {
public Map<Integer, Set<PartitionLocation>> partitionGroups;
- public Map<String, Set<PushFailedBatch>> pushFailedBatches;
+ public Map<String, LocationPushFailedBatches> pushFailedBatches;
public int[] mapAttempts;
public Set<Integer> partitionIds;
@@ -158,7 +158,7 @@ public class ShuffleClientImpl extends ShuffleClient {
Map<Integer, Set<PartitionLocation>> partitionGroups,
int[] mapAttempts,
Set<Integer> partitionIds,
- Map<String, Set<PushFailedBatch>> pushFailedBatches) {
+ Map<String, LocationPushFailedBatches> pushFailedBatches) {
this.partitionGroups = partitionGroups;
this.mapAttempts = mapAttempts;
this.partitionIds = partitionIds;
@@ -1122,8 +1122,8 @@ public class ShuffleClientImpl extends ShuffleClient {
partitionId,
nextBatchId);
if (dataPushFailureTrackingEnabled && pushReplicateEnabled) {
- pushState.addFailedBatch(
- latest.getUniqueId(), new PushFailedBatch(mapId,
attemptId, nextBatchId));
+ pushState.recordFailedBatch(
+ latest.getUniqueId(), mapId, attemptId, nextBatchId);
}
ReviveRequest reviveRequest =
new ReviveRequest(
@@ -1192,8 +1192,7 @@ public class ShuffleClientImpl extends ShuffleClient {
@Override
public void onFailure(Throwable e) {
if (dataPushFailureTrackingEnabled) {
- pushState.addFailedBatch(
- latest.getUniqueId(), new PushFailedBatch(mapId,
attemptId, nextBatchId));
+ pushState.recordFailedBatch(latest.getUniqueId(), mapId,
attemptId, nextBatchId);
}
if (pushState.exception.get() != null) {
return;
@@ -1563,9 +1562,8 @@ public class ShuffleClientImpl extends ShuffleClient {
} else {
if (dataPushFailureTrackingEnabled && pushReplicateEnabled) {
for (DataBatches.DataBatch resubmitBatch :
batchesNeedResubmit) {
- pushState.addFailedBatch(
- resubmitBatch.loc.getUniqueId(),
- new PushFailedBatch(mapId, attemptId,
resubmitBatch.batchId));
+ pushState.recordFailedBatch(
+ resubmitBatch.loc.getUniqueId(), mapId, attemptId,
resubmitBatch.batchId);
}
}
ReviveRequest[] requests =
@@ -1625,8 +1623,7 @@ public class ShuffleClientImpl extends ShuffleClient {
public void onFailure(Throwable e) {
if (dataPushFailureTrackingEnabled) {
for (int i = 0; i < numBatches; i++) {
- pushState.addFailedBatch(
- partitionUniqueIds[i], new PushFailedBatch(mapId,
attemptId, batchIds[i]));
+ pushState.recordFailedBatch(partitionUniqueIds[i], mapId,
attemptId, batchIds[i]);
}
}
if (pushState.exception.get() != null) {
@@ -1915,7 +1912,7 @@ public class ShuffleClientImpl extends ShuffleClient {
ExceptionMaker exceptionMaker,
ArrayList<PartitionLocation> locations,
ArrayList<PbStreamHandler> streamHandlers,
- Map<String, Set<PushFailedBatch>> failedBatchSetMap,
+ Map<String, LocationPushFailedBatches> failedBatchSetMap,
Map<String, Pair<Integer, Integer>> chunksRange,
int[] mapAttempts,
MetricsCallback metricsCallback)
diff --git
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index ffcc50dde..5bed692b4 100644
---
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -48,7 +48,7 @@ import org.apache.celeborn.common.protocol.*;
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.ExceptionMaker;
import org.apache.celeborn.common.util.Utils;
-import org.apache.celeborn.common.write.PushFailedBatch;
+import org.apache.celeborn.common.write.LocationPushFailedBatches;
public abstract class CelebornInputStream extends InputStream {
private static final Logger logger =
LoggerFactory.getLogger(CelebornInputStream.class);
@@ -60,7 +60,7 @@ public abstract class CelebornInputStream extends InputStream
{
ArrayList<PartitionLocation> locations,
ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
- Map<String, Set<PushFailedBatch>> failedBatchSetMap,
+ Map<String, LocationPushFailedBatches> failedBatchSetMap,
Map<String, Pair<Integer, Integer>> chunksRange,
int attemptNumber,
long taskId,
@@ -176,7 +176,7 @@ public abstract class CelebornInputStream extends
InputStream {
private Map<Integer, Set<Integer>> batchesRead = new HashMap<>();
- private final Map<String, Set<PushFailedBatch>> failedBatches;
+ private final Map<String, LocationPushFailedBatches> failedBatches;
private byte[] compressedBuf;
private byte[] rawDataBuf;
@@ -223,7 +223,7 @@ public abstract class CelebornInputStream extends
InputStream {
ArrayList<PartitionLocation> locations,
ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
- Map<String, Set<PushFailedBatch>> failedBatchSet,
+ Map<String, LocationPushFailedBatches> failedBatchSet,
int attemptNumber,
long taskId,
Map<String, Pair<Integer, Integer>> partitionLocationToChunkRange,
@@ -266,7 +266,7 @@ public abstract class CelebornInputStream extends
InputStream {
ArrayList<PartitionLocation> locations,
ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
- Map<String, Set<PushFailedBatch>> failedBatchSet,
+ Map<String, LocationPushFailedBatches> failedBatchSet,
int attemptNumber,
long taskId,
int startMapIndex,
@@ -758,7 +758,7 @@ public abstract class CelebornInputStream extends
InputStream {
return false;
}
- PushFailedBatch failedBatch = new PushFailedBatch(-1, -1, -1);
+ LocationPushFailedBatches failedBatch = new
LocationPushFailedBatches();
boolean hasData = false;
while (currentChunk.isReadable() || moveToNextChunk()) {
currentChunk.readBytes(sizeBuf);
@@ -784,14 +784,15 @@ public abstract class CelebornInputStream extends
InputStream {
// de-duplicate
if (attemptId == attempts[mapId]) {
if (readSkewPartitionWithoutMapRange) {
- Set<PushFailedBatch> failedBatchSet =
+ LocationPushFailedBatches locationPushFailedBatches =
this.failedBatches.get(currentReader.getLocation().getUniqueId());
- if (null != failedBatchSet) {
- failedBatch.setMapId(mapId);
- failedBatch.setAttemptId(attemptId);
- failedBatch.setBatchId(batchId);
- if (failedBatchSet.contains(failedBatch)) {
- logger.warn("Skip duplicated batch: {}.", failedBatch);
+ if (null != locationPushFailedBatches) {
+ if (locationPushFailedBatches.contains(mapId, attemptId,
batchId)) {
+ logger.warn(
+ "Skip duplicated batch: mapId={}, attemptId={},
batchId={}",
+ mapId,
+ attemptId,
+ batchId);
continue;
}
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index bffd05430..702441dce 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -42,7 +42,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.JavaUtils
import org.apache.celeborn.common.util.ThreadUtils
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
case class ShuffleCommittedInfo(
// partition id -> unique partition ids
@@ -219,7 +219,7 @@ class CommitManager(appUniqueId: String, val conf:
CelebornConf, lifecycleManage
attemptId: Int,
numMappers: Int,
partitionId: Int = -1,
- pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] =
Collections.emptyMap())
+ pushFailedBatches: util.Map[String, LocationPushFailedBatches] =
Collections.emptyMap())
: (Boolean, Boolean) = {
getCommitHandler(shuffleId).finishMapperAttempt(
shuffleId,
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 20e1099d2..bcb018570 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -59,7 +59,7 @@ import org.apache.celeborn.common.util.{JavaUtils,
PbSerDeUtils, ThreadUtils, Ut
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.ThreadUtils.awaitResult
import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
object LifecycleManager {
// shuffle id -> partition id -> partition locations
@@ -67,7 +67,7 @@ object LifecycleManager {
ConcurrentHashMap[Int, ConcurrentHashMap[Integer,
util.Set[PartitionLocation]]]
// shuffle id -> partition uniqueId -> PushFailedBatch set
type ShufflePushFailedBatches =
- ConcurrentHashMap[Int, util.HashMap[String, util.Set[PushFailedBatch]]]
+ ConcurrentHashMap[Int, util.HashMap[String, LocationPushFailedBatches]]
type ShuffleAllocatedWorkers =
ConcurrentHashMap[Int, ConcurrentHashMap[String,
ShufflePartitionLocationInfo]]
type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]
@@ -825,7 +825,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
mapId: Int,
attemptId: Int,
numMappers: Int,
- pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]]): Unit = {
+ pushFailedBatches: util.Map[String, LocationPushFailedBatches]): Unit = {
val (mapperAttemptFinishedSuccess, allMapperFinished) =
commitManager.finishMapperAttempt(
@@ -1118,12 +1118,8 @@ class LifecycleManager(val appUniqueId: String, val
conf: CelebornConf) extends
}
}
- val (mapperAttemptFinishedSuccess, _) = commitManager.finishMapperAttempt(
- shuffleId,
- mapId,
- attemptId,
- numMappers,
- partitionId)
+ val (mapperAttemptFinishedSuccess, _) =
+ commitManager.finishMapperAttempt(shuffleId, mapId, attemptId,
numMappers, partitionId)
reply(mapperAttemptFinishedSuccess)
}
@@ -1817,7 +1813,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
}
}
- // Once a partition is released, it will be never needed anymore
+ // Once a partition is released, it will never be needed anymore
def releasePartition(shuffleId: Int, partitionId: Int): Unit = {
commitManager.releasePartitionResource(shuffleId, partitionId)
val partitionLocation = latestPartitionLocation.get(shuffleId)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index c8b563372..42ea0379f 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -43,7 +43,7 @@ import org.apache.celeborn.common.util.{CollectionUtils,
JavaUtils, Utils}
// Can Remove this if celeborn don't support scala211 in future
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.ThreadUtils.awaitResult
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
case class CommitFilesParam(
worker: WorkerInfo,
@@ -206,7 +206,7 @@ abstract class CommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
- pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
+ pushFailedBatches: util.Map[String, LocationPushFailedBatches],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean)
def registerShuffle(
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 4f31018e5..715950531 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -40,7 +40,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.JavaUtils
import org.apache.celeborn.common.util.Utils
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
/**
* This commit handler is for MapPartition ShuffleType, which means that a Map
Partition contains all data produced
@@ -186,7 +186,7 @@ class MapPartitionCommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
- pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
+ pushFailedBatches: util.Map[String, LocationPushFailedBatches],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) =
{
val inProcessingPartitionIds =
inProcessMapPartitionEndIds.computeIfAbsent(
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index a65f7e1eb..45c371ee5 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -41,7 +41,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc.RpcCallContext
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext,
RemoteNettyRpcCallContext}
import org.apache.celeborn.common.util.JavaUtils
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
/**
* This commit handler is for ReducePartition ShuffleType, which means that a
Reduce Partition contains all data
@@ -94,18 +94,18 @@ class ReducePartitionCommitHandler(
.build().asInstanceOf[Cache[Int, ByteBuffer]]
private val newShuffleId2PushFailedBatchMapFunc
- : function.Function[Int, util.HashMap[String,
util.Set[PushFailedBatch]]] =
- new util.function.Function[Int, util.HashMap[String,
util.Set[PushFailedBatch]]]() {
- override def apply(s: Int): util.HashMap[String,
util.Set[PushFailedBatch]] = {
- new util.HashMap[String, util.Set[PushFailedBatch]]()
+ : function.Function[Int, util.HashMap[String,
LocationPushFailedBatches]] =
+ new util.function.Function[Int, util.HashMap[String,
LocationPushFailedBatches]]() {
+ override def apply(s: Int): util.HashMap[String,
LocationPushFailedBatches] = {
+ new util.HashMap[String, LocationPushFailedBatches]()
}
}
private val uniqueId2PushFailedBatchMapFunc
- : function.Function[String, util.Set[PushFailedBatch]] =
- new util.function.Function[String, util.Set[PushFailedBatch]]() {
- override def apply(s: String): util.Set[PushFailedBatch] = {
- Sets.newHashSet[PushFailedBatch]()
+ : function.Function[String, LocationPushFailedBatches] =
+ new util.function.Function[String, LocationPushFailedBatches]() {
+ override def apply(s: String): LocationPushFailedBatches = {
+ new LocationPushFailedBatches()
}
}
@@ -267,7 +267,7 @@ class ReducePartitionCommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
- pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
+ pushFailedBatches: util.Map[String, LocationPushFailedBatches],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) =
{
shuffleMapperAttempts.synchronized {
if (getMapperAttempts(shuffleId) == null) {
@@ -282,11 +282,11 @@ class ReducePartitionCommitHandler(
val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent(
shuffleId,
newShuffleId2PushFailedBatchMapFunc)
- for ((partitionUniqId, pushFailedBatchSet) <-
pushFailedBatches.asScala) {
+ for ((partitionUniqId, locationPushFailedBatches) <-
pushFailedBatches.asScala) {
val partitionPushFailedBatches =
pushFailedBatchesMap.computeIfAbsent(
partitionUniqId,
uniqueId2PushFailedBatchMapFunc)
- partitionPushFailedBatches.addAll(pushFailedBatchSet)
+ partitionPushFailedBatches.merge(locationPushFailedBatches)
}
}
// Mapper with this attemptId finished, also check all other mapper
finished or not.
@@ -361,7 +361,7 @@ class ReducePartitionCommitHandler(
pushFailedBatches =
shufflePushFailedBatches.getOrDefault(
shuffleId,
- new util.HashMap[String, util.Set[PushFailedBatch]]()),
+ new util.HashMap[String, LocationPushFailedBatches]()),
serdeVersion = serdeVersion)
val serializedMsg =
diff --git
a/common/src/main/java/org/apache/celeborn/common/write/LocationPushFailedBatches.java
b/common/src/main/java/org/apache/celeborn/common/write/LocationPushFailedBatches.java
new file mode 100644
index 000000000..2273a3698
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/write/LocationPushFailedBatches.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.write;
+
+import java.io.Serializable;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.commons.lang3.StringUtils;
+
+import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.Utils;
+
+public class LocationPushFailedBatches implements Serializable {
+
+ // map-attempt-id -> failed batch ids
+ private final Map<String, Set<Integer>> failedBatches =
JavaUtils.newConcurrentHashMap();
+
+ public Map<String, Set<Integer>> getFailedBatches() {
+ return failedBatches;
+ }
+
+ public boolean contains(int mapId, int attemptId, int batchId) {
+ Set<Integer> batches = failedBatches.get(Utils.makeAttemptKey(mapId,
attemptId));
+ return batches != null && batches.contains(batchId);
+ }
+
+ public void merge(LocationPushFailedBatches batches) {
+ Map<String, Set<Integer>> otherFailedBatchesMap =
batches.getFailedBatches();
+ otherFailedBatchesMap.forEach(
+ (k, v) -> {
+ Set<Integer> failedBatches =
+ this.failedBatches.computeIfAbsent(k, (s) ->
ConcurrentHashMap.newKeySet());
+ failedBatches.addAll(v);
+ });
+ }
+
+ public void addFailedBatch(int mapId, int attemptId, int batchId) {
+ String attemptKey = Utils.makeAttemptKey(mapId, attemptId);
+ Set<Integer> failedBatches =
+ this.failedBatches.computeIfAbsent(attemptKey, (s) ->
ConcurrentHashMap.newKeySet());
+ failedBatches.add(batchId);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ LocationPushFailedBatches that = (LocationPushFailedBatches) o;
+ if (that.failedBatches.size() != failedBatches.size()) {
+ return false;
+ }
+
+ return failedBatches.entrySet().stream()
+ .allMatch(
+ item -> {
+ Set<Integer> failedBatchesSet =
that.failedBatches.get(item.getKey());
+ if (failedBatchesSet == null) return false;
+ return failedBatchesSet.equals(item.getValue());
+ });
+ }
+
+ @Override
+ public int hashCode() {
+ return failedBatches.entrySet().hashCode();
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder stringBuilder = new StringBuilder();
+ failedBatches.forEach(
+ (attemptKey, value) -> {
+ stringBuilder.append("failed attemptKey:");
+ stringBuilder.append(attemptKey);
+ stringBuilder.append(" fail batch Ids:");
+ stringBuilder.append(StringUtils.join(value, ","));
+ });
+ return "LocationPushFailedBatches{" + "failedBatches=" + stringBuilder +
'}';
+ }
+}
diff --git
a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
deleted file mode 100644
index ccee8bf11..000000000
--- a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.celeborn.common.write;
-
-import java.io.Serializable;
-
-import com.google.common.base.Objects;
-import org.apache.commons.lang3.builder.ToStringBuilder;
-import org.apache.commons.lang3.builder.ToStringStyle;
-
-public class PushFailedBatch implements Serializable {
-
- private int mapId;
- private int attemptId;
- private int batchId;
-
- public PushFailedBatch(int mapId, int attemptId, int batchId) {
- this.mapId = mapId;
- this.attemptId = attemptId;
- this.batchId = batchId;
- }
-
- public int getMapId() {
- return mapId;
- }
-
- public void setMapId(int mapId) {
- this.mapId = mapId;
- }
-
- public int getAttemptId() {
- return attemptId;
- }
-
- public void setAttemptId(int attemptId) {
- this.attemptId = attemptId;
- }
-
- public int getBatchId() {
- return batchId;
- }
-
- public void setBatchId(int batchId) {
- this.batchId = batchId;
- }
-
- @Override
- public boolean equals(Object other) {
- if (!(other instanceof PushFailedBatch)) {
- return false;
- }
- PushFailedBatch o = (PushFailedBatch) other;
- return mapId == o.mapId && attemptId == o.attemptId && batchId ==
o.batchId;
- }
-
- @Override
- public int hashCode() {
- return Objects.hashCode(mapId, attemptId, batchId);
- }
-
- @Override
- public String toString() {
- return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
- .append("mapId", mapId)
- .append("attemptId", attemptId)
- .append("batchId", batchId)
- .toString();
- }
-}
diff --git
a/common/src/main/java/org/apache/celeborn/common/write/PushState.java
b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
index 3977e5d54..afa22bb8a 100644
--- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java
+++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
@@ -19,11 +19,9 @@ package org.apache.celeborn.common.write;
import java.io.IOException;
import java.util.Map;
-import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
-import com.google.common.collect.Sets;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.celeborn.common.CelebornConf;
@@ -36,7 +34,7 @@ public class PushState {
public AtomicReference<IOException> exception = new AtomicReference<>();
private final InFlightRequestTracker inFlightRequestTracker;
- private final Map<String, Set<PushFailedBatch>> failedBatchMap;
+ private final Map<String, LocationPushFailedBatches> failedBatchMap;
public PushState(CelebornConf conf) {
pushBufferMaxSize = conf.clientPushBufferMaxSize();
@@ -95,13 +93,13 @@ public class PushState {
return inFlightRequestTracker.remainingAllowPushes(hostAndPushPort);
}
- public void addFailedBatch(String partitionId, PushFailedBatch failedBatch) {
+ public void recordFailedBatch(String partitionId, int mapId, int attemptId,
int batchId) {
this.failedBatchMap
- .computeIfAbsent(partitionId, (s) -> Sets.newConcurrentHashSet())
- .add(failedBatch);
+ .computeIfAbsent(partitionId, (s) -> new LocationPushFailedBatches())
+ .addFailedBatch(mapId, attemptId, batchId);
}
- public Map<String, Set<PushFailedBatch>> getFailedBatches() {
+ public Map<String, LocationPushFailedBatches> getFailedBatches() {
return this.failedBatchMap;
}
}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index 7b0d0bec2..c120b9952 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -356,17 +356,15 @@ message PbMapperEnd {
int32 attemptId = 3;
int32 numMappers = 4;
int32 partitionId = 5;
- map<string, PbPushFailedBatchSet> pushFailureBatches= 6;
+ map<string, PbLocationPushFailedBatches> pushFailureBatches= 6;
}
-message PbPushFailedBatchSet {
- repeated PbPushFailedBatch failureBatches = 1;
+message PbLocationPushFailedBatches {
+ map<string, PbFailedBatches> failedBatches = 1;
}
-message PbPushFailedBatch {
- int32 mapId = 1;
- int32 attemptId = 2;
- int32 batchId = 3;
+message PbFailedBatches{
+ repeated int32 failedBatches = 1;
}
message PbMapperEndResponse {
@@ -389,7 +387,7 @@ message PbGetReducerFileGroupResponse {
// only map partition mode has succeed partitionIds
repeated int32 partitionIds = 4;
- map<string, PbPushFailedBatchSet> pushFailedBatches = 5;
+ map<string, PbLocationPushFailedBatches> pushFailedBatches = 5;
bytes broadcast = 6;
}
diff --git
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 57831f213..ccd5d79a2 100644
---
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -33,7 +33,7 @@ import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.MessageType._
import org.apache.celeborn.common.quota.ResourceConsumption
import org.apache.celeborn.common.util.{PbSerDeUtils, Utils}
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
sealed trait Message extends Serializable
@@ -274,7 +274,7 @@ object ControlMessages extends Logging {
attemptId: Int,
numMappers: Int,
partitionId: Int,
- failedBatchSet: util.Map[String, util.Set[PushFailedBatch]])
+ failedBatchSet: util.Map[String, LocationPushFailedBatches])
extends MasterMessage
case class MapperEndResponse(status: StatusCode) extends MasterMessage
@@ -292,7 +292,8 @@ object ControlMessages extends Logging {
fileGroup: util.Map[Integer, util.Set[PartitionLocation]] =
Collections.emptyMap(),
attempts: Array[Int] = Array.emptyIntArray,
partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
- pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] =
Collections.emptyMap(),
+ pushFailedBatches: util.Map[String, LocationPushFailedBatches] =
+ Collections.emptyMap(),
broadcast: Array[Byte] = Array.emptyByteArray,
serdeVersion: SerdeVersion = SerdeVersion.V1)
extends MasterMessage
@@ -732,7 +733,7 @@ object ControlMessages extends Logging {
case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId,
pushFailedBatch) =>
val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
- val resultValue = PbSerDeUtils.toPbPushFailedBatchSet(v)
+ val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
(k, resultValue)
}.toMap.asJava
val payload = PbMapperEnd.newBuilder()
@@ -781,7 +782,7 @@ object ControlMessages extends Logging {
builder.putAllPushFailedBatches(
failedBatches.asScala.map {
case (uniqueId, pushFailedBatchSet) =>
- (uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet))
+ (uniqueId,
PbSerDeUtils.toPbLocationPushFailedBatches(pushFailedBatchSet))
}.asJava)
builder.setBroadcast(ByteString.copyFrom(broadcast))
val payload = builder.build().toByteArray
@@ -1171,7 +1172,7 @@ object ControlMessages extends Logging {
pbMapperEnd.getPartitionId,
pbMapperEnd.getPushFailureBatchesMap.asScala.map {
case (partitionId, pushFailedBatchSet) =>
- (partitionId,
PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
+ (partitionId,
PbSerDeUtils.fromPbLocationPushFailedBatches(pushFailedBatchSet))
}.toMap.asJava)
case MAPPER_END_RESPONSE_VALUE =>
@@ -1211,7 +1212,7 @@ object ControlMessages extends Logging {
val partitionIds = new
util.HashSet(pbGetReducerFileGroupResponse.getPartitionIdsList)
val pushFailedBatches =
pbGetReducerFileGroupResponse.getPushFailedBatchesMap.asScala.map {
case (uniqueId, pushFailedBatchSet) =>
- (uniqueId,
PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
+ (uniqueId,
PbSerDeUtils.fromPbLocationPushFailedBatches(pushFailedBatchSet))
}.toMap.asJava
val broadcast = pbGetReducerFileGroupResponse.getBroadcast.toByteArray
GetReducerFileGroupResponse(
diff --git
a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
index 3e0d1afd9..6ba547684 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
@@ -32,7 +32,7 @@ import
org.apache.celeborn.common.protocol.PartitionLocation.Mode
import
org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
import org.apache.celeborn.common.quota.ResourceConsumption
import org.apache.celeborn.common.util.{CollectionUtils =>
localCollectionUtils}
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
object PbSerDeUtils {
@@ -477,7 +477,7 @@ object PbSerDeUtils {
}
def fromPbApplicationMeta(pbApplicationMeta: PbApplicationMeta):
ApplicationMeta = {
- new ApplicationMeta(pbApplicationMeta.getAppId,
pbApplicationMeta.getSecret)
+ ApplicationMeta(pbApplicationMeta.getAppId, pbApplicationMeta.getSecret)
}
def toPbWorkerStatus(workerStatus: WorkerStatus): PbWorkerStatus = {
@@ -501,7 +501,7 @@ object PbSerDeUtils {
def fromPbWorkerEventInfo(pbWorkerEventInfo: PbWorkerEventInfo):
WorkerEventInfo = {
new WorkerEventInfo(
pbWorkerEventInfo.getWorkerEventType.getNumber,
- pbWorkerEventInfo.getEventStartTime())
+ pbWorkerEventInfo.getEventStartTime)
}
private def toPackedPartitionLocation(
@@ -695,34 +695,36 @@ object PbSerDeUtils {
}.asJava
}
- def toPbPushFailedBatch(pushFailedBatch: PushFailedBatch): PbPushFailedBatch
= {
- PbPushFailedBatch.newBuilder()
- .setMapId(pushFailedBatch.getMapId)
- .setAttemptId(pushFailedBatch.getAttemptId)
- .setBatchId(pushFailedBatch.getBatchId)
+ def toPbFailedBatches(failedBatches: util.Set[Int]): PbFailedBatches = {
+ PbFailedBatches.newBuilder()
+ .addAllFailedBatches(failedBatches.asScala.map(new Integer(_)).asJava)
.build()
}
- def fromPbPushFailedBatch(pbPushFailedBatch: PbPushFailedBatch):
PushFailedBatch = {
- new PushFailedBatch(
- pbPushFailedBatch.getMapId,
- pbPushFailedBatch.getAttemptId,
- pbPushFailedBatch.getBatchId)
+ def fromPbFailedBatches(failedBatches: PbFailedBatches): util.Set[Int] = {
+ failedBatches.getFailedBatchesList.asScala.map(_.intValue()).toSet.asJava
}
- def toPbPushFailedBatchSet(failedBatchSet: util.Set[PushFailedBatch]):
PbPushFailedBatchSet = {
- val builder = PbPushFailedBatchSet.newBuilder()
- failedBatchSet.asScala.foreach(batch =>
builder.addFailureBatches(toPbPushFailedBatch(batch)))
-
+ def toPbLocationPushFailedBatches(locationPushFailedBatches:
LocationPushFailedBatches)
+ : PbLocationPushFailedBatches = {
+ val builder = PbLocationPushFailedBatches.newBuilder()
+ builder.putAllFailedBatches(
+ locationPushFailedBatches
+ .getFailedBatches
+ .asScala
+ .map(item =>
+ (item._1,
toPbFailedBatches(item._2.asScala.map(_.intValue()).toSet.asJava))).asJava)
builder.build()
}
- def fromPbPushFailedBatchSet(pbFailedBatchSet: PbPushFailedBatchSet)
- : util.Set[PushFailedBatch] = {
- val failedBatchSet = new util.HashSet[PushFailedBatch]()
- pbFailedBatchSet.getFailureBatchesList.asScala.foreach(batch =>
- failedBatchSet.add(fromPbPushFailedBatch(batch)))
-
- failedBatchSet
+ def fromPbLocationPushFailedBatches(pbLocationPushFailedBatches:
PbLocationPushFailedBatches)
+ : LocationPushFailedBatches = {
+ val batches = new LocationPushFailedBatches()
+ pbLocationPushFailedBatches.getFailedBatchesMap.asScala.foreach { case
(key, value) =>
+ batches.getFailedBatches.put(
+ key,
+ fromPbFailedBatches(value).asScala.map(new Integer(_)).asJava)
+ }
+ batches
}
}
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index e825abce3..73f070ab9 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -717,6 +717,17 @@ object Utils extends Logging {
s"$shuffleId-$mapId-$attemptId"
}
+ def makeAttemptKey(mapId: Int, attemptId: Int): String = {
+ s"$mapId-$attemptId"
+ }
+
+ def splitAttemptKey(attemptKey: String): (Int, Int) = {
+ val splits = attemptKey.split("-")
+ val mapId = splits(0).toInt
+ val attemptId = splits(1).toInt
+ (mapId, attemptId)
+ }
+
def shuffleKeyPrefix(shuffleKey: String): String = {
shuffleKey + "-"
}
diff --git
a/common/src/test/java/org/apache/celeborn/common/write/LocationPushFailedBatchesSuiteJ.java
b/common/src/test/java/org/apache/celeborn/common/write/LocationPushFailedBatchesSuiteJ.java
new file mode 100644
index 000000000..6fc58eb17
--- /dev/null
+++
b/common/src/test/java/org/apache/celeborn/common/write/LocationPushFailedBatchesSuiteJ.java
@@ -0,0 +1,114 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.celeborn.common.write;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.junit.Test;

/** Unit tests for {@link LocationPushFailedBatches} equality, hashing and concurrency. */
public class LocationPushFailedBatchesSuiteJ {

  @Test
  public void equalsReturnsTrueForIdenticalBatches() {
    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
    batch1.addFailedBatch(1, 2, 3);
    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
    batch2.addFailedBatch(1, 2, 3);
    assertEquals(batch1, batch2);
  }

  @Test
  public void equalsReturnsFalseForDifferentBatches() {
    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
    batch1.addFailedBatch(1, 2, 3);
    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
    batch2.addFailedBatch(4, 5, 6);
    assertNotEquals(batch1, batch2);
  }

  @Test
  public void hashCodeDiffersForDifferentBatches() {
    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
    batch1.addFailedBatch(1, 2, 3);
    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
    batch2.addFailedBatch(4, 5, 6);
    assertNotEquals(batch1.hashCode(), batch2.hashCode());
  }

  @Test
  public void hashCodeSameForIdenticalBatches() {
    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
    batch1.addFailedBatch(1, 2, 3);
    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
    batch2.addFailedBatch(1, 2, 3);
    assertEquals(batch1.hashCode(), batch2.hashCode());
  }

  @Test
  public void hashCodeIsConsistent() {
    LocationPushFailedBatches batch = new LocationPushFailedBatches();
    batch.addFailedBatch(1, 2, 3);
    assertEquals(batch.hashCode(), batch.hashCode());
  }

  @Test
  public void toStringReturnsExpectedFormat() {
    LocationPushFailedBatches batch = new LocationPushFailedBatches();
    batch.addFailedBatch(1, 2, 3);
    String expected =
        "LocationPushFailedBatches{failedBatches=failed attemptKey:1-2 fail batch Ids:3}";
    assertEquals(expected, batch.toString());
  }

  @Test
  public void hashCodeAndEqualsWorkInSet() {
    Set<LocationPushFailedBatches> set = new HashSet<>();
    LocationPushFailedBatches batch1 = new LocationPushFailedBatches();
    batch1.addFailedBatch(1, 2, 3);
    LocationPushFailedBatches batch2 = new LocationPushFailedBatches();
    batch2.addFailedBatch(1, 2, 3);
    set.add(batch1);
    assertTrue(set.contains(batch2));
  }

  @Test
  public void concurrentAddAndGetShouldNotConflict() throws InterruptedException {
    LocationPushFailedBatches batches = new LocationPushFailedBatches();
    ExecutorService executor = Executors.newFixedThreadPool(4);

    // Batch ids are unique across submissions, so the per-attempt sets must
    // collectively contain exactly totalFailedBatches entries afterwards.
    int totalFailedBatches = 1000;
    for (int i = 0; i < totalFailedBatches; i++) {
      final int tIdx = i;
      executor.submit(() -> batches.addFailedBatch(tIdx % 10, tIdx % 5, tIdx));
    }
    executor.shutdown();
    // Fail loudly on timeout instead of silently asserting on partial state:
    // the original ignored awaitTermination's return value.
    assertTrue(
        "executor did not terminate in time", executor.awaitTermination(10, TimeUnit.SECONDS));
    assertFalse(batches.getFailedBatches().isEmpty());
    assertEquals(
        totalFailedBatches,
        batches.getFailedBatches().values().stream().mapToInt(Set::size).sum());
  }
}
diff --git
a/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
b/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
deleted file mode 100644
index fcfc6b799..000000000
---
a/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.celeborn.common.write;
-
-import java.util.HashSet;
-import java.util.Set;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class PushFailedBatchSuiteJ {
-
- @Test
- public void equalsReturnsTrueForIdenticalBatches() {
- PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
- PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
- Assert.assertEquals(batch1, batch2);
- }
-
- @Test
- public void equalsReturnsFalseForDifferentBatches() {
- PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
- PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6);
- Assert.assertNotEquals(batch1, batch2);
- }
-
- @Test
- public void hashCodeDiffersForDifferentBatches() {
- PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
- PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6);
- Assert.assertNotEquals(batch1.hashCode(), batch2.hashCode());
- }
-
- @Test
- public void hashCodeSameForIdenticalBatches() {
- PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
- PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
- Assert.assertEquals(batch1.hashCode(), batch2.hashCode());
- }
-
- @Test
- public void hashCodeIsConsistent() {
- PushFailedBatch batch = new PushFailedBatch(1, 2, 3);
- int hashCode1 = batch.hashCode();
- int hashCode2 = batch.hashCode();
- Assert.assertEquals(hashCode1, hashCode2);
- }
-
- @Test
- public void toStringReturnsExpectedFormat() {
- PushFailedBatch batch = new PushFailedBatch(1, 2, 3);
- String expected = "PushFailedBatch[mapId=1,attemptId=2,batchId=3]";
- Assert.assertEquals(expected, batch.toString());
- }
-
- @Test
- public void hashCodeAndEqualsWorkInSet() {
- Set<PushFailedBatch> set = new HashSet<>();
- PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
- PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
- set.add(batch1);
- Assert.assertTrue(set.contains(batch2));
- }
-}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
index 3136584b1..9d4059ff9 100644
---
a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
+++
b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
-import com.google.common.collect.{Lists, Sets}
+import com.google.common.collect.Lists
import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils
import org.apache.celeborn.CelebornFunSuite
@@ -36,7 +36,7 @@ import
org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
import
org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse,
WorkerResource}
import org.apache.celeborn.common.quota.ResourceConsumption
import
org.apache.celeborn.common.util.PbSerDeUtils.{fromPbPackedPartitionLocationsPair,
toPbPackedPartitionLocationsPair}
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
class PbSerDeUtilsTest extends CelebornFunSuite {
@@ -664,19 +664,12 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
}
test("fromAndToPushFailedBatch") {
- val failedBatch = new PushFailedBatch(1, 1, 2)
- val pbPushFailedBatch = PbSerDeUtils.toPbPushFailedBatch(failedBatch)
- val restoredFailedBatch =
PbSerDeUtils.fromPbPushFailedBatch(pbPushFailedBatch)
+ val failedBatch = new LocationPushFailedBatches()
+ failedBatch.addFailedBatch(1, 1, 2)
+ val pbPushFailedBatch =
PbSerDeUtils.toPbLocationPushFailedBatches(failedBatch)
+ val restoredFailedBatch =
PbSerDeUtils.fromPbLocationPushFailedBatches(pbPushFailedBatch)
assert(restoredFailedBatch.equals(failedBatch))
}
- test("fromAndToPushFailedBatchSet") {
- val failedBatchSet = Sets.newHashSet(new PushFailedBatch(1, 1, 2), new
PushFailedBatch(2, 2, 3))
- val pbPushFailedBatchSet =
PbSerDeUtils.toPbPushFailedBatchSet(failedBatchSet)
- val restoredFailedBatchSet =
PbSerDeUtils.fromPbPushFailedBatchSet(pbPushFailedBatchSet)
-
- assert(restoredFailedBatchSet.equals(failedBatchSet))
- }
-
}
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/LocationPushFailedBatchesSuite.scala
similarity index 93%
rename from
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
rename to
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/LocationPushFailedBatchesSuite.scala
index 716ad4e36..0cb31db48 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/LocationPushFailedBatchesSuite.scala
@@ -27,10 +27,10 @@ import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.ShuffleMode
-import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.common.write.LocationPushFailedBatches
import org.apache.celeborn.service.deploy.worker.PushDataHandler
-class PushFailedBatchSuite extends AnyFunSuite
+class LocationPushFailedBatchesSuite extends AnyFunSuite
with SparkTestBase
with BeforeAndAfterEach {
@@ -79,9 +79,11 @@ class PushFailedBatchSuite extends AnyFunSuite
// and PartitionLocation uniqueId will be 0-0
val pushFailedBatch =
manager.commitManager.getCommitHandler(0).getShuffleFailedBatches()
assert(!pushFailedBatch.isEmpty)
+ val failedBatchObj = new LocationPushFailedBatches()
+ failedBatchObj.addFailedBatch(0, 0, 1)
Assert.assertEquals(
pushFailedBatch.get(0).get("0-0"),
- Sets.newHashSet(new PushFailedBatch(0, 0, 1)))
+ failedBatchObj)
sc.stop()
}