AngersZhuuuu commented on code in PR #993:
URL: 
https://github.com/apache/incubator-celeborn/pull/993#discussion_r1031149784


##########
client/src/main/java/org/apache/celeborn/client/write/PushState.java:
##########
@@ -30,14 +31,17 @@
 
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.util.Utils;
 
 public class PushState {
   private static final Logger logger = 
LoggerFactory.getLogger(PushState.class);
 
   private int pushBufferMaxSize;
 
   public final AtomicInteger batchId = new AtomicInteger();
-  public final ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches =
+  private final ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches =
+      new ConcurrentHashMap<>();
+  private final ConcurrentHashMap<String, Set<Integer>> batchIdPerAddressPair =
       new ConcurrentHashMap<>();

Review Comment:
   Why not just
   ```
   // master location's hostToPushport -> batches
   private final ConcurrentHashMap<String, Set<Integer>> batchIdsPerAddressPair =
         new ConcurrentHashMap<>();
   ```
   ?



##########
client/src/main/java/org/apache/celeborn/client/write/PushState.java:
##########
@@ -88,4 +100,21 @@ public boolean addBatchData(String addressPair, 
PartitionLocation loc, int batch
   public DataBatches takeDataBatches(String addressPair) {
     return batchesMap.remove(addressPair);
   }
+
+  public void addFlightBatches(int batchId, PartitionLocation loc) {
+    String addressPair = Utils.genAddressPair(loc);
+    Set<Integer> batchIdSetPerPair =
+        batchIdPerAddressPair.computeIfAbsent(addressPair, id -> new 
HashSet<>());
+    batchIdSetPerPair.add(batchId);
+    inFlightBatches.put(batchId, loc);
+  }
+
+  public void removeFlightBatches(int batchId) {
+    PartitionLocation loc = inFlightBatches.remove(batchId);
+    String addressPair = Utils.genAddressPair(loc);

Review Comment:
   Celeborn only pushes data to the master location; the master location's
   corresponding worker then pushes the data to the slave location's worker.
   So it seems we don't need to check whether its peer is null here?



##########
client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java:
##########
@@ -379,13 +371,52 @@ private void limitMaxInFlight(String mapKey, PushState 
pushState, int limit) thr
 
     if (times <= 0) {
       logger.error(
-          "After waiting for {} ms, there are still {} batches in flight for 
map {}, "
+          "After waiting for {} ms, there are still {} batches in flight for 
map {} and addressPair {}, "
               + "which exceeds the limit {}.",
           timeoutMs,
-          inFlightBatches.size(),
+          batchIdSet.size(),
           mapKey,
+          addressPair,
           limit);
-      logger.error("Map: {} in flight batches: {}", mapKey, inFlightBatches);
+      logger.error(
+          "Map: {} with addressPair {} in flight batches: {}", mapKey, 
addressPair, batchIdSet);
+      throw new IOException("wait timeout for task " + mapKey, 
pushState.exception.get());
+    }
+    if (pushState.exception.get() != null) {
+      throw pushState.exception.get();
+    }
+  }
+
+  private void limitZeroInFlight(String mapKey, PushState pushState) throws 
IOException {
+    if (pushState.exception.get() != null) {
+      throw pushState.exception.get();
+    }
+    long timeoutMs = conf.pushLimitInFlightTimeoutMs();
+    long delta = conf.pushLimitInFlightSleepDeltaMs();
+    long times = timeoutMs / delta;
+    ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches = 
pushState.getInFlightBatches();
+
+    try {
+      while (times > 0) {
+        if (inFlightBatches.size() == 0) {

Review Comment:
   `pushState.batchIdsPerAddress.values.map(_.asScala.sum).asScala.sum`?



##########
client/src/main/java/org/apache/celeborn/client/write/PushState.java:
##########
@@ -30,14 +31,17 @@
 
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.util.Utils;
 
 public class PushState {
   private static final Logger logger = 
LoggerFactory.getLogger(PushState.class);
 
   private int pushBufferMaxSize;
 
   public final AtomicInteger batchId = new AtomicInteger();
-  public final ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches =
+  private final ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches =

Review Comment:
   I think we can remove this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to