This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new bdade5c8 [#986] [Improvement][tez] Optimize the method of obtain the
vertex id. (#990)
bdade5c8 is described below
commit bdade5c83e98c3d59188a283e66cc323fd38b00c
Author: zhengchenyu <[email protected]>
AuthorDate: Thu Jul 6 18:45:56 2023 +0800
[#986] [Improvement][tez] Optimize the method of obtain the vertex id.
(#990)
### What changes were proposed in this pull request?
Optimize the method of obtaining the vertex id.
### Why are the changes needed?
Currently, the vertex id is extracted from the vertex name. This approach only
supports vertex names like "Map 0" or "Reduce 1", which are generally generated by Hive.
For the Tez examples, the vertex name is arbitrary, so we can't derive the vertex
id from it. Therefore, we need a new way to get the vertex id.
Fix: #986
### How was this patch tested?
Integration tests, unit tests, tests in a YARN cluster, and tests in Tez local mode.
---
.../org/apache/tez/common/InputContextUtils.java | 12 ----------
.../java/org/apache/tez/common/RssTezConfig.java | 4 ++++
.../java/org/apache/tez/common/RssTezUtils.java | 22 +++++-------------
.../org/apache/tez/dag/app/RssDAGAppMaster.java | 8 +++++++
.../common/shuffle/impl/RssShuffleManager.java | 7 +++---
.../common/shuffle/impl/RssTezFetcherTask.java | 12 ++++------
.../common/shuffle/orderedgrouped/RssShuffle.java | 5 +++--
.../orderedgrouped/RssShuffleScheduler.java | 7 +++---
.../library/input/RssOrderedGroupedKVInput.java | 21 ++++++++++++++++-
.../runtime/library/input/RssUnorderedKVInput.java | 20 ++++++++++++++++-
.../output/RssOrderedPartitionedKVOutput.java | 10 +++++++--
.../library/output/RssUnorderedKVOutput.java | 10 +++++++--
.../output/RssUnorderedPartitionedKVOutput.java | 10 +++++++--
.../apache/tez/common/InputContextUtilsTest.java | 11 ---------
.../org/apache/tez/common/RssTezUtilsTest.java | 6 ++---
.../apache/tez/dag/app/RssDAGAppMasterTest.java | 26 +++++++++++++++-------
.../common/shuffle/impl/RssShuffleManagerTest.java | 2 +-
.../orderedgrouped/RssShuffleSchedulerTest.java | 2 +-
.../shuffle/orderedgrouped/RssShuffleTest.java | 11 +++++++--
.../common/sort/buffer/WriteBufferManagerTest.java | 10 ++++-----
.../input/RssOrderedGroupedKVInputTest.java | 4 ++++
21 files changed, 137 insertions(+), 83 deletions(-)
diff --git
a/client-tez/src/main/java/org/apache/tez/common/InputContextUtils.java
b/client-tez/src/main/java/org/apache/tez/common/InputContextUtils.java
index 88cdacf7..ce5adea6 100644
--- a/client-tez/src/main/java/org/apache/tez/common/InputContextUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/InputContextUtils.java
@@ -29,16 +29,4 @@ public class InputContextUtils {
String uniqueIdentifier = inputContext.getUniqueIdentifier();
return TezTaskAttemptID.fromString(uniqueIdentifier.substring(0,
uniqueIdentifier.length() - 6));
}
-
- /**
- *
- * @param inputContext
- * @return Compute shuffle id using InputContext
- */
- public static int computeShuffleId(InputContext inputContext) {
- int dagIdentifier = inputContext.getDagIdentifier();
- String sourceVertexName = inputContext.getSourceVertexName();
- String taskVertexName = inputContext.getTaskVertexName();
- return RssTezUtils.computeShuffleId(dagIdentifier, sourceVertexName,
taskVertexName);
- }
}
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
index b575f2fc..67f181f5 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
@@ -175,6 +175,10 @@ public class RssTezConfig {
public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
ImmutableSet.of(RSS_STORAGE_TYPE, RSS_REMOTE_STORAGE_PATH);
+ public static final String RSS_SHUFFLE_SOURCE_VERTEX_ID =
TEZ_RSS_CONFIG_PREFIX + "rss.shuffle.source.vertex.id";
+ public static final String RSS_SHUFFLE_DESTINATION_VERTEX_ID =
+ TEZ_RSS_CONFIG_PREFIX + "rss.shuffle.destination.vertex.id";
+
public static RssConf toRssConf(Configuration jobConf) {
RssConf rssConf = new RssConf();
for (Map.Entry<String, String> entry : jobConf) {
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
index 81b334d3..a1439b8e 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
@@ -200,26 +200,16 @@ public class RssTezUtils {
/**
*
* @param tezDagID Get from tez InputContext, represent dag id.
- * @param upVertexName Up stream vertex name of the task, like "Map 1" or
"Reducer 2".
- * @param downVertexName The vertex name of task, like "Map 1" or "Reducer
2".
- * @return The shuffle id. First convert upVertexName of String type to int,
by invoke mapVertexId() method,
- * Then convert downVertexName of String type to int, by invoke
mapVertexId() method.
- * Finally compute shuffle id by pass tezDagID, upVertexId, downVertexId and
invoke computeShuffleId() method.
- * By map vertex name of String type to int type, we can compute shuffle id.
+ * @param upVertexId Up stream vertex id of the task.
+ * @param downVertexId The vertex id of task.
+ * @return The shuffle id.
*/
- public static int computeShuffleId(int tezDagID, String upVertexName, String
downVertexName) {
- int upVertexId = mapVertexId(upVertexName);
- int downVertexId = mapVertexId(downVertexName);
- int shuffleId = computeShuffleId(tezDagID, upVertexId, downVertexId);
- LOG.info("Compute Shuffle Id, upVertexName:{}, id:{}, downVertexName:{},
id:{}, shuffleId:{}",
- upVertexName, upVertexId, downVertexName, downVertexId, shuffleId);
+ public static int computeShuffleId(int tezDagID, int upVertexId, int
downVertexId) {
+ int shuffleId = tezDagID * (SHUFFLE_ID_MAGIC * SHUFFLE_ID_MAGIC) +
upVertexId * SHUFFLE_ID_MAGIC + downVertexId;
+ LOG.info("Compute Shuffle Id:{}, up vertex id:{}, down vertex id:{}",
shuffleId, upVertexId, downVertexId);
return shuffleId;
}
- private static int computeShuffleId(int tezDagID, int upTezVertexID, int
downTezVertexID) {
- return tezDagID * (SHUFFLE_ID_MAGIC * SHUFFLE_ID_MAGIC) + upTezVertexID *
SHUFFLE_ID_MAGIC + downTezVertexID;
- }
-
/**
*
* @param vertexName: vertex name, like "Map 1" or "Reducer 2"
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index 6e024354..f7d90c65 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -73,6 +73,8 @@ import static
org.apache.log4j.LogManager.CONFIGURATOR_CLASS_KEY;
import static org.apache.log4j.LogManager.DEFAULT_CONFIGURATION_KEY;
import static
org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS;
import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT;
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
public class RssDAGAppMaster extends DAGAppMaster {
private static final Logger LOG =
LoggerFactory.getLogger(RssDAGAppMaster.class);
@@ -339,10 +341,14 @@ public class RssDAGAppMaster extends DAGAppMaster {
Map<String, Edge> edges = (Map<String, Edge>) getPrivateField(dag,
"edges");
for (Map.Entry<String, Edge> entry : edges.entrySet()) {
Edge edge = entry.getValue();
+ int sourceVertexId =
dag.getVertex(edge.getSourceVertexName()).getVertexId().getId();
+ int destinationVertexId =
dag.getVertex(edge.getDestinationVertexName()).getVertexId().getId();
// add user defined config to edge source conf
Configuration edgeSourceConf =
TezUtils.createConfFromUserPayload(edge.getEdgeProperty().getEdgeSource().getUserPayload());
+ edgeSourceConf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, sourceVertexId);
+ edgeSourceConf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID,
destinationVertexId);
edgeSourceConf.set(RSS_AM_SHUFFLE_MANAGER_ADDRESS,
this.appMaster.getTezRemoteShuffleManager().getAddress().getHostName());
edgeSourceConf.setInt(RSS_AM_SHUFFLE_MANAGER_PORT,
@@ -363,6 +369,8 @@ public class RssDAGAppMaster extends DAGAppMaster {
// add user defined config to edge destination conf
Configuration edgeDestinationConf =
TezUtils.createConfFromUserPayload(edge.getEdgeProperty().getEdgeSource().getUserPayload());
+ edgeDestinationConf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID,
sourceVertexId);
+ edgeDestinationConf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID,
destinationVertexId);
edgeDestinationConf.set(RSS_AM_SHUFFLE_MANAGER_ADDRESS,
this.appMaster.getTezRemoteShuffleManager().getAddress().getHostName());
edgeDestinationConf.setInt(RSS_AM_SHUFFLE_MANAGER_PORT,
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
index 1c53f8c0..7de9a1d0 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
@@ -112,6 +112,7 @@ public class RssShuffleManager extends ShuffleManager {
private final InputContext inputContext;
private final int numInputs;
+ private final int shuffleId;
private final DecimalFormat mbpsFormat = new DecimalFormat("0.00");
@@ -227,12 +228,13 @@ public class RssShuffleManager extends ShuffleManager {
public RssShuffleManager(InputContext inputContext, Configuration conf, int
numInputs,
int bufferSize, boolean ifileReadAheadEnabled, int
ifileReadAheadLength,
- CompressionCodec codec, FetchedInputAllocator inputAllocator) throws
IOException {
+ CompressionCodec codec, FetchedInputAllocator inputAllocator, int
shuffleId) throws IOException {
super(inputContext, conf, numInputs, bufferSize, ifileReadAheadEnabled,
ifileReadAheadLength, codec,
inputAllocator);
this.inputContext = inputContext;
this.conf = conf;
this.numInputs = numInputs;
+ this.shuffleId = shuffleId;
this.shuffledInputsCounter =
inputContext.getCounters().findCounter(TaskCounter.NUM_SHUFFLED_INPUTS);
this.failedShufflesCounter =
inputContext.getCounters().findCounter(TaskCounter.NUM_FAILED_SHUFFLE_INPUTS);
@@ -343,7 +345,6 @@ public class RssShuffleManager extends ShuffleManager {
@Override
public void run() throws IOException {
- int shuffleId = InputContextUtils.computeShuffleId(this.inputContext);
TezTaskAttemptID tezTaskAttemptId =
InputContextUtils.getTezTaskAttemptID(this.inputContext);
this.partitionToServers = UmbilicalUtils.requestShuffleServer(
this.inputContext.getApplicationId(), this.conf, tezTaskAttemptId,
shuffleId);
@@ -501,7 +502,7 @@ public class RssShuffleManager extends ShuffleManager {
partition, partitionToServers.get(partition),
partitionToServers);
RssTezFetcherTask fetcher = new
RssTezFetcherTask(RssShuffleManager.this, inputContext,
- conf, inputManager, partition,
partitionToInput.get(partition),
+ conf, inputManager, partition, shuffleId,
partitionToInput.get(partition),
new
HashSet<ShuffleServerInfo>(partitionToServers.get(partition)),
rssAllBlockIdBitmapMap, rssSuccessBlockIdBitmapMap,
numInputs, partitionToServers.size());
rssRunningFetchers.add(fetcher);
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
index 35cedf72..46d435f6 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
@@ -28,7 +28,6 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapred.JobConf;
import org.apache.tez.common.CallableWithNdc;
import org.apache.tez.common.IdUtils;
-import org.apache.tez.common.InputContextUtils;
import org.apache.tez.common.RssTezConfig;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.runtime.api.InputContext;
@@ -73,10 +72,10 @@ public class RssTezFetcherTask extends
CallableWithNdc<FetchResult> {
private final int readBufferSize;
private final int partitionNumPerRange;
private final int partitionNum;
-
+ private final int shuffleId;
public RssTezFetcherTask(FetcherCallback fetcherCallback, InputContext
inputContext, Configuration conf,
- FetchedInputAllocator inputManager, int partition,
+ FetchedInputAllocator inputManager, int partition, int shuffleId,
List<InputAttemptIdentifier> inputs, Set<ShuffleServerInfo>
serverInfoList,
Map<Integer, Roaring64NavigableMap> rssAllBlockIdBitmapMap,
Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap,
@@ -88,6 +87,7 @@ public class RssTezFetcherTask extends
CallableWithNdc<FetchResult> {
this.inputManager = inputManager;
this.partition = partition; // partition id to fetch
this.inputs = inputs;
+ this.shuffleId = shuffleId;
this.serverInfoSet = serverInfoList;
this.rssAllBlockIdBitmapMap = rssAllBlockIdBitmapMap;
@@ -116,10 +116,6 @@ public class RssTezFetcherTask extends
CallableWithNdc<FetchResult> {
@Override
protected FetchResult callInternal() throws Exception {
- // get assigned RSS servers
- // just get blockIds from RSS servers
- int shuffleId = InputContextUtils.computeShuffleId(inputContext);
-
ShuffleWriteClient writeClient =
RssTezUtils.createShuffleClient(this.conf);
LOG.info("WriteClient getShuffleResult, clientType:{}, serverInfoSet:{},
appId:{}, shuffleId:{}, partition:{}",
clientType, serverInfoSet, appId, shuffleId, partition);
@@ -147,7 +143,7 @@ public class RssTezFetcherTask extends
CallableWithNdc<FetchResult> {
boolean expectedTaskIdsBitmapFilterEnable = serverInfoSet.size() > 1;
CreateShuffleReadClientRequest request = new
CreateShuffleReadClientRequest(
appId,
- InputContextUtils.computeShuffleId(inputContext),
+ shuffleId,
partition,
basePath,
partitionNumPerRange,
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffle.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffle.java
index 38345f37..5a1891f6 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffle.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffle.java
@@ -107,7 +107,7 @@ public class RssShuffle implements ExceptionReporter {
Usage: Create instance, RssShuffle
*/
public RssShuffle(InputContext inputContext, Configuration conf, int
numInputs,
- long initialMemoryAvailable) throws IOException {
+ long initialMemoryAvailable, int shuffleId) throws
IOException {
this.inputContext = inputContext;
this.conf = conf;
@@ -178,7 +178,8 @@ public class RssShuffle implements ExceptionReporter {
codec,
ifileReadAhead,
ifileReadAheadLength,
- srcNameTrimmed);
+ srcNameTrimmed,
+ shuffleId);
this.mergePhaseTime =
inputContext.getCounters().findCounter(TaskCounter.MERGE_PHASE_TIME);
this.shufflePhaseTime =
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_PHASE_TIME);
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
index 5af487b1..6ce9d637 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
@@ -245,6 +245,7 @@ class RssShuffleScheduler extends ShuffleScheduler {
private final int dagId;
private final boolean asyncHttp;
private final boolean sslShuffle;
+ private final int shuffleId;
private final TezCounter ioErrsCounter;
private final TezCounter wrongLengthErrsCounter;
@@ -300,7 +301,8 @@ class RssShuffleScheduler extends ShuffleScheduler {
CompressionCodec codec,
boolean ifileReadAhead,
int ifileReadAheadLength,
- String srcNameTrimmed) throws IOException {
+ String srcNameTrimmed,
+ int shuffleId) throws IOException {
super(inputContext, conf, numberOfInputs, exceptionReporter, mergeManager,
allocator, startTime, codec,
ifileReadAhead, ifileReadAheadLength, srcNameTrimmed);
this.inputContext = inputContext;
@@ -324,6 +326,7 @@ class RssShuffleScheduler extends ShuffleScheduler {
this.ifileReadAhead = ifileReadAhead;
this.ifileReadAheadLength = ifileReadAheadLength;
this.srcNameTrimmed = srcNameTrimmed;
+ this.shuffleId = shuffleId;
this.codec = codec;
int configuredNumFetchers =
conf.getInt(
@@ -483,7 +486,6 @@ class RssShuffleScheduler extends ShuffleScheduler {
@Override
public void start() throws Exception {
- int shuffleId = InputContextUtils.computeShuffleId(this.inputContext);
TezTaskAttemptID tezTaskAttemptID =
InputContextUtils.getTezTaskAttemptID(this.inputContext);
this.partitionToServers = UmbilicalUtils.requestShuffleServer(
inputContext.getApplicationId(), conf, tezTaskAttemptID, shuffleId);
@@ -1575,7 +1577,6 @@ class RssShuffleScheduler extends ShuffleScheduler {
ShuffleWriteClient writeClient = RssTezUtils.createShuffleClient(conf);
String clientType = "";
- int shuffleId = InputContextUtils.computeShuffleId(inputContext);
Roaring64NavigableMap blockIdBitmap = writeClient.getShuffleResult(
clientType, shuffleServerInfoSet, applicationId, shuffleId,
mapHost.getPartitionId());
writeClient.close();
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/input/RssOrderedGroupedKVInput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/input/RssOrderedGroupedKVInput.java
index d18c6c48..10b19281 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/input/RssOrderedGroupedKVInput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/input/RssOrderedGroupedKVInput.java
@@ -33,12 +33,16 @@ import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.RawComparator;
+import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.counters.TaskCounter;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.AbstractLogicalInput;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputContext;
@@ -57,6 +61,9 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.exception.RssException;
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+
/**
* {@link RssOrderedGroupedKVInput} in a {@link AbstractLogicalInput} which
shuffles
* intermediate sorted data, merges them and provides key/<values> to the
@@ -79,6 +86,7 @@ public class RssOrderedGroupedKVInput extends
AbstractLogicalInput {
protected Configuration conf;
protected RssShuffle shuffle;
protected MemoryUpdateCallbackHandler memoryUpdateCallbackHandler;
+ private int shuffleId;
private final BlockingQueue<Event> pendingEvents = new
LinkedBlockingQueue<>();
private long firstEventReceivedTime = -1;
@SuppressWarnings("rawtypes")
@@ -116,6 +124,16 @@ public class RssOrderedGroupedKVInput extends
AbstractLogicalInput {
this.inputValueCounter =
getContext().getCounters().findCounter(TaskCounter.REDUCE_INPUT_RECORDS);
this.shuffledInputs =
getContext().getCounters().findCounter(TaskCounter.NUM_SHUFFLED_INPUTS);
this.conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS,
getContext().getWorkDirs());
+
+ TezTaskAttemptID taskAttemptId = TezTaskAttemptID.fromString(
+
RssTezUtils.uniqueIdentifierToAttemptId(getContext().getUniqueIdentifier()));
+ TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
+ TezDAGID tezDAGID = tezVertexID.getDAGId();
+ int sourceVertexId = this.conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1);
+ int destinationVertexId =
this.conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1);
+ assert sourceVertexId != -1;
+ assert destinationVertexId != -1;
+ this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
sourceVertexId, destinationVertexId);
return Collections.emptyList();
}
@@ -141,7 +159,8 @@ public class RssOrderedGroupedKVInput extends
AbstractLogicalInput {
@VisibleForTesting
RssShuffle createRssShuffle() throws IOException {
- return new RssShuffle(getContext(), conf, getNumPhysicalInputs(),
memoryUpdateCallbackHandler.getMemoryAssigned());
+ return new RssShuffle(getContext(), conf, getNumPhysicalInputs(),
memoryUpdateCallbackHandler.getMemoryAssigned(),
+ shuffleId);
}
/**
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/input/RssUnorderedKVInput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/input/RssUnorderedKVInput.java
index c1ac3119..67f52fca 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/input/RssUnorderedKVInput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/input/RssUnorderedKVInput.java
@@ -35,12 +35,16 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.io.compress.DefaultCodec;
import org.apache.hadoop.util.ReflectionUtils;
+import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.TezUtilsInternal;
import org.apache.tez.common.counters.TaskCounter;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.AbstractLogicalInput;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputContext;
@@ -62,6 +66,9 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.exception.RssException;
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+
/**
* {@link RssUnorderedKVInput} provides unordered key value input by
* bringing together (shuffling) a set of distributed data and providing a
@@ -85,6 +92,7 @@ public class RssUnorderedKVInput extends AbstractLogicalInput
{
private SimpleFetchedInputAllocator inputManager;
private ShuffleEventHandler inputEventHandler;
+ private int shuffleId;
public RssUnorderedKVInput(InputContext inputContext, int numPhysicalInputs)
{
super(inputContext, numPhysicalInputs);
@@ -112,6 +120,16 @@ public class RssUnorderedKVInput extends
AbstractLogicalInput {
this.conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS,
getContext().getWorkDirs());
this.inputRecordCounter = getContext().getCounters().findCounter(
TaskCounter.INPUT_RECORDS_PROCESSED);
+
+ TezTaskAttemptID taskAttemptId = TezTaskAttemptID.fromString(
+
RssTezUtils.uniqueIdentifierToAttemptId(getContext().getUniqueIdentifier()));
+ TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
+ TezDAGID tezDAGID = tezVertexID.getDAGId();
+ int sourceVertexId = this.conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1);
+ int destinationVertexId =
this.conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1);
+ assert sourceVertexId != -1;
+ assert destinationVertexId != -1;
+ this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
sourceVertexId, destinationVertexId);
return Collections.emptyList();
}
@@ -154,7 +172,7 @@ public class RssUnorderedKVInput extends
AbstractLogicalInput {
memoryUpdateCallbackHandler.getMemoryAssigned());
this.rssShuffleManager = new RssShuffleManager(getContext(), conf,
getNumPhysicalInputs(), ifileBufferSize,
- ifileReadAhead, ifileReadAheadLength, codec, inputManager);
+ ifileReadAhead, ifileReadAheadLength, codec, inputManager,
shuffleId);
this.inputEventHandler = new ShuffleInputEventHandlerImpl(getContext(),
rssShuffleManager,
inputManager, codec, ifileReadAhead, ifileReadAheadLength,
compositeFetch);
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
index ac7a4a62..3b1bc77e 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -65,6 +65,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
import static
org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS;
import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT;
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
/**
* {@link RssOrderedPartitionedKVOutput} is an {@link AbstractLogicalOutput}
which
@@ -132,7 +134,7 @@ public class RssOrderedPartitionedKVOutput extends
AbstractLogicalOutput {
UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(this.applicationId.toString());
- TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
+ final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
.doAs(new
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
@Override
public TezRemoteShuffleUmbilicalProtocol run() throws Exception {
@@ -143,7 +145,11 @@ public class RssOrderedPartitionedKVOutput extends
AbstractLogicalOutput {
});
TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
TezDAGID tezDAGID = tezVertexID.getDAGId();
- this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
this.taskVertexName, this.destinationVertexName);
+ int sourceVertexId = this.conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1);
+ int destinationVertexId =
this.conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1);
+ assert sourceVertexId != -1;
+ assert destinationVertexId != -1;
+ this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
sourceVertexId, destinationVertexId);
GetShuffleServerRequest request = new
GetShuffleServerRequest(this.taskAttemptId, this.mapNum,
this.numOutputs, this.shuffleId);
GetShuffleServerResponse response =
umbilical.getShuffleAssignments(request);
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
index b40ef8e7..94a1a024 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
@@ -65,6 +65,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
import static
org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS;
import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT;
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
/**
* {@link RssUnorderedKVOutput} is an {@link AbstractLogicalOutput} which
@@ -133,7 +135,7 @@ public class RssUnorderedKVOutput extends
AbstractLogicalOutput {
UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(this.applicationId.toString());
- TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
+ final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
.doAs(new
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
@Override
public TezRemoteShuffleUmbilicalProtocol run() throws Exception {
@@ -145,7 +147,11 @@ public class RssUnorderedKVOutput extends
AbstractLogicalOutput {
});
TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
TezDAGID tezDAGID = tezVertexID.getDAGId();
- this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
this.taskVertexName, this.destinationVertexName);
+ int sourceVertexId = this.conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1);
+ int destinationVertexId =
this.conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1);
+ assert sourceVertexId != -1;
+ assert destinationVertexId != -1;
+ this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
sourceVertexId, destinationVertexId);
GetShuffleServerRequest request = new
GetShuffleServerRequest(this.taskAttemptId, this.mapNum,
this.numOutputs, this.shuffleId);
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
index 6cd8c1d2..2b2262e8 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
@@ -65,6 +65,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
import static
org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS;
import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT;
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
/**
* {@link RssUnorderedPartitionedKVOutput} is an {@link AbstractLogicalOutput}
which
@@ -132,7 +134,7 @@ public class RssUnorderedPartitionedKVOutput extends
AbstractLogicalOutput {
final InetSocketAddress address = NetUtils.createSocketAddrForHost(host,
port);
UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(this.applicationId.toString());
- TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
+ final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
.doAs(new
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
@Override
public TezRemoteShuffleUmbilicalProtocol run() throws Exception {
@@ -144,7 +146,11 @@ public class RssUnorderedPartitionedKVOutput extends
AbstractLogicalOutput {
});
TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
TezDAGID tezDAGID = tezVertexID.getDAGId();
- this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
this.taskVertexName, this.destinationVertexName);
+ int sourceVertexId = this.conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1);
+ int destinationVertexId =
this.conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1);
+ assert sourceVertexId != -1;
+ assert destinationVertexId != -1;
+ this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
sourceVertexId, destinationVertexId);
GetShuffleServerRequest request = new
GetShuffleServerRequest(this.taskAttemptId, this.mapNum,
this.numOutputs, this.shuffleId);
GetShuffleServerResponse response =
umbilical.getShuffleAssignments(request);
diff --git
a/client-tez/src/test/java/org/apache/tez/common/InputContextUtilsTest.java
b/client-tez/src/test/java/org/apache/tez/common/InputContextUtilsTest.java
index 350291d5..1b9f1ff6 100644
--- a/client-tez/src/test/java/org/apache/tez/common/InputContextUtilsTest.java
+++ b/client-tez/src/test/java/org/apache/tez/common/InputContextUtilsTest.java
@@ -35,15 +35,4 @@ public class InputContextUtilsTest {
TezTaskAttemptID rightTaskAttemptID =
TezTaskAttemptID.fromString("attempt_1685094627632_0157_1_01_000000_0");
assertEquals(rightTaskAttemptID,
InputContextUtils.getTezTaskAttemptID(inputContext));
}
-
-
- @Test
- public void testComputeShuffleId() {
- InputContext inputContext = mock(InputContext.class);
- when(inputContext.getDagIdentifier()).thenReturn(1);
- when(inputContext.getSourceVertexName()).thenReturn("Map 1");
- when(inputContext.getTaskVertexName()).thenReturn("Reducer 1");
-
- assertEquals(1001601, InputContextUtils.computeShuffleId(inputContext));
- }
}
diff --git
a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
index a888baec..361c05bd 100644
--- a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
+++ b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
@@ -147,9 +147,9 @@ public class RssTezUtilsTest {
@Test
public void testComputeShuffleId() {
int dagId = 1;
- String upVertexName = "Map 1";
- String downVertexName = "Reducer 2";
- assertEquals(1001602, RssTezUtils.computeShuffleId(dagId, upVertexName,
downVertexName));
+ int upVertexId = 1;
+ int downVertexID = 2;
+ assertEquals(1001002, RssTezUtils.computeShuffleId(dagId, upVertexId,
downVertexID));
}
@Test
diff --git
a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
index 535ba953..ef2bb5de 100644
--- a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
+++ b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
@@ -76,6 +76,8 @@ import org.apache.uniffle.storage.util.StorageType;
import static
org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS;
import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT;
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
import static org.apache.tez.common.RssTezConfig.RSS_STORAGE_TYPE;
import static
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD_BYTES;
import static org.awaitility.Awaitility.await;
@@ -148,15 +150,16 @@ public class RssDAGAppMasterTest {
await().atMost(2, TimeUnit.SECONDS).until(() ->
dagImpl.getState().equals(DAGState.INITED));
// 8 verify I/O for vertexImpl
- verfiyOutput(dagImpl, "vertex1",
RssOrderedPartitionedKVOutput.class.getName());
- verfiyInput(dagImpl, "vertex2", RssOrderedGroupedKVInput.class.getName());
- verfiyOutput(dagImpl, "vertex2", RssUnorderedKVOutput.class.getName());
- verfiyInput(dagImpl, "vertex3", RssUnorderedKVInput.class.getName());
- verfiyOutput(dagImpl, "vertex3",
RssUnorderedPartitionedKVOutput.class.getName());
- verfiyInput(dagImpl, "vertex4", RssUnorderedKVInput.class.getName());
+ verfiyOutput(dagImpl, "vertex1",
RssOrderedPartitionedKVOutput.class.getName(), 0, 1);
+ verfiyInput(dagImpl, "vertex2", RssOrderedGroupedKVInput.class.getName(),
0, 1);
+ verfiyOutput(dagImpl, "vertex2", RssUnorderedKVOutput.class.getName(), 1,
2);
+ verfiyInput(dagImpl, "vertex3", RssUnorderedKVInput.class.getName(), 1, 2);
+ verfiyOutput(dagImpl, "vertex3",
RssUnorderedPartitionedKVOutput.class.getName(), 2, 3);
+ verfiyInput(dagImpl, "vertex4", RssUnorderedKVInput.class.getName(), 2, 3);
}
- public static void verfiyInput(DAGImpl dag, String name, String
expectedInputClassName) throws Exception {
+ public static void verfiyInput(DAGImpl dag, String name, String
expectedInputClassName,
+ int expectedSourceVertexId, int
expectedDestinationVertexId) throws Exception {
// 1 verfiy rename rss io class name
List<InputSpec> inputSpecs = dag.getVertex(name).getInputSpecList(0);
Assertions.assertEquals(1, inputSpecs.size());
@@ -176,9 +179,13 @@ public class RssDAGAppMasterTest {
// should not deliver to Input/Output.
Assertions.assertEquals(12345,
conf.getInt(TEZ_RUNTIME_IFILE_READAHEAD_BYTES, -1));
Assertions.assertNull(conf.get("tez.config.from.client"));
+ // 4 verify vertex id
+ Assertions.assertEquals(expectedSourceVertexId,
conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1));
+ Assertions.assertEquals(expectedDestinationVertexId,
conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1));
}
- public static void verfiyOutput(DAGImpl dag, String name, String
expectedOutputClassName) throws Exception {
+ public static void verfiyOutput(DAGImpl dag, String name, String
expectedOutputClassName,
+ int expectedSourceVertexId, int
expectedDestinationVertexId) throws Exception {
// 1 verfiy rename rss io class name
List<OutputSpec> outputSpecs = dag.getVertex(name).getOutputSpecList(0);
Assertions.assertEquals(1, outputSpecs.size());
@@ -193,6 +200,9 @@ public class RssDAGAppMasterTest {
Assertions.assertEquals("value1", conf.get("tez.config1"));
Assertions.assertEquals("value3", conf.get("tez.config3"));
Assertions.assertNull(conf.get("tez.config2"));
+ // 4 verify vertex id
+ Assertions.assertEquals(expectedSourceVertexId,
conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1));
+ Assertions.assertEquals(expectedDestinationVertexId,
conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1));
}
private static DAG createDAG(String dageName, Configuration conf) {
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManagerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManagerTest.java
index 8d49cf15..3bd376d3 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManagerTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManagerTest.java
@@ -261,7 +261,7 @@ public class RssShuffleManagerTest {
boolean ifileReadAheadEnabled, int ifileReadAheadLength,
CompressionCodec codec,
FetchedInputAllocator inputAllocator) throws IOException {
super(inputContext, conf, numInputs, bufferSize, ifileReadAheadEnabled,
- ifileReadAheadLength, codec, inputAllocator);
+ ifileReadAheadLength, codec, inputAllocator, 0);
}
@Override
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java
index e64645be..ae84ac7e 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java
@@ -886,7 +886,7 @@ public class RssShuffleSchedulerTest {
boolean ifileReadAhead, int ifileReadAheadLength,
String srcNameTrimmed, boolean fetcherShouldWait) throws
IOException {
super(inputContext, conf, numberOfInputs, shuffle, mergeManager,
allocator, startTime, codec,
- ifileReadAhead, ifileReadAheadLength, srcNameTrimmed);
+ ifileReadAhead, ifileReadAheadLength, srcNameTrimmed, 0);
this.fetcherShouldWait = fetcherShouldWait;
this.reporter = shuffle;
this.inputContext = inputContext;
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleTest.java
index a345c0dd..7d0255a2 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleTest.java
@@ -34,6 +34,10 @@ import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.common.security.JobTokenSecretManager;
import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.ExecutionContext;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.TaskFailureType;
@@ -86,7 +90,7 @@ public class RssShuffleTest {
InputContext inputContext = createTezInputContext();
TezConfiguration conf = new TezConfiguration();
conf.setLong(Constants.TEZ_RUNTIME_TASK_MEMORY, 300000L);
- RssShuffle shuffle = new RssShuffle(inputContext, conf, 1, 3000000L);
+ RssShuffle shuffle = new RssShuffle(inputContext, conf, 1, 3000000L,
0);
try {
shuffle.run();
ShuffleScheduler scheduler = shuffle.rssScheduler;
@@ -127,7 +131,7 @@ public class RssShuffleTest {
InputContext inputContext = createTezInputContext();
TezConfiguration conf = new TezConfiguration();
conf.setLong(Constants.TEZ_RUNTIME_TASK_MEMORY, 300000L);
- RssShuffle shuffle = new RssShuffle(inputContext, conf, 1, 3000000L);
+ RssShuffle shuffle = new RssShuffle(inputContext, conf, 1, 3000000L,
0);
try {
shuffle.run();
ShuffleScheduler scheduler = shuffle.rssScheduler;
@@ -160,6 +164,9 @@ public class RssShuffleTest {
doReturn(applicationId).when(inputContext).getApplicationId();
doReturn("Map 1").when(inputContext).getSourceVertexName();
doReturn("Reducer 1").when(inputContext).getTaskVertexName();
+ String uniqueId = String.format("%s_%05d", TezTaskAttemptID.getInstance(
+
TezTaskID.getInstance(TezVertexID.getInstance(TezDAGID.getInstance(applicationId,
1), 1), 1), 1), 1);
+ doReturn(uniqueId).when(inputContext).getUniqueIdentifier();
when(inputContext.getCounters()).thenReturn(new TezCounters());
ExecutionContext executionContext = new ExecutionContextImpl("localhost");
doReturn(executionContext).when(inputContext).getExecutionContext();
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index 2b581656..23ca1bfa 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -83,7 +83,7 @@ public class WriteBufferManagerTest {
long sendCheckInterval = 500L;
long sendCheckTimeout = 5;
int bitmapSplitNum = 1;
- int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+ int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
@@ -138,7 +138,7 @@ public class WriteBufferManagerTest {
long sendCheckInterval = 500L;
long sendCheckTimeout = 60 * 1000 * 10L;
int bitmapSplitNum = 1;
- int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+ int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
@@ -200,7 +200,7 @@ public class WriteBufferManagerTest {
long sendCheckInterval = 500L;
long sendCheckTimeout = 60 * 1000 * 10L;
int bitmapSplitNum = 1;
- int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+ int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
@@ -227,9 +227,9 @@ public class WriteBufferManagerTest {
writeClient.mockedShuffleServer.getFlushBlockSize());
}
- private int getShuffleId(TezTaskAttemptID tezTaskAttemptID, String
upVertexName, String downVertexName) {
+ private int getShuffleId(TezTaskAttemptID tezTaskAttemptID, int upVertexId,
int downVertexId) {
TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
- int shuffleId =
RssTezUtils.computeShuffleId(tezVertexID.getDAGId().getId(), upVertexName,
downVertexName);
+ int shuffleId =
RssTezUtils.computeShuffleId(tezVertexID.getDAGId().getId(), upVertexId,
downVertexId);
return shuffleId;
}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/input/RssOrderedGroupedKVInputTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/input/RssOrderedGroupedKVInputTest.java
index 1502c95d..5db3ee8d 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/input/RssOrderedGroupedKVInputTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/input/RssOrderedGroupedKVInputTest.java
@@ -39,6 +39,8 @@ import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
@@ -86,6 +88,8 @@ public class RssOrderedGroupedKVInputTest {
doReturn(executionContext).when(inputContext).getExecutionContext();
Configuration conf = new TezConfiguration();
+ conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 1);
+ conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 2);
UserPayload payLoad = TezUtils.createUserPayloadFromConf(conf);
String[] workingDirs = new String[]{"workDir1"};
TezCounters counters = new TezCounters();