This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 44eb4e5e5 [#134] improvement(spark3): Use taskId and attemptNo as
taskAttemptId (#1529)
44eb4e5e5 is described below
commit 44eb4e5e57a062b4f427ee7ba737bf213fdb442f
Author: Enrico Minack <[email protected]>
AuthorDate: Tue Feb 20 09:19:50 2024 +0100
[#134] improvement(spark3): Use taskId and attemptNo as taskAttemptId
(#1529)
### What changes were proposed in this pull request?
Use map index and task attempt number as the task attempt id in Spark3.
This requires reworking the bits of the blockId to maximize bit utilization
for Spark3:
https://github.com/apache/incubator-uniffle/blob/b924acacb0c555370a593f3a069187cf8b8081d7/common/src/main/java/org/apache/uniffle/common/util/Constants.java#L30-L35
Ideally, the `TASK_ATTEMPT_ID_MAX_LENGTH` is set equal to
`PARTITION_ID_MAX_LENGTH` + the number of bits required to store the largest
task attempt number. The largest task attempt number is `maxFailures - 1`, or
`maxFailures` if speculative execution is enabled (configured via
`spark.speculation` and disabled by default). The `maxFailures` is configured
via `spark.task.maxFailures` and defaults to 4. So by default, two bits are
required to store the largest attempt number and `TASK_A [...]
Example:
- with `PARTITION_ID_MAX_LENGTH = 20`, Uniffle supports 1,048,576 partitions
- requiring `TASK_ATTEMPT_ID_MAX_LENGTH = 22`
- allowing for `ATOMIC_INT_MAX_LENGTH = 21`.
### Why are the changes needed?
The map index (map partition id) is limited to the number of partitions of
a shuffle. The task attempt number is limited by the max number of failures
configured by `spark.task.maxFailures`, which defaults to 4. This provides us
an id that is unique per shuffle while not growing arbitrarily large as
`context.taskAttemptId` does.
Fix: #134
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit and integration tests.
---
.../apache/spark/shuffle/RssShuffleManager.java | 73 +++++-
.../spark/shuffle/RssShuffleManagerTest.java | 250 +++++++++++++++++++++
.../org/apache/uniffle/test/FailingTasksTest.java | 102 +++++++++
.../test/RSSStageDynamicServerReWriteTest.java | 13 +-
.../apache/uniffle/test/RSSStageResubmitTest.java | 13 +-
.../test/SparkTaskFailureIntegrationTestBase.java | 37 +++
6 files changed, 465 insertions(+), 23 deletions(-)
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 625089118..cb35ce3a2 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -78,6 +78,7 @@ import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.rpc.GrpcServer;
+import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
@@ -112,6 +113,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
private boolean dynamicConfEnabled = false;
private final ShuffleDataDistributionType dataDistributionType;
private final int maxConcurrencyPerPartitionToWrite;
+ private final int maxFailures;
+ private final boolean speculation;
private String user;
private String uuid;
private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
@@ -182,6 +185,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.dataDistributionType = getDataDistributionType(sparkConf);
this.maxConcurrencyPerPartitionToWrite =
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+ this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+ this.speculation = sparkConf.getBoolean("spark.speculation", false);
long retryIntervalMax =
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
int heartBeatThreadNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
this.dataTransferPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
@@ -307,6 +312,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
RssSparkConfig.toRssConf(sparkConf).get(RssClientConf.DATA_DISTRIBUTION_TYPE);
this.maxConcurrencyPerPartitionToWrite =
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+ this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+ this.speculation = sparkConf.getBoolean("spark.speculation", false);
this.heartbeatInterval =
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
this.heartbeatTimeout =
sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(),
heartbeatInterval / 2);
@@ -503,11 +510,18 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(),
rssHandle.getShuffleId());
+ long taskAttemptId =
+ getTaskAttemptId(
+ context.partitionId(),
+ context.attemptNumber(),
+ maxFailures,
+ speculation,
+ Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
return new RssShuffleWriter<>(
rssHandle.getAppId(),
shuffleId,
taskId,
- context.taskAttemptId(),
+ taskAttemptId,
writeMetrics,
this,
sparkConf,
@@ -518,6 +532,63 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleHandleInfo);
}
+ /**
+ * Provides a task attempt id that is unique for a shuffle stage.
+ *
+ * <p>We are not using context.taskAttemptId() here as this is a
monotonically increasing number
+ * that is unique across the entire Spark app which can reach very large
numbers, which can
+ * practically reach Long.MAX_VALUE. That would overflow the bits in the
block id.
+ *
+ * <p>Here we use the map index or task id, appended by the attempt number
per task. The map index
+ * is limited by the number of partitions of a stage. The attempt number per
task is limited /
+ * configured by spark.task.maxFailures (default: 4).
+ *
+ * @return a task attempt id unique for a shuffle stage
+ */
+ @VisibleForTesting
+ protected static long getTaskAttemptId(
+ int mapIndex, int attemptNo, int maxFailures, boolean speculation, int
maxTaskAttemptIdBits) {
+ // attempt number is zero based: 0, 1, …, maxFailures-1
+ // maxFailures < 1 is not allowed but for safety, we interpret that as
maxFailures == 1
+ int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
+
+ // with speculative execution enabled we could observe +1 attempts
+ if (speculation) {
+ maxAttemptNo++;
+ }
+
+ if (attemptNo > maxAttemptNo) {
+ // this should never happen, if it does, our assumptions are wrong,
+ // and we risk overflowing the attempt number bits
+ throw new RssException(
+ "Observing attempt number "
+ + attemptNo
+ + " while maxFailures is set to "
+ + maxFailures
+ + (speculation ? " with speculation enabled" : "")
+ + ".");
+ }
+
+ int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
+ int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
+ if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
+ throw new RssException(
+ "Observing mapIndex["
+ + mapIndex
+ + "] that would produce a taskAttemptId with "
+ + (mapIndexBits + attemptBits)
+ + " bits which is larger than the allowed "
+ + maxTaskAttemptIdBits
+ + " bits (maxFailures["
+ + maxFailures
+ + "], speculation["
+ + speculation
+ + "]). Please consider providing more bits for taskAttemptIds.");
+ }
+
+ return (long) mapIndex << attemptBits | attemptNo;
+ }
+
public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
// todo: this implement is tricky, we should refactor it
if (id.get() == null) {
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 64bd6f902..9150d6d30 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle;
+import java.util.Arrays;
+
import org.apache.spark.SparkConf;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.jupiter.api.Test;
@@ -24,12 +26,14 @@ import org.junit.jupiter.api.Test;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.storage.util.StorageType;
import static
org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class RssShuffleManagerTest extends RssShuffleManagerTestBase {
@@ -84,6 +88,252 @@ public class RssShuffleManagerTest extends
RssShuffleManagerTestBase {
}
}
+ private long bits(String string) {
+ return Long.parseLong(string.replaceAll("[|]", ""), 2);
+ }
+
+ @Test
+ public void testGetTaskAttemptIdWithoutSpeculation() {
+ // the expected bits("xy|z") represents the expected Long in bit notation
where | is used to
+ // separate map index from attempt number, so merely for visualization
purposes
+
+ // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+ for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+ assertEquals(
+ bits("0000|"),
+ RssShuffleManager.getTaskAttemptId(0, 0, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertEquals(
+ bits("0001|"),
+ RssShuffleManager.getTaskAttemptId(1, 0, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertEquals(
+ bits("0010|"),
+ RssShuffleManager.getTaskAttemptId(2, 0, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ }
+
+ // maxFailures of 2
+ assertEquals(bits("000|0"), RssShuffleManager.getTaskAttemptId(0, 0, 2,
false, 10));
+ assertEquals(bits("000|1"), RssShuffleManager.getTaskAttemptId(0, 1, 2,
false, 10));
+ assertEquals(bits("001|0"), RssShuffleManager.getTaskAttemptId(1, 0, 2,
false, 10));
+ assertEquals(bits("001|1"), RssShuffleManager.getTaskAttemptId(1, 1, 2,
false, 10));
+ assertEquals(bits("010|0"), RssShuffleManager.getTaskAttemptId(2, 0, 2,
false, 10));
+ assertEquals(bits("010|1"), RssShuffleManager.getTaskAttemptId(2, 1, 2,
false, 10));
+ assertEquals(bits("011|0"), RssShuffleManager.getTaskAttemptId(3, 0, 2,
false, 10));
+ assertEquals(bits("011|1"), RssShuffleManager.getTaskAttemptId(3, 1, 2,
false, 10));
+
+ // maxFailures of 3
+ assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3,
false, 10));
+ assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3,
false, 10));
+ assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3,
false, 10));
+ assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3,
false, 10));
+ assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3,
false, 10));
+ assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3,
false, 10));
+ assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3,
false, 10));
+ assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3,
false, 10));
+ assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3,
false, 10));
+ assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3,
false, 10));
+ assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3,
false, 10));
+ assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3,
false, 10));
+
+ // maxFailures of 4
+ assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 4,
false, 10));
+ assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 4,
false, 10));
+ assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 4,
false, 10));
+ assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 4,
false, 10));
+ assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 4,
false, 10));
+ assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 4,
false, 10));
+ assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 4,
false, 10));
+ assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 4,
false, 10));
+ assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 4,
false, 10));
+ assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 4,
false, 10));
+ assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 4,
false, 10));
+ assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 4,
false, 10));
+ assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 4,
false, 10));
+ assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 4,
false, 10));
+ assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 4,
false, 10));
+ assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 4,
false, 10));
+
+ // maxFailures of 5
+ assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 5,
false, 10));
+ assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 5,
false, 10));
+
+ // test with ints that overflow into signed int and long
+ assertEquals(
+ Integer.MAX_VALUE,
RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, false, 31));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 1 | 1,
+ RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 2, false,
32));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 2 | 3,
+ RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 4, false,
33));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 3 | 7,
+ RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 8, false,
34));
+
+ // test with attemptNo >= maxFailures
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, -1,
false, 10));
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, 0,
false, 10));
+ for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManager.getTaskAttemptId(0, maxFailures,
maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1,
maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2,
maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ Exception e =
+ assertThrowsExactly(
+ RssException.class,
+ () ->
+ RssShuffleManager.getTaskAttemptId(0, maxFailures + 128,
maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertEquals(
+ "Observing attempt number "
+ + (maxFailures + 128)
+ + " while maxFailures is set to "
+ + maxFailures
+ + ".",
+ e.getMessage());
+ }
+
+ // test with mapIndex that would require more than maxTaskAttemptBits
+ Exception e =
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManager.getTaskAttemptId(256,
0, 3, true, 10));
+ assertEquals(
+ "Observing mapIndex[256] that would produce a taskAttemptId with 11
bits "
+ + "which is larger than the allowed 10 bits (maxFailures[3],
speculation[true]). "
+ + "Please consider providing more bits for taskAttemptIds.",
+ e.getMessage());
+ // check that a lower mapIndex works as expected
+ assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255,
0, 3, true, 10));
+ }
+
+ @Test
+ public void testGetTaskAttemptIdWithSpeculation() {
+ // with speculation, we expect maxFailures+1 attempts
+
+ // the expected bits("xy|z") represents the expected Long in bit notation
where | is used to
+ // separate map index from attempt number, so merely for visualization
purposes
+
+ // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+ for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+ for (int attemptNo : Arrays.asList(0, 1)) {
+ assertEquals(
+ bits("0000|" + attemptNo),
+ RssShuffleManager.getTaskAttemptId(0, attemptNo, maxFailures,
true, 10),
+ "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+ assertEquals(
+ bits("0001|" + attemptNo),
+ RssShuffleManager.getTaskAttemptId(1, attemptNo, maxFailures,
true, 10),
+ "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+ assertEquals(
+ bits("0010|" + attemptNo),
+ RssShuffleManager.getTaskAttemptId(2, attemptNo, maxFailures,
true, 10),
+ "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+ }
+ }
+
+ // maxFailures of 2
+ assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 2,
true, 10));
+ assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 2,
true, 10));
+ assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 2,
true, 10));
+ assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 2,
true, 10));
+ assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 2,
true, 10));
+ assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 2,
true, 10));
+ assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 2,
true, 10));
+ assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 2,
true, 10));
+ assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 2,
true, 10));
+ assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 2,
true, 10));
+ assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 2,
true, 10));
+ assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 2,
true, 10));
+
+ // maxFailures of 3
+ assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3,
true, 10));
+ assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3,
true, 10));
+ assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3,
true, 10));
+ assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 3,
true, 10));
+ assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3,
true, 10));
+ assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3,
true, 10));
+ assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3,
true, 10));
+ assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 3,
true, 10));
+ assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3,
true, 10));
+ assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3,
true, 10));
+ assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3,
true, 10));
+ assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 3,
true, 10));
+ assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3,
true, 10));
+ assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3,
true, 10));
+ assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3,
true, 10));
+ assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 3,
true, 10));
+
+ // maxFailures of 4
+ assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 4,
true, 10));
+ assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 4,
true, 10));
+
+ // test with ints that overflow into signed int and long
+ assertEquals(
+ (long) Integer.MAX_VALUE << 1,
+ RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, true, 32));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 1 | 1,
+ RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 1, true, 32));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 2 | 3,
+ RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 3, true, 33));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 3 | 7,
+ RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 7, true, 34));
+
+ // test with attemptNo > maxFailures (attemptNo == maxFailures allowed for
speculation enabled)
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, -1,
true, 10));
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, 0,
true, 10));
+ for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1,
maxFailures, true, 10),
+ String.valueOf(maxFailures));
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2,
maxFailures, true, 10),
+ String.valueOf(maxFailures));
+ Exception e =
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 128,
maxFailures, true, 10),
+ String.valueOf(maxFailures));
+ assertEquals(
+ "Observing attempt number "
+ + (maxFailures + 128)
+ + " while maxFailures is set to "
+ + maxFailures
+ + " with speculation enabled.",
+ e.getMessage());
+ }
+
+ // test with mapIndex that would require more than maxTaskAttemptBits
+ Exception e =
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManager.getTaskAttemptId(256,
0, 4, false, 10));
+ assertEquals(
+ "Observing mapIndex[256] that would produce a taskAttemptId with 11
bits "
+ + "which is larger than the allowed 10 bits (maxFailures[4],
speculation[false]). "
+ + "Please consider providing more bits for taskAttemptIds.",
+ e.getMessage());
+ // check that a lower mapIndex works as expected
+ assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255,
0, 4, false, 10));
+ }
+
@Test
public void testCreateShuffleManagerServer() {
setupMockedRssShuffleUtils(StatusCode.SUCCESS);
diff --git
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java
new file mode 100644
index 000000000..e9a0818aa
--- /dev/null
+++
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import com.google.common.collect.Maps;
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.java.function.MapPartitionsFunction;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+// This test has all tasks fail twice, the third attempt succeeds.
+// The failing attempts all provide zeros to the shuffle step, while the
succeeding attempts
+// provide the actual non-zero integers (actually only one zero). If blocks
from the failing
+// attempts leak into the read shuffle data, we would see those zeros and fail
when comparing
+// to without RSS.
+public class FailingTasksTest extends SparkTaskFailureIntegrationTestBase {
+
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ shutdownServers();
+ CoordinatorConf coordinatorConf = getCoordinatorConf();
+ Map<String, String> dynamicConf = Maps.newHashMap();
+ dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(),
HDFS_URI + "rss/test");
+ dynamicConf.put(
+ RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.MEMORY_LOCALFILE_HDFS.name());
+ addDynamicConf(coordinatorConf, dynamicConf);
+ createCoordinatorServer(coordinatorConf);
+ ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+ createShuffleServer(shuffleServerConf);
+ startServers();
+ }
+
+ @Override
+ Map runTest(SparkSession spark, String fileName) throws Exception {
+ int n = 1000000;
+ return spark.range(0, n, 1, 4)
+ .mapPartitions(
+ (MapPartitionsFunction<Long, Long>)
+ it ->
+ new Iterator<Long>() {
+ final TaskContext context = TaskContext.get();
+
+ @Override
+ public boolean hasNext() {
+ // the first two attempts fail in the end
+ return context.attemptNumber() < 2 || it.hasNext();
+ }
+
+ @Override
+ public Long next() {
+ if (it.hasNext()) {
+ Long next = it.next();
+ // the failing attempt returns only zeros
+ if (context.attemptNumber() < 2) {
+ return 0L;
+ } else {
+ return next;
+ }
+ } else {
+ throw new RuntimeException("let this task fail");
+ }
+ }
+ },
+ Encoders.LONG())
+ .repartition(3, new Column("value"))
+ .mapPartitions((MapPartitionsFunction<Long, Long>) it -> it,
Encoders.LONG())
+ .collectAsList().stream()
+ .collect(Collectors.toMap(v -> v, v -> v));
+ }
+
+ @Test
+ public void testFailedTasks() throws Exception {
+ run();
+ }
+}
diff --git
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
index 3e739690f..e7fed902f 100644
---
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
+++
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
@@ -39,12 +39,10 @@ import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.storage.util.StorageType;
-public class RSSStageDynamicServerReWriteTest extends SparkIntegrationTestBase
{
+public class RSSStageDynamicServerReWriteTest extends
SparkTaskFailureIntegrationTestBase {
private static final Logger LOG =
LoggerFactory.getLogger(RSSStageDynamicServerReWriteTest.class);
- private static int maxTaskFailures = 3;
-
@BeforeAll
public static void setupServers(@TempDir File tmpDir) throws Exception {
CoordinatorConf coordinatorConf = getCoordinatorConf();
@@ -96,18 +94,11 @@ public class RSSStageDynamicServerReWriteTest extends
SparkIntegrationTestBase {
return result;
}
- @Override
- protected SparkConf createSparkConf() {
- return new SparkConf()
- .setAppName(this.getClass().getSimpleName())
- .setMaster(String.format("local[4,%d]", maxTaskFailures));
- }
-
@Override
public void updateSparkConfCustomer(SparkConf sparkConf) {
+ super.updateSparkConfCustomer(sparkConf);
sparkConf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_RESUBMIT_STAGE, "true");
- sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
}
@Test
diff --git
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
index 5e95bc009..497e232a2 100644
---
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
+++
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
@@ -35,9 +35,7 @@ import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.storage.util.StorageType;
-public class RSSStageResubmitTest extends SparkIntegrationTestBase {
-
- private static int maxTaskFailures = 3;
+public class RSSStageResubmitTest extends SparkTaskFailureIntegrationTestBase {
@BeforeAll
public static void setupServers() throws Exception {
@@ -73,18 +71,11 @@ public class RSSStageResubmitTest extends
SparkIntegrationTestBase {
return result;
}
- @Override
- protected SparkConf createSparkConf() {
- return new SparkConf()
- .setAppName(this.getClass().getSimpleName())
- .setMaster(String.format("local[4,%d]", maxTaskFailures));
- }
-
@Override
public void updateSparkConfCustomer(SparkConf sparkConf) {
+ super.updateSparkConfCustomer(sparkConf);
sparkConf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_RESUBMIT_STAGE, "true");
- sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
}
@Test
diff --git
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java
new file mode 100644
index 000000000..a1e0c37f4
--- /dev/null
+++
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import org.apache.spark.SparkConf;
+
+public abstract class SparkTaskFailureIntegrationTestBase extends
SparkIntegrationTestBase {
+
+ protected static final int maxTaskFailures = 3;
+
+ @Override
+ protected SparkConf createSparkConf() {
+ return new SparkConf()
+ .setAppName(this.getClass().getSimpleName())
+ .setMaster(String.format("local[4,%d]", maxTaskFailures));
+ }
+
+ @Override
+ public void updateSparkConfCustomer(SparkConf sparkConf) {
+ sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
+ }
+}