This is an automated email from the ASF dual-hosted git repository.
zhengchenyu pushed a commit to branch branch-0.9
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/branch-0.9 by this push:
new 664c4f248 [#1398] fix(mr)(tez): Make attempId computable and move it
to taskAttemptId in BlockId layout. (#2027)
664c4f248 is described below
commit 664c4f248c7274660b2f55c9eb981ee9caf3487e
Author: QI Jiale <[email protected]>
AuthorDate: Mon Aug 12 17:41:59 2024 +0800
[#1398] fix(mr)(tez): Make attempId computable and move it to taskAttemptId
in BlockId layout. (#2027)
### What changes were proposed in this pull request?
Before this PR, in the MR and Tez engines:
1. attemptId is in sequenceNo of BlockId instead of taskAttemptId.
2. attemptId has a fixed width of 6 bits.
After this PR:
1. attemptId is in taskAttemptId. This is more reasonable.
2. The bit width of attemptId is calculated from the max number of allowed
failures and whether speculative execution is enabled.
### Why are the changes needed?
Fix: #1398
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing unit tests and integration tests.
---
.../hadoop/mapred/RssMapOutputCollector.java | 4 +-
.../org/apache/hadoop/mapreduce/RssMRUtils.java | 87 +++++++++++----------
.../mapreduce/task/reduce/RssEventFetcher.java | 2 +-
.../hadoop/mapred/SortWriteBufferManagerTest.java | 12 +--
.../apache/hadoop/mapreduce/RssMRUtilsTest.java | 17 ++--
.../mapreduce/task/reduce/EventFetcherTest.java | 28 +++----
.../hadoop/mapreduce/task/reduce/FetcherTest.java | 5 +-
.../shuffle/manager/RssShuffleManagerBase.java | 30 ++-----
.../shuffle/manager/RssShuffleManagerBaseTest.java | 40 ----------
.../java/org/apache/tez/common/RssTezUtils.java | 91 ++++++++++++----------
.../common/shuffle/impl/RssShuffleManager.java | 6 +-
.../common/shuffle/impl/RssTezFetcherTask.java | 8 +-
.../orderedgrouped/RssShuffleScheduler.java | 6 +-
.../library/common/sort/impl/RssSorter.java | 4 +-
.../library/common/sort/impl/RssUnSorter.java | 4 +-
.../output/RssOrderedPartitionedKVOutput.java | 4 +-
.../library/output/RssUnorderedKVOutput.java | 4 +-
.../output/RssUnorderedPartitionedKVOutput.java | 4 +-
.../org/apache/tez/common/RssTezUtilsTest.java | 10 +--
.../library/common/sort/impl/RssSorterTest.java | 5 +-
.../library/common/sort/impl/RssUnSorterTest.java | 5 +-
.../apache/uniffle/client/util/ClientUtils.java | 17 ++++
.../org/apache/uniffle/client/ClientUtilsTest.java | 42 ++++++++++
.../org/apache/uniffle/common/util/BlockId.java | 4 +-
.../uniffle/test/TezWordCountWithFailuresTest.java | 2 +-
.../uniffle/server/buffer/BufferTestBase.java | 2 +-
.../handler/impl/HadoopShuffleReadHandlerTest.java | 2 +-
27 files changed, 247 insertions(+), 198 deletions(-)
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
index 3acf0b417..30f98c52e 100644
---
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
+++
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
@@ -100,8 +100,8 @@ public class RssMapOutputCollector<K extends Object, V
extends Object>
ApplicationAttemptId applicationAttemptId =
RssMRUtils.getApplicationAttemptId();
String appId = applicationAttemptId.toString();
long taskAttemptId =
- RssMRUtils.convertTaskAttemptIdToLong(
- mapTask.getTaskID(), applicationAttemptId.getAttemptId());
+ RssMRUtils.createRssTaskAttemptId(
+ mapTask.getTaskID(), applicationAttemptId.getAttemptId(),
mrJobConf);
double sendThreshold =
RssMRUtils.getDouble(
rssJobConf,
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
index f1220486b..9012e618e 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
@@ -44,37 +45,61 @@ public class RssMRUtils {
private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class);
private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
- private static final int MAX_ATTEMPT_LENGTH = 6;
- private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
- private static final int MAX_SEQUENCE_NO =
- (1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1;
-
- // Class TaskAttemptId have two field id and mapId, rss taskAttemptID have
21 bits,
- // mapId is 19 bits, id is 2 bits. MR have a trick logic, taskAttemptId will
increase
- // 1000 * (appAttemptId - 1), so we will decrease it.
- public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID,
int appAttemptId) {
- int lowBytes = taskAttemptID.getTaskID().getId();
- if (lowBytes > LAYOUT.maxTaskAttemptId) {
- throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " +
lowBytes + " exceed");
- }
+
+ // Class TaskAttemptId have two field id and mapId. MR have a trick logic,
taskAttemptId will
+ // increase 1000 * (appAttemptId - 1), so we will decrease it.
+ public static int createRssTaskAttemptId(
+ TaskAttemptID taskAttemptID, int appAttemptId, int maxAttemptNo) {
+ int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
+
if (appAttemptId < 1) {
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
- int highBytes = taskAttemptID.getId() - (appAttemptId - 1) * 1000;
- if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) {
+ int attemptId = taskAttemptID.getId() - (appAttemptId - 1) * 1000;
+ if (attemptId > maxAttemptNo || attemptId < 0) {
throw new RssException(
- "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + "
exceed");
+ "TaskAttempt " + taskAttemptID + " attemptId " + attemptId + "
exceed " + maxAttemptNo);
}
- return LAYOUT.getBlockId(highBytes, 0, lowBytes);
+ int taskId = taskAttemptID.getTaskID().getId();
+
+ int mapIndexBits = ClientUtils.getNumberOfSignificantBits(taskId);
+ if (mapIndexBits + attemptBits > LAYOUT.taskAttemptIdBits) {
+ throw new RssException(
+ "Observing taskId["
+ + taskId
+ + "] that would produce a taskAttemptId with "
+ + (mapIndexBits + attemptBits)
+ + " bits which is larger than the allowed "
+ + LAYOUT.taskAttemptIdBits
+ + "]). Please consider providing more bits for taskAttemptIds.");
+ }
+
+ return (taskId << attemptBits) | attemptId;
+ }
+
+ public static int createRssTaskAttemptId(
+ TaskAttemptID taskAttemptID, int appAttemptId, int maxFailures, boolean
speculation) {
+ int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
+ return createRssTaskAttemptId(taskAttemptID, appAttemptId, maxAttemptNo);
+ }
+
+ public static int createRssTaskAttemptId(
+ TaskAttemptID taskAttemptID, int appAttemptId, Configuration conf) {
+ int maxFailures = conf.getInt(MRJobConfig.MAP_MAX_ATTEMPTS, 4);
+ boolean speculation = conf.getBoolean(MRJobConfig.MAP_SPECULATIVE, true);
+ return createRssTaskAttemptId(taskAttemptID, appAttemptId, maxFailures,
speculation);
}
public static TaskAttemptID createMRTaskAttemptId(
- JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId)
{
+ JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId,
int maxAttemptNo) {
+ int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
if (appAttemptId < 1) {
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
- TaskID taskID = new TaskID(jobID, taskType,
LAYOUT.getTaskAttemptId(rssTaskAttemptId));
- int id = LAYOUT.getSequenceNo(rssTaskAttemptId) + 1000 * (appAttemptId -
1);
+ int task = (int) rssTaskAttemptId >> attemptBits;
+ int attempt = (int) rssTaskAttemptId & ((1 << attemptBits) - 1);
+ TaskID taskID = new TaskID(jobID, taskType, task);
+ int id = attempt + 1000 * (appAttemptId - 1);
return new TaskAttemptID(taskID, id);
}
@@ -228,27 +253,11 @@ public class RssMRUtils {
}
public static long getBlockId(int partitionId, long taskAttemptId, int
nextSeqNo) {
- long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits +
LAYOUT.taskAttemptIdBits);
- if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) {
- throw new RssException(
- "Can't support attemptId [" + attemptId + "], the max value should
be " + MAX_ATTEMPT_ID);
- }
- if (nextSeqNo < 0 || nextSeqNo > MAX_SEQUENCE_NO) {
- throw new RssException(
- "Can't support sequence [" + nextSeqNo + "], the max value should be
" + MAX_SEQUENCE_NO);
- }
-
- int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId);
- long taskId =
- taskAttemptId - (attemptId << (LAYOUT.partitionIdBits +
LAYOUT.taskAttemptIdBits));
-
- return LAYOUT.getBlockId(atomicInt, partitionId, taskId);
+ return LAYOUT.getBlockId(nextSeqNo, partitionId, taskAttemptId);
}
- public static long getTaskAttemptId(long blockId) {
- int mapId = LAYOUT.getTaskAttemptId(blockId);
- int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
- return LAYOUT.getBlockId(attemptId, 0, mapId);
+ public static int getTaskAttemptId(long blockId) {
+ return LAYOUT.getTaskAttemptId(blockId);
}
public static int estimateTaskConcurrency(JobConf jobConf) {
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java
index 397d45fb2..5fb1f0fe1 100644
---
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java
+++
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java
@@ -75,7 +75,7 @@ public class RssEventFetcher<K, V> {
String errMsg = "TaskAttemptIDs are inconsistent with map tasks";
for (TaskAttemptID taskAttemptID : successMaps) {
if (!obsoleteMaps.contains(taskAttemptID)) {
- long rssTaskId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID,
appAttemptId);
+ long rssTaskId = RssMRUtils.createRssTaskAttemptId(taskAttemptID,
appAttemptId, jobConf);
int mapIndex = taskAttemptID.getTaskID().getId();
// There can be multiple successful attempts on same map task.
// So we only need to accept one of them.
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 69a8215d5..17350b9b9 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -75,7 +75,7 @@ public class SortWriteBufferManagerTest {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
- 1L,
+ 1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -139,7 +139,7 @@ public class SortWriteBufferManagerTest {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
100,
- 1L,
+ 1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -191,7 +191,7 @@ public class SortWriteBufferManagerTest {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
- 1L,
+ 1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -243,7 +243,7 @@ public class SortWriteBufferManagerTest {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
- 1L,
+ 1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -310,7 +310,7 @@ public class SortWriteBufferManagerTest {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
- 1L,
+ 1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -389,7 +389,7 @@ public class SortWriteBufferManagerTest {
SortWriteBufferManager<Text, IntWritable> manager =
new SortWriteBufferManager<Text, IntWritable>(
10240,
- 1L,
+ 1,
10,
keySerializer,
valueSerializer,
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java
index 2d3710ca1..31e3590d2 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java
@@ -45,20 +45,21 @@ public class RssMRUtilsTest {
TaskAttemptID mrTaskAttemptId = new TaskAttemptID(taskId, 3);
boolean isException = false;
try {
- RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
+ RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
} catch (RssException e) {
isException = true;
}
assertTrue(isException);
- taskAttemptId = (1 << 20) + 0x123;
- mrTaskAttemptId = RssMRUtils.createMRTaskAttemptId(new JobID(),
TaskType.MAP, taskAttemptId, 1);
- long testId = RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
+ taskAttemptId = (0x123 << 3) + 1;
+ mrTaskAttemptId =
+ RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP,
taskAttemptId, 1, 4);
+ int testId = RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
assertEquals(taskAttemptId, testId);
- TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(),
TaskType.MAP, (int) (1 << 21));
+ TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(),
TaskType.MAP, 1 << 21);
mrTaskAttemptId = new TaskAttemptID(taskID, 2);
isException = false;
try {
- RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
+ RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
} catch (RssException e) {
isException = true;
}
@@ -70,7 +71,7 @@ public class RssMRUtilsTest {
JobID jobID = new JobID();
TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
- long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID,
1);
+ long taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1,
4);
long blockId = RssMRUtils.getBlockId(1, taskAttemptId, 0);
long newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId);
assertEquals(taskAttemptId, newTaskAttemptId);
@@ -85,7 +86,7 @@ public class RssMRUtilsTest {
JobID jobID = new JobID();
TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
- long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID,
1);
+ long taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1,
4);
long mask = (1L << layout.partitionIdBits) - 1;
for (int partitionId = 0; partitionId <= 3000; partitionId++) {
for (int seqNo = 0; seqNo <= 10; seqNo++) {
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java
index 8e8ab8f4e..6dfb81f45 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java
@@ -59,8 +59,8 @@ public class EventFetcherTest {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
- RssMRUtils.convertTaskAttemptIdToLong(
- new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+ RssMRUtils.createRssTaskAttemptId(
+ new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
@@ -89,8 +89,8 @@ public class EventFetcherTest {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
- RssMRUtils.convertTaskAttemptIdToLong(
- new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+ RssMRUtils.createRssTaskAttemptId(
+ new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
@@ -121,8 +121,8 @@ public class EventFetcherTest {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
- RssMRUtils.convertTaskAttemptIdToLong(
- new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+ RssMRUtils.createRssTaskAttemptId(
+ new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
Roaring64NavigableMap taskIdBitmap = ef.fetchAllRssTaskIds();
@@ -146,8 +146,8 @@ public class EventFetcherTest {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
- RssMRUtils.convertTaskAttemptIdToLong(
- new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+ RssMRUtils.createRssTaskAttemptId(
+ new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
IllegalStateException ex =
@@ -172,8 +172,8 @@ public class EventFetcherTest {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
- RssMRUtils.convertTaskAttemptIdToLong(
- new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+ RssMRUtils.createRssTaskAttemptId(
+ new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
IllegalStateException ex =
@@ -205,14 +205,14 @@ public class EventFetcherTest {
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
if (!tipFailed.contains(mapIndex) && !obsoleted.contains(mapIndex)) {
long rssTaskId =
- RssMRUtils.convertTaskAttemptIdToLong(
- new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+ RssMRUtils.createRssTaskAttemptId(
+ new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1,
4);
expected.addLong(rssTaskId);
}
if (obsoleted.contains(mapIndex)) {
long rssTaskId =
- RssMRUtils.convertTaskAttemptIdToLong(
- new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1);
+ RssMRUtils.createRssTaskAttemptId(
+ new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1,
4);
expected.addLong(rssTaskId);
}
}
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index eeae036e0..8f695052e 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -290,7 +290,8 @@ public class FetcherTest {
null,
new Progress(),
new MROutputFiles());
- TaskAttemptID taskAttemptID = RssMRUtils.createMRTaskAttemptId(new
JobID(), TaskType.MAP, 1, 1);
+ TaskAttemptID taskAttemptID =
+ RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 1, 1, 4);
byte[] buffer = new byte[10];
MapOutput mapOutput1 = merger.reserve(taskAttemptID, 10, 1);
RssBypassWriter.write(mapOutput1, buffer);
@@ -349,7 +350,7 @@ public class FetcherTest {
SortWriteBufferManager<Text, Text> manager =
new SortWriteBufferManager(
10240,
- 1L,
+ 1,
10,
serializationFactory.getSerializer(Text.class),
serializationFactory.getSerializer(Text.class),
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 23922644a..651286923 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -45,6 +45,7 @@ import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
+import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.config.ConfigOption;
@@ -112,8 +113,10 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
+ maxPartitions);
}
- int attemptIdBits = getAttemptIdBits(getMaxAttemptNo(maxFailures,
speculation));
- int partitionIdBits = 32 - Integer.numberOfLeadingZeros(maxPartitions -
1); // [1..31]
+ int attemptIdBits =
+ ClientUtils.getNumberOfSignificantBits(
+ ClientUtils.getMaxAttemptNo(maxFailures, speculation));
+ int partitionIdBits = ClientUtils.getNumberOfSignificantBits(maxPartitions
- 1); // [1..31]
int taskAttemptIdBits = partitionIdBits + attemptIdBits; //
[1+attemptIdBits..31+attemptIdBits]
int sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits; //
[1-attemptIdBits..61]
@@ -252,23 +255,6 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
}
}
- protected static int getMaxAttemptNo(int maxFailures, boolean speculation) {
- // attempt number is zero based: 0, 1, …, maxFailures-1
- // max maxFailures < 1 is not allowed but for safety, we interpret that as
maxFailures == 1
- int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
-
- // with speculative execution enabled we could observe +1 attempts
- if (speculation) {
- maxAttemptNo++;
- }
-
- return maxAttemptNo;
- }
-
- protected static int getAttemptIdBits(int maxAttemptNo) {
- return 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
- }
-
/** See static overload of this method. */
public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);
@@ -287,8 +273,8 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
*/
protected static long getTaskAttemptIdForBlockId(
int mapIndex, int attemptNo, int maxFailures, boolean speculation, int
maxTaskAttemptIdBits) {
- int maxAttemptNo = getMaxAttemptNo(maxFailures, speculation);
- int attemptBits = getAttemptIdBits(maxAttemptNo);
+ int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
+ int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
if (attemptNo > maxAttemptNo) {
// this should never happen, if it does, our assumptions are wrong,
@@ -302,7 +288,7 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
+ ".");
}
- int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
+ int mapIndexBits = ClientUtils.getNumberOfSignificantBits(mapIndex);
if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
throw new RssException(
"Observing mapIndex["
diff --git
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
index fffb7af3f..610b42c8c 100644
---
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
+++
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
@@ -370,46 +370,6 @@ public class RssShuffleManagerBaseTest {
assertTrue(e.getMessage().startsWith("All block id bit config keys must be
provided "));
}
- @Test
- public void testGetMaxAttemptNo() {
- // without speculation
- assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(-1, false));
- assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(0, false));
- assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(1, false));
- assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(2, false));
- assertEquals(2, RssShuffleManagerBase.getMaxAttemptNo(3, false));
- assertEquals(3, RssShuffleManagerBase.getMaxAttemptNo(4, false));
- assertEquals(4, RssShuffleManagerBase.getMaxAttemptNo(5, false));
- assertEquals(1023, RssShuffleManagerBase.getMaxAttemptNo(1024, false));
-
- // with speculation
- assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(-1, true));
- assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(0, true));
- assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(1, true));
- assertEquals(2, RssShuffleManagerBase.getMaxAttemptNo(2, true));
- assertEquals(3, RssShuffleManagerBase.getMaxAttemptNo(3, true));
- assertEquals(4, RssShuffleManagerBase.getMaxAttemptNo(4, true));
- assertEquals(5, RssShuffleManagerBase.getMaxAttemptNo(5, true));
- assertEquals(1024, RssShuffleManagerBase.getMaxAttemptNo(1024, true));
- }
-
- @Test
- public void testGetAttemptIdBits() {
- assertEquals(0, RssShuffleManagerBase.getAttemptIdBits(0));
- assertEquals(1, RssShuffleManagerBase.getAttemptIdBits(1));
- assertEquals(2, RssShuffleManagerBase.getAttemptIdBits(2));
- assertEquals(2, RssShuffleManagerBase.getAttemptIdBits(3));
- assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(4));
- assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(5));
- assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(6));
- assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(7));
- assertEquals(4, RssShuffleManagerBase.getAttemptIdBits(8));
- assertEquals(4, RssShuffleManagerBase.getAttemptIdBits(9));
- assertEquals(10, RssShuffleManagerBase.getAttemptIdBits(1023));
- assertEquals(11, RssShuffleManagerBase.getAttemptIdBits(1024));
- assertEquals(11, RssShuffleManagerBase.getAttemptIdBits(1025));
- }
-
private long bits(String string) {
return Long.parseLong(string.replaceAll("[|]", ""), 2);
}
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
index c9a762fa9..2545822b9 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
@@ -24,6 +24,7 @@ import java.util.Set;
import com.google.common.base.Preconditions;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
@@ -51,6 +52,7 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
@@ -60,10 +62,6 @@ public class RssTezUtils {
private static final Logger LOG = LoggerFactory.getLogger(RssTezUtils.class);
private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
- private static final int MAX_ATTEMPT_LENGTH = 6;
- private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
- private static final int MAX_SEQUENCE_NO =
- (1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1;
public static final String HOST_NAME = "hostname";
@@ -159,32 +157,11 @@ public class RssTezUtils {
}
public static long getBlockId(int partitionId, long taskAttemptId, int
nextSeqNo) {
- LOG.info(
- "GetBlockId, partitionId:{}, taskAttemptId:{}, nextSeqNo:{}",
- partitionId,
- taskAttemptId,
- nextSeqNo);
- long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits +
LAYOUT.taskAttemptIdBits);
- if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) {
- throw new RssException(
- "Can't support attemptId [" + attemptId + "], the max value should
be " + MAX_ATTEMPT_ID);
- }
- if (nextSeqNo < 0 || nextSeqNo > MAX_SEQUENCE_NO) {
- throw new RssException(
- "Can't support sequence [" + nextSeqNo + "], the max value should be
" + MAX_SEQUENCE_NO);
- }
-
- int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId);
- long taskId =
- taskAttemptId - (attemptId << (LAYOUT.partitionIdBits +
LAYOUT.taskAttemptIdBits));
-
- return LAYOUT.getBlockId(atomicInt, partitionId, taskId);
+ return LAYOUT.getBlockId(nextSeqNo, partitionId, taskAttemptId);
}
- public static long getTaskAttemptId(long blockId) {
- int mapId = LAYOUT.getTaskAttemptId(blockId);
- int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
- return LAYOUT.getBlockId(attemptId, 0, mapId);
+ public static int getTaskAttemptId(long blockId) {
+ return LAYOUT.getTaskAttemptId(blockId);
}
public static int estimateTaskConcurrency(Configuration jobConf, int mapNum,
int reduceNum) {
@@ -276,23 +253,55 @@ public class RssTezUtils {
}
}
- public static long convertTaskAttemptIdToLong(TezTaskAttemptID
taskAttemptID) {
- int lowBytes = taskAttemptID.getTaskID().getId();
- if (lowBytes > LAYOUT.maxTaskAttemptId) {
- throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " +
lowBytes + " exceed");
+ public static int createRssTaskAttemptId(TezTaskAttemptID taskAttemptId, int
maxAttemptNo) {
+ int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
+
+ int attemptId = taskAttemptId.getId();
+ if (attemptId > maxAttemptNo || attemptId < 0) {
+ throw new RssException(
+ "TaskAttempt " + taskAttemptId + " attemptId " + attemptId + "
exceed");
}
- int highBytes = taskAttemptID.getId();
- if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) {
+ int taskId = taskAttemptId.getTaskID().getId();
+
+ int mapIndexBits = ClientUtils.getNumberOfSignificantBits(taskId);
+ if (mapIndexBits + attemptBits > LAYOUT.taskAttemptIdBits) {
throw new RssException(
- "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + "
exceed.");
+ "Observing taskId["
+ + taskId
+ + "] that would produce a taskAttemptId with "
+ + (mapIndexBits + attemptBits)
+ + " bits which is larger than the allowed "
+ + LAYOUT.taskAttemptIdBits
+ + "]). Please consider providing more bits for taskAttemptIds.");
}
- long id = LAYOUT.getBlockId(highBytes, 0, lowBytes);
- LOG.info("ConvertTaskAttemptIdToLong taskAttemptID:{}, id is {}, .",
taskAttemptID, id);
+
+ int id = (taskId << attemptBits) | attemptId;
+ LOG.info("createRssTaskAttemptId taskAttemptId:{}, id is {}, .",
taskAttemptId, id);
return id;
}
+ public static int createRssTaskAttemptId(TezTaskAttemptID taskAttemptId,
Configuration conf) {
+ int maxAttemptNo = getMaxAttemptNo(conf);
+ return createRssTaskAttemptId(taskAttemptId, maxAttemptNo);
+ }
+
+ public static int getMaxAttemptNo(Configuration conf) {
+ int maxFailures =
+ conf.getInt(
+ TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS,
+ TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT);
+ boolean speculation =
+ conf.getBoolean(
+ TezConfiguration.TEZ_AM_SPECULATION_ENABLED,
+ TezConfiguration.TEZ_AM_SPECULATION_ENABLED_DEFAULT);
+ return ClientUtils.getMaxAttemptNo(maxFailures, speculation);
+ }
+
public static Roaring64NavigableMap fetchAllRssTaskIds(
- Set<InputAttemptIdentifier> successMapTaskAttempts, int totalMapsCount,
int appAttemptId) {
+ Set<InputAttemptIdentifier> successMapTaskAttempts,
+ int totalMapsCount,
+ int appAttemptId,
+ int maxAttemptNo) {
String errMsg = "TaskAttemptIDs are inconsistent with map tasks";
Roaring64NavigableMap rssTaskIdBitmap = Roaring64NavigableMap.bitmapOf();
Roaring64NavigableMap mapTaskIdBitmap = Roaring64NavigableMap.bitmapOf();
@@ -301,9 +310,9 @@ public class RssTezUtils {
for (InputAttemptIdentifier inputAttemptIdentifier :
successMapTaskAttempts) {
String pathComponent = inputAttemptIdentifier.getPathComponent();
- TezTaskAttemptID mapTaskAttemptID =
IdUtils.convertTezTaskAttemptID(pathComponent);
- long rssTaskId =
RssTezUtils.convertTaskAttemptIdToLong(mapTaskAttemptID);
- long mapTaskId = mapTaskAttemptID.getTaskID().getId();
+ TezTaskAttemptID mapTaskAttemptId =
IdUtils.convertTezTaskAttemptID(pathComponent);
+ long rssTaskId = RssTezUtils.createRssTaskAttemptId(mapTaskAttemptId,
maxAttemptNo);
+ long mapTaskId = mapTaskAttemptId.getTaskID().getId();
LOG.info(
"FetchAllRssTaskIds, pathComponent: {}, mapTaskId:{}, rssTaskId:{},
is contains:{}",
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
index a734b102b..236ffe615 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
@@ -68,6 +68,7 @@ import org.apache.hadoop.util.Time;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.tez.common.CallableWithNdc;
import org.apache.tez.common.InputContextUtils;
+import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.common.TezUtilsInternal;
import org.apache.tez.common.UmbilicalUtils;
@@ -589,6 +590,8 @@ public class RssShuffleManager extends ShuffleManager {
partitionToServers.get(partition),
partitionToServers);
+ int maxAttemptNo = RssTezUtils.getMaxAttemptNo(conf);
+
RssTezFetcherTask fetcher =
new RssTezFetcherTask(
RssShuffleManager.this,
@@ -603,7 +606,8 @@ public class RssShuffleManager extends ShuffleManager {
rssAllBlockIdBitmapMap,
rssSuccessBlockIdBitmapMap,
numInputs,
- partitionToServers.size());
+ partitionToServers.size(),
+ maxAttemptNo);
rssRunningFetchers.add(fetcher);
if (isShutdown.get()) {
LOG.info(
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
index f1dd85b8b..8015b43ac 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
@@ -74,6 +74,7 @@ public class RssTezFetcherTask extends
CallableWithNdc<FetchResult> {
private final int partitionNum;
private final int shuffleId;
private final ApplicationAttemptId applicationAttemptId;
+ private final int maxAttemptNo;
public RssTezFetcherTask(
FetcherCallback fetcherCallback,
@@ -88,7 +89,8 @@ public class RssTezFetcherTask extends
CallableWithNdc<FetchResult> {
Map<Integer, Roaring64NavigableMap> rssAllBlockIdBitmapMap,
Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap,
int numPhysicalInputs,
- int partitionNum) {
+ int partitionNum,
+ int maxAttemptNo) {
assert (inputs != null && inputs.size() > 0);
this.fetcherCallback = fetcherCallback;
this.inputContext = inputContext;
@@ -131,6 +133,7 @@ public class RssTezFetcherTask extends
CallableWithNdc<FetchResult> {
conf.getInt(
RssTezConfig.RSS_PARTITION_NUM_PER_RANGE,
RssTezConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE);
+ this.maxAttemptNo = maxAttemptNo;
LOG.info(
"RssTezFetcherTask fetch partition:{}, with inputs:{},
readBufferSize:{}, partitionNumPerRange:{}.",
this.partition,
@@ -159,7 +162,8 @@ public class RssTezFetcherTask extends
CallableWithNdc<FetchResult> {
// final RssEventFetcher eventFetcher = new RssEventFetcher(inputs,
numPhysicalInputs);
int appAttemptId = applicationAttemptId.getAttemptId();
Roaring64NavigableMap taskIdBitmap =
- RssTezUtils.fetchAllRssTaskIds(new HashSet<>(inputs),
numPhysicalInputs, appAttemptId);
+ RssTezUtils.fetchAllRssTaskIds(
+ new HashSet<>(inputs), numPhysicalInputs, appAttemptId,
this.maxAttemptNo);
LOG.info(
"Inputs:{}, num input:{}, appAttemptId:{}, taskIdBitmap:{}",
inputs,
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
index 0528037d3..d584ebe65 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
@@ -282,6 +282,8 @@ class RssShuffleScheduler extends ShuffleScheduler {
private RemoteStorageInfo remoteStorageInfo;
private int indexReadLimit;
+ private final int maxAttemptNo;
+
RssShuffleScheduler(
InputContext inputContext,
Configuration conf,
@@ -538,6 +540,7 @@ class RssShuffleScheduler extends ShuffleScheduler {
this.basePath = this.conf.get(RssTezConfig.RSS_REMOTE_STORAGE_PATH);
String remoteStorageConf =
this.conf.get(RssTezConfig.RSS_REMOTE_STORAGE_CONF);
this.remoteStorageInfo = new RemoteStorageInfo(basePath,
remoteStorageConf);
+ this.maxAttemptNo = RssTezUtils.getMaxAttemptNo(conf);
LOG.info(
"RSSShuffleScheduler running for sourceVertex: "
@@ -1832,7 +1835,8 @@ class RssShuffleScheduler extends ShuffleScheduler {
RssTezUtils.fetchAllRssTaskIds(
partitionIdToSuccessMapTaskAttempts.get(mapHost.getPartitionId()),
this.numInputs,
- appAttemptId);
+ appAttemptId,
+ maxAttemptNo);
LOG.info(
"In reduce: {}, RSS Tez client has fetched blockIds and taskIds
successfully, partitionId:{}.",
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
index 94c9aa90d..fe4f11e13 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
@@ -61,7 +61,8 @@ public class RssSorter extends ExternalSorter {
long initialMemoryAvailable,
int shuffleId,
ApplicationAttemptId applicationAttemptId,
- Map<Integer, List<ShuffleServerInfo>> partitionToServers)
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ long taskAttemptId)
throws IOException {
super(outputContext, conf, numOutputs, initialMemoryAvailable);
this.partitionToServers = partitionToServers;
@@ -81,7 +82,6 @@ public class RssSorter extends ExternalSorter {
conf.getDouble(
RssTezConfig.RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD,
RssTezConfig.RSS_CLIENT_DEFAULT_SORT_MEMORY_USE_THRESHOLD);
- long taskAttemptId =
RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptID);
long maxSegmentSize =
conf.getLong(
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
index 0248bb8a2..94e87a40e 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
@@ -60,7 +60,8 @@ public class RssUnSorter extends ExternalSorter {
long initialMemoryAvailable,
int shuffleId,
ApplicationAttemptId applicationAttemptId,
- Map<Integer, List<ShuffleServerInfo>> partitionToServers)
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ long taskAttemptId)
throws IOException {
super(outputContext, conf, numOutputs, initialMemoryAvailable);
this.partitionToServers = partitionToServers;
@@ -80,7 +81,6 @@ public class RssUnSorter extends ExternalSorter {
conf.getDouble(
RssTezConfig.RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD,
RssTezConfig.RSS_CLIENT_DEFAULT_SORT_MEMORY_USE_THRESHOLD);
- long taskAttemptId =
RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptID);
long maxSegmentSize =
conf.getLong(
RssTezConfig.RSS_CLIENT_MAX_BUFFER_SIZE,
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
index fc163e90f..f926dc0e4 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -212,6 +212,7 @@ public class RssOrderedPartitionedKVOutput extends
AbstractLogicalOutput {
public void start() throws Exception {
if (!isStarted.get()) {
memoryUpdateCallbackHandler.validateUpdateReceived();
+ long rssTaskAttemptId =
RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf);
sorter =
new RssSorter(
taskAttemptId,
@@ -222,7 +223,8 @@ public class RssOrderedPartitionedKVOutput extends
AbstractLogicalOutput {
memoryUpdateCallbackHandler.getMemoryAssigned(),
shuffleId,
applicationAttemptId,
- partitionToServers);
+ partitionToServers,
+ rssTaskAttemptId);
LOG.info("Initialized RssSorter.");
isStarted.set(true);
}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
index 39cfc4693..5d903c3e9 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
@@ -217,6 +217,7 @@ public class RssUnorderedKVOutput extends
AbstractLogicalOutput {
public void start() throws Exception {
if (!isStarted.get()) {
memoryUpdateCallbackHandler.validateUpdateReceived();
+ long rssTaskAttemptId =
RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf);
sorter =
new RssUnSorter(
taskAttemptId,
@@ -227,7 +228,8 @@ public class RssUnorderedKVOutput extends
AbstractLogicalOutput {
memoryUpdateCallbackHandler.getMemoryAssigned(),
shuffleId,
applicationAttemptId,
- partitionToServers);
+ partitionToServers,
+ rssTaskAttemptId);
LOG.info("Initialized RssUnSorter.");
isStarted.set(true);
}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
index ef73b640b..9f8d818a7 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
@@ -215,6 +215,7 @@ public class RssUnorderedPartitionedKVOutput extends
AbstractLogicalOutput {
public void start() throws Exception {
if (!isStarted.get()) {
memoryUpdateCallbackHandler.validateUpdateReceived();
+ long rssTaskAttemptId =
RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf);
sorter =
new RssUnSorter(
taskAttemptId,
@@ -225,7 +226,8 @@ public class RssUnorderedPartitionedKVOutput extends
AbstractLogicalOutput {
memoryUpdateCallbackHandler.getMemoryAssigned(),
shuffleId,
applicationAttemptId,
- partitionToServers);
+ partitionToServers,
+ rssTaskAttemptId);
LOG.info("Initialized RssUnSorter.");
isStarted.set(true);
}
diff --git
a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
index c71dbefcf..12404a14f 100644
--- a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
+++ b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
@@ -56,17 +56,17 @@ public class RssTezUtilsTest {
boolean isException = false;
try {
- RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
+ RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 3);
} catch (RssException e) {
isException = true;
}
assertTrue(isException);
- taskId = TezTaskID.getInstance(vId, (int) (1 << 21));
+ taskId = TezTaskID.getInstance(vId, 1 << 21);
tezTaskAttemptId = TezTaskAttemptID.getInstance(taskId, 2);
isException = false;
try {
- RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
+ RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 3);
} catch (RssException e) {
isException = true;
}
@@ -80,7 +80,7 @@ public class RssTezUtilsTest {
TezVertexID vId = TezVertexID.getInstance(dagId, 35);
TezTaskID tId = TezTaskID.getInstance(vId, 389);
TezTaskAttemptID tezTaskAttemptId = TezTaskAttemptID.getInstance(tId, 2);
- long taskAttemptId =
RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
+ long taskAttemptId = RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId,
3);
long blockId = RssTezUtils.getBlockId(1, taskAttemptId, 0);
long newTaskAttemptId = RssTezUtils.getTaskAttemptId(blockId);
assertEquals(taskAttemptId, newTaskAttemptId);
@@ -97,7 +97,7 @@ public class RssTezUtilsTest {
TezVertexID vId = TezVertexID.getInstance(dagId, 35);
TezTaskID tId = TezTaskID.getInstance(vId, 389);
TezTaskAttemptID tezTaskAttemptId = TezTaskAttemptID.getInstance(tId, 2);
- long taskAttemptId =
RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
+ long taskAttemptId = RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId,
3);
long mask = (1L << layout.partitionIdBits) - 1;
for (int partitionId = 0; partitionId <= 3000; partitionId++) {
for (int seqNo = 0; seqNo <= 10; seqNo++) {
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
index 588abf32d..db9f9cea5 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
@@ -31,6 +31,7 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.OutputContext;
@@ -87,6 +88,7 @@ public class RssSorterTest {
long initialMemoryAvailable = 10240000;
int shuffleId = 1001;
+ long rssTaskAttemptId =
RssTezUtils.createRssTaskAttemptId(tezTaskAttemptID, 3);
RssSorter rssSorter =
new RssSorter(
@@ -98,7 +100,8 @@ public class RssSorterTest {
initialMemoryAvailable,
shuffleId,
applicationAttemptId,
- partitionToServers);
+ partitionToServers,
+ rssTaskAttemptId);
rssSorter.collect(new Text("0"), new Text("0"), 0);
rssSorter.collect(new Text("0"), new Text("1"), 0);
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java
index e54a37fd7..57f3f7626 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.OutputContext;
@@ -82,6 +83,7 @@ public class RssUnSorterTest {
long initialMemoryAvailable = 10240000;
int shuffleId = 1001;
+ long rssTaskAttemptId =
RssTezUtils.createRssTaskAttemptId(tezTaskAttemptID, 3);
RssUnSorter rssSorter =
new RssUnSorter(
@@ -93,7 +95,8 @@ public class RssUnSorterTest {
initialMemoryAvailable,
shuffleId,
APPATTEMPT_ID,
- partitionToServers);
+ partitionToServers,
+ rssTaskAttemptId);
rssSorter.collect(new Text("0"), new Text("0"), 0);
rssSorter.collect(new Text("0"), new Text("1"), 0);
diff --git
a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
index b3d40dcde..29fc4b241 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
@@ -112,4 +112,21 @@ public class ClientUtils {
String.format("The value of %s should be one of %s", clientType,
types));
}
}
+
+ public static int getMaxAttemptNo(int maxFailures, boolean speculation) {
+ // attempt number is zero based: 0, 1, …, maxFailures-1
+ // maxFailures < 1 is not allowed but for safety, we interpret that as
maxFailures == 1
+ int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
+
+ // with speculative execution enabled we could observe +1 attempts
+ if (speculation) {
+ maxAttemptNo++;
+ }
+
+ return maxAttemptNo;
+ }
+
+ public static int getNumberOfSignificantBits(int number) {
+ return 32 - Integer.numberOfLeadingZeros(number);
+ }
}
diff --git
a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
index dd9ef62b1..611a46b44 100644
--- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
+++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
@@ -36,6 +36,8 @@ import org.apache.uniffle.client.util.DefaultIdHelper;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.RssUtils;
+import static org.apache.uniffle.client.util.ClientUtils.getMaxAttemptNo;
+import static
org.apache.uniffle.client.util.ClientUtils.getNumberOfSignificantBits;
import static org.apache.uniffle.client.util.ClientUtils.waitUntilDoneOrFail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@@ -134,4 +136,44 @@ public class ClientUtilsTest {
// Ignore
}
}
+
+ @Test
+ public void testGetMaxAttemptNo() {
+ // without speculation
+ assertEquals(0, getMaxAttemptNo(-1, false));
+ assertEquals(0, getMaxAttemptNo(0, false));
+ assertEquals(0, getMaxAttemptNo(1, false));
+ assertEquals(1, getMaxAttemptNo(2, false));
+ assertEquals(2, getMaxAttemptNo(3, false));
+ assertEquals(3, getMaxAttemptNo(4, false));
+ assertEquals(4, getMaxAttemptNo(5, false));
+ assertEquals(1023, getMaxAttemptNo(1024, false));
+
+ // with speculation
+ assertEquals(1, getMaxAttemptNo(-1, true));
+ assertEquals(1, getMaxAttemptNo(0, true));
+ assertEquals(1, getMaxAttemptNo(1, true));
+ assertEquals(2, getMaxAttemptNo(2, true));
+ assertEquals(3, getMaxAttemptNo(3, true));
+ assertEquals(4, getMaxAttemptNo(4, true));
+ assertEquals(5, getMaxAttemptNo(5, true));
+ assertEquals(1024, getMaxAttemptNo(1024, true));
+ }
+
+ @Test
+ public void testGetNumberOfSignificantBits() {
+ assertEquals(0, getNumberOfSignificantBits(0));
+ assertEquals(1, getNumberOfSignificantBits(1));
+ assertEquals(2, getNumberOfSignificantBits(2));
+ assertEquals(2, getNumberOfSignificantBits(3));
+ assertEquals(3, getNumberOfSignificantBits(4));
+ assertEquals(3, getNumberOfSignificantBits(5));
+ assertEquals(3, getNumberOfSignificantBits(6));
+ assertEquals(3, getNumberOfSignificantBits(7));
+ assertEquals(4, getNumberOfSignificantBits(8));
+ assertEquals(4, getNumberOfSignificantBits(9));
+ assertEquals(10, getNumberOfSignificantBits(1023));
+ assertEquals(11, getNumberOfSignificantBits(1024));
+ assertEquals(11, getNumberOfSignificantBits(1025));
+ }
}
diff --git a/common/src/main/java/org/apache/uniffle/common/util/BlockId.java
b/common/src/main/java/org/apache/uniffle/common/util/BlockId.java
index 36025f66b..6c93e6345 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/BlockId.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/BlockId.java
@@ -32,10 +32,10 @@ public class BlockId {
public final BlockIdLayout layout;
public final int sequenceNo;
public final int partitionId;
- public final int taskAttemptId;
+ public final long taskAttemptId;
protected BlockId(
- long blockId, BlockIdLayout layout, int sequenceNo, int partitionId, int
taskAttemptId) {
+ long blockId, BlockIdLayout layout, int sequenceNo, int partitionId,
long taskAttemptId) {
this.blockId = blockId;
this.layout = layout;
this.sequenceNo = sequenceNo;
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
index 438bbdf39..d31d842ab 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
@@ -362,7 +362,7 @@ public class TezWordCountWithFailuresTest extends
IntegrationTestBase {
// verifyMode is 0: avoid recompute succeeded task is true
Assertions.assertEquals(0,
progressMap.get("Tokenizer").getKilledTaskAttemptCount());
} else if (verifyMode == 1) {
- // verifyMode is 1: avoid recompute succeeded task is true
+ // verifyMode is 1: avoid recompute succeeded task is false
Assertions.assertTrue(progressMap.get("Tokenizer").getKilledTaskAttemptCount()
> 0);
}
return 0;
diff --git
a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
index 5e53a80b2..d7deb7570 100644
--- a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
+++ b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
@@ -50,7 +50,7 @@ public abstract class BufferTestBase {
return createData(partitionId, 0, len);
}
- protected ShufflePartitionedData createData(int partitionId, int
taskAttemptId, int len) {
+ protected ShufflePartitionedData createData(int partitionId, long
taskAttemptId, int len) {
byte[] buf = new byte[len];
new Random().nextBytes(buf);
ShufflePartitionedBlock block =
diff --git
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
index c11fc27a7..d1b663f1f 100644
---
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
+++
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
@@ -108,7 +108,7 @@ public class HadoopShuffleReadHandlerTest extends
HadoopTestBase {
int totalBlockNum = 0;
int expectTotalBlockNum = 6;
int blockSize = 7;
- int taskAttemptId = 0;
+ long taskAttemptId = 0;
// write expectTotalBlockNum - 1 complete block
HadoopShuffleHandlerTestBase.writeTestData(