This is an automated email from the ASF dual-hosted git repository.
zhengchenyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new be23f6137 [#1750] feat(remote merge): Support Tez. (#2160)
be23f6137 is described below
commit be23f6137fb5f478372328bf492923f115203cf6
Author: zhengchenyu <[email protected]>
AuthorDate: Thu Oct 10 10:39:36 2024 +0800
[#1750] feat(remote merge): Support Tez. (#2160)
### What changes were proposed in this pull request?
Tez now supports remote merge.
### Why are the changes needed?
Fix: #1750
### Does this PR introduce _any_ user-facing change?
Yes. Documentation will be refined in a follow-up PR.
### How was this patch tested?
Unit tests, integration tests, and a real job in a cluster.
---
client-tez/pom.xml | 34 +-
.../apache/tez/common/GetShuffleServerRequest.java | 36 ++
.../java/org/apache/tez/common/RssTezConfig.java | 10 +
.../java/org/apache/tez/common/RssTezUtils.java | 17 +-
.../org/apache/tez/dag/app/RssDAGAppMaster.java | 9 +-
.../tez/dag/app/TezRemoteShuffleManager.java | 25 +-
.../shuffle/orderedgrouped/RMRssShuffle.java | 269 ++++++++++++
.../orderedgrouped/RMRssShuffleScheduler.java | 84 ++++
.../library/common/sort/buffer/WriteBuffer.java | 34 +-
.../common/sort/buffer/WriteBufferManager.java | 40 +-
.../library/common/sort/impl/RssSorter.java | 9 +-
.../library/common/sort/impl/RssUnSorter.java | 5 +-
.../library/input/RMRssOrderedGroupedKVInput.java | 320 ++++++++++++++
.../output/RssOrderedPartitionedKVOutput.java | 27 +-
.../shuffle/orderedgrouped/RMRssShuffleTest.java | 443 +++++++++++++++++++
.../common/sort/buffer/WriteBufferManagerTest.java | 151 ++++++-
.../common/sort/buffer/WriteBufferTest.java | 95 +++-
.../input/RMRssOrderedGroupedKVInputTest.java | 481 +++++++++++++++++++++
...untTest.java => RMTezOrderedWordCountTest.java} | 42 +-
.../uniffle/test/TezCartesianProductTest.java | 6 +
.../uniffle/test/TezIntegrationTestBase.java | 9 +
.../uniffle/test/TezJoinIntegrationTestBase.java | 6 +
.../uniffle/test/TezOrderedWordCountTest.java | 6 +
.../uniffle/test/TezSimpleSessionExampleTest.java | 6 +
.../org/apache/uniffle/test/TezWordCountTest.java | 6 +
25 files changed, 2121 insertions(+), 49 deletions(-)
diff --git a/client-tez/pom.xml b/client-tez/pom.xml
index 52430441f..3f7ba7a02 100644
--- a/client-tez/pom.xml
+++ b/client-tez/pom.xml
@@ -128,17 +128,29 @@
</exclusion>
</exclusions>
</dependency>
- <dependency>
- <groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-minicluster</artifactId>
- <scope>test</scope>
- <exclusions>
- <exclusion>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-log4j12</artifactId>
- </exclusion>
- </exclusions>
- </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-minicluster</artifactId>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.uniffle</groupId>
+ <artifactId>rss-common</artifactId>
+ <scope>test</scope>
+ <type>test-jar</type>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.uniffle</groupId>
+ <artifactId>rss-client</artifactId>
+ <scope>test</scope>
+ <type>test-jar</type>
+ </dependency>
</dependencies>
<build>
diff --git
a/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
index e87b97406..dd9a6decc 100644
---
a/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
+++
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
@@ -22,6 +22,7 @@ import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableUtils;
import org.apache.tez.dag.records.TezTaskAttemptID;
public class GetShuffleServerRequest implements Writable {
@@ -29,15 +30,32 @@ public class GetShuffleServerRequest implements Writable {
private int startIndex;
private int partitionNum;
private int shuffleId;
+ private String keyClassName;
+ private String valueClassName;
+ private String comparatorClassName;
public GetShuffleServerRequest() {}
public GetShuffleServerRequest(
TezTaskAttemptID currentTaskAttemptID, int startIndex, int partitionNum,
int shuffleId) {
+ this(currentTaskAttemptID, startIndex, partitionNum, shuffleId, "", "",
"");
+ }
+
+ public GetShuffleServerRequest(
+ TezTaskAttemptID currentTaskAttemptID,
+ int startIndex,
+ int partitionNum,
+ int shuffleId,
+ String keyClassName,
+ String valueClassName,
+ String comparatorClassName) {
this.currentTaskAttemptID = currentTaskAttemptID;
this.startIndex = startIndex;
this.partitionNum = partitionNum;
this.shuffleId = shuffleId;
+ this.keyClassName = keyClassName;
+ this.valueClassName = valueClassName;
+ this.comparatorClassName = comparatorClassName;
}
@Override
@@ -51,6 +69,9 @@ public class GetShuffleServerRequest implements Writable {
} else {
output.writeBoolean(false);
}
+ WritableUtils.writeString(output, keyClassName);
+ WritableUtils.writeString(output, valueClassName);
+ WritableUtils.writeString(output, comparatorClassName);
}
@Override
@@ -63,6 +84,9 @@ public class GetShuffleServerRequest implements Writable {
currentTaskAttemptID = new TezTaskAttemptID();
currentTaskAttemptID.readFields(dataInput);
}
+ keyClassName = WritableUtils.readString(dataInput);
+ valueClassName = WritableUtils.readString(dataInput);
+ comparatorClassName = WritableUtils.readString(dataInput);
}
@Override
@@ -94,4 +118,16 @@ public class GetShuffleServerRequest implements Writable {
public int getShuffleId() {
return shuffleId;
}
+
+ public String getKeyClassName() {
+ return keyClassName;
+ }
+
+ public String getValueClassName() {
+ return valueClassName;
+ }
+
+ public String getComparatorClassName() {
+ return comparatorClassName;
+ }
}
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
index a1186534a..e73b04fb5 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
@@ -225,6 +225,16 @@ public class RssTezConfig {
public static final String RSS_SHUFFLE_MODE = TEZ_RSS_CONFIG_PREFIX +
"shuffle.mode";
public static final String DEFAULT_RSS_SHUFFLE_MODE = "remote";
+ public static final String RSS_REMOTE_MERGE_ENABLE =
+ TEZ_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_ENABLE;
+ public static final boolean RSS_REMOTE_MERGE_ENABLE_DEFAULT = false;
+ public static final String RSS_MERGED_BLOCK_SZIE =
+ TEZ_RSS_CONFIG_PREFIX + RssClientConfig.RSS_MERGED_BLOCK_SZIE;
+ public static final int RSS_MERGED_BLOCK_SZIE_DEFAULT =
+ RssClientConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT;
+ public static final String RSS_REMOTE_MERGE_CLASS_LOADER =
+ TEZ_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_CLASS_LOADER;
+
public static RssConf toRssConf(Configuration jobConf) {
RssConf rssConf = new RssConf();
for (Map.Entry<String, String> entry : jobConf) {
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
index 3df582bde..7b6d9ad2c 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
@@ -33,6 +33,7 @@ import
org.apache.tez.runtime.library.input.ConcatenatedMergedKeyValuesInput;
import org.apache.tez.runtime.library.input.OrderedGroupedInputLegacy;
import org.apache.tez.runtime.library.input.OrderedGroupedKVInput;
import org.apache.tez.runtime.library.input.OrderedGroupedMergedKVInput;
+import org.apache.tez.runtime.library.input.RMRssOrderedGroupedKVInput;
import org.apache.tez.runtime.library.input.RssConcatenatedMergedKeyValueInput;
import
org.apache.tez.runtime.library.input.RssConcatenatedMergedKeyValuesInput;
import org.apache.tez.runtime.library.input.RssOrderedGroupedInputLegacy;
@@ -122,7 +123,8 @@ public class RssTezUtils {
.replicaRead(replicaRead)
.replicaSkipEnabled(replicaSkipEnabled)
.dataTransferPoolSize(dataTransferPoolSize)
- .dataCommitPoolSize(dataCommitPoolSize));
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .rssConf(RssTezConfig.toRssConf(conf)));
return client;
}
@@ -423,13 +425,14 @@ public class RssTezUtils {
}
}
- public static String replaceRssInputClassName(String className) {
+ public static String replaceRssInputClassName(String className, boolean
isRemoteMergeEnable) {
if (className.equals(OrderedGroupedKVInput.class.getName())) {
- LOG.info(
- "Input class name will transient from {} to {}",
- className,
- RssOrderedGroupedKVInput.class.getName());
- return RssOrderedGroupedKVInput.class.getName();
+ String orderedInputClasName =
+ isRemoteMergeEnable
+ ? RMRssOrderedGroupedKVInput.class.getName()
+ : RssOrderedGroupedKVInput.class.getName();
+ LOG.info("Input class name will transient from {} to {}", className,
orderedInputClasName);
+ return orderedInputClasName;
} else if (className.equals(OrderedGroupedMergedKVInput.class.getName())) {
LOG.info(
"Input class name will transient from {} to {}",
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index ac1f9c28f..0fb41146c 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -608,7 +608,14 @@ public class RssDAGAppMaster extends DAGAppMaster {
outputDescriptor.getClass().getSuperclass().getDeclaredField("className");
inputClassNameField.setAccessible(true);
String inputClassName = (String)
outputClassNameField.get(inputDescriptor);
- String rssInputClassName =
RssTezUtils.replaceRssInputClassName(inputClassName);
+ String rssInputClassName =
+ RssTezUtils.replaceRssInputClassName(
+ inputClassName,
+ appMaster
+ .getConfig()
+ .getBoolean(
+ RssTezConfig.RSS_REMOTE_MERGE_ENABLE,
+ RssTezConfig.RSS_REMOTE_MERGE_ENABLE_DEFAULT));
outputClassNameField.set(inputDescriptor, rssInputClassName);
}
} catch (IOException | IllegalAccessException | NoSuchFieldException e) {
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
index bf376814e..85a138c13 100644
---
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
+++
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -190,7 +190,13 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
if (shuffleIdToShuffleAssignsInfo.containsKey(shuffleId)) {
shuffleAssignmentsInfo =
shuffleIdToShuffleAssignsInfo.get(shuffleId);
} else {
- shuffleAssignmentsInfo =
getShuffleWorks(request.getPartitionNum(), shuffleId);
+ shuffleAssignmentsInfo =
+ getShuffleWorks(
+ request.getPartitionNum(),
+ shuffleId,
+ request.getKeyClassName(),
+ request.getValueClassName(),
+ request.getComparatorClassName());
}
if (shuffleAssignmentsInfo == null) {
@@ -221,7 +227,12 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
}
}
- private ShuffleAssignmentsInfo getShuffleWorks(int partitionNum, int
shuffleId) {
+ private ShuffleAssignmentsInfo getShuffleWorks(
+ int partitionNum,
+ int shuffleId,
+ String keyClassName,
+ String valueClassName,
+ String comparatorClassName) {
ShuffleAssignmentsInfo shuffleAssignmentsInfo;
int requiredAssignmentShuffleServersNum =
RssTezUtils.getRequiredShuffleServerNumber(conf, 200, partitionNum);
@@ -292,7 +303,15 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
remoteStorage,
ShuffleDataDistributionType.NORMAL,
RssTezConfig.toRssConf(conf)
-
.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE)));
+
.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE),
+ 0,
+ keyClassName,
+ valueClassName,
+ comparatorClassName,
+ conf.getInt(
+
RssTezConfig.RSS_MERGED_BLOCK_SZIE,
+
RssTezConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT),
+
conf.get(RssTezConfig.RSS_REMOTE_MERGE_CLASS_LOADER)));
LOG.info(
"Finish register shuffle with "
+ (System.currentTimeMillis() - start)
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java
new file mode 100644
index 000000000..542d251e7
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.tez.common.InputContextUtils;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.library.common.ConfigUtils;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.KeyValuesReader;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+
+@Private
[email protected]
+public class RMRssShuffle implements ExceptionReporter {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(RMRssShuffle.class);
+
+ private final Configuration conf;
+ private final RssConf rssConf;
+ private final InputContext inputContext;
+ private final int numInputs;
+ private final int shuffleId;
+ private final ApplicationAttemptId applicationAttemptId;
+ private final String appId;
+ private ShuffleInputEventHandlerOrderedGrouped eventHandler;
+ private final TezTaskAttemptID tezTaskAttemptID;
+ private final String srcNameTrimmed;
+ private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+
+ private AtomicBoolean isShutDown = new AtomicBoolean(false);
+
+ final TezCounter skippedInputCounter;
+ final TezCounter inputRecordCounter;
+
+ final Map<Integer, Set<InputAttemptIdentifier>>
partitionIdToSuccessMapTaskAttempts =
+ new HashMap<>();
+ final Map<Integer, Set<TezTaskID>> partitionIdToSuccessTezTasks = new
HashMap<>();
+
+ final Set<Integer> partitionIds = new HashSet<>();
+ private RMRecordsReader reader = null;
+ private RMRssShuffleScheduler scheduler;
+
+ public RMRssShuffle(
+ InputContext inputContext,
+ Configuration conf,
+ int numInputs,
+ int shuffleId,
+ ApplicationAttemptId applicationAttemptId)
+ throws IOException {
+ this.inputContext = inputContext;
+ this.conf = conf;
+ this.rssConf = RssTezConfig.toRssConf(conf);
+ this.numInputs = numInputs;
+ this.shuffleId = shuffleId;
+ this.applicationAttemptId = applicationAttemptId;
+ this.appId = this.applicationAttemptId.toString();
+ this.srcNameTrimmed =
TezUtilsInternal.cleanVertexName(inputContext.getSourceVertexName());
+ LOG.info(srcNameTrimmed + ": Shuffle assigned with " + numInputs + "
inputs.");
+ this.skippedInputCounter =
+ inputContext.getCounters().findCounter(TaskCounter.NUM_SKIPPED_INPUTS);
+ this.inputRecordCounter =
+
inputContext.getCounters().findCounter(TaskCounter.INPUT_RECORDS_PROCESSED);
+
+ this.scheduler =
+ new RMRssShuffleScheduler(
+ this.inputContext,
+ this.conf,
+ numInputs,
+ this,
+ null,
+ null,
+ System.currentTimeMillis(),
+ null,
+ false,
+ 0,
+ srcNameTrimmed,
+ this);
+ this.eventHandler =
+ new ShuffleInputEventHandlerOrderedGrouped(
+ inputContext, scheduler, ShuffleUtils.isTezShuffleHandler(conf));
+ this.tezTaskAttemptID =
InputContextUtils.getTezTaskAttemptID(this.inputContext);
+    // When remote merge is enabled, we use the reading-while-processing
+    // method, so we set the input ready directly.
+ inputContext.inputIsReady();
+ }
+
+ public void handleEvents(List<Event> events) throws IOException {
+ if (!isShutDown.get()) {
+ eventHandler.handleEvents(events);
+ } else {
+ LOG.info(
+ srcNameTrimmed
+ + ": Ignoring events since already shutdown. EventCount: "
+ + events.size());
+ }
+ }
+
+ public void run() throws IOException {
+ this.partitionToServers =
+ UmbilicalUtils.requestShuffleServer(
+ inputContext.getApplicationId(), conf, tezTaskAttemptID,
shuffleId);
+ }
+
+ public void shutdown() {
+ if (!isShutDown.getAndSet(true)) {
+ if (reader != null) {
+ reader.close();
+ }
+ LOG.info("Shutting down Shuffle for source: " + srcNameTrimmed);
+ }
+ }
+
+ public void waitForEvents() throws InterruptedException {
+ while (!allInputTaskAttemptDone()) {
+ Thread.sleep(100);
+ }
+ // report unique blocks
+ reportUniqueBlockIds();
+ if (partitionIds.size() > 0) {
+ reader = createRMRecordsReader(partitionIds);
+ reader.start();
+ }
+ }
+
+ private boolean allInputTaskAttemptDone() {
+ return (this.partitionIdToSuccessTezTasks.values().stream().mapToInt(s ->
s.size()).sum()
+ + skippedInputCounter.getValue())
+ == numInputs;
+ }
+
+ public void reportUniqueBlockIds() {
+ ShuffleWriteClient writeClient = RssTezUtils.createShuffleClient(conf);
+ for (int partitionId : partitionIds) {
+ Roaring64NavigableMap blockIdBitmap =
+ writeClient.getShuffleResult(
+ null,
+ new HashSet<>(partitionToServers.get(partitionId)),
+ appId,
+ shuffleId,
+ partitionId);
+ Roaring64NavigableMap taskIdBitmap =
+ RssTezUtils.fetchAllRssTaskIds(
+ partitionIdToSuccessMapTaskAttempts.get(partitionId),
+ numInputs,
+ applicationAttemptId.getAttemptId(),
+ RssTezUtils.getMaxAttemptNo(conf));
+ Roaring64NavigableMap uniqueBlockIdBitMap =
Roaring64NavigableMap.bitmapOf();
+ blockIdBitmap.forEach(
+ blockId -> {
+ long taId = RssTezUtils.getTaskAttemptId(blockId);
+ if (taskIdBitmap.contains(taId)) {
+ uniqueBlockIdBitMap.add(blockId);
+ }
+ });
+ writeClient.startSortMerge(
+ new HashSet<>(partitionToServers.get(partitionId)),
+ appId,
+ shuffleId,
+ partitionId,
+ uniqueBlockIdBitMap);
+ }
+ }
+
+ public KeyValuesReader getKeyValuesReader() {
+ if (reader == null) {
+ return new KeyValuesReader() {
+ @Override
+ public boolean next() {
+ return false;
+ }
+
+ @Override
+ public Object getCurrentKey() throws IOException {
+ throw new IOException("No data available");
+ }
+
+ @Override
+ public Iterable getCurrentValues() throws IOException {
+ throw new IOException("No data available");
+ }
+ };
+ }
+ return this.reader.keyValuesReader();
+ }
+
+ @VisibleForTesting
+ public RMRecordsReader createRMRecordsReader(Set partitionIds) {
+ Class keyClass = ConfigUtils.getIntermediateInputKeyClass(conf);
+ Class valueClass = ConfigUtils.getIntermediateInputValueClass(conf);
+ // For hive on tez, we use separate serializer and comparator, namely
+ // TezBytesWritableSerialization and TezBytesComparator. But in remote
+ // merge mode, we use separate serializers, so we should also use
+ // separate comparators.
+ RawComparator rawComparator = WritableComparator.get(keyClass);
+ return new RMRecordsReader(
+ appId,
+ shuffleId,
+ partitionIds,
+ partitionToServers,
+ rssConf,
+ keyClass,
+ valueClass,
+ rawComparator,
+ true,
+ null,
+ false,
+ (inc) -> {
+ inputRecordCounter.increment(inc);
+ });
+ }
+
+ @Override
+ public void reportException(Throwable t) {
+ throw new RssException("should never happen!");
+ }
+
+ @Override
+ public void killSelf(Exception exception, String message) {
+ throw new RssException("should never happen!");
+ }
+}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleScheduler.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleScheduler.java
new file mode 100644
index 000000000..14732ec0a
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleScheduler.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.util.HashSet;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.tez.common.IdUtils;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.library.common.CompositeInputAttemptIdentifier;
+
+// Although RMRssShuffleScheduler extends ShuffleScheduler, it is only used
+// to handle DME events. Because the protobuf-java version of Tez differs
+// from uniffle's, it is not convenient to parse DME events here, so there
+// is no need to start RMRssShuffleScheduler.
+public class RMRssShuffleScheduler extends ShuffleScheduler {
+
+ private RMRssShuffle rssShuffle;
+
+ public RMRssShuffleScheduler(
+ InputContext inputContext,
+ Configuration conf,
+ int numberOfInputs,
+ ExceptionReporter exceptionReporter,
+ MergeManager mergeManager,
+ FetchedInputAllocatorOrderedGrouped allocator,
+ long startTime,
+ CompressionCodec codec,
+ boolean ifileReadAhead,
+ int ifileReadAheadLength,
+ String srcNameTrimmed,
+ RMRssShuffle rssShuffle)
+ throws IOException {
+ super(
+ inputContext,
+ conf,
+ numberOfInputs,
+ exceptionReporter,
+ mergeManager,
+ allocator,
+ startTime,
+ codec,
+ ifileReadAhead,
+ ifileReadAheadLength,
+ srcNameTrimmed);
+ this.rssShuffle = rssShuffle;
+ }
+
+ @Override
+ public synchronized void addKnownMapOutput(
+ String inputHostName, int port, int partitionId,
CompositeInputAttemptIdentifier srcAttempt) {
+ super.addKnownMapOutput(inputHostName, port, partitionId, srcAttempt);
+ rssShuffle.partitionIds.add(partitionId);
+ if
(!rssShuffle.partitionIdToSuccessMapTaskAttempts.containsKey(partitionId)) {
+ rssShuffle.partitionIdToSuccessMapTaskAttempts.put(partitionId, new
HashSet<>());
+ }
+
rssShuffle.partitionIdToSuccessMapTaskAttempts.get(partitionId).add(srcAttempt);
+
+ String pathComponent = srcAttempt.getPathComponent();
+ TezTaskAttemptID tezTaskAttemptId =
IdUtils.convertTezTaskAttemptID(pathComponent);
+ if (!rssShuffle.partitionIdToSuccessTezTasks.containsKey(partitionId)) {
+ rssShuffle.partitionIdToSuccessTezTasks.put(partitionId, new
HashSet<>());
+ }
+
rssShuffle.partitionIdToSuccessTezTasks.get(partitionId).add(tezTaskAttemptId.getTaskID());
+ }
+}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
index 43d784f1d..c1af1d006 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
@@ -17,6 +17,7 @@
package org.apache.tez.runtime.library.common.sort.buffer;
+import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Comparator;
@@ -29,6 +30,8 @@ import org.apache.hadoop.io.serializer.Serializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.serializer.SerializerInstance;
+
public class WriteBuffer<K, V> extends OutputStream {
private static final Logger LOG = LoggerFactory.getLogger(WriteBuffer.class);
@@ -47,30 +50,47 @@ public class WriteBuffer<K, V> extends OutputStream {
private final List<WrappedBuffer> buffers = Lists.newArrayList();
private final List<Record<K>> records = Lists.newArrayList();
+ private final boolean useUniffleSerializer;
+ private SerializerInstance serializerInstance;
+ private DataOutputStream dataOutputStream;
+
public WriteBuffer(
boolean isNeedSorted,
int partitionId,
RawComparator<K> comparator,
long maxSegmentSize,
+ boolean useUniffleSerializer,
Serializer<K> keySerializer,
- Serializer<V> valueSerializer) {
+ Serializer<V> valueSerializer,
+ SerializerInstance serializerInstance) {
this.partitionId = partitionId;
this.comparator = comparator;
+ this.useUniffleSerializer = useUniffleSerializer;
this.maxSegmentSize = maxSegmentSize;
this.keySerializer = keySerializer;
this.valSerializer = valueSerializer;
this.isNeedSorted = isNeedSorted;
+ this.serializerInstance = serializerInstance;
+ if (useUniffleSerializer) {
+ this.dataOutputStream = new DataOutputStream(this);
+ }
}
/** add records */
public int addRecord(K key, V value) throws IOException {
- keySerializer.open(this);
- valSerializer.open(this);
+ if (!useUniffleSerializer) {
+ keySerializer.open(this);
+ valSerializer.open(this);
+ }
int lastOffSet = currentOffset;
int lastIndex = currentIndex;
int lastDataLength = dataLength;
int keyIndex = lastIndex;
- keySerializer.serialize(key);
+ if (useUniffleSerializer) {
+ serializerInstance.serialize(key, this.dataOutputStream);
+ } else {
+ keySerializer.serialize(key);
+ }
int keyLength = dataLength - lastDataLength;
int keyOffset = lastOffSet;
if (compact(lastIndex, lastOffSet, keyLength)) {
@@ -78,7 +98,11 @@ public class WriteBuffer<K, V> extends OutputStream {
keyIndex = lastIndex;
}
lastDataLength = dataLength;
- valSerializer.serialize(value);
+ if (useUniffleSerializer) {
+ serializerInstance.serialize(value, this.dataOutputStream);
+ } else {
+ valSerializer.serialize(value);
+ }
int valueLength = dataLength - lastDataLength;
records.add(new Record<K>(keyIndex, keyOffset, keyLength, valueLength));
return keyLength + valueLength;
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
index 93735efa4..78c235447 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
@@ -38,6 +38,8 @@ import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.serializer.Serializer;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.counters.TezCounter;
@@ -54,6 +56,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.serializer.SerializerFactory;
+import org.apache.uniffle.common.serializer.SerializerInstance;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.ThreadUtils;
@@ -105,6 +109,9 @@ public class WriteBufferManager<K, V> {
private final TezCounter mapOutputByteCounter;
private final TezCounter mapOutputRecordCounter;
+ private final boolean useUniffleSerializer;
+ private SerializerInstance serializerInstance;
+
/** WriteBufferManager */
public WriteBufferManager(
TezTaskAttemptID tezTaskAttemptID,
@@ -133,7 +140,10 @@ public class WriteBufferManager<K, V> {
int shuffleId,
boolean isNeedSorted,
TezCounter mapOutputByteCounter,
- TezCounter mapOutputRecordCounter) {
+ TezCounter mapOutputRecordCounter,
+ boolean useUniffleSerializer,
+ Class<K> keyClass,
+ Class<V> valClass) {
this.tezTaskAttemptID = tezTaskAttemptID;
this.maxMemSize = maxMemSize;
this.appId = appId;
@@ -142,7 +152,14 @@ public class WriteBufferManager<K, V> {
this.successBlockIds = successBlockIds;
this.failedBlockIds = failedBlockIds;
this.shuffleWriteClient = shuffleWriteClient;
- this.comparator = comparator;
+ // For hive on tez, we use separate serializer and comparator, namely
+ // TezBytesWritableSerialization and TezBytesComparator. But in remote
+ // merge mode, we use separate serializers, so we should also use
+ // separate comparators.
+ this.comparator =
+ useUniffleSerializer
+ ? WritableComparator.get((Class<? extends WritableComparable>)
keyClass)
+ : comparator;
this.maxSegmentSize = maxSegmentSize;
this.keySerializer = keySerializer;
this.valSerializer = valSerializer;
@@ -162,6 +179,13 @@ public class WriteBufferManager<K, V> {
this.isNeedSorted = isNeedSorted;
this.mapOutputByteCounter = mapOutputByteCounter;
this.mapOutputRecordCounter = mapOutputRecordCounter;
+ this.useUniffleSerializer = useUniffleSerializer;
+ if (useUniffleSerializer) {
+ SerializerFactory factory = new SerializerFactory(rssConf);
+ org.apache.uniffle.common.serializer.Serializer serializer =
factory.getSerializer(keyClass);
+ this.serializerInstance = serializer.newInstance();
+ assert
factory.getSerializer(valClass).getClass().equals(serializer.getClass());
+ }
this.sendExecutorService =
Executors.newFixedThreadPool(sendThreadNum,
ThreadUtils.getThreadFactory("send-thread"));
}
@@ -188,7 +212,14 @@ public class WriteBufferManager<K, V> {
if (!buffers.containsKey(partitionId)) {
WriteBuffer<K, V> sortWriterBuffer =
new WriteBuffer(
- isNeedSorted, partitionId, comparator, maxSegmentSize,
keySerializer, valSerializer);
+ isNeedSorted,
+ partitionId,
+ comparator,
+ maxSegmentSize,
+ useUniffleSerializer,
+ keySerializer,
+ valSerializer,
+ serializerInstance);
buffers.putIfAbsent(partitionId, sortWriterBuffer);
waitSendBuffers.add(sortWriterBuffer);
}
@@ -371,7 +402,8 @@ public class WriteBufferManager<K, V> {
final int uncompressLength = data.length;
long start = System.currentTimeMillis();
- final byte[] compressed = codec.map(c -> c.compress(data)).orElse(data);
+ final byte[] compressed =
+ useUniffleSerializer ? data : codec.map(c ->
c.compress(data)).orElse(data);
final long crc32 = ChecksumUtils.getCrc32(compressed);
compressTime += System.currentTimeMillis() - start;
final long blockId =
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
index fe4f11e13..84a71680e 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
@@ -30,6 +30,7 @@ import org.apache.tez.common.RssTezConfig;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.library.common.ConfigUtils;
import org.apache.tez.runtime.library.common.sort.buffer.WriteBufferManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -139,6 +140,9 @@ public class RssSorter extends ExternalSorter {
LOG.info("applicationAttemptId is {}", applicationAttemptId.toString());
+ boolean isRemoteMergeEnable =
+ conf.getBoolean(
+ RssTezConfig.RSS_REMOTE_MERGE_ENABLE,
RssTezConfig.RSS_REMOTE_MERGE_ENABLE_DEFAULT);
bufferManager =
new WriteBufferManager(
tezTaskAttemptID,
@@ -167,7 +171,10 @@ public class RssSorter extends ExternalSorter {
shuffleId,
true,
mapOutputByteCounter,
- mapOutputRecordCounter);
+ mapOutputRecordCounter,
+ isRemoteMergeEnable,
+ ConfigUtils.getIntermediateOutputKeyClass(this.conf),
+ ConfigUtils.getIntermediateOutputValueClass(this.conf));
LOG.info("Initialized WriteBufferManager.");
}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
index 94e87a40e..b553700a6 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
@@ -165,7 +165,10 @@ public class RssUnSorter extends ExternalSorter {
shuffleId,
false,
mapOutputByteCounter,
- mapOutputRecordCounter);
+ mapOutputRecordCounter,
+ false,
+ null,
+ null);
LOG.info("Initialized WriteBufferManager.");
}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInput.java
new file mode 100644
index 000000000..ff7d0c8ae
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInput.java
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.input;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.AbstractLogicalInput;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.ProgressFailedException;
+import org.apache.tez.runtime.library.api.KeyValuesReader;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.Constants;
+import org.apache.tez.runtime.library.common.MemoryUpdateCallbackHandler;
+import
org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RMRssShuffle;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RssShuffle;
+import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
+
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+
+/**
+ * {@link RMRssOrderedGroupedKVInput} is an {@link AbstractLogicalInput} which
+ * shuffles intermediate sorted data, merges it and provides key/<values> to
+ * the consumer. This is typically used to bring one partition of a set of
+ * partitioned distributed data to one consumer. The shuffle operation brings
+ * all partitions to one place. These partitions are assumed to be sorted and
+ * are merge-sorted into a single input view.
+ *
+ * <p>The Copy and Merge will be triggered by the initialization - which is
handled by the Tez
+ * framework. Input is not consumable until the Copy and Merge are complete.
Methods are provided to
+ * check for this, as well as to wait for completion. Attempting to get a
reader on a non-complete
+ * input will block.
+ */
+@Public
+public class RMRssOrderedGroupedKVInput extends AbstractLogicalInput {
+
+ static final Logger LOG =
LoggerFactory.getLogger(RMRssOrderedGroupedKVInput.class);
+
+ protected TezRawKeyValueIterator rawIter = null;
+ protected Configuration conf;
+ protected RMRssShuffle shuffle;
+ protected MemoryUpdateCallbackHandler memoryUpdateCallbackHandler;
+ private int shuffleId;
+ private ApplicationAttemptId applicationAttemptId;
+ private final BlockingQueue<Event> pendingEvents = new
LinkedBlockingQueue<>();
+ private long firstEventReceivedTime = -1;
+
+ private final AtomicBoolean isStarted = new AtomicBoolean(false);
+
+ public RMRssOrderedGroupedKVInput(InputContext inputContext, int
numPhysicalInputs) {
+ super(inputContext, numPhysicalInputs);
+ }
+
+ @Override
+ public synchronized List<Event> initialize() throws IOException {
+ this.conf =
TezUtils.createConfFromUserPayload(getContext().getUserPayload());
+
+ if (this.getNumPhysicalInputs() == 0) {
+ getContext().requestInitialMemory(0L, null);
+ isStarted.set(true);
+ getContext().inputIsReady();
+ LOG.info(
+ "input fetch not required since there are 0 physical inputs for
input vertex: "
+ + getContext().getSourceVertexName());
+ return Collections.emptyList();
+ }
+
+ long initialMemoryRequest =
+ RssShuffle.getInitialMemoryRequirement(conf,
getContext().getTotalMemoryAvailableToTask());
+ this.memoryUpdateCallbackHandler = new MemoryUpdateCallbackHandler();
+ getContext().requestInitialMemory(initialMemoryRequest,
memoryUpdateCallbackHandler);
+
+ this.conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS,
getContext().getWorkDirs());
+
+ TezTaskAttemptID taskAttemptId =
+ TezTaskAttemptID.fromString(
+
RssTezUtils.uniqueIdentifierToAttemptId(getContext().getUniqueIdentifier()));
+ TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
+ TezDAGID tezDAGID = tezVertexID.getDAGId();
+ int sourceVertexId = this.conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1);
+ int destinationVertexId =
this.conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1);
+ assert sourceVertexId != -1;
+ assert destinationVertexId != -1;
+ this.shuffleId =
+ RssTezUtils.computeShuffleId(tezDAGID.getId(), sourceVertexId,
destinationVertexId);
+ this.applicationAttemptId =
+ ApplicationAttemptId.newInstance(
+ getContext().getApplicationId(),
getContext().getDAGAttemptNumber());
+ return Collections.emptyList();
+ }
+
+ @Override
+ public synchronized void start() throws IOException {
+ if (!isStarted.get()) {
+ memoryUpdateCallbackHandler.validateUpdateReceived();
+ // Start the shuffle - copy and merge
+ shuffle = createRssShuffle();
+ shuffle.run();
+ List<Event> pending = new LinkedList<>();
+ pendingEvents.drainTo(pending);
+ if (pending.size() > 0) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(
+ "NoAutoStart delay in processing first event: "
+ + (System.currentTimeMillis() - firstEventReceivedTime));
+ }
+ shuffle.handleEvents(pending);
+ }
+ isStarted.set(true);
+ }
+ }
+
+ @VisibleForTesting
+ RMRssShuffle createRssShuffle() throws IOException {
+ return new RMRssShuffle(
+ getContext(), conf, getNumPhysicalInputs(), shuffleId,
applicationAttemptId);
+ }
+
+ @Override
+ public synchronized List<Event> close() throws IOException {
+ if (this.getNumPhysicalInputs() != 0 && rawIter != null) {
+ rawIter.close();
+ }
+ if (shuffle != null) {
+ shuffle.shutdown();
+ }
+
+ long dataSize =
+
getContext().getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_DECOMPRESSED).getValue();
+ getContext().getStatisticsReporter().reportDataSize(dataSize);
+ long inputRecords =
+
getContext().getCounters().findCounter(TaskCounter.REDUCE_INPUT_RECORDS).getValue();
+ getContext().getStatisticsReporter().reportItemsProcessed(inputRecords);
+
+ return Collections.emptyList();
+ }
+
+ /**
+ * Get a KVReader for the Input. This method will block until the input is
ready - i.e. the copy
+ * and merge stages are complete. Users can use the isInputReady method to
check if the input is
+ * ready, which gives an indication of whether this method will block or not.
+ *
+ * <p>NOTE: All values for the current K-V pair must be read prior to
invoking moveToNext. Once
+ * moveToNext() is called, the valueIterator from the previous K-V pair will
throw an Exception
+ *
+ * @return a KVReader over the sorted input.
+ * @throws IOInterruptedException if IO was performing a blocking operation
+ *     and was interrupted
+ */
+ @Override
+ public KeyValuesReader getReader() throws Exception {
+ synchronized (this) {
+ if (getNumPhysicalInputs() == 0) {
+ return new KeyValuesReader() {
+ @Override
+ public boolean next() throws IOException {
+ getContext().notifyProgress();
+ hasCompletedProcessing();
+ completedProcessing = true;
+ return false;
+ }
+
+ @Override
+ public Object getCurrentKey() throws IOException {
+ throw new RssException("No data available in Input");
+ }
+
+ @Override
+ public Iterable<Object> getCurrentValues() throws IOException {
+ throw new RssException("No data available in Input");
+ }
+ };
+ }
+ }
+ shuffle.waitForEvents();
+ org.apache.uniffle.client.record.reader.KeyValuesReader keyValuesReader =
+ shuffle.getKeyValuesReader();
+ return new KeyValuesReader() {
+ @Override
+ public boolean next() throws IOException {
+ return keyValuesReader.next();
+ }
+
+ @Override
+ public Object getCurrentKey() throws IOException {
+ return keyValuesReader.getCurrentKey();
+ }
+
+ @Override
+ public Iterable<Object> getCurrentValues() throws IOException {
+ return keyValuesReader.getCurrentValues();
+ }
+ };
+ }
+
+ @Override
+ public float getProgress() throws ProgressFailedException,
InterruptedException {
+ // TODO: add progress
+ return super.getProgress();
+ }
+
+ @Override
+ public void handleEvents(List<Event> inputEvents) throws IOException {
+ RMRssShuffle shuffleLocalRef;
+ synchronized (this) {
+ if (getNumPhysicalInputs() == 0) {
+ throw new RssException("No input events expected as numInputs is 0");
+ }
+ if (!isStarted.get()) {
+ if (firstEventReceivedTime == -1) {
+ firstEventReceivedTime = System.currentTimeMillis();
+ }
+ pendingEvents.addAll(inputEvents);
+ return;
+ }
+ shuffleLocalRef = shuffle;
+ }
+ shuffleLocalRef.handleEvents(inputEvents);
+ }
+
+ private static final Set<String> CONF_KEYS = new HashSet<String>();
+
+ static {
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD_BYTES);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_IO_FILE_BUFFER_SIZE);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_IO_SORT_FACTOR);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMBINE_MIN_SPILLS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMBINER_CLASS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_USE_ASYNC_HTTP);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES_LIMIT);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_MAX_TASK_OUTPUT_AT_ONCE);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_CONNECT_TIMEOUT);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_KEEP_ALIVE_ENABLED);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_KEEP_ALIVE_MAX_CONNECTIONS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_READ_TIMEOUT);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_BUFFER_SIZE);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_VERIFY_DISK_CHECKSUM);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_BUFFER_PERCENT);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MEMTOMEM_SEGMENTS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_ENABLE_MEMTOMEM);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_SOURCE_ATTEMPT_ABORT_LIMIT);
+ CONF_KEYS.add(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_ACCEPTABLE_HOST_FETCH_FAILURE_FRACTION);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_FAILURES_PER_HOST);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_STALL_TIME_FRACTION);
+ CONF_KEYS.add(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_ALLOWED_FAILED_FETCH_ATTEMPT_FRACTION);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_REQUIRED_PROGRESS_FRACTION);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FAILED_CHECK_SINCE_LAST_COMPLETION);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCHER_USE_SHARED_POOL);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_INPUT_POST_MERGE_BUFFER_PERCENT);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_GROUP_COMPARATOR_CLASS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS_CODEC);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_SECONDARY_COMPARATOR_CLASS);
+ CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_OPTIMIZE_LOCAL_FETCH);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_CONVERT_USER_PAYLOAD_TO_HISTORY_TEXT);
+ CONF_KEYS.add(TezConfiguration.TEZ_COUNTERS_MAX);
+ CONF_KEYS.add(TezConfiguration.TEZ_COUNTERS_GROUP_NAME_MAX_LENGTH);
+ CONF_KEYS.add(TezConfiguration.TEZ_COUNTERS_COUNTER_NAME_MAX_LENGTH);
+ CONF_KEYS.add(TezConfiguration.TEZ_COUNTERS_MAX_GROUPS);
+
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_CLEANUP_FILES_ON_INTERRUPT);
+ CONF_KEYS.add(Constants.TEZ_RUNTIME_TASK_MEMORY);
+ CONF_KEYS.add(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID);
+ }
+
+ @InterfaceAudience.Private
+ public static Set<String> getConfigurationKeySet() {
+ return Collections.unmodifiableSet(CONF_KEYS);
+ }
+}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
index 997dcdfa6..d95033c53 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -30,9 +30,11 @@ import java.util.zip.Deflater;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
+import org.apache.commons.lang.ClassUtils;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.Credentials;
@@ -43,6 +45,7 @@ import
org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.GetShuffleServerRequest;
import org.apache.tez.common.GetShuffleServerResponse;
+import org.apache.tez.common.RssTezConfig;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
@@ -175,9 +178,29 @@ public class RssOrderedPartitionedKVOutput extends
AbstractLogicalOutput {
}
this.shuffleId =
RssTezUtils.computeShuffleId(tezDAGID.getId(), sourceVertexId,
destinationVertexId);
+ String keyClassName =
conf.get(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, "");
+ String valueClassName =
conf.get(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, "");
+ // For Hive on Tez, a dedicated serializer/comparator pair is normally used,
+ // namely TezBytesWritableSerialization and TezBytesComparator. In remote
+ // merge mode we switch to uniffle's own serializers, so the comparator must
+ // also be switched to the matching WritableComparator for the key class.
+ String comparatorClassName =
+
WritableComparator.get(ClassUtils.getClass(keyClassName)).getClass().getName();
+ boolean remoteMergeEnable =
+ conf.getBoolean(
+ RssTezConfig.RSS_REMOTE_MERGE_ENABLE,
RssTezConfig.RSS_REMOTE_MERGE_ENABLE_DEFAULT);
GetShuffleServerRequest request =
- new GetShuffleServerRequest(
- this.taskAttemptId, this.mapNum, this.numOutputs, this.shuffleId);
+ remoteMergeEnable
+ ? new GetShuffleServerRequest(
+ this.taskAttemptId,
+ this.mapNum,
+ this.numOutputs,
+ this.shuffleId,
+ keyClassName,
+ valueClassName,
+ comparatorClassName)
+ : new GetShuffleServerRequest(
+ this.taskAttemptId, this.mapNum, this.numOutputs,
this.shuffleId);
GetShuffleServerResponse response =
umbilical.getShuffleAssignments(request);
this.partitionToServers =
response
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java
new file mode 100644
index 000000000..b42b57bda
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java
@@ -0,0 +1,443 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.zip.Deflater;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.ExecutionContext;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.api.events.DataMovementEvent;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.KeyValuesReader;
+import org.apache.uniffle.client.record.reader.MockedShuffleServerClient;
+import org.apache.uniffle.client.record.reader.MockedShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.util.BlockIdLayout;
+
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+import static
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS;
+import static
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS;
+import static
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS;
+import static
org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+public class RMRssShuffleTest {
+
+ private static final int RECORDS_NUM = 1009;
+ private static final ApplicationAttemptId APPLICATION_ATTEMPT_ID =
+ ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0);
+ private static final int SHUFFLE_ID = 0;
+ private static final int PARTITION_ID = 0;
+
+ @Test
+ public void testReadShuffleData() throws Exception {
+ // 1 basic parameter
+ final Class keyClass = Text.class;
+ final Class valueClass = IntWritable.class;
+ final Comparator comparator = new Text.Comparator();
+ final Configuration conf = new Configuration();
+ conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
+ conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
+ final RssConf rssConf = new RssConf();
+ final List<ShuffleServerInfo> serverInfos =
+ Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+ final int taskAttemptId = 0;
+ BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+ final long[] blockIds = new long[] {blockIdLayout.getBlockId(0,
PARTITION_ID, taskAttemptId)};
+ final int duplicated = 5;
+
+ // 2 mock input context
+ InputContext inputContext = mock(InputContext.class);
+ when(inputContext.getSourceVertexName()).thenReturn("Map 0");
+ TezCounters tezCounters = new TezCounters();
+ when(inputContext.getCounters()).thenReturn(tezCounters);
+ TezTaskAttemptID tezTaskAttemptID =
+ TezTaskAttemptID.getInstance(
+ TezTaskID.getInstance(
+ TezVertexID.getInstance(
+
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+ 0),
+ 0);
+ when(inputContext.getUniqueIdentifier())
+ .thenReturn(String.format("%s_%05d", tezTaskAttemptID.toString(), 0));
+ when(inputContext.getDagIdentifier()).thenReturn(0);
+
when(inputContext.getApplicationId()).thenReturn(APPLICATION_ATTEMPT_ID.getApplicationId());
+ ExecutionContext executionContext = mock(ExecutionContext.class);
+ when(executionContext.getHostName()).thenReturn("hostname");
+ when(inputContext.getExecutionContext()).thenReturn(executionContext);
+ DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
+ dataOutputBuffer.writeInt(-1);
+ when(inputContext.getServiceProviderMetaData(anyString()))
+ .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0,
dataOutputBuffer.getLength()));
+ Token<JobTokenIdentifier> sessionToken =
+ new Token<JobTokenIdentifier>(
+ new JobTokenIdentifier(new Text("text")), new
JobTokenSecretManager());
+ ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
+
doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());
+ conf.setClass(TEZ_RUNTIME_KEY_CLASS, keyClass, Writable.class);
+ conf.setClass(TEZ_RUNTIME_VALUE_CLASS, valueClass, Writable.class);
+ conf.setClass(TEZ_RUNTIME_KEY_COMPARATOR_CLASS, comparator.getClass(),
Comparator.class);
+
+ // 3 mock recordsReader
+ RMRecordsReader recordsReader =
+ new RMRecordsReader(
+ APPLICATION_ATTEMPT_ID.toString(),
+ SHUFFLE_ID,
+ Sets.newHashSet(PARTITION_ID),
+ ImmutableMap.of(PARTITION_ID, serverInfos),
+ rssConf,
+ keyClass,
+ valueClass,
+ comparator,
+ true,
+ null,
+ false,
+ null);
+ RMRecordsReader recordsReaderSpy = spy(recordsReader);
+ ByteBuffer[][] buffers =
+ new ByteBuffer[][] {
+ {
+ ByteBuffer.wrap(
+ genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1,
RECORDS_NUM, duplicated))
+ }
+ };
+ ShuffleServerClient serverClient =
+ new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers,
blockIds);
+
doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+ // 4 mock shuffle
+ RMRssShuffle rssShuffle = new RMRssShuffle(inputContext, conf, 5, 0,
APPLICATION_ATTEMPT_ID);
+ RMRssShuffle rssShuffleSpy = spy(rssShuffle);
+
doReturn(recordsReaderSpy).when(rssShuffleSpy).createRMRecordsReader(anySet());
+
+ try (MockedStatic<UmbilicalUtils> umbilicalUtils =
Mockito.mockStatic(UmbilicalUtils.class);
+ MockedStatic<RssTezUtils> tezUtils =
Mockito.mockStatic(RssTezUtils.class)) {
+ umbilicalUtils
+ .when(() -> UmbilicalUtils.requestShuffleServer(any(), any(), any(),
anyInt()))
+ .thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos));
+ ShuffleWriteClient writeClient = new MockedShuffleWriteClient();
+ writeClient.reportShuffleResult(
+ ImmutableMap.of(
+ serverInfos.get(0), ImmutableMap.of(PARTITION_ID,
Sets.newHashSet(blockIds[0]))),
+ APPLICATION_ATTEMPT_ID.toString(),
+ SHUFFLE_ID,
+ taskAttemptId,
+ 0);
+ tezUtils.when(() ->
RssTezUtils.createShuffleClient(any())).thenReturn(writeClient);
+ tezUtils
+ .when(() -> RssTezUtils.fetchAllRssTaskIds(anySet(), anyInt(),
anyInt(), anyInt()))
+ .thenReturn(Roaring64NavigableMap.bitmapOf(taskAttemptId));
+
+ // 5 run shuffle
+ rssShuffleSpy.run();
+
+ // 6 send and handle dme
+ List<Event> events = new ArrayList<>();
+ for (int i = 0; i < 5; i++) {
+ TezTaskAttemptID taskAttemptID =
+ TezTaskAttemptID.getInstance(
+ TezTaskID.getInstance(
+ TezVertexID.getInstance(
+
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+ i),
+ 0);
+ events.add(
+ createDataMovementEvent(
+ PARTITION_ID, String.format("%s_%05d",
taskAttemptID.toString(), 1)));
+ }
+ rssShuffleSpy.handleEvents(events);
+ rssShuffleSpy.waitForEvents();
+ }
+
+ // 7 verify result
+ int index = 0;
+ KeyValuesReader keyValuesReader = rssShuffleSpy.getKeyValuesReader();
+ while (keyValuesReader.next()) {
+ assertEquals(SerializerUtils.genData(keyClass, index),
keyValuesReader.getCurrentKey());
+ Iterator iterator = keyValuesReader.getCurrentValues().iterator();
+ int dup = 0;
+ while (iterator.hasNext()) {
+ assertEquals(SerializerUtils.genData(valueClass, index),
iterator.next());
+ dup++;
+ }
+ assertEquals(duplicated, dup);
+ index++;
+ }
+ assertEquals(RECORDS_NUM, index);
+ }
+
+ @Test
+ public void testReadMultiPartitionShuffleData() throws Exception {
+ // 1 basic parameter
+ final Class keyClass = Text.class;
+ final Class valueClass = IntWritable.class;
+ final Comparator comparator = new Text.Comparator();
+ final Configuration conf = new Configuration();
+ conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
+ conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
+ final RssConf rssConf = new RssConf();
+ final List<ShuffleServerInfo> serverInfos =
+ Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+ final int taskAttemptId = 0;
+ BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+ final long[] blockIds =
+ new long[] {
+ blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId),
+ blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId + 1),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId + 2),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId + 3),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId + 4),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId + 5),
+ };
+ final int duplicated = 3;
+
+ // 2 mock input context
+ InputContext inputContext = mock(InputContext.class);
+ when(inputContext.getSourceVertexName()).thenReturn("Map 0");
+ TezCounters tezCounters = new TezCounters();
+ when(inputContext.getCounters()).thenReturn(tezCounters);
+ TezTaskAttemptID tezTaskAttemptID =
+ TezTaskAttemptID.getInstance(
+ TezTaskID.getInstance(
+ TezVertexID.getInstance(
+
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+ 0),
+ 0);
+ when(inputContext.getUniqueIdentifier())
+ .thenReturn(String.format("%s_%05d", tezTaskAttemptID.toString(), 0));
+ when(inputContext.getDagIdentifier()).thenReturn(0);
+
when(inputContext.getApplicationId()).thenReturn(APPLICATION_ATTEMPT_ID.getApplicationId());
+ ExecutionContext executionContext = mock(ExecutionContext.class);
+ when(executionContext.getHostName()).thenReturn("hostname");
+ when(inputContext.getExecutionContext()).thenReturn(executionContext);
+ DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
+ dataOutputBuffer.writeInt(-1);
+ when(inputContext.getServiceProviderMetaData(anyString()))
+ .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0,
dataOutputBuffer.getLength()));
+ Token<JobTokenIdentifier> sessionToken =
+ new Token<JobTokenIdentifier>(
+ new JobTokenIdentifier(new Text("text")), new
JobTokenSecretManager());
+ ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
+
doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());
+ conf.setClass(TEZ_RUNTIME_KEY_CLASS, keyClass, Writable.class);
+ conf.setClass(TEZ_RUNTIME_VALUE_CLASS, valueClass, Writable.class);
+ conf.setClass(TEZ_RUNTIME_KEY_COMPARATOR_CLASS, comparator.getClass(),
Comparator.class);
+
+ // 3 mock recordsReader
+ RMRecordsReader recordsReader =
+ new RMRecordsReader(
+ APPLICATION_ATTEMPT_ID.toString(),
+ SHUFFLE_ID,
+ Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2),
+ ImmutableMap.of(
+ PARTITION_ID,
+ serverInfos,
+ PARTITION_ID + 1,
+ serverInfos,
+ PARTITION_ID + 2,
+ serverInfos),
+ rssConf,
+ keyClass,
+ valueClass,
+ comparator,
+ true,
+ null,
+ false,
+ null);
+ RMRecordsReader recordsReaderSpy = spy(recordsReader);
+ ByteBuffer[][] buffers = new ByteBuffer[3][2];
+ for (int i = 0; i < 3; i++) {
+ buffers[i][0] =
+ ByteBuffer.wrap(
+ genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3,
RECORDS_NUM, duplicated));
+ buffers[i][1] =
+ ByteBuffer.wrap(
+ genSortedRecordBytes(
+ rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3,
RECORDS_NUM, duplicated));
+ }
+ ShuffleServerClient serverClient =
+ new MockedShuffleServerClient(
+ new int[] {PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2},
buffers, blockIds);
+
doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+ // 4 mock shuffle
+ RMRssShuffle rssShuffle = new RMRssShuffle(inputContext, conf, 6, 0,
APPLICATION_ATTEMPT_ID);
+ RMRssShuffle rssShuffleSpy = spy(rssShuffle);
+
doReturn(recordsReaderSpy).when(rssShuffleSpy).createRMRecordsReader(anySet());
+ try (MockedStatic<UmbilicalUtils> umbilicalUtils =
Mockito.mockStatic(UmbilicalUtils.class);
+ MockedStatic<RssTezUtils> tezUtils =
Mockito.mockStatic(RssTezUtils.class)) {
+ umbilicalUtils
+ .when(() -> UmbilicalUtils.requestShuffleServer(any(), any(), any(),
anyInt()))
+ .thenReturn(
+ ImmutableMap.of(
+ PARTITION_ID,
+ serverInfos,
+ PARTITION_ID + 1,
+ serverInfos,
+ PARTITION_ID + 2,
+ serverInfos));
+ ShuffleWriteClient writeClient = new MockedShuffleWriteClient();
+ writeClient.reportShuffleResult(
+ ImmutableMap.of(
+ serverInfos.get(0),
+ ImmutableMap.of(
+ PARTITION_ID,
+ Sets.newHashSet(blockIds[0], blockIds[1]),
+ PARTITION_ID + 1,
+ Sets.newHashSet(blockIds[2], blockIds[3]),
+ PARTITION_ID + 2,
+ Sets.newHashSet(blockIds[4], blockIds[5]))),
+ APPLICATION_ATTEMPT_ID.toString(),
+ SHUFFLE_ID,
+ taskAttemptId,
+ 0);
+ tezUtils.when(() ->
RssTezUtils.createShuffleClient(any())).thenReturn(writeClient);
+ tezUtils
+ .when(() -> RssTezUtils.fetchAllRssTaskIds(anySet(), anyInt(),
anyInt(), anyInt()))
+ .thenReturn(Roaring64NavigableMap.bitmapOf(taskAttemptId));
+
+ // 5 run shuffle
+ rssShuffleSpy.run();
+
+ // 6 send and handle dme
+ List<Event> events = new ArrayList<>();
+ for (int i = 0; i < 6; i++) {
+ TezTaskAttemptID taskAttemptID =
+ TezTaskAttemptID.getInstance(
+ TezTaskID.getInstance(
+ TezVertexID.getInstance(
+
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+ i),
+ 0);
+ events.add(
+ createDataMovementEvent(
+ PARTITION_ID, String.format("%s_%05d",
taskAttemptID.toString(), 1)));
+ }
+ rssShuffleSpy.handleEvents(events);
+ rssShuffleSpy.waitForEvents();
+ }
+
+ // 7 verify result
+ int index = 0;
+ KeyValuesReader keyValuesReader = rssShuffleSpy.getKeyValuesReader();
+ while (keyValuesReader.next()) {
+ assertEquals(SerializerUtils.genData(keyClass, index),
keyValuesReader.getCurrentKey());
+ Iterator iterator = keyValuesReader.getCurrentValues().iterator();
+ int dup = 0;
+ while (iterator.hasNext()) {
+ assertEquals(SerializerUtils.genData(valueClass, index),
iterator.next());
+ dup++;
+ }
+ assertEquals(dup, dup);
+ index++;
+ }
+ assertEquals(RECORDS_NUM * 6, index);
+ }
+
  /**
   * Builds a {@link DataMovementEvent} for the given partition whose payload embeds
   * {@code path} as the path component, by reflectively invoking the non-public
   * {@code ShuffleUtils.generateDMEPayload} with a minimally mocked output context.
   *
   * @param partition target partition index of the event
   * @param path path/identifier string embedded in the DME payload
   * @throws Exception on reflection or serialization failure
   */
  public static DataMovementEvent createDataMovementEvent(int partition, String path)
      throws Exception {
    // Minimal mocked OutputContext: empty host name, and service-provider metadata
    // carrying a serialized int -1 (presumably the shuffle port — TODO confirm).
    OutputContext context = mock(OutputContext.class);
    ExecutionContext executionContext = mock(ExecutionContext.class);
    when(executionContext.getHostName()).thenReturn("");
    when(context.getExecutionContext()).thenReturn(executionContext);
    DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
    dataOutputBuffer.writeInt(-1);
    when(context.getServiceProviderMetaData(anyString()))
        .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0, dataOutputBuffer.getLength()));
    // generateDMEPayload is not public, so look it up and open it via reflection.
    Method method =
        ShuffleUtils.class.getDeclaredMethod(
            "generateDMEPayload",
            new Class[] {
              boolean.class,
              int.class,
              TezSpillRecord.class,
              OutputContext.class,
              int.class,
              boolean.class,
              boolean.class,
              String.class,
              String.class,
              Deflater.class
            });
    method.setAccessible(true);
    // Static invocation (first arg null); argument order mirrors the Class[] above.
    ByteBuffer byteBuffer =
        (ByteBuffer)
            method.invoke(
                null,
                false,
                0,
                null,
                context,
                -1,
                true,
                true,
                path,
                "mapreduce_shuffle",
                TezCommonUtils.newBestCompressionDeflater());
    return DataMovementEvent.create(partition, byteBuffer);
  }
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index dbe40fd06..724cec6c0 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -19,7 +19,9 @@ package org.apache.tez.runtime.library.common.sort.buffer;
import java.io.File;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -30,10 +32,12 @@ import java.util.stream.Collectors;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
+import io.netty.buffer.ByteBuf;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparator;
@@ -44,6 +48,7 @@ import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.common.counters.TaskCounter;
import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
@@ -66,7 +71,13 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.records.RecordsReader;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerializerFactory;
+import org.apache.uniffle.common.serializer.SerializerInstance;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -75,6 +86,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
public class WriteBufferManagerTest {
+
+ private static final int RECORDS = 1009;
+
@Test
public void testWriteException(@TempDir File tmpDir) throws IOException,
InterruptedException {
TezTaskAttemptID tezTaskAttemptID =
@@ -155,7 +169,10 @@ public class WriteBufferManagerTest {
shuffleId,
true,
mapOutputByteCounter,
- mapOutputRecordCounter);
+ mapOutputRecordCounter,
+ false,
+ null,
+ null);
partitionToServers.put(1,
Lists.newArrayList(mock(ShuffleServerInfo.class)));
Random random = new Random();
for (int i = 0; i < 1000; i++) {
@@ -259,7 +276,10 @@ public class WriteBufferManagerTest {
shuffleId,
true,
mapOutputByteCounter,
- mapOutputRecordCounter);
+ mapOutputRecordCounter,
+ false,
+ null,
+ null);
Random random = new Random();
for (int i = 0; i < 1000; i++) {
@@ -286,8 +306,6 @@ public class WriteBufferManagerTest {
}
assertEquals(1175900, mapOutputByteCounter.getValue());
- assert (1 == bufferManager.getWaitSendBuffers().size());
- assert (4928 == bufferManager.getWaitSendBuffers().get(0).getDataLength());
bufferManager.waitSendFinished();
assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
@@ -374,10 +392,13 @@ public class WriteBufferManagerTest {
shuffleId,
true,
mapOutputByteCounter,
- mapOutputRecordCounter);
+ mapOutputRecordCounter,
+ false,
+ null,
+ null);
Random random = new Random();
- for (int i = 0; i < 10000; i++) {
+ for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
byte[] value = new byte[1024];
random.nextBytes(key);
@@ -388,8 +409,8 @@ public class WriteBufferManagerTest {
}
bufferManager.waitSendFinished();
- assertEquals(10000, mapOutputRecordCounter.getValue());
- assertEquals(10520000, mapOutputByteCounter.getValue());
+ assertEquals(1000, mapOutputRecordCounter.getValue());
+ assertEquals(1052000, mapOutputByteCounter.getValue());
assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
assertEquals(
writeClient.mockedShuffleServer.getFinishBlockSize(),
@@ -478,7 +499,10 @@ public class WriteBufferManagerTest {
shuffleId,
true,
mapOutputByteCounter,
- mapOutputRecordCounter);
+ mapOutputRecordCounter,
+ false,
+ null,
+ null);
Random random = new Random();
RssException rssException =
@@ -506,6 +530,107 @@ public class WriteBufferManagerTest {
assertTrue(mapOutputByteCounter.getValue() < 10520000);
}
  /**
   * Exercises WriteBufferManager in remote-merge mode: RECORDS Text key/value pairs are
   * added in random order, flushed, and the single resulting cached block is decoded with
   * RecordsReader to verify the records come back complete and sorted by key.
   */
  @Test
  public void testWriteWithRemoteMerge() throws Exception {
    MockShuffleWriteClient client = new MockShuffleWriteClient();
    client.setMode(3);
    Map<Integer, List<ShuffleServerInfo>> partitionToServers = JavaUtils.newConcurrentMap();
    partitionToServers.put(0, new ArrayList());
    partitionToServers.get(0).add(new ShuffleServerInfo("host", 39998));
    Set<Long> successBlockIds = Sets.newConcurrentHashSet();
    Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
    RssConf rssConf = new RssConf();
    TezCounters counter = new TezCounters();
    TezCounter mapOutputByteCounter = counter.findCounter("group", "bytes");
    TezCounter mapOutputRecordCounter = counter.findCounter("group", "records");
    TezTaskAttemptID tezTaskAttemptID =
        TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
    final long maxMemSize = 102400L;
    final String appId = "application_1681717153064_3770270";
    final int taskAttemptId = 0;
    long maxSegmentSize = 3 * 1024;
    long maxBufferSize = 14 * 1024 * 1024;
    int shuffleId =
        RssTezUtils.computeShuffleId(
            tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

    // Trailing (true, Text.class, Text.class) enables remote-merge serialization.
    WriteBufferManager<Text, Text> manager =
        new WriteBufferManager<Text, Text>(
            tezTaskAttemptID,
            maxMemSize,
            appId,
            taskAttemptId,
            successBlockIds,
            failedBlockIds,
            client,
            new Text.Comparator(),
            maxSegmentSize,
            null,
            null,
            maxBufferSize,
            0.8f,
            1,
            0.2f,
            50,
            rssConf,
            partitionToServers,
            1,
            false,
            500L,
            5 * 1000,
            1,
            shuffleId,
            true,
            mapOutputByteCounter,
            mapOutputRecordCounter,
            true,
            Text.class,
            Text.class);

    // Insert record indices in shuffled order so the read-back order check below
    // actually proves the manager sorts by key.
    List<Integer> indexes = new ArrayList<>();
    for (int i = 0; i < RECORDS; i++) {
      indexes.add(i);
    }
    Collections.shuffle(indexes);
    for (Integer index : indexes) {
      manager.addRecord(
          0,
          (Text) SerializerUtils.genData(Text.class, index),
          (Text) SerializerUtils.genData(Text.class, index + 1));
    }
    manager.waitSendFinished();
    assertEquals(RECORDS, mapOutputRecordCounter.getValue());
    SerializerFactory factory = new SerializerFactory(rssConf);
    org.apache.uniffle.common.serializer.Serializer serializer = factory.getSerializer(Text.class);
    SerializerInstance instance = serializer.newInstance();
    DataOutputBuffer keyBuffer = new DataOutputBuffer();
    instance.serialize(SerializerUtils.genData(Text.class, 0), keyBuffer);
    // Byte counter = RECORDS * (key + value) bytes; assumes every genData(Text, i)
    // serializes to the same length as index 0 — TODO confirm genData is fixed-width.
    assertEquals(RECORDS * keyBuffer.getLength() * 2, mapOutputByteCounter.getValue());
    assertTrue(manager.getWaitSendBuffers().isEmpty());

    // check blocks: exactly one cached block whose payload decodes to the sorted records
    List<ShuffleBlockInfo> blockInfos = client.getCachedBlockInfos();
    assertEquals(1, blockInfos.size());
    ByteBuf buf = blockInfos.get(0).getData();
    byte[] bytes = new byte[blockInfos.get(0).getLength()];
    buf.readBytes(bytes);
    RecordsReader<Text, Text> reader =
        new RecordsReader<>(
            rssConf,
            PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes)),
            Text.class,
            Text.class,
            false);
    int index = 0;
    while (reader.next()) {
      assertEquals(SerializerUtils.genData(Text.class, index), reader.getCurrentKey());
      assertEquals(SerializerUtils.genData(Text.class, index + 1), reader.getCurrentValue());
      index++;
    }
    reader.close();
    assertEquals(RECORDS, index);
  }
+
class MockShuffleServer {
private List<ShuffleBlockInfo> cachedBlockInfos = new ArrayList<>();
private List<ShuffleBlockInfo> flushBlockInfos = new ArrayList<>();
@@ -530,6 +655,10 @@ public class WriteBufferManagerTest {
public synchronized int getFinishBlockSize() {
return finishedBlockInfos.size();
}
+
+ public List<ShuffleBlockInfo> getCachedBlockInfos() {
+ return cachedBlockInfos;
+ }
}
class MockShuffleWriteClient implements ShuffleWriteClient {
@@ -697,5 +826,9 @@ public class WriteBufferManagerTest {
int shuffleId,
int partitionId,
Roaring64NavigableMap expectedTaskIds) {}
+
+ public List<ShuffleBlockInfo> getCachedBlockInfos() {
+ return mockedShuffleServer.getCachedBlockInfos();
+ }
}
}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
index dc3ebf707..73d399062 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
@@ -20,11 +20,17 @@ package org.apache.tez.runtime.library.common.sort.buffer;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.io.serializer.Deserializer;
@@ -33,11 +39,23 @@ import org.apache.hadoop.io.serializer.Serializer;
import org.apache.hadoop.mapred.JobConf;
import org.junit.jupiter.api.Test;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.serializer.DeserializationStream;
+import org.apache.uniffle.common.serializer.PartialInputStream;
+import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerializerFactory;
+import org.apache.uniffle.common.serializer.SerializerInstance;
+
import static com.google.common.collect.Maps.newConcurrentMap;
+import static org.apache.uniffle.common.serializer.SerializerUtils.genData;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
public class WriteBufferTest {
+ private static final int RECORDS_NUM = 1009;
+
@Test
public void testReadWrite() throws IOException {
@@ -57,8 +75,10 @@ public class WriteBufferTest {
1,
WritableComparator.get(BytesWritable.class),
1024L,
+ false,
keySerializer,
- valSerializer);
+ valSerializer,
+ null);
long recordLength = buffer.addRecord(key, value);
assertEquals(20, buffer.getData().length);
@@ -88,8 +108,10 @@ public class WriteBufferTest {
1,
WritableComparator.get(BytesWritable.class),
528L,
+ false,
keySerializer,
- valSerializer);
+ valSerializer,
+ null);
long start = buffer.getDataLength();
assertEquals(0, start);
keyStr = "key3";
@@ -161,6 +183,75 @@ public class WriteBufferTest {
assertEquals(bigWritableValue, valueRead);
}
+ @Test
+ public void testReadWriteWithRemoteMergeAndNoSort() throws IOException {
+ RssConf rssConf = new RssConf();
+ SerializerFactory factory = new SerializerFactory(rssConf);
+ org.apache.uniffle.common.serializer.Serializer serializer =
factory.getSerializer(Text.class);
+ SerializerInstance instance = serializer.newInstance();
+ WriteBuffer buffer =
+ new WriteBuffer<BytesWritable, BytesWritable>(
+ false,
+ 1,
+ WritableComparator.get(BytesWritable.class),
+ 1024L,
+ true,
+ null,
+ null,
+ instance);
+ for (int i = 0; i < RECORDS_NUM; i++) {
+ buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i));
+ }
+ byte[] bytes = buffer.getData();
+ PartialInputStream inputStream =
PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes));
+ DeserializationStream dStream =
+ instance.deserializeStream(inputStream, Text.class, IntWritable.class,
false);
+ for (int i = 0; i < RECORDS_NUM; i++) {
+ assertTrue(dStream.nextRecord());
+ assertEquals(genData(Text.class, i), dStream.getCurrentKey());
+ assertEquals(genData(IntWritable.class, i), dStream.getCurrentValue());
+ }
+ assertFalse(dStream.nextRecord());
+ dStream.close();
+ }
+
  /**
   * Intended to verify that a remote-merge WriteBuffer sorts records inserted out of
   * order; records should come back in ascending key order regardless of insertion order.
   */
  @Test
  public void testReadWriteWithRemoteMergeAndSort() throws IOException {
    RssConf rssConf = new RssConf();
    SerializerFactory factory = new SerializerFactory(rssConf);
    org.apache.uniffle.common.serializer.Serializer serializer = factory.getSerializer(Text.class);
    SerializerInstance instance = serializer.newInstance();
    WriteBuffer buffer =
        new WriteBuffer<BytesWritable, BytesWritable>(
            false,
            1,
            WritableComparator.get(BytesWritable.class),
            1024L,
            true,
            null,
            null,
            instance);
    List<Integer> indices = new ArrayList<>();
    for (int i = 0; i < RECORDS_NUM; i++) {
      indices.add(i);
    }
    Collections.shuffle(indices);
    // NOTE(review): `indices` is shuffled but never used below — records are added in
    // already-sorted order, making this test behaviorally identical to the NoSort test,
    // so the sort path is not actually exercised. Likely intended:
    // buffer.addRecord(genData(Text.class, indices.get(i)), genData(IntWritable.class, indices.get(i))).
    // Confirm WriteBuffer's remote-merge sorting semantics (comparator is for
    // BytesWritable while data is Text) before making that change.
    for (int i = 0; i < RECORDS_NUM; i++) {
      buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i));
    }
    byte[] bytes = buffer.getData();
    PartialInputStream inputStream = PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes));
    DeserializationStream dStream =
        instance.deserializeStream(inputStream, Text.class, IntWritable.class, false);
    for (int i = 0; i < RECORDS_NUM; i++) {
      assertTrue(dStream.nextRecord());
      assertEquals(genData(Text.class, i), dStream.getCurrentKey());
      assertEquals(genData(IntWritable.class, i), dStream.getCurrentValue());
    }
    assertFalse(dStream.nextRecord());
    dStream.close();
  }
+
int readInt(DataInputStream dStream) throws IOException {
return WritableUtils.readVInt(dStream);
}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java
new file mode 100644
index 000000000..24157c883
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java
@@ -0,0 +1,481 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.input;
+
+import java.io.ByteArrayOutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.ExecutionContext;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.MemoryUpdateCallback;
+import org.apache.tez.runtime.api.events.DataMovementEvent;
+import org.apache.tez.runtime.library.api.KeyValuesReader;
+import org.apache.tez.runtime.library.common.MemoryUpdateCallbackHandler;
+import
org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RMRssShuffle;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.MockedShuffleServerClient;
+import org.apache.uniffle.client.record.reader.MockedShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.merger.Merger;
+import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.util.BlockIdLayout;
+
+import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+import static
org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+public class RMRssOrderedGroupedKVInputTest {
+
+ private static final int RECORDS_NUM = 1009;
+ private static final ApplicationAttemptId APPLICATION_ATTEMPT_ID =
+ ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0);
+ private static final int SHUFFLE_ID = 0;
+ private static final int PARTITION_ID = 0;
+
  /**
   * End-to-end test of RMRssOrderedGroupedKVInput against a fully mocked stack:
   * a mocked InputContext, a spied RMRecordsReader backed by a pre-merged in-memory
   * partition, and static mocks for the umbilical/shuffle-client plumbing. Three
   * memory segments (two covering even keys, one covering odd keys) are merged, so
   * even-indexed keys must yield 2 values and odd-indexed keys 1 value.
   */
  @Test
  public void testRMRssOrderedGroupedKVInput() throws Exception {
    // 1 basic parameter
    final Class keyClass = Text.class;
    final Class valueClass = IntWritable.class;
    final Comparator comparator = new Text.Comparator();
    final Configuration conf = new Configuration();
    conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
    conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
    final RssConf rssConf = new RssConf();
    final List<ShuffleServerInfo> serverInfos =
        Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
    final int taskAttemptId = 0;
    BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
    final long[] blockIds = new long[] {blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId)};

    // 2 mock input context
    InputContext inputContext = mock(InputContext.class);
    when(inputContext.getSourceVertexName()).thenReturn("Map 0");
    TezCounters tezCounters = new TezCounters();
    when(inputContext.getCounters()).thenReturn(tezCounters);
    TezTaskAttemptID tezTaskAttemptID =
        TezTaskAttemptID.getInstance(
            TezTaskID.getInstance(
                TezVertexID.getInstance(
                    TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
                0),
            0);
    when(inputContext.getUniqueIdentifier())
        .thenReturn(String.format("%s_%05d", tezTaskAttemptID.toString(), 0));
    when(inputContext.getUserPayload()).thenReturn(TezUtils.createUserPayloadFromConf(conf));
    when(inputContext.getWorkDirs()).thenReturn(new String[] {"/dummy"});
    // Grant whatever initial memory the input asks for, immediately.
    doAnswer(
            new Answer<Void>() {
              @Override
              public Void answer(InvocationOnMock invocation) throws Throwable {
                long requestedSize = (Long) invocation.getArguments()[0];
                MemoryUpdateCallbackHandler callback =
                    (MemoryUpdateCallbackHandler) invocation.getArguments()[1];
                callback.memoryAssigned(requestedSize);
                return null;
              }
            })
        .when(inputContext)
        .requestInitialMemory(anyLong(), any(MemoryUpdateCallback.class));
    when(inputContext.getDagIdentifier()).thenReturn(0);
    when(inputContext.getApplicationId()).thenReturn(APPLICATION_ATTEMPT_ID.getApplicationId());
    ExecutionContext executionContext = mock(ExecutionContext.class);
    when(executionContext.getHostName()).thenReturn("hostname");
    when(inputContext.getExecutionContext()).thenReturn(executionContext);
    DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
    dataOutputBuffer.writeInt(-1);
    when(inputContext.getServiceProviderMetaData(anyString()))
        .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0, dataOutputBuffer.getLength()));
    Token<JobTokenIdentifier> sessionToken =
        new Token<JobTokenIdentifier>(
            new JobTokenIdentifier(new Text("text")), new JobTokenSecretManager());
    ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
    doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());

    // 3 mock recordsReader: merge three in-memory segments (interval 2) into one
    // partition buffer — keys starting at 0 appear twice, keys starting at 1 once.
    List<Segment> segments = new ArrayList<>();
    segments.add(
        SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM));
    segments.add(
        SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM));
    segments.add(
        SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM));
    ByteArrayOutputStream output = new ByteArrayOutputStream();
    Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, false);
    output.close();
    ByteBuffer[][] buffers = new ByteBuffer[][] {{ByteBuffer.wrap(output.toByteArray())}};
    ShuffleServerClient serverClient =
        new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds);
    RMRecordsReader recordsReader =
        new RMRecordsReader(
            APPLICATION_ATTEMPT_ID.toString(),
            SHUFFLE_ID,
            Sets.newHashSet(PARTITION_ID),
            ImmutableMap.of(PARTITION_ID, serverInfos),
            rssConf,
            keyClass,
            valueClass,
            comparator,
            true,
            null,
            false,
            null);
    RMRecordsReader recordsReaderSpy = spy(recordsReader);
    doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());

    // 4 mock shuffle
    RMRssShuffle rssShuffle = new RMRssShuffle(inputContext, conf, 5, 0, APPLICATION_ATTEMPT_ID);
    RMRssShuffle rssShuffleSpy = spy(rssShuffle);
    doReturn(recordsReaderSpy).when(rssShuffleSpy).createRMRecordsReader(anySet());
    // rssShuffleSpy.setEventHandler(new RMShuffleEventHandler(rssShuffleSpy));

    try (MockedStatic<UmbilicalUtils> umbilicalUtils = Mockito.mockStatic(UmbilicalUtils.class);
        MockedStatic<RssTezUtils> tezUtils = Mockito.mockStatic(RssTezUtils.class)) {
      umbilicalUtils
          .when(() -> UmbilicalUtils.requestShuffleServer(any(), any(), any(), anyInt()))
          .thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos));
      ShuffleWriteClient writeClient = new MockedShuffleWriteClient();

      Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds =
          ImmutableMap.of(
              serverInfos.get(0), ImmutableMap.of(PARTITION_ID, Sets.newHashSet(blockIds[0])));
      writeClient.reportShuffleResult(
          serverToPartitionToBlockIds,
          APPLICATION_ATTEMPT_ID.toString(),
          SHUFFLE_ID,
          taskAttemptId,
          0);
      tezUtils.when(() -> RssTezUtils.createShuffleClient(any())).thenReturn(writeClient);
      tezUtils
          .when(() -> RssTezUtils.fetchAllRssTaskIds(anySet(), anyInt(), anyInt(), anyInt()))
          .thenReturn(Roaring64NavigableMap.bitmapOf(taskAttemptId));
      tezUtils
          .when(() -> RssTezUtils.uniqueIdentifierToAttemptId(anyString()))
          .thenReturn(tezTaskAttemptID.toString());

      // 5 init and start kv input; DMEs are delivered in two batches — before and
      // after start() — to cover both event-handling paths.
      RMRssOrderedGroupedKVInput input = new RMRssOrderedGroupedKVInput(inputContext, 5);
      RMRssOrderedGroupedKVInput inputSpy = spy(input);
      doReturn(rssShuffleSpy).when(inputSpy).createRssShuffle();
      inputSpy.initialize();
      List<Event> events = new ArrayList<>();
      for (int i = 0; i < 2; i++) {
        TezTaskAttemptID taskAttemptID =
            TezTaskAttemptID.getInstance(
                TezTaskID.getInstance(
                    TezVertexID.getInstance(
                        TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
                    i),
                0);
        events.add(
            createDataMovementEvent(
                PARTITION_ID, String.format("%s_%05d", taskAttemptID.toString(), 1)));
      }
      inputSpy.handleEvents(events);
      inputSpy.start();
      events.clear();
      for (int i = 2; i < 5; i++) {
        TezTaskAttemptID taskAttemptID =
            TezTaskAttemptID.getInstance(
                TezTaskID.getInstance(
                    TezVertexID.getInstance(
                        TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
                    i),
                0);
        events.add(
            createDataMovementEvent(
                PARTITION_ID, String.format("%s_%05d", taskAttemptID.toString(), 1)));
      }
      inputSpy.handleEvents(events);

      // 6 verify result: 2 values for even keys, 1 for odd, RECORDS_NUM * 2 keys total
      KeyValuesReader reader = inputSpy.getReader();
      int index = 0;
      while (reader.next()) {
        assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey());
        Iterator iterator = reader.getCurrentValues().iterator();
        int i = 0;
        while (iterator.hasNext()) {
          assertEquals(SerializerUtils.genData(valueClass, index), iterator.next());
          i++;
        }
        assertEquals(index % 2 == 0 ? 2 : 1, i);
        index++;
      }
      assertEquals(RECORDS_NUM * 2, index);
    }
  }
+
+ @Test
+ public void testRMRssOrderedGroupedKVInputMulitPartition() throws Exception {
+ // 1 basic parameter
+ final Class keyClass = Text.class;
+ final Class valueClass = IntWritable.class;
+ final Comparator comparator = new Text.Comparator();
+ final Configuration conf = new Configuration();
+ conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
+ conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
+ final RssConf rssConf = new RssConf();
+ final List<ShuffleServerInfo> serverInfos =
+ Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+ final int taskAttemptId = 0;
+ BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+ final long[] blockIds =
+ new long[] {
+ blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID, taskAttemptId),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID + 1, taskAttemptId),
+ blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId),
+ blockIdLayout.getBlockId(1, PARTITION_ID + 2, taskAttemptId)
+ };
+
+ // 2 mock input context
+ InputContext inputContext = mock(InputContext.class);
+ when(inputContext.getSourceVertexName()).thenReturn("Map 0");
+ TezCounters tezCounters = new TezCounters();
+ when(inputContext.getCounters()).thenReturn(tezCounters);
+ TezTaskAttemptID tezTaskAttemptID =
+ TezTaskAttemptID.getInstance(
+ TezTaskID.getInstance(
+ TezVertexID.getInstance(
+
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+ 0),
+ 0);
+ when(inputContext.getUniqueIdentifier())
+ .thenReturn(String.format("%s_%05d", tezTaskAttemptID.toString(), 0));
+
when(inputContext.getUserPayload()).thenReturn(TezUtils.createUserPayloadFromConf(conf));
+ when(inputContext.getWorkDirs()).thenReturn(new String[] {"/dummy"});
+ doAnswer(
+ new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable
{
+ long requestedSize = (Long) invocation.getArguments()[0];
+ MemoryUpdateCallbackHandler callback =
+ (MemoryUpdateCallbackHandler) invocation.getArguments()[1];
+ callback.memoryAssigned(requestedSize);
+ return null;
+ }
+ })
+ .when(inputContext)
+ .requestInitialMemory(anyLong(), any(MemoryUpdateCallback.class));
+ when(inputContext.getDagIdentifier()).thenReturn(0);
+
when(inputContext.getApplicationId()).thenReturn(APPLICATION_ATTEMPT_ID.getApplicationId());
+ ExecutionContext executionContext = mock(ExecutionContext.class);
+ when(executionContext.getHostName()).thenReturn("hostname");
+ when(inputContext.getExecutionContext()).thenReturn(executionContext);
+ DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
+ dataOutputBuffer.writeInt(-1);
+ when(inputContext.getServiceProviderMetaData(anyString()))
+ .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0,
dataOutputBuffer.getLength()));
+ Token<JobTokenIdentifier> sessionToken =
+ new Token<JobTokenIdentifier>(
+ new JobTokenIdentifier(new Text("text")), new
JobTokenSecretManager());
+ ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
+
doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());
+
+ // 3 mock recordsReader
+ RMRecordsReader recordsReader =
+ new RMRecordsReader(
+ APPLICATION_ATTEMPT_ID.toString(),
+ SHUFFLE_ID,
+ Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2),
+ ImmutableMap.of(
+ PARTITION_ID,
+ serverInfos,
+ PARTITION_ID + 1,
+ serverInfos,
+ PARTITION_ID + 2,
+ serverInfos),
+ rssConf,
+ keyClass,
+ valueClass,
+ comparator,
+ true,
+ null,
+ false,
+ null);
+ RMRecordsReader recordsReaderSpy = spy(recordsReader);
+ ByteBuffer[][] buffers = new ByteBuffer[3][2];
+ for (int i = 0; i < 3; i++) {
+ buffers[i][0] =
+ ByteBuffer.wrap(
+ genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3,
RECORDS_NUM, 1));
+ buffers[i][1] =
+ ByteBuffer.wrap(
+ genSortedRecordBytes(
+ rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3,
RECORDS_NUM, 1));
+ }
+ ShuffleServerClient serverClient =
+ new MockedShuffleServerClient(
+ new int[] {PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2},
buffers, blockIds);
+
doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+ // 4 mock shuffle
+ RMRssShuffle rssShuffle = new RMRssShuffle(inputContext, conf, 5, 0,
APPLICATION_ATTEMPT_ID);
+ RMRssShuffle rssShuffleSpy = spy(rssShuffle);
+
doReturn(recordsReaderSpy).when(rssShuffleSpy).createRMRecordsReader(anySet());
+ // rssShuffleSpy.setEventHandler(new RMShuffleEventHandler(rssShuffleSpy));
+
+ try (MockedStatic<UmbilicalUtils> umbilicalUtils =
Mockito.mockStatic(UmbilicalUtils.class);
+ MockedStatic<RssTezUtils> tezUtils =
Mockito.mockStatic(RssTezUtils.class)) {
+ umbilicalUtils
+ .when(() -> UmbilicalUtils.requestShuffleServer(any(), any(), any(),
anyInt()))
+ .thenReturn(
+ ImmutableMap.of(
+ PARTITION_ID,
+ serverInfos,
+ PARTITION_ID + 1,
+ serverInfos,
+ PARTITION_ID + 2,
+ serverInfos));
+ ShuffleWriteClient writeClient = new MockedShuffleWriteClient();
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>>
serverToPartitionToBlockIds =
+ ImmutableMap.of(
+ serverInfos.get(0),
+ ImmutableMap.of(
+ PARTITION_ID,
+ Sets.newHashSet(blockIds[0], blockIds[1]),
+ PARTITION_ID + 1,
+ Sets.newHashSet(blockIds[2], blockIds[3]),
+ PARTITION_ID + 2,
+ Sets.newHashSet(blockIds[4], blockIds[5])));
+ writeClient.reportShuffleResult(
+ serverToPartitionToBlockIds,
+ APPLICATION_ATTEMPT_ID.toString(),
+ SHUFFLE_ID,
+ taskAttemptId,
+ 0);
+ tezUtils.when(() ->
RssTezUtils.createShuffleClient(any())).thenReturn(writeClient);
+ tezUtils
+ .when(() -> RssTezUtils.fetchAllRssTaskIds(anySet(), anyInt(),
anyInt(), anyInt()))
+ .thenReturn(Roaring64NavigableMap.bitmapOf(taskAttemptId));
+ tezUtils
+ .when(() -> RssTezUtils.uniqueIdentifierToAttemptId(anyString()))
+ .thenReturn(tezTaskAttemptID.toString());
+
+ // 5 init and start kv input
+ RMRssOrderedGroupedKVInput input = new
RMRssOrderedGroupedKVInput(inputContext, 5);
+ RMRssOrderedGroupedKVInput inputSpy = spy(input);
+ doReturn(rssShuffleSpy).when(inputSpy).createRssShuffle();
+ inputSpy.initialize();
+ List<Event> events = new ArrayList<>();
+ for (int i = 0; i < 2; i++) {
+ TezTaskAttemptID taskAttemptID =
+ TezTaskAttemptID.getInstance(
+ TezTaskID.getInstance(
+ TezVertexID.getInstance(
+
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+ i),
+ 0);
+ events.add(
+ createDataMovementEvent(i / 2, String.format("%s_%05d",
taskAttemptID.toString(), 1)));
+ }
+ inputSpy.handleEvents(events);
+ inputSpy.start();
+ events.clear();
+ for (int i = 2; i < 5; i++) {
+ TezTaskAttemptID taskAttemptID =
+ TezTaskAttemptID.getInstance(
+ TezTaskID.getInstance(
+ TezVertexID.getInstance(
+
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+ i),
+ 0);
+ events.add(
+ createDataMovementEvent(i / 2, String.format("%s_%05d",
taskAttemptID.toString(), 1)));
+ }
+ inputSpy.handleEvents(events);
+
+ // 6 verify result
+ KeyValuesReader reader = inputSpy.getReader();
+ int index = 0;
+ while (reader.next()) {
+ assertEquals(SerializerUtils.genData(keyClass, index),
reader.getCurrentKey());
+ Iterator iterator = reader.getCurrentValues().iterator();
+ int i = 0;
+ while (iterator.hasNext()) {
+ assertEquals(SerializerUtils.genData(valueClass, index),
iterator.next());
+ i++;
+ }
+ assertEquals(1, i);
+ index++;
+ }
+ assertEquals(RECORDS_NUM * 6, index);
+ }
+ }
+
+ private static DataMovementEvent createDataMovementEvent(int partition,
String path) {
+ ShuffleUserPayloads.DataMovementEventPayloadProto proto =
+ ShuffleUserPayloads.DataMovementEventPayloadProto.newBuilder()
+ .setPathComponent(path)
+ .build();
+ return DataMovementEvent.create(partition,
proto.toByteString().asReadOnlyByteBuffer());
+ }
+}
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java
similarity index 63%
copy from
integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
copy to
integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java
index 4bd2f7e8a..eb5ee51dc 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java
@@ -28,23 +28,59 @@ import com.google.common.collect.Lists;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.examples.OrderedWordCount;
+import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
-public class TezOrderedWordCountTest extends TezIntegrationTestBase {
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.server.ShuffleServerConf;
- private String inputPath = "ordered_word_count_input";
- private String outputPath = "ordered_word_count_output";
+import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_ENABLE;
+
+public class RMTezOrderedWordCountTest extends TezIntegrationTestBase {
+
+ private String inputPath = "rm_ordered_word_count_input";
+ private String outputPath = "rm_ordered_word_count_output";
private List<String> wordTable =
Lists.newArrayList(
"apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan",
"tomato");
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ ShuffleServerConf serverConf = new ShuffleServerConf();
+ serverConf.set(SERVER_MERGE_ENABLE, true);
+ TezIntegrationTestBase.setupServers(serverConf);
+ }
+
@Test
public void orderedWordCountTest() throws Exception {
generateInputFile();
run();
}
+ public void run() throws Exception {
+ // 1 Run original Tez examples
+ TezConfiguration appConf = new
TezConfiguration(miniTezCluster.getConfig());
+ updateCommonConfiguration(appConf);
+ runTezApp(appConf, getTestTool(), getTestArgs("origin"));
+ final String originPath = getOutputDir("origin");
+
+ // 2 Run the same job with RSS remote merge enabled (GRPC client) and compare results
+ runRemoteMergeRssTest(ClientType.GRPC, "rss-grpc", originPath);
+ }
+
+ private void runRemoteMergeRssTest(ClientType clientType, String testName,
String originPath)
+ throws Exception {
+ TezConfiguration appConf = new
TezConfiguration(miniTezCluster.getConfig());
+ appConf.set(RssTezConfig.RSS_REMOTE_MERGE_ENABLE, "true");
+ updateRssConfiguration(appConf, clientType);
+ appendAndUploadRssJars(appConf);
+ runTezApp(appConf, getTestTool(), getTestArgs(testName));
+ verifyResults(originPath, getOutputDir(testName));
+ }
+
private void generateInputFile() throws Exception {
// For ordered word count, the key of last ordered sorter is the summation
of word, the value is
// the word. So it means this key may not be unique. Because Sorter can
only make sure key is
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezCartesianProductTest.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezCartesianProductTest.java
index 7e8b2a031..6edb99cd4 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezCartesianProductTest.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezCartesianProductTest.java
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
import org.apache.tez.examples.CartesianProduct;
+import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
public class TezCartesianProductTest extends TezIntegrationTestBase {
@@ -31,6 +32,11 @@ public class TezCartesianProductTest extends
TezIntegrationTestBase {
private String inputPath3 = "cartesian_product_input3";
private String outputPath = "cartesian_product_output";
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ TezIntegrationTestBase.setupServers(null);
+ }
+
@Test
public void cartesianProductTest() throws Exception {
generateInputFile();
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
index cdaffffb1..d077e3608 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
@@ -75,6 +75,9 @@ public class TezIntegrationTestBase extends
IntegrationTestBase {
miniTezCluster.init(conf);
miniTezCluster.start();
}
+ }
+
+ protected static void setupServers(ShuffleServerConf serverConf) throws
Exception {
LOG.info("Starting coordinators and shuffle servers");
CoordinatorConf coordinatorConf = getCoordinatorConf();
Map<String, String> dynamicConf = new HashMap<>();
@@ -83,8 +86,14 @@ public class TezIntegrationTestBase extends
IntegrationTestBase {
addDynamicConf(coordinatorConf, dynamicConf);
createCoordinatorServer(coordinatorConf);
ShuffleServerConf grpcShuffleServerConf =
getShuffleServerConf(ServerType.GRPC);
+ if (serverConf != null) {
+ grpcShuffleServerConf.addAll(serverConf);
+ }
createShuffleServer(grpcShuffleServerConf);
ShuffleServerConf nettyShuffleServerConf =
getShuffleServerConf(ServerType.GRPC_NETTY);
+ if (serverConf != null) {
+ nettyShuffleServerConf.addAll(serverConf);
+ }
createShuffleServer(nettyShuffleServerConf);
startServers();
}
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezJoinIntegrationTestBase.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezJoinIntegrationTestBase.java
index 641e40d34..6d87237e5 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezJoinIntegrationTestBase.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezJoinIntegrationTestBase.java
@@ -22,6 +22,7 @@ import org.apache.hadoop.util.ToolRunner;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.examples.JoinDataGen;
import org.apache.tez.examples.JoinValidate;
+import org.junit.jupiter.api.BeforeAll;
import org.apache.uniffle.common.ClientType;
@@ -36,6 +37,11 @@ public class TezJoinIntegrationTestBase extends
TezIntegrationTestBase {
protected static final String JOIN_EXPECTED_PATH = "join_expected";
protected static final String NUM_TASKS = "2";
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ TezIntegrationTestBase.setupServers(null);
+ }
+
protected void generateInputFile() throws Exception {
fs.delete(new Path(STREAM_INPUT_PATH), true);
fs.delete(new Path(HASH_INPUT_PATH), true);
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
index 4bd2f7e8a..820995b4b 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
@@ -29,6 +29,7 @@ import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
import org.apache.tez.examples.OrderedWordCount;
+import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
public class TezOrderedWordCountTest extends TezIntegrationTestBase {
@@ -39,6 +40,11 @@ public class TezOrderedWordCountTest extends
TezIntegrationTestBase {
Lists.newArrayList(
"apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan",
"tomato");
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ TezIntegrationTestBase.setupServers(null);
+ }
+
@Test
public void orderedWordCountTest() throws Exception {
generateInputFile();
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezSimpleSessionExampleTest.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezSimpleSessionExampleTest.java
index 788465d61..54da4e25a 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezSimpleSessionExampleTest.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezSimpleSessionExampleTest.java
@@ -31,6 +31,7 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.examples.SimpleSessionExample;
+import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.apache.uniffle.common.ClientType;
@@ -44,6 +45,11 @@ public class TezSimpleSessionExampleTest extends
TezIntegrationTestBase {
Lists.newArrayList(
"apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan",
"tomato");
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ TezIntegrationTestBase.setupServers(null);
+ }
+
@Test
public void simpleSessionExampleTest() throws Exception {
generateInputFile();
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountTest.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountTest.java
index e2e9b564b..a99bf9bc0 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountTest.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountTest.java
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
import org.apache.tez.examples.WordCount;
+import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
public class TezWordCountTest extends TezIntegrationTestBase {
@@ -35,6 +36,11 @@ public class TezWordCountTest extends TezIntegrationTestBase
{
Lists.newArrayList(
"apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan",
"tomato");
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ TezIntegrationTestBase.setupServers(null);
+ }
+
@Test
public void wordCountTest() throws Exception {
generateInputFile();