This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f068226 [CELEBORN-119] Add timeout for pushdata (#1097)
2f068226 is described below

commit 2f0682265e2044604ccfbe8185124081b0002568
Author: Keyong Zhou <[email protected]>
AuthorDate: Tue Dec 20 20:40:42 2022 +0800

    [CELEBORN-119] Add timeout for pushdata (#1097)
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 75 ++++++++++---------
 .../apache/celeborn/client/write/PushState.java    | 83 +++++++++++++++++-----
 .../common/network/client/TransportClient.java     |  2 +-
 .../common/protocol/message/StatusCode.java        |  6 +-
 .../org/apache/celeborn/common/CelebornConf.scala  | 15 +++-
 docs/configuration/client.md                       |  2 +-
 docs/configuration/worker.md                       |  1 +
 .../celeborn/tests/spark/PushdataTimeoutTest.scala | 80 +++++++++++++++++++++
 .../service/deploy/worker/PushDataHandler.scala    | 20 ++++++
 9 files changed, 221 insertions(+), 63 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index a7193d03..65c65cec 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -178,7 +178,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     } else if (mapperEnded(shuffleId, mapId, attemptId)) {
       logger.debug(
           "Retrying push data, but the mapper(map {} attempt {}) has ended.", 
mapId, attemptId);
-      pushState.inFlightBatches.remove(batchId);
+      pushState.removeBatch(batchId);
     } else {
       PartitionLocation newLoc = 
reducePartitionMap.get(shuffleId).get(partitionId);
       logger.info("Revive success, new location for reduce {} is {}.", 
partitionId, newLoc);
@@ -191,7 +191,7 @@ public class ShuffleClientImpl extends ShuffleClient {
         PushData newPushData =
             new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(), 
newBuffer);
         ChannelFuture future = client.pushData(newPushData, callback);
-        pushState.addFuture(batchId, future);
+        pushState.pushStarted(batchId, future, callback);
       } catch (Exception ex) {
         logger.warn(
             "Exception raised while pushing data for shuffle {} map {} attempt 
{}" + " batch {}.",
@@ -256,7 +256,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           pushState,
           true);
     }
-    pushState.inFlightBatches.remove(oldGroupedBatchId);
+    pushState.removeBatch(oldGroupedBatchId);
   }
 
   private String genAddressPair(PartitionLocation loc) {
@@ -360,18 +360,20 @@ public class ShuffleClientImpl extends ShuffleClient {
       throw pushState.exception.get();
     }
 
-    ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches = 
pushState.inFlightBatches;
     long timeoutMs = conf.pushLimitInFlightTimeoutMs();
     long delta = conf.pushLimitInFlightSleepDeltaMs();
     long times = timeoutMs / delta;
     try {
       while (times > 0) {
-        if (inFlightBatches.size() <= limit) {
+        if (pushState.inflightBatchCount() <= limit) {
           break;
         }
         if (pushState.exception.get() != null) {
           throw pushState.exception.get();
         }
+
+        pushState.failExpiredBatch();
+
         Thread.sleep(delta);
         times--;
       }
@@ -384,10 +386,9 @@ public class ShuffleClientImpl extends ShuffleClient {
           "After waiting for {} ms, there are still {} batches in flight for 
map {}, "
               + "which exceeds the limit {}.",
           timeoutMs,
-          inFlightBatches.size(),
+          pushState.inflightBatchCount(),
           mapKey,
           limit);
-      logger.error("Map: {} in flight batches: {}", mapKey, inFlightBatches);
       throw new IOException("wait timeout for task " + mapKey, 
pushState.exception.get());
     }
     if (pushState.exception.get() != null) {
@@ -507,7 +508,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           attemptId);
       PushState pushState = pushStates.get(mapKey);
       if (pushState != null) {
-        pushState.cancelFutures();
+        pushState.cleanup();
       }
       return 0;
     }
@@ -546,7 +547,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           attemptId);
       PushState pushState = pushStates.get(mapKey);
       if (pushState != null) {
-        pushState.cancelFutures();
+        pushState.cleanup();
       }
       return 0;
     }
@@ -593,7 +594,7 @@ public class ShuffleClientImpl extends ShuffleClient {
       limitMaxInFlight(mapKey, pushState, currentMaxReqsInFlight);
 
       // add inFlight requests
-      pushState.inFlightBatches.put(nextBatchId, loc);
+      pushState.addBatch(nextBatchId);
 
       // build PushData request
       NettyManagedBuffer buffer = new 
NettyManagedBuffer(Unpooled.wrappedBuffer(body));
@@ -604,7 +605,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           new RpcResponseCallback() {
             @Override
             public void onSuccess(ByteBuffer response) {
-              pushState.inFlightBatches.remove(nextBatchId);
+              pushState.removeBatch(nextBatchId);
               // TODO Need to adjust maxReqsInFlight if server response is 
congested, see
               // CELEBORN-62
               if (response.remaining() > 0 && response.get() == 
StatusCode.STAGE_ENDED.getValue()) {
@@ -612,7 +613,6 @@ public class ShuffleClientImpl extends ShuffleClient {
                     .computeIfAbsent(shuffleId, (id) -> 
ConcurrentHashMap.newKeySet())
                     .add(mapKey);
               }
-              pushState.removeFuture(nextBatchId);
               logger.debug(
                   "Push data to {}:{} success for map {} attempt {} batch {}.",
                   loc.getHost(),
@@ -626,7 +626,6 @@ public class ShuffleClientImpl extends ShuffleClient {
             public void onFailure(Throwable e) {
               pushState.exception.compareAndSet(
                   null, new IOException("Revived PushData failed!", e));
-              pushState.removeFuture(nextBatchId);
               logger.error(
                   "Push data to {}:{} failed for map {} attempt {} batch {}.",
                   loc.getHost(),
@@ -679,6 +678,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                       attemptId,
                       nextBatchId);
                   congestionControl();
+                  callback.onSuccess(response);
                 } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_SLAVE_CONGESTED.getValue()) {
                   logger.debug(
                       "Push data split for map {} attempt {} batch {} return 
slave congested.",
@@ -686,6 +686,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                       attemptId,
                       nextBatchId);
                   congestionControl();
+                  callback.onSuccess(response);
                 } else {
                   response.rewind();
                   slowStart();
@@ -726,7 +727,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                             pushState,
                             getPushDataFailCause(e.getMessage())));
               } else {
-                pushState.inFlightBatches.remove(nextBatchId);
+                pushState.removeBatch(nextBatchId);
                 logger.info(
                     "Mapper shuffleId:{} mapId:{} attempt:{} already ended, 
remove batchId:{}.",
                     shuffleId,
@@ -742,7 +743,7 @@ public class ShuffleClientImpl extends ShuffleClient {
         TransportClient client =
             dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), 
partitionId);
         ChannelFuture future = client.pushData(pushData, wrappedCallback);
-        pushState.addFuture(nextBatchId, future);
+        pushState.pushStarted(nextBatchId, future, wrappedCallback);
       } catch (Exception e) {
         logger.warn("PushData failed", e);
         wrappedCallback.onFailure(
@@ -898,7 +899,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     final int port = Integer.parseInt(splits[1]);
 
     int groupedBatchId = pushState.batchId.addAndGet(1);
-    pushState.inFlightBatches.put(groupedBatchId, batches.get(0).loc);
+    pushState.addBatch(groupedBatchId);
 
     final int numBatches = batches.size();
     final String[] partitionUniqueIds = new String[numBatches];
@@ -928,7 +929,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                 mapId,
                 attemptId,
                 groupedBatchId);
-            pushState.inFlightBatches.remove(groupedBatchId);
+            pushState.removeBatch(groupedBatchId);
             // TODO Need to adjust maxReqsInFlight if server response is 
congested, see CELEBORN-62
             if (response.remaining() > 0 && response.get() == 
StatusCode.STAGE_ENDED.getValue()) {
               mapperEndMap
@@ -995,6 +996,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                     attemptId,
                     Arrays.toString(batchIds));
                 congestionControl();
+                callback.onSuccess(response);
               } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_SLAVE_CONGESTED.getValue()) {
                 logger.debug(
                     "Push data split for map {} attempt {} batchs {} return 
slave congested.",
@@ -1002,6 +1004,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                     attemptId,
                     Arrays.toString(batchIds));
                 congestionControl();
+                callback.onSuccess(response);
               } else {
                 // Should not happen in current architecture.
                 response.rewind();
@@ -1056,7 +1059,8 @@ public class ShuffleClientImpl extends ShuffleClient {
     // do push merged data
     try {
       TransportClient client = dataClientFactory.createClient(host, port);
-      client.pushMergedData(mergedData, wrappedCallback);
+      ChannelFuture future = client.pushMergedData(mergedData, 
wrappedCallback);
+      pushState.pushStarted(groupedBatchId, future, wrappedCallback);
     } catch (Exception e) {
       logger.warn("PushMergedData failed", e);
       wrappedCallback.onFailure(new 
Exception(getPushDataFailCause(e.getMessage()).toString(), e));
@@ -1114,7 +1118,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     PushState pushState = pushStates.remove(mapKey);
     if (pushState != null) {
       pushState.exception.compareAndSet(null, new IOException("Cleaned Up"));
-      pushState.cancelFutures();
+      pushState.cleanup();
     }
   }
 
@@ -1303,6 +1307,8 @@ public class ShuffleClientImpl extends ShuffleClient {
     } else if (StatusCode.PUSH_DATA_FAIL_MASTER.getMessage().equals(message)
         || connectFail(message)) {
       cause = StatusCode.PUSH_DATA_FAIL_MASTER;
+    } else if (StatusCode.PUSH_DATA_TIMEOUT.getMessage().equals(message)) {
+      cause = StatusCode.PUSH_DATA_TIMEOUT;
     } else {
       cause = StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE;
     }
@@ -1338,7 +1344,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           attemptId);
       PushState pushState = pushStates.get(mapKey);
       if (pushState != null) {
-        pushState.cancelFutures();
+        pushState.cleanup();
       }
       return 0;
     }
@@ -1367,7 +1373,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     limitMaxInFlight(mapKey, pushState, maxInFlight);
 
     // add inFlight requests
-    pushState.inFlightBatches.put(nextBatchId, location);
+    pushState.addBatch(nextBatchId);
 
     // build PushData request
     NettyManagedBuffer buffer = new NettyManagedBuffer(data);
@@ -1379,8 +1385,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           @Override
           public void onSuccess(ByteBuffer response) {
             closeCallBack.getAsBoolean();
-            pushState.inFlightBatches.remove(nextBatchId);
-            pushState.removeFuture(nextBatchId);
+            pushState.removeBatch(nextBatchId);
             if (response.remaining() > 0) {
               byte reason = response.get();
               if (reason == StatusCode.STAGE_ENDED.getValue()) {
@@ -1401,8 +1406,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           @Override
           public void onFailure(Throwable e) {
             closeCallBack.getAsBoolean();
-            pushState.inFlightBatches.remove(nextBatchId);
-            pushState.removeFuture(nextBatchId);
+            pushState.removeBatch(nextBatchId);
             if (pushState.exception.get() != null) {
               return;
             }
@@ -1432,7 +1436,7 @@ public class ShuffleClientImpl extends ShuffleClient {
       TransportClient client =
           dataClientFactory.createClient(location.getHost(), 
location.getPushPort(), partitionId);
       ChannelFuture future = client.pushData(pushData, callback);
-      pushState.addFuture(nextBatchId, future);
+      pushState.pushStarted(nextBatchId, future, callback);
     } catch (Exception e) {
       logger.warn("PushData byteBuf failed", e);
       callback.onFailure(new 
Exception(getPushDataFailCause(e.getMessage()).toString(), e));
@@ -1454,7 +1458,6 @@ public class ShuffleClientImpl extends ShuffleClient {
         shuffleId,
         mapId,
         attemptId,
-        location,
         () -> {
           String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
           logger.info(
@@ -1473,7 +1476,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                   attemptId,
                   numPartitions,
                   bufferSize);
-          client.sendRpcSync(handShake.toByteBuffer(), 
conf.pushDataRpcTimeoutMs());
+          client.sendRpcSync(handShake.toByteBuffer(), 
conf.pushDataTimeoutMs());
           return null;
         });
   }
@@ -1492,7 +1495,6 @@ public class ShuffleClientImpl extends ShuffleClient {
         shuffleId,
         mapId,
         attemptId,
-        location,
         () -> {
           String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
           logger.info(
@@ -1512,7 +1514,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                   currentRegionIdx,
                   isBroadcast);
           ByteBuffer regionStartResponse =
-              client.sendRpcSync(regionStart.toByteBuffer(), 
conf.pushDataRpcTimeoutMs());
+              client.sendRpcSync(regionStart.toByteBuffer(), 
conf.pushDataTimeoutMs());
           if (regionStartResponse.hasRemaining()
               && regionStartResponse.get() == 
StatusCode.HARD_SPLIT.getValue()) {
             // if split then revive
@@ -1561,7 +1563,6 @@ public class ShuffleClientImpl extends ShuffleClient {
         shuffleId,
         mapId,
         attemptId,
-        location,
         () -> {
           final String shuffleKey = Utils.makeShuffleKey(applicationId, 
shuffleId);
           logger.info(
@@ -1574,17 +1575,13 @@ public class ShuffleClientImpl extends ShuffleClient {
               dataClientFactory.createClient(location.getHost(), 
location.getPushPort());
           RegionFinish regionFinish =
               new RegionFinish(MASTER_MODE, shuffleKey, 
location.getUniqueId(), attemptId);
-          client.sendRpcSync(regionFinish.toByteBuffer(), 
conf.pushDataRpcTimeoutMs());
+          client.sendRpcSync(regionFinish.toByteBuffer(), 
conf.pushDataTimeoutMs());
           return null;
         });
   }
 
   private <R> R sendMessageInternal(
-      int shuffleId,
-      int mapId,
-      int attemptId,
-      PartitionLocation location,
-      ThrowingExceptionSupplier<R, Exception> supplier)
+      int shuffleId, int mapId, int attemptId, ThrowingExceptionSupplier<R, 
Exception> supplier)
       throws IOException {
     PushState pushState = null;
     int batchId = 0;
@@ -1606,11 +1603,11 @@ public class ShuffleClientImpl extends ShuffleClient {
 
       // add inFlight requests
       batchId = pushState.batchId.incrementAndGet();
-      pushState.inFlightBatches.put(batchId, location);
+      pushState.addBatch(batchId);
       return retrySendMessage(supplier);
     } finally {
       if (pushState != null) {
-        pushState.inFlightBatches.remove(batchId);
+        pushState.removeBatch(batchId);
       }
     }
   }
diff --git 
a/client/src/main/java/org/apache/celeborn/client/write/PushState.java 
b/client/src/main/java/org/apache/celeborn/client/write/PushState.java
index 67a719aa..a8f96240 100644
--- a/client/src/main/java/org/apache/celeborn/client/write/PushState.java
+++ b/client/src/main/java/org/apache/celeborn/client/write/PushState.java
@@ -18,8 +18,6 @@
 package org.apache.celeborn.client.write;
 
 import java.io.IOException;
-import java.util.HashSet;
-import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
@@ -29,41 +27,90 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.message.StatusCode;
 
 public class PushState {
+  class BatchInfo {
+    ChannelFuture channelFuture;
+    long pushTime;
+    RpcResponseCallback callback;
+  }
+
   private static final Logger logger = 
LoggerFactory.getLogger(PushState.class);
 
   private int pushBufferMaxSize;
+  private long pushTimeout;
 
   public final AtomicInteger batchId = new AtomicInteger();
-  public final ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches =
+  private final ConcurrentHashMap<Integer, BatchInfo> inflightBatchInfos =
       new ConcurrentHashMap<>();
-  public final ConcurrentHashMap<Integer, ChannelFuture> futures = new 
ConcurrentHashMap<>();
   public AtomicReference<IOException> exception = new AtomicReference<>();
 
   public PushState(CelebornConf conf) {
     pushBufferMaxSize = conf.pushBufferMaxSize();
+    pushTimeout = conf.pushDataTimeoutMs();
+  }
+
+  public void addBatch(int batchId) {
+    inflightBatchInfos.computeIfAbsent(batchId, id -> new BatchInfo());
+  }
+
+  public void removeBatch(int batchId) {
+    BatchInfo info = inflightBatchInfos.remove(batchId);
+    if (info != null && info.channelFuture != null) {
+      info.channelFuture.cancel(true);
+    }
   }
 
-  public void addFuture(int batchId, ChannelFuture future) {
-    futures.put(batchId, future);
+  public int inflightBatchCount() {
+    return inflightBatchInfos.size();
   }
 
-  public void removeFuture(int batchId) {
-    futures.remove(batchId);
+  public synchronized void failExpiredBatch() {
+    long currentTime = System.currentTimeMillis();
+    inflightBatchInfos
+        .values()
+        .forEach(
+            info -> {
+              if (currentTime - info.pushTime > pushTimeout) {
+                if (info.callback != null) {
+                  info.channelFuture.cancel(true);
+                  info.callback.onFailure(
+                      new 
IOException(StatusCode.PUSH_DATA_TIMEOUT.getMessage()));
+                  info.channelFuture = null;
+                  info.callback = null;
+                }
+              }
+            });
+  }
+
+  public void pushStarted(int batchId, ChannelFuture future, 
RpcResponseCallback callback) {
+    BatchInfo info = inflightBatchInfos.get(batchId);
+    // In rare cases info could be null. For example, a speculative task has 
one thread pushing,
+    // and another thread retry-pushing. At time 1 thread 1 finds StageEnded, 
then it cleans up
+    // PushState; at the same time thread 2 pushes data and calls pushStarted,
+    // at this time info will be null
+    if (info != null) {
+      info.pushTime = System.currentTimeMillis();
+      info.channelFuture = future;
+      info.callback = callback;
+    }
   }
 
-  public synchronized void cancelFutures() {
-    if (!futures.isEmpty()) {
-      Set<Integer> keys = new HashSet<>(futures.keySet());
-      logger.debug("Cancel all {} futures.", keys.size());
-      for (Integer batchId : keys) {
-        ChannelFuture future = futures.remove(batchId);
-        if (future != null) {
-          future.cancel(true);
-        }
-      }
+  public void cleanup() {
+    if (!inflightBatchInfos.isEmpty()) {
+      logger.debug("Cancel all {} futures.", inflightBatchInfos.size());
+      inflightBatchInfos
+          .values()
+          .forEach(
+              entry -> {
+                if (entry.channelFuture != null) {
+                  entry.channelFuture.cancel(true);
+                }
+              });
+      inflightBatchInfos.clear();
     }
   }
 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index 4ffc579d..88e8c42a 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -288,7 +288,7 @@ public class TransportClient implements Closeable {
       } else {
         String errorMsg =
             String.format(
-                "Failed to send RPC %s to %s: %s, channel will be closed",
+                "Failed to send request %s to %s: %s, channel will be closed",
                 requestId, NettyUtils.getRemoteAddress(channel), 
future.cause());
         logger.warn(errorMsg);
         channel.close();
diff --git 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index 3b3f067d..0fb46765 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -66,7 +66,9 @@ public enum StatusCode {
   REGION_START_FAIL_SLAVE(34),
   REGION_START_FAIL_MASTER(35),
   REGION_FINISH_FAIL_SLAVE(36),
-  REGION_FINISH_FAIL_MASTER(37);
+  REGION_FINISH_FAIL_MASTER(37),
+
+  PUSH_DATA_TIMEOUT(38);
 
   private final byte value;
 
@@ -103,6 +105,8 @@ public enum StatusCode {
       msg = "RegionFinishFailMaster";
     } else if (value == REGION_FINISH_FAIL_SLAVE.getValue()) {
       msg = "RegionFinishFailSlave";
+    } else if (value == PUSH_DATA_TIMEOUT.getValue()) {
+      msg = "PushDataTimeout";
     }
 
     return msg;
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 16b078ba..0dfbc3b8 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -556,6 +556,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
   // //////////////////////////////////////////////////////
   def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
   def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE)
+  def testPushDataTimeout: Boolean = get(TEST_PUSHDATA_TIMEOUT)
 
   def masterHost: String = get(MASTER_HOST)
 
@@ -679,7 +680,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
   def rpcCacheSize: Int = get(RPC_CACHE_SIZE)
   def rpcCacheConcurrencyLevel: Int = get(RPC_CACHE_CONCURRENCY_LEVEL)
   def rpcCacheExpireTime: Long = get(RPC_CACHE_EXPIRE_TIME)
-  def pushDataRpcTimeoutMs = get(PUSH_DATA_RPC_TIMEOUT)
+  def pushDataTimeoutMs = get(PUSH_DATA_TIMEOUT)
 
   def registerShuffleRpcAskTimeout: RpcTimeout =
     new RpcTimeout(
@@ -2192,8 +2193,8 @@ object CelebornConf extends Logging {
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("5s")
 
-  val PUSH_DATA_RPC_TIMEOUT: ConfigEntry[Long] =
-    buildConf("celeborn.push.data.rpc.timeout")
+  val PUSH_DATA_TIMEOUT: ConfigEntry[Long] =
+    buildConf("celeborn.push.data.timeout")
       .withAlternative("rss.push.data.rpc.timeout")
       .categories("client")
       .version("0.2.0")
@@ -2201,6 +2202,14 @@ object CelebornConf extends Logging {
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("120s")
 
+  val TEST_PUSHDATA_TIMEOUT: ConfigEntry[Boolean] =
+    buildConf("celeborn.test.pushdataTimeout")
+      .categories("worker")
+      .version("0.2.0")
+      .doc("Whether to test pushdata timeout")
+      .booleanConf
+      .createWithDefault(false)
+
   val REGISTER_SHUFFLE_RPC_ASK_TIMEOUT: OptionalConfigEntry[Long] =
     buildConf("celeborn.rpc.registerShuffle.askTimeout")
       .categories("client")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 8ec3aafc..b58fac84 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -27,7 +27,7 @@ license: |
 | celeborn.master.endpoints | &lt;localhost&gt;:9097 | Endpoints of master 
nodes for celeborn client to connect, allowed pattern is: 
`<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If 
the port is omitted, 9097 will be used. | 0.2.0 | 
 | celeborn.push.buffer.initial.size | 8k |  | 0.2.0 | 
 | celeborn.push.buffer.max.size | 64k | Max size of reducer partition buffer 
memory for shuffle hash writer. The pushed data will be buffered in memory 
before sending to Celeborn worker. For performance consideration keep this 
buffer size higher than 32K. Example: If reducer amount is 2000, buffer size is 
64K, then each task will consume up to `64KiB * 2000 = 125MiB` heap memory. | 
0.2.0 | 
-| celeborn.push.data.rpc.timeout | 120s | Timeout for a task to push data rpc 
message. | 0.2.0 | 
+| celeborn.push.data.timeout | 120s | Timeout for a task to push data rpc 
message. | 0.2.0 | 
 | celeborn.push.limit.inFlight.sleepInterval | 50ms | Sleep interval when 
check netty in-flight requests to be done. | 0.2.0 | 
 | celeborn.push.limit.inFlight.timeout | 240s | Timeout for netty in-flight 
requests to be done. | 0.2.0 | 
 | celeborn.push.maxReqsInFlight | 32 | Amount of Netty in-flight requests. The 
maximum memory is `celeborn.push.maxReqsInFlight` * 
`celeborn.push.buffer.max.size` * compression ratio(1 in worst case), default: 
64Kib * 32 = 2Mib | 0.2.0 | 
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index 0ae29ef9..44b62816 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -29,6 +29,7 @@ license: |
 | celeborn.shuffle.chuck.size | 8m | Max chunk size of reducer's merged 
shuffle data. For example, if a reducer's shuffle data is 128M and the data 
will need 16 fetch chunk requests to fetch. | 0.2.0 | 
 | celeborn.shuffle.minPartitionSizeToEstimate | 8mb | Ignore partition size 
smaller than this configuration of partition size for estimation. | 0.2.0 | 
 | celeborn.storage.hdfs.dir | &lt;undefined&gt; | HDFS dir configuration for 
Celeborn to access HDFS. | 0.2.0 | 
+| celeborn.test.pushdataTimeout | false | Whether to test pushdata timeout | 
0.2.0 | 
 | celeborn.worker.closeIdleConnections | false | Whether worker will close 
idle connections. | 0.2.0 | 
 | celeborn.worker.commit.threads | 32 | Thread number of worker to commit 
shuffle data files asynchronously. | 0.2.0 | 
 | celeborn.worker.directMemoryRatioForMemoryShuffleStorage | 0.1 | Max ratio 
of direct memory to store shuffle data | 0.2.0 | 
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
new file mode 100644
index 00000000..ed876524
--- /dev/null
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+
+class PushdataTimeoutTest extends AnyFunSuite
+  with SparkTestBase
+  with BeforeAndAfterAll
+  with BeforeAndAfterEach {
+
+  override def beforeAll(): Unit = {
+    logInfo("test initialized , setup rss mini cluster")
+    val workerConf = Map(
+      "celeborn.test.pushdataTimeout" -> s"true")
+    tuple = setupRssMiniClusterSpark(masterConfs = null, workerConfs = 
workerConf)
+  }
+
+  override def afterAll(): Unit = {
+    logInfo("all test complete , stop rss mini cluster")
+    clearMiniCluster(tuple)
+  }
+
+  override def beforeEach(): Unit = {
+    ShuffleClient.reset()
+  }
+
+  override def afterEach(): Unit = {
+    System.gc()
+  }
+
+  test("celeborn spark integration test - pushdata timeout") {
+    val sparkConf = new 
SparkConf().setAppName("rss-demo").setMaster("local[4]")
+      .set("spark.celeborn.push.data.timeout", "10s")
+    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+    val combineResult = combine(sparkSession)
+    val groupbyResult = groupBy(sparkSession)
+    val repartitionResult = repartition(sparkSession)
+    val sqlResult = runsql(sparkSession)
+
+    Thread.sleep(3000L)
+    sparkSession.stop()
+
+    val rssSparkSession = SparkSession.builder()
+      .config(updateSparkConf(sparkConf, false)).getOrCreate()
+    val rssCombineResult = combine(rssSparkSession)
+    val rssGroupbyResult = groupBy(rssSparkSession)
+    val rssRepartitionResult = repartition(rssSparkSession)
+    val rssSqlResult = runsql(rssSparkSession)
+
+    assert(combineResult.equals(rssCombineResult))
+    assert(groupbyResult.equals(rssGroupbyResult))
+    assert(repartitionResult.equals(rssRepartitionResult))
+    assert(combineResult.equals(rssCombineResult))
+    assert(sqlResult.equals(rssSqlResult))
+
+    rssSparkSession.stop()
+
+  }
+}
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 933f3b6b..d91c55c9 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, 
AtomicIntegerArray}
 import com.google.common.base.Throwables
 import io.netty.buffer.ByteBuf
 
+import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.AlreadyClosedException
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
@@ -55,6 +56,8 @@ class PushDataHandler extends BaseMessageHandler with Logging 
{
   var partitionSplitMinimumSize: Long = _
   var shutdown: AtomicBoolean = _
   var storageManager: StorageManager = _
+  var conf: CelebornConf = _
+  @volatile var pushDataTimeoutTested = false
 
   def init(worker: Worker): Unit = {
     workerSource = worker.workerSource
@@ -71,6 +74,7 @@ class PushDataHandler extends BaseMessageHandler with Logging 
{
     partitionSplitMinimumSize = worker.conf.partitionSplitMinimumSize
     storageManager = worker.storageManager
     shutdown = worker.shutdown
+    conf = worker.conf
 
     logInfo(s"diskReserveSize $diskReserveSize")
   }
@@ -122,6 +126,12 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
     val body = pushData.body.asInstanceOf[NettyManagedBuffer].getBuf
     val isMaster = mode == PartitionLocation.Mode.MASTER
 
+    // For test
+    if (conf.testPushDataTimeout && !pushDataTimeoutTested) {
+      pushDataTimeoutTested = true
+      return
+    }
+
     val key = s"${pushData.requestId}"
     if (isMaster) {
       workerSource.startTimer(WorkerSource.MasterPushDataTime, key)
@@ -315,6 +325,12 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
       workerSource.startTimer(WorkerSource.SlavePushDataTime, key)
     }
 
+    // For test
+    if (conf.testPushDataTimeout && !PushDataHandler.pushDataTimeoutTested) {
+      PushDataHandler.pushDataTimeoutTested = true
+      return
+    }
+
     val wrappedCallback = new RpcResponseCallback() {
       override def onSuccess(response: ByteBuffer): Unit = {
         if (isMaster) {
@@ -897,3 +913,7 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
     (PackedPartitionId.getRawPartitionId(id), 
PackedPartitionId.getAttemptId(id))
   }
 }
+
+object PushDataHandler {
+  @volatile var pushDataTimeoutTested = false
+}

Reply via email to