This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 44eb4e5e5 [#134] improvement(spark3): Use taskId and attemptNo as 
taskAttemptId (#1529)
44eb4e5e5 is described below

commit 44eb4e5e57a062b4f427ee7ba737bf213fdb442f
Author: Enrico Minack <[email protected]>
AuthorDate: Tue Feb 20 09:19:50 2024 +0100

    [#134] improvement(spark3): Use taskId and attemptNo as taskAttemptId 
(#1529)
    
    ### What changes were proposed in this pull request?
    Use map index and task attempt number as the task attempt id in Spark3.
    
    This requires reworking the bit layout of the blockId to maximize bit utilization for Spark3:
    
https://github.com/apache/incubator-uniffle/blob/b924acacb0c555370a593f3a069187cf8b8081d7/common/src/main/java/org/apache/uniffle/common/util/Constants.java#L30-L35
    
    Ideally, `TASK_ATTEMPT_ID_MAX_LENGTH` should be set to `PARTITION_ID_MAX_LENGTH` plus the number of bits required to store the largest task attempt number. The largest task attempt number is `maxFailures - 1`, or `maxFailures` if speculative execution is enabled (configured via `spark.speculation`, disabled by default). `maxFailures` is configured via `spark.task.maxFailures` and defaults to 4. So by default, two bits are required to store the largest attempt number, and `TASK_ATTEMPT_ID_MAX_LENGTH` should be `PARTITION_ID_MAX_LENGTH + 2`.
    
    Example:
    
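    - with `PARTITION_ID_MAX_LENGTH = 20`, Uniffle supports up to 1,048,576 (2^20) partitions
    - requiring `TASK_ATTEMPT_ID_MAX_LENGTH = 22` (`PARTITION_ID_MAX_LENGTH` plus 2 bits for the attempt number)
    - allowing for `ATOMIC_INT_MAX_LENGTH = 21` (20 + 22 + 21 = 63 bits in total).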
    
    ### Why are the changes needed?
    The map index (map partition id) is bounded by the number of map partitions of the shuffle, and the task attempt number is bounded by the maximum number of failures configured by `spark.task.maxFailures`, which defaults to 4. Together they provide an id that is unique per shuffle and stays small. This is unlike `context.taskAttemptId()`, which increases monotonically across the entire Spark application and can therefore grow arbitrarily large.
    
    Fix: #134
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit and integration tests.
---
 .../apache/spark/shuffle/RssShuffleManager.java    |  73 +++++-
 .../spark/shuffle/RssShuffleManagerTest.java       | 250 +++++++++++++++++++++
 .../org/apache/uniffle/test/FailingTasksTest.java  | 102 +++++++++
 .../test/RSSStageDynamicServerReWriteTest.java     |  13 +-
 .../apache/uniffle/test/RSSStageResubmitTest.java  |  13 +-
 .../test/SparkTaskFailureIntegrationTestBase.java  |  37 +++
 6 files changed, 465 insertions(+), 23 deletions(-)

diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 625089118..cb35ce3a2 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -78,6 +78,7 @@ import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.rpc.GrpcServer;
+import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
@@ -112,6 +113,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   private boolean dynamicConfEnabled = false;
   private final ShuffleDataDistributionType dataDistributionType;
   private final int maxConcurrencyPerPartitionToWrite;
+  private final int maxFailures;
+  private final boolean speculation;
   private String user;
   private String uuid;
   private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
@@ -182,6 +185,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     this.dataDistributionType = getDataDistributionType(sparkConf);
     this.maxConcurrencyPerPartitionToWrite =
         
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+    this.speculation = sparkConf.getBoolean("spark.speculation", false);
     long retryIntervalMax = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
     int heartBeatThreadNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
     this.dataTransferPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
@@ -307,6 +312,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         
RssSparkConfig.toRssConf(sparkConf).get(RssClientConf.DATA_DISTRIBUTION_TYPE);
     this.maxConcurrencyPerPartitionToWrite =
         
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+    this.speculation = sparkConf.getBoolean("spark.speculation", false);
     this.heartbeatInterval = 
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
     this.heartbeatTimeout =
         sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), 
heartbeatInterval / 2);
@@ -503,11 +510,18 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     }
     String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), 
rssHandle.getShuffleId());
+    long taskAttemptId =
+        getTaskAttemptId(
+            context.partitionId(),
+            context.attemptNumber(),
+            maxFailures,
+            speculation,
+            Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
     return new RssShuffleWriter<>(
         rssHandle.getAppId(),
         shuffleId,
         taskId,
-        context.taskAttemptId(),
+        taskAttemptId,
         writeMetrics,
         this,
         sparkConf,
@@ -518,6 +532,63 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         shuffleHandleInfo);
   }
 
+  /**
+   * Provides a task attempt id that is unique for a shuffle stage.
+   *
+   * <p>We are not using context.taskAttemptId() here as this is a 
monotonically increasing number
+   * that is unique across the entire Spark app which can reach very large 
numbers, which can
+   * practically reach LONG.MAX_VALUE. That would overflow the bits in the 
block id.
+   *
+   * <p>Here we use the map index or task id, appended by the attempt number 
per task. The map index
+   * is limited by the number of partitions of a stage. The attempt number per 
task is limited /
+   * configured by spark.task.maxFailures (default: 4).
+   *
+   * @return a task attempt id unique for a shuffle stage
+   */
+  @VisibleForTesting
+  protected static long getTaskAttemptId(
+      int mapIndex, int attemptNo, int maxFailures, boolean speculation, int 
maxTaskAttemptIdBits) {
+    // attempt number is zero based: 0, 1, …, maxFailures-1
+    // max maxFailures < 1 is not allowed but for safety, we interpret that as 
maxFailures == 1
+    int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
+
+    // with speculative execution enabled we could observe +1 attempts
+    if (speculation) {
+      maxAttemptNo++;
+    }
+
+    if (attemptNo > maxAttemptNo) {
+      // this should never happen, if it does, our assumptions are wrong,
+      // and we risk overflowing the attempt number bits
+      throw new RssException(
+          "Observing attempt number "
+              + attemptNo
+              + " while maxFailures is set to "
+              + maxFailures
+              + (speculation ? " with speculation enabled" : "")
+              + ".");
+    }
+
+    int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
+    int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
+    if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
+      throw new RssException(
+          "Observing mapIndex["
+              + mapIndex
+              + "] that would produce a taskAttemptId with "
+              + (mapIndexBits + attemptBits)
+              + " bits which is larger than the allowed "
+              + maxTaskAttemptIdBits
+              + " bits (maxFailures["
+              + maxFailures
+              + "], speculation["
+              + speculation
+              + "]). Please consider providing more bits for taskAttemptIds.");
+    }
+
+    return (long) mapIndex << attemptBits | attemptNo;
+  }
+
   public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
     // todo: this implement is tricky, we should refactor it
     if (id.get() == null) {
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 64bd6f902..9150d6d30 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle;
 
+import java.util.Arrays;
+
 import org.apache.spark.SparkConf;
 import org.apache.spark.sql.internal.SQLConf;
 import org.junit.jupiter.api.Test;
@@ -24,12 +26,14 @@ import org.junit.jupiter.api.Test;
 import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class RssShuffleManagerTest extends RssShuffleManagerTestBase {
@@ -84,6 +88,252 @@ public class RssShuffleManagerTest extends 
RssShuffleManagerTestBase {
     }
   }
 
+  private long bits(String string) {
+    return Long.parseLong(string.replaceAll("[|]", ""), 2);
+  }
+
+  @Test
+  public void testGetTaskAttemptIdWithoutSpeculation() {
+    // the expected bits("xy|z") represents the expected Long in bit notation 
where | is used to
+    // separate map index from attempt number, so merely for visualization 
purposes
+
+    // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+    for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+      assertEquals(
+          bits("0000|"),
+          RssShuffleManager.getTaskAttemptId(0, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertEquals(
+          bits("0001|"),
+          RssShuffleManager.getTaskAttemptId(1, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertEquals(
+          bits("0010|"),
+          RssShuffleManager.getTaskAttemptId(2, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+    }
+
+    // maxFailures of 2
+    assertEquals(bits("000|0"), RssShuffleManager.getTaskAttemptId(0, 0, 2, 
false, 10));
+    assertEquals(bits("000|1"), RssShuffleManager.getTaskAttemptId(0, 1, 2, 
false, 10));
+    assertEquals(bits("001|0"), RssShuffleManager.getTaskAttemptId(1, 0, 2, 
false, 10));
+    assertEquals(bits("001|1"), RssShuffleManager.getTaskAttemptId(1, 1, 2, 
false, 10));
+    assertEquals(bits("010|0"), RssShuffleManager.getTaskAttemptId(2, 0, 2, 
false, 10));
+    assertEquals(bits("010|1"), RssShuffleManager.getTaskAttemptId(2, 1, 2, 
false, 10));
+    assertEquals(bits("011|0"), RssShuffleManager.getTaskAttemptId(3, 0, 2, 
false, 10));
+    assertEquals(bits("011|1"), RssShuffleManager.getTaskAttemptId(3, 1, 2, 
false, 10));
+
+    // maxFailures of 3
+    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3, 
false, 10));
+    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3, 
false, 10));
+    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3, 
false, 10));
+    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3, 
false, 10));
+    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3, 
false, 10));
+    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3, 
false, 10));
+    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3, 
false, 10));
+    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3, 
false, 10));
+    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3, 
false, 10));
+    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3, 
false, 10));
+    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3, 
false, 10));
+    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3, 
false, 10));
+
+    // maxFailures of 4
+    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 4, 
false, 10));
+    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 4, 
false, 10));
+    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 4, 
false, 10));
+    assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 4, 
false, 10));
+    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 4, 
false, 10));
+    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 4, 
false, 10));
+    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 4, 
false, 10));
+    assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 4, 
false, 10));
+    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 4, 
false, 10));
+    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 4, 
false, 10));
+    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 4, 
false, 10));
+    assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 4, 
false, 10));
+    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 4, 
false, 10));
+    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 4, 
false, 10));
+    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 4, 
false, 10));
+    assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 4, 
false, 10));
+
+    // maxFailures of 5
+    assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 5, 
false, 10));
+    assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 5, 
false, 10));
+
+    // test with ints that overflow into signed int and long
+    assertEquals(
+        Integer.MAX_VALUE, 
RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, false, 31));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1 | 1,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 2, false, 
32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 2 | 3,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 4, false, 
33));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 3 | 7,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 8, false, 
34));
+
+    // test with attemptNo >= maxFailures
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, -1, 
false, 10));
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, 0, 
false, 10));
+    for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures, 
maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1, 
maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2, 
maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      Exception e =
+          assertThrowsExactly(
+              RssException.class,
+              () ->
+                  RssShuffleManager.getTaskAttemptId(0, maxFailures + 128, 
maxFailures, false, 10),
+              String.valueOf(maxFailures));
+      assertEquals(
+          "Observing attempt number "
+              + (maxFailures + 128)
+              + " while maxFailures is set to "
+              + maxFailures
+              + ".",
+          e.getMessage());
+    }
+
+    // test with mapIndex that would require more than maxTaskAttemptBits
+    Exception e =
+        assertThrowsExactly(
+            RssException.class, () -> RssShuffleManager.getTaskAttemptId(256, 
0, 3, true, 10));
+    assertEquals(
+        "Observing mapIndex[256] that would produce a taskAttemptId with 11 
bits "
+            + "which is larger than the allowed 10 bits (maxFailures[3], 
speculation[true]). "
+            + "Please consider providing more bits for taskAttemptIds.",
+        e.getMessage());
+    // check that a lower mapIndex works as expected
+    assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255, 
0, 3, true, 10));
+  }
+
+  @Test
+  public void testGetTaskAttemptIdWithSpeculation() {
+    // with speculation, we expect maxFailures+1 attempts
+
+    // the expected bits("xy|z") represents the expected Long in bit notation 
where | is used to
+    // separate map index from attempt number, so merely for visualization 
purposes
+
+    // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+    for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+      for (int attemptNo : Arrays.asList(0, 1)) {
+        assertEquals(
+            bits("0000|" + attemptNo),
+            RssShuffleManager.getTaskAttemptId(0, attemptNo, maxFailures, 
true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+        assertEquals(
+            bits("0001|" + attemptNo),
+            RssShuffleManager.getTaskAttemptId(1, attemptNo, maxFailures, 
true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+        assertEquals(
+            bits("0010|" + attemptNo),
+            RssShuffleManager.getTaskAttemptId(2, attemptNo, maxFailures, 
true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+      }
+    }
+
+    // maxFailures of 2
+    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 2, 
true, 10));
+    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 2, 
true, 10));
+    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 2, 
true, 10));
+    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 2, 
true, 10));
+    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 2, 
true, 10));
+    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 2, 
true, 10));
+    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 2, 
true, 10));
+    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 2, 
true, 10));
+    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 2, 
true, 10));
+    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 2, 
true, 10));
+    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 2, 
true, 10));
+    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 2, 
true, 10));
+
+    // maxFailures of 3
+    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3, 
true, 10));
+    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3, 
true, 10));
+    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3, 
true, 10));
+    assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 3, 
true, 10));
+    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3, 
true, 10));
+    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3, 
true, 10));
+    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3, 
true, 10));
+    assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 3, 
true, 10));
+    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3, 
true, 10));
+    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3, 
true, 10));
+    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3, 
true, 10));
+    assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 3, 
true, 10));
+    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3, 
true, 10));
+    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3, 
true, 10));
+    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3, 
true, 10));
+    assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 3, 
true, 10));
+
+    // maxFailures of 4
+    assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 4, 
true, 10));
+    assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 4, 
true, 10));
+
+    // test with ints that overflow into signed int and long
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, true, 32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1 | 1,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 1, true, 32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 2 | 3,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 3, true, 33));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 3 | 7,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 7, true, 34));
+
+    // test with attemptNo > maxFailures (attemptNo == maxFailures allowed for 
speculation enabled)
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, -1, 
true, 10));
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, 0, 
true, 10));
+    for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1, 
maxFailures, true, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2, 
maxFailures, true, 10),
+          String.valueOf(maxFailures));
+      Exception e =
+          assertThrowsExactly(
+              RssException.class,
+              () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 128, 
maxFailures, true, 10),
+              String.valueOf(maxFailures));
+      assertEquals(
+          "Observing attempt number "
+              + (maxFailures + 128)
+              + " while maxFailures is set to "
+              + maxFailures
+              + " with speculation enabled.",
+          e.getMessage());
+    }
+
+    // test with mapIndex that would require more than maxTaskAttemptBits
+    Exception e =
+        assertThrowsExactly(
+            RssException.class, () -> RssShuffleManager.getTaskAttemptId(256, 
0, 4, false, 10));
+    assertEquals(
+        "Observing mapIndex[256] that would produce a taskAttemptId with 11 
bits "
+            + "which is larger than the allowed 10 bits (maxFailures[4], 
speculation[false]). "
+            + "Please consider providing more bits for taskAttemptIds.",
+        e.getMessage());
+    // check that a lower mapIndex works as expected
+    assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255, 
0, 4, false, 10));
+  }
+
   @Test
   public void testCreateShuffleManagerServer() {
     setupMockedRssShuffleUtils(StatusCode.SUCCESS);
diff --git 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java
 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java
new file mode 100644
index 000000000..e9a0818aa
--- /dev/null
+++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import com.google.common.collect.Maps;
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.java.function.MapPartitionsFunction;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+// This test has all tasks fail twice, the third attempt succeeds.
+// The failing attempts all provide zeros to the shuffle step, while the 
succeeding attempts
+// provide the actual non-zero integers (actually only one zero). If blocks 
from the failing
+// attempts leak into the read shuffle data, we would see those zeros and fail 
when comparing
+// to without RSS.
+public class FailingTasksTest extends SparkTaskFailureIntegrationTestBase {
+
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    shutdownServers();
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    Map<String, String> dynamicConf = Maps.newHashMap();
+    dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), 
HDFS_URI + "rss/test");
+    dynamicConf.put(
+        RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE_HDFS.name());
+    addDynamicConf(coordinatorConf, dynamicConf);
+    createCoordinatorServer(coordinatorConf);
+    ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+    createShuffleServer(shuffleServerConf);
+    startServers();
+  }
+
+  @Override
+  Map runTest(SparkSession spark, String fileName) throws Exception {
+    int n = 1000000;
+    return spark.range(0, n, 1, 4)
+        .mapPartitions(
+            (MapPartitionsFunction<Long, Long>)
+                it ->
+                    new Iterator<Long>() {
+                      final TaskContext context = TaskContext.get();
+
+                      @Override
+                      public boolean hasNext() {
+                        // the first two attempts fail in the end
+                        return context.attemptNumber() < 2 || it.hasNext();
+                      }
+
+                      @Override
+                      public Long next() {
+                        if (it.hasNext()) {
+                          Long next = it.next();
+                          // the failing attempt returns only zeros
+                          if (context.attemptNumber() < 2) {
+                            return 0L;
+                          } else {
+                            return next;
+                          }
+                        } else {
+                          throw new RuntimeException("let this task fail");
+                        }
+                      }
+                    },
+            Encoders.LONG())
+        .repartition(3, new Column("value"))
+        .mapPartitions((MapPartitionsFunction<Long, Long>) it -> it, 
Encoders.LONG())
+        .collectAsList().stream()
+        .collect(Collectors.toMap(v -> v, v -> v));
+  }
+
+  @Test
+  public void testFailedTasks() throws Exception {
+    run();
+  }
+}
diff --git 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
index 3e739690f..e7fed902f 100644
--- 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
+++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
@@ -39,12 +39,10 @@ import org.apache.uniffle.server.ShuffleServer;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
-public class RSSStageDynamicServerReWriteTest extends SparkIntegrationTestBase 
{
+public class RSSStageDynamicServerReWriteTest extends 
SparkTaskFailureIntegrationTestBase {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(RSSStageDynamicServerReWriteTest.class);
 
-  private static int maxTaskFailures = 3;
-
   @BeforeAll
   public static void setupServers(@TempDir File tmpDir) throws Exception {
     CoordinatorConf coordinatorConf = getCoordinatorConf();
@@ -96,18 +94,11 @@ public class RSSStageDynamicServerReWriteTest extends 
SparkIntegrationTestBase {
     return result;
   }
 
-  @Override
-  protected SparkConf createSparkConf() {
-    return new SparkConf()
-        .setAppName(this.getClass().getSimpleName())
-        .setMaster(String.format("local[4,%d]", maxTaskFailures));
-  }
-
   @Override
   public void updateSparkConfCustomer(SparkConf sparkConf) {
+    super.updateSparkConfCustomer(sparkConf);
     sparkConf.set(
         RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + 
RssClientConfig.RSS_RESUBMIT_STAGE, "true");
-    sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
   }
 
   @Test
diff --git 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
index 5e95bc009..497e232a2 100644
--- 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
+++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
@@ -35,9 +35,7 @@ import org.apache.uniffle.server.ShuffleServer;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
-public class RSSStageResubmitTest extends SparkIntegrationTestBase {
-
-  private static int maxTaskFailures = 3;
+public class RSSStageResubmitTest extends SparkTaskFailureIntegrationTestBase {
 
   @BeforeAll
   public static void setupServers() throws Exception {
@@ -73,18 +71,11 @@ public class RSSStageResubmitTest extends 
SparkIntegrationTestBase {
     return result;
   }
 
-  @Override
-  protected SparkConf createSparkConf() {
-    return new SparkConf()
-        .setAppName(this.getClass().getSimpleName())
-        .setMaster(String.format("local[4,%d]", maxTaskFailures));
-  }
-
   @Override
   public void updateSparkConfCustomer(SparkConf sparkConf) {
+    super.updateSparkConfCustomer(sparkConf);
     sparkConf.set(
         RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + 
RssClientConfig.RSS_RESUBMIT_STAGE, "true");
-    sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
   }
 
   @Test
diff --git 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java
 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java
new file mode 100644
index 000000000..a1e0c37f4
--- /dev/null
+++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import org.apache.spark.SparkConf;
+
+public abstract class SparkTaskFailureIntegrationTestBase extends 
SparkIntegrationTestBase {
+
+  protected static final int maxTaskFailures = 3;
+
+  @Override
+  protected SparkConf createSparkConf() {
+    return new SparkConf()
+        .setAppName(this.getClass().getSimpleName())
+        .setMaster(String.format("local[4,%d]", maxTaskFailures));
+  }
+
+  @Override
+  public void updateSparkConfCustomer(SparkConf sparkConf) {
+    sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
+  }
+}

Reply via email to