This is an automated email from the ASF dual-hosted git repository.

zhengchenyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new be23f6137 [#1750] feat(remote merge): Support Tez. (#2160)
be23f6137 is described below

commit be23f6137fb5f478372328bf492923f115203cf6
Author: zhengchenyu <[email protected]>
AuthorDate: Thu Oct 10 10:39:36 2024 +0800

    [#1750] feat(remote merge): Support Tez. (#2160)
    
    ### What changes were proposed in this pull request?
    
    Tez supports remote merge.
    
    ### Why are the changes needed?
    
    Fix: #1750
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, I will refine the documentation in another PR.
    
    ### How was this patch tested?
    
    unit test, integration test, real job in cluster.
---
 client-tez/pom.xml                                 |  34 +-
 .../apache/tez/common/GetShuffleServerRequest.java |  36 ++
 .../java/org/apache/tez/common/RssTezConfig.java   |  10 +
 .../java/org/apache/tez/common/RssTezUtils.java    |  17 +-
 .../org/apache/tez/dag/app/RssDAGAppMaster.java    |   9 +-
 .../tez/dag/app/TezRemoteShuffleManager.java       |  25 +-
 .../shuffle/orderedgrouped/RMRssShuffle.java       | 269 ++++++++++++
 .../orderedgrouped/RMRssShuffleScheduler.java      |  84 ++++
 .../library/common/sort/buffer/WriteBuffer.java    |  34 +-
 .../common/sort/buffer/WriteBufferManager.java     |  40 +-
 .../library/common/sort/impl/RssSorter.java        |   9 +-
 .../library/common/sort/impl/RssUnSorter.java      |   5 +-
 .../library/input/RMRssOrderedGroupedKVInput.java  | 320 ++++++++++++++
 .../output/RssOrderedPartitionedKVOutput.java      |  27 +-
 .../shuffle/orderedgrouped/RMRssShuffleTest.java   | 443 +++++++++++++++++++
 .../common/sort/buffer/WriteBufferManagerTest.java | 151 ++++++-
 .../common/sort/buffer/WriteBufferTest.java        |  95 +++-
 .../input/RMRssOrderedGroupedKVInputTest.java      | 481 +++++++++++++++++++++
 ...untTest.java => RMTezOrderedWordCountTest.java} |  42 +-
 .../uniffle/test/TezCartesianProductTest.java      |   6 +
 .../uniffle/test/TezIntegrationTestBase.java       |   9 +
 .../uniffle/test/TezJoinIntegrationTestBase.java   |   6 +
 .../uniffle/test/TezOrderedWordCountTest.java      |   6 +
 .../uniffle/test/TezSimpleSessionExampleTest.java  |   6 +
 .../org/apache/uniffle/test/TezWordCountTest.java  |   6 +
 25 files changed, 2121 insertions(+), 49 deletions(-)

diff --git a/client-tez/pom.xml b/client-tez/pom.xml
index 52430441f..3f7ba7a02 100644
--- a/client-tez/pom.xml
+++ b/client-tez/pom.xml
@@ -128,17 +128,29 @@
           </exclusion>
         </exclusions>
       </dependency>
-        <dependency>
-          <groupId>org.apache.hadoop</groupId>
-          <artifactId>hadoop-minicluster</artifactId>
-          <scope>test</scope>
-          <exclusions>
-            <exclusion>
-              <groupId>org.slf4j</groupId>
-              <artifactId>slf4j-log4j12</artifactId>
-            </exclusion>
-          </exclusions>
-        </dependency>
+      <dependency>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-minicluster</artifactId>
+        <scope>test</scope>
+        <exclusions>
+          <exclusion>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-log4j12</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.uniffle</groupId>
+        <artifactId>rss-common</artifactId>
+        <scope>test</scope>
+        <type>test-jar</type>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.uniffle</groupId>
+        <artifactId>rss-client</artifactId>
+        <scope>test</scope>
+        <type>test-jar</type>
+      </dependency>
     </dependencies>
 
     <build>
diff --git 
a/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java 
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
index e87b97406..dd9a6decc 100644
--- 
a/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
+++ 
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
@@ -22,6 +22,7 @@ import java.io.DataOutput;
 import java.io.IOException;
 
 import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableUtils;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 
 public class GetShuffleServerRequest implements Writable {
@@ -29,15 +30,32 @@ public class GetShuffleServerRequest implements Writable {
   private int startIndex;
   private int partitionNum;
   private int shuffleId;
+  private String keyClassName;
+  private String valueClassName;
+  private String comparatorClassName;
 
   public GetShuffleServerRequest() {}
 
   public GetShuffleServerRequest(
       TezTaskAttemptID currentTaskAttemptID, int startIndex, int partitionNum, 
int shuffleId) {
+    this(currentTaskAttemptID, startIndex, partitionNum, shuffleId, "", "", 
"");
+  }
+
+  public GetShuffleServerRequest(
+      TezTaskAttemptID currentTaskAttemptID,
+      int startIndex,
+      int partitionNum,
+      int shuffleId,
+      String keyClassName,
+      String valueClassName,
+      String comparatorClassName) {
     this.currentTaskAttemptID = currentTaskAttemptID;
     this.startIndex = startIndex;
     this.partitionNum = partitionNum;
     this.shuffleId = shuffleId;
+    this.keyClassName = keyClassName;
+    this.valueClassName = valueClassName;
+    this.comparatorClassName = comparatorClassName;
   }
 
   @Override
@@ -51,6 +69,9 @@ public class GetShuffleServerRequest implements Writable {
     } else {
       output.writeBoolean(false);
     }
+    WritableUtils.writeString(output, keyClassName);
+    WritableUtils.writeString(output, valueClassName);
+    WritableUtils.writeString(output, comparatorClassName);
   }
 
   @Override
@@ -63,6 +84,9 @@ public class GetShuffleServerRequest implements Writable {
       currentTaskAttemptID = new TezTaskAttemptID();
       currentTaskAttemptID.readFields(dataInput);
     }
+    keyClassName = WritableUtils.readString(dataInput);
+    valueClassName = WritableUtils.readString(dataInput);
+    comparatorClassName = WritableUtils.readString(dataInput);
   }
 
   @Override
@@ -94,4 +118,16 @@ public class GetShuffleServerRequest implements Writable {
   public int getShuffleId() {
     return shuffleId;
   }
+
+  public String getKeyClassName() {
+    return keyClassName;
+  }
+
+  public String getValueClassName() {
+    return valueClassName;
+  }
+
+  public String getComparatorClassName() {
+    return comparatorClassName;
+  }
 }
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java 
b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
index a1186534a..e73b04fb5 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
@@ -225,6 +225,16 @@ public class RssTezConfig {
   public static final String RSS_SHUFFLE_MODE = TEZ_RSS_CONFIG_PREFIX + 
"shuffle.mode";
   public static final String DEFAULT_RSS_SHUFFLE_MODE = "remote";
 
+  public static final String RSS_REMOTE_MERGE_ENABLE =
+      TEZ_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_ENABLE;
+  public static final boolean RSS_REMOTE_MERGE_ENABLE_DEFAULT = false;
+  public static final String RSS_MERGED_BLOCK_SZIE =
+      TEZ_RSS_CONFIG_PREFIX + RssClientConfig.RSS_MERGED_BLOCK_SZIE;
+  public static final int RSS_MERGED_BLOCK_SZIE_DEFAULT =
+      RssClientConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT;
+  public static final String RSS_REMOTE_MERGE_CLASS_LOADER =
+      TEZ_RSS_CONFIG_PREFIX + RssClientConfig.RSS_REMOTE_MERGE_CLASS_LOADER;
+
   public static RssConf toRssConf(Configuration jobConf) {
     RssConf rssConf = new RssConf();
     for (Map.Entry<String, String> entry : jobConf) {
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java 
b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
index 3df582bde..7b6d9ad2c 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
@@ -33,6 +33,7 @@ import 
org.apache.tez.runtime.library.input.ConcatenatedMergedKeyValuesInput;
 import org.apache.tez.runtime.library.input.OrderedGroupedInputLegacy;
 import org.apache.tez.runtime.library.input.OrderedGroupedKVInput;
 import org.apache.tez.runtime.library.input.OrderedGroupedMergedKVInput;
+import org.apache.tez.runtime.library.input.RMRssOrderedGroupedKVInput;
 import org.apache.tez.runtime.library.input.RssConcatenatedMergedKeyValueInput;
 import 
org.apache.tez.runtime.library.input.RssConcatenatedMergedKeyValuesInput;
 import org.apache.tez.runtime.library.input.RssOrderedGroupedInputLegacy;
@@ -122,7 +123,8 @@ public class RssTezUtils {
                     .replicaRead(replicaRead)
                     .replicaSkipEnabled(replicaSkipEnabled)
                     .dataTransferPoolSize(dataTransferPoolSize)
-                    .dataCommitPoolSize(dataCommitPoolSize));
+                    .dataCommitPoolSize(dataCommitPoolSize)
+                    .rssConf(RssTezConfig.toRssConf(conf)));
     return client;
   }
 
@@ -423,13 +425,14 @@ public class RssTezUtils {
     }
   }
 
-  public static String replaceRssInputClassName(String className) {
+  public static String replaceRssInputClassName(String className, boolean 
isRemoteMergeEnable) {
     if (className.equals(OrderedGroupedKVInput.class.getName())) {
-      LOG.info(
-          "Input class name will transient from {} to {}",
-          className,
-          RssOrderedGroupedKVInput.class.getName());
-      return RssOrderedGroupedKVInput.class.getName();
+      String orderedInputClasName =
+          isRemoteMergeEnable
+              ? RMRssOrderedGroupedKVInput.class.getName()
+              : RssOrderedGroupedKVInput.class.getName();
+      LOG.info("Input class name will transient from {} to {}", className, 
orderedInputClasName);
+      return orderedInputClasName;
     } else if (className.equals(OrderedGroupedMergedKVInput.class.getName())) {
       LOG.info(
           "Input class name will transient from {} to {}",
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java 
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index ac1f9c28f..0fb41146c 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -608,7 +608,14 @@ public class RssDAGAppMaster extends DAGAppMaster {
               
outputDescriptor.getClass().getSuperclass().getDeclaredField("className");
           inputClassNameField.setAccessible(true);
           String inputClassName = (String) 
outputClassNameField.get(inputDescriptor);
-          String rssInputClassName = 
RssTezUtils.replaceRssInputClassName(inputClassName);
+          String rssInputClassName =
+              RssTezUtils.replaceRssInputClassName(
+                  inputClassName,
+                  appMaster
+                      .getConfig()
+                      .getBoolean(
+                          RssTezConfig.RSS_REMOTE_MERGE_ENABLE,
+                          RssTezConfig.RSS_REMOTE_MERGE_ENABLE_DEFAULT));
           outputClassNameField.set(inputDescriptor, rssInputClassName);
         }
       } catch (IOException | IllegalAccessException | NoSuchFieldException e) {
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java 
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
index bf376814e..85a138c13 100644
--- 
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
+++ 
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -190,7 +190,13 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
           if (shuffleIdToShuffleAssignsInfo.containsKey(shuffleId)) {
             shuffleAssignmentsInfo = 
shuffleIdToShuffleAssignsInfo.get(shuffleId);
           } else {
-            shuffleAssignmentsInfo = 
getShuffleWorks(request.getPartitionNum(), shuffleId);
+            shuffleAssignmentsInfo =
+                getShuffleWorks(
+                    request.getPartitionNum(),
+                    shuffleId,
+                    request.getKeyClassName(),
+                    request.getValueClassName(),
+                    request.getComparatorClassName());
           }
 
           if (shuffleAssignmentsInfo == null) {
@@ -221,7 +227,12 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
     }
   }
 
-  private ShuffleAssignmentsInfo getShuffleWorks(int partitionNum, int 
shuffleId) {
+  private ShuffleAssignmentsInfo getShuffleWorks(
+      int partitionNum,
+      int shuffleId,
+      String keyClassName,
+      String valueClassName,
+      String comparatorClassName) {
     ShuffleAssignmentsInfo shuffleAssignmentsInfo;
     int requiredAssignmentShuffleServersNum =
         RssTezUtils.getRequiredShuffleServerNumber(conf, 200, partitionNum);
@@ -292,7 +303,15 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
                                           remoteStorage,
                                           ShuffleDataDistributionType.NORMAL,
                                           RssTezConfig.toRssConf(conf)
-                                              
.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE)));
+                                              
.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE),
+                                          0,
+                                          keyClassName,
+                                          valueClassName,
+                                          comparatorClassName,
+                                          conf.getInt(
+                                              
RssTezConfig.RSS_MERGED_BLOCK_SZIE,
+                                              
RssTezConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT),
+                                          
conf.get(RssTezConfig.RSS_REMOTE_MERGE_CLASS_LOADER)));
                           LOG.info(
                               "Finish register shuffle with "
                                   + (System.currentTimeMillis() - start)
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java
new file mode 100644
index 000000000..542d251e7
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.tez.common.InputContextUtils;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.library.common.ConfigUtils;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.KeyValuesReader;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+
+@Private
[email protected]
+public class RMRssShuffle implements ExceptionReporter {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(RMRssShuffle.class);
+
+  private final Configuration conf;
+  private final RssConf rssConf;
+  private final InputContext inputContext;
+  private final int numInputs;
+  private final int shuffleId;
+  private final ApplicationAttemptId applicationAttemptId;
+  private final String appId;
+  private ShuffleInputEventHandlerOrderedGrouped eventHandler;
+  private final TezTaskAttemptID tezTaskAttemptID;
+  private final String srcNameTrimmed;
+  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+
+  private AtomicBoolean isShutDown = new AtomicBoolean(false);
+
+  final TezCounter skippedInputCounter;
+  final TezCounter inputRecordCounter;
+
+  final Map<Integer, Set<InputAttemptIdentifier>> 
partitionIdToSuccessMapTaskAttempts =
+      new HashMap<>();
+  final Map<Integer, Set<TezTaskID>> partitionIdToSuccessTezTasks = new 
HashMap<>();
+
+  final Set<Integer> partitionIds = new HashSet<>();
+  private RMRecordsReader reader = null;
+  private RMRssShuffleScheduler scheduler;
+
+  public RMRssShuffle(
+      InputContext inputContext,
+      Configuration conf,
+      int numInputs,
+      int shuffleId,
+      ApplicationAttemptId applicationAttemptId)
+      throws IOException {
+    this.inputContext = inputContext;
+    this.conf = conf;
+    this.rssConf = RssTezConfig.toRssConf(conf);
+    this.numInputs = numInputs;
+    this.shuffleId = shuffleId;
+    this.applicationAttemptId = applicationAttemptId;
+    this.appId = this.applicationAttemptId.toString();
+    this.srcNameTrimmed = 
TezUtilsInternal.cleanVertexName(inputContext.getSourceVertexName());
+    LOG.info(srcNameTrimmed + ": Shuffle assigned with " + numInputs + " 
inputs.");
+    this.skippedInputCounter =
+        inputContext.getCounters().findCounter(TaskCounter.NUM_SKIPPED_INPUTS);
+    this.inputRecordCounter =
+        
inputContext.getCounters().findCounter(TaskCounter.INPUT_RECORDS_PROCESSED);
+
+    this.scheduler =
+        new RMRssShuffleScheduler(
+            this.inputContext,
+            this.conf,
+            numInputs,
+            this,
+            null,
+            null,
+            System.currentTimeMillis(),
+            null,
+            false,
+            0,
+            srcNameTrimmed,
+            this);
+    this.eventHandler =
+        new ShuffleInputEventHandlerOrderedGrouped(
+            inputContext, scheduler, ShuffleUtils.isTezShuffleHandler(conf));
+    this.tezTaskAttemptID = 
InputContextUtils.getTezTaskAttemptID(this.inputContext);
+    // When remote merge is enable, we use the reading-while-processing 
method, so we set input
+    // ready directly.
+    inputContext.inputIsReady();
+  }
+
+  public void handleEvents(List<Event> events) throws IOException {
+    if (!isShutDown.get()) {
+      eventHandler.handleEvents(events);
+    } else {
+      LOG.info(
+          srcNameTrimmed
+              + ": Ignoring events since already shutdown. EventCount: "
+              + events.size());
+    }
+  }
+
+  public void run() throws IOException {
+    this.partitionToServers =
+        UmbilicalUtils.requestShuffleServer(
+            inputContext.getApplicationId(), conf, tezTaskAttemptID, 
shuffleId);
+  }
+
+  public void shutdown() {
+    if (!isShutDown.getAndSet(true)) {
+      if (reader != null) {
+        reader.close();
+      }
+      LOG.info("Shutting down Shuffle for source: " + srcNameTrimmed);
+    }
+  }
+
+  public void waitForEvents() throws InterruptedException {
+    while (!allInputTaskAttemptDone()) {
+      Thread.sleep(100);
+    }
+    // report unique blocks
+    reportUniqueBlockIds();
+    if (partitionIds.size() > 0) {
+      reader = createRMRecordsReader(partitionIds);
+      reader.start();
+    }
+  }
+
+  private boolean allInputTaskAttemptDone() {
+    return (this.partitionIdToSuccessTezTasks.values().stream().mapToInt(s -> 
s.size()).sum()
+            + skippedInputCounter.getValue())
+        == numInputs;
+  }
+
+  public void reportUniqueBlockIds() {
+    ShuffleWriteClient writeClient = RssTezUtils.createShuffleClient(conf);
+    for (int partitionId : partitionIds) {
+      Roaring64NavigableMap blockIdBitmap =
+          writeClient.getShuffleResult(
+              null,
+              new HashSet<>(partitionToServers.get(partitionId)),
+              appId,
+              shuffleId,
+              partitionId);
+      Roaring64NavigableMap taskIdBitmap =
+          RssTezUtils.fetchAllRssTaskIds(
+              partitionIdToSuccessMapTaskAttempts.get(partitionId),
+              numInputs,
+              applicationAttemptId.getAttemptId(),
+              RssTezUtils.getMaxAttemptNo(conf));
+      Roaring64NavigableMap uniqueBlockIdBitMap = 
Roaring64NavigableMap.bitmapOf();
+      blockIdBitmap.forEach(
+          blockId -> {
+            long taId = RssTezUtils.getTaskAttemptId(blockId);
+            if (taskIdBitmap.contains(taId)) {
+              uniqueBlockIdBitMap.add(blockId);
+            }
+          });
+      writeClient.startSortMerge(
+          new HashSet<>(partitionToServers.get(partitionId)),
+          appId,
+          shuffleId,
+          partitionId,
+          uniqueBlockIdBitMap);
+    }
+  }
+
+  public KeyValuesReader getKeyValuesReader() {
+    if (reader == null) {
+      return new KeyValuesReader() {
+        @Override
+        public boolean next() {
+          return false;
+        }
+
+        @Override
+        public Object getCurrentKey() throws IOException {
+          throw new IOException("No data available");
+        }
+
+        @Override
+        public Iterable getCurrentValues() throws IOException {
+          throw new IOException("No data available");
+        }
+      };
+    }
+    return this.reader.keyValuesReader();
+  }
+
+  @VisibleForTesting
+  public RMRecordsReader createRMRecordsReader(Set partitionIds) {
+    Class keyClass = ConfigUtils.getIntermediateInputKeyClass(conf);
+    Class valueClass = ConfigUtils.getIntermediateInputValueClass(conf);
+    // For hive on tez, we use separate serializer and comparator, namely
+    // TezBytesWritableSerialization and TezBytesComparator. But in remote
+    // merge mode, we use separate serializers, so we should also use
+    // separate comparators.
+    RawComparator rawComparator = WritableComparator.get(keyClass);
+    return new RMRecordsReader(
+        appId,
+        shuffleId,
+        partitionIds,
+        partitionToServers,
+        rssConf,
+        keyClass,
+        valueClass,
+        rawComparator,
+        true,
+        null,
+        false,
+        (inc) -> {
+          inputRecordCounter.increment(inc);
+        });
+  }
+
+  @Override
+  public void reportException(Throwable t) {
+    throw new RssException("should never happen!");
+  }
+
+  @Override
+  public void killSelf(Exception exception, String message) {
+    throw new RssException("should never happen!");
+  }
+}
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleScheduler.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleScheduler.java
new file mode 100644
index 000000000..14732ec0a
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleScheduler.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.util.HashSet;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.tez.common.IdUtils;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.library.common.CompositeInputAttemptIdentifier;
+
+// Although RMRssShuffleScheduler is extended from ShuffleScheduler, it is 
only used
+// to handle DME events. Because the protobuf-java version of tez is different 
from
+// uniffle, it is not convenient to parse DME events. It means there is no 
need to
+// start RMRssShuffleScheduler.
+public class RMRssShuffleScheduler extends ShuffleScheduler {
+
+  private RMRssShuffle rssShuffle;
+
+  public RMRssShuffleScheduler(
+      InputContext inputContext,
+      Configuration conf,
+      int numberOfInputs,
+      ExceptionReporter exceptionReporter,
+      MergeManager mergeManager,
+      FetchedInputAllocatorOrderedGrouped allocator,
+      long startTime,
+      CompressionCodec codec,
+      boolean ifileReadAhead,
+      int ifileReadAheadLength,
+      String srcNameTrimmed,
+      RMRssShuffle rssShuffle)
+      throws IOException {
+    super(
+        inputContext,
+        conf,
+        numberOfInputs,
+        exceptionReporter,
+        mergeManager,
+        allocator,
+        startTime,
+        codec,
+        ifileReadAhead,
+        ifileReadAheadLength,
+        srcNameTrimmed);
+    this.rssShuffle = rssShuffle;
+  }
+
+  @Override
+  public synchronized void addKnownMapOutput(
+      String inputHostName, int port, int partitionId, 
CompositeInputAttemptIdentifier srcAttempt) {
+    super.addKnownMapOutput(inputHostName, port, partitionId, srcAttempt);
+    rssShuffle.partitionIds.add(partitionId);
+    if 
(!rssShuffle.partitionIdToSuccessMapTaskAttempts.containsKey(partitionId)) {
+      rssShuffle.partitionIdToSuccessMapTaskAttempts.put(partitionId, new 
HashSet<>());
+    }
+    
rssShuffle.partitionIdToSuccessMapTaskAttempts.get(partitionId).add(srcAttempt);
+
+    String pathComponent = srcAttempt.getPathComponent();
+    TezTaskAttemptID tezTaskAttemptId = 
IdUtils.convertTezTaskAttemptID(pathComponent);
+    if (!rssShuffle.partitionIdToSuccessTezTasks.containsKey(partitionId)) {
+      rssShuffle.partitionIdToSuccessTezTasks.put(partitionId, new 
HashSet<>());
+    }
+    
rssShuffle.partitionIdToSuccessTezTasks.get(partitionId).add(tezTaskAttemptId.getTaskID());
+  }
+}
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
index 43d784f1d..c1af1d006 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
@@ -17,6 +17,7 @@
 
 package org.apache.tez.runtime.library.common.sort.buffer;
 
+import java.io.DataOutputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.util.Comparator;
@@ -29,6 +30,8 @@ import org.apache.hadoop.io.serializer.Serializer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.serializer.SerializerInstance;
+
 public class WriteBuffer<K, V> extends OutputStream {
 
   private static final Logger LOG = LoggerFactory.getLogger(WriteBuffer.class);
@@ -47,30 +50,47 @@ public class WriteBuffer<K, V> extends OutputStream {
   private final List<WrappedBuffer> buffers = Lists.newArrayList();
   private final List<Record<K>> records = Lists.newArrayList();
 
+  private final boolean useUniffleSerializer;
+  private SerializerInstance serializerInstance;
+  private DataOutputStream dataOutputStream;
+
   public WriteBuffer(
       boolean isNeedSorted,
       int partitionId,
       RawComparator<K> comparator,
       long maxSegmentSize,
+      boolean useUniffleSerializer,
       Serializer<K> keySerializer,
-      Serializer<V> valueSerializer) {
+      Serializer<V> valueSerializer,
+      SerializerInstance serializerInstance) {
     this.partitionId = partitionId;
     this.comparator = comparator;
+    this.useUniffleSerializer = useUniffleSerializer;
     this.maxSegmentSize = maxSegmentSize;
     this.keySerializer = keySerializer;
     this.valSerializer = valueSerializer;
     this.isNeedSorted = isNeedSorted;
+    this.serializerInstance = serializerInstance;
+    if (useUniffleSerializer) {
+      this.dataOutputStream = new DataOutputStream(this);
+    }
   }
 
   /** add records */
   public int addRecord(K key, V value) throws IOException {
-    keySerializer.open(this);
-    valSerializer.open(this);
+    if (!useUniffleSerializer) {
+      keySerializer.open(this);
+      valSerializer.open(this);
+    }
     int lastOffSet = currentOffset;
     int lastIndex = currentIndex;
     int lastDataLength = dataLength;
     int keyIndex = lastIndex;
-    keySerializer.serialize(key);
+    if (useUniffleSerializer) {
+      serializerInstance.serialize(key, this.dataOutputStream);
+    } else {
+      keySerializer.serialize(key);
+    }
     int keyLength = dataLength - lastDataLength;
     int keyOffset = lastOffSet;
     if (compact(lastIndex, lastOffSet, keyLength)) {
@@ -78,7 +98,11 @@ public class WriteBuffer<K, V> extends OutputStream {
       keyIndex = lastIndex;
     }
     lastDataLength = dataLength;
-    valSerializer.serialize(value);
+    if (useUniffleSerializer) {
+      serializerInstance.serialize(value, this.dataOutputStream);
+    } else {
+      valSerializer.serialize(value);
+    }
     int valueLength = dataLength - lastDataLength;
     records.add(new Record<K>(keyIndex, keyOffset, keyLength, valueLength));
     return keyLength + valueLength;
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
index 93735efa4..78c235447 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
@@ -38,6 +38,8 @@ import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
 import org.apache.hadoop.io.serializer.Serializer;
 import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.counters.TezCounter;
@@ -54,6 +56,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.serializer.SerializerFactory;
+import org.apache.uniffle.common.serializer.SerializerInstance;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
@@ -105,6 +109,9 @@ public class WriteBufferManager<K, V> {
   private final TezCounter mapOutputByteCounter;
   private final TezCounter mapOutputRecordCounter;
 
+  private final boolean useUniffleSerializer;
+  private SerializerInstance serializerInstance;
+
   /** WriteBufferManager */
   public WriteBufferManager(
       TezTaskAttemptID tezTaskAttemptID,
@@ -133,7 +140,10 @@ public class WriteBufferManager<K, V> {
       int shuffleId,
       boolean isNeedSorted,
       TezCounter mapOutputByteCounter,
-      TezCounter mapOutputRecordCounter) {
+      TezCounter mapOutputRecordCounter,
+      boolean useUniffleSerializer,
+      Class<K> keyClass,
+      Class<V> valClass) {
     this.tezTaskAttemptID = tezTaskAttemptID;
     this.maxMemSize = maxMemSize;
     this.appId = appId;
@@ -142,7 +152,14 @@ public class WriteBufferManager<K, V> {
     this.successBlockIds = successBlockIds;
     this.failedBlockIds = failedBlockIds;
     this.shuffleWriteClient = shuffleWriteClient;
-    this.comparator = comparator;
+    // For Hive on Tez, a dedicated serializer and comparator are normally
+    // used, namely TezBytesWritableSerialization and TezBytesComparator. But
+    // in remote merge mode we use different serializers, so we must also use
+    // the matching comparators.
+    this.comparator =
+        useUniffleSerializer
+            ? WritableComparator.get((Class<? extends WritableComparable>) 
keyClass)
+            : comparator;
     this.maxSegmentSize = maxSegmentSize;
     this.keySerializer = keySerializer;
     this.valSerializer = valSerializer;
@@ -162,6 +179,13 @@ public class WriteBufferManager<K, V> {
     this.isNeedSorted = isNeedSorted;
     this.mapOutputByteCounter = mapOutputByteCounter;
     this.mapOutputRecordCounter = mapOutputRecordCounter;
+    this.useUniffleSerializer = useUniffleSerializer;
+    if (useUniffleSerializer) {
+      SerializerFactory factory = new SerializerFactory(rssConf);
+      org.apache.uniffle.common.serializer.Serializer serializer = 
factory.getSerializer(keyClass);
+      this.serializerInstance = serializer.newInstance();
+      assert 
factory.getSerializer(valClass).getClass().equals(serializer.getClass());
+    }
     this.sendExecutorService =
         Executors.newFixedThreadPool(sendThreadNum, 
ThreadUtils.getThreadFactory("send-thread"));
   }
@@ -188,7 +212,14 @@ public class WriteBufferManager<K, V> {
     if (!buffers.containsKey(partitionId)) {
       WriteBuffer<K, V> sortWriterBuffer =
           new WriteBuffer(
-              isNeedSorted, partitionId, comparator, maxSegmentSize, 
keySerializer, valSerializer);
+              isNeedSorted,
+              partitionId,
+              comparator,
+              maxSegmentSize,
+              useUniffleSerializer,
+              keySerializer,
+              valSerializer,
+              serializerInstance);
       buffers.putIfAbsent(partitionId, sortWriterBuffer);
       waitSendBuffers.add(sortWriterBuffer);
     }
@@ -371,7 +402,8 @@ public class WriteBufferManager<K, V> {
     final int uncompressLength = data.length;
     long start = System.currentTimeMillis();
 
-    final byte[] compressed = codec.map(c -> c.compress(data)).orElse(data);
+    final byte[] compressed =
+        useUniffleSerializer ? data : codec.map(c -> 
c.compress(data)).orElse(data);
     final long crc32 = ChecksumUtils.getCrc32(compressed);
     compressTime += System.currentTimeMillis() - start;
     final long blockId =
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
index fe4f11e13..84a71680e 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
@@ -30,6 +30,7 @@ import org.apache.tez.common.RssTezConfig;
 import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.library.common.ConfigUtils;
 import org.apache.tez.runtime.library.common.sort.buffer.WriteBufferManager;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -139,6 +140,9 @@ public class RssSorter extends ExternalSorter {
 
     LOG.info("applicationAttemptId is {}", applicationAttemptId.toString());
 
+    boolean isRemoteMergeEnable =
+        conf.getBoolean(
+            RssTezConfig.RSS_REMOTE_MERGE_ENABLE, 
RssTezConfig.RSS_REMOTE_MERGE_ENABLE_DEFAULT);
     bufferManager =
         new WriteBufferManager(
             tezTaskAttemptID,
@@ -167,7 +171,10 @@ public class RssSorter extends ExternalSorter {
             shuffleId,
             true,
             mapOutputByteCounter,
-            mapOutputRecordCounter);
+            mapOutputRecordCounter,
+            isRemoteMergeEnable,
+            ConfigUtils.getIntermediateOutputKeyClass(this.conf),
+            ConfigUtils.getIntermediateOutputValueClass(this.conf));
     LOG.info("Initialized WriteBufferManager.");
   }
 
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
index 94e87a40e..b553700a6 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
@@ -165,7 +165,10 @@ public class RssUnSorter extends ExternalSorter {
             shuffleId,
             false,
             mapOutputByteCounter,
-            mapOutputRecordCounter);
+            mapOutputRecordCounter,
+            false,
+            null,
+            null);
     LOG.info("Initialized WriteBufferManager.");
   }
 
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInput.java
new file mode 100644
index 000000000..ff7d0c8ae
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInput.java
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.input;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.AbstractLogicalInput;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.ProgressFailedException;
+import org.apache.tez.runtime.library.api.KeyValuesReader;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.Constants;
+import org.apache.tez.runtime.library.common.MemoryUpdateCallbackHandler;
+import 
org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RMRssShuffle;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RssShuffle;
+import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
+
+import static 
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+
+/**
+ * {@link RMRssOrderedGroupedKVInput} is an {@link AbstractLogicalInput} which
+ * shuffles intermediate sorted data, merges it and provides key/<values> to the
+ * consumer. This is typically used to bring one partition of a set of
+ * partitioned distributed data to one consumer. The shuffle operation brings
+ * all partitions to one place. These partitions are assumed to be sorted and
+ * are merge-sorted into a single input view.
+ *
+ * <p>The Copy and Merge will be triggered by the initialization - which is 
handled by the Tez
+ * framework. Input is not consumable until the Copy and Merge are complete. 
Methods are provided to
+ * check for this, as well as to wait for completion. Attempting to get a 
reader on a non-complete
+ * input will block.
+ */
+@Public
+public class RMRssOrderedGroupedKVInput extends AbstractLogicalInput {
+
+  static final Logger LOG = 
LoggerFactory.getLogger(RMRssOrderedGroupedKVInput.class);
+
+  protected TezRawKeyValueIterator rawIter = null;
+  protected Configuration conf;
+  protected RMRssShuffle shuffle;
+  protected MemoryUpdateCallbackHandler memoryUpdateCallbackHandler;
+  private int shuffleId;
+  private ApplicationAttemptId applicationAttemptId;
+  private final BlockingQueue<Event> pendingEvents = new 
LinkedBlockingQueue<>();
+  private long firstEventReceivedTime = -1;
+
+  private final AtomicBoolean isStarted = new AtomicBoolean(false);
+
+  public RMRssOrderedGroupedKVInput(InputContext inputContext, int 
numPhysicalInputs) {
+    super(inputContext, numPhysicalInputs);
+  }
+
+  @Override
+  public synchronized List<Event> initialize() throws IOException {
+    this.conf = 
TezUtils.createConfFromUserPayload(getContext().getUserPayload());
+
+    if (this.getNumPhysicalInputs() == 0) {
+      getContext().requestInitialMemory(0L, null);
+      isStarted.set(true);
+      getContext().inputIsReady();
+      LOG.info(
+          "input fetch not required since there are 0 physical inputs for 
input vertex: "
+              + getContext().getSourceVertexName());
+      return Collections.emptyList();
+    }
+
+    long initialMemoryRequest =
+        RssShuffle.getInitialMemoryRequirement(conf, 
getContext().getTotalMemoryAvailableToTask());
+    this.memoryUpdateCallbackHandler = new MemoryUpdateCallbackHandler();
+    getContext().requestInitialMemory(initialMemoryRequest, 
memoryUpdateCallbackHandler);
+
+    this.conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, 
getContext().getWorkDirs());
+
+    TezTaskAttemptID taskAttemptId =
+        TezTaskAttemptID.fromString(
+            
RssTezUtils.uniqueIdentifierToAttemptId(getContext().getUniqueIdentifier()));
+    TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
+    TezDAGID tezDAGID = tezVertexID.getDAGId();
+    int sourceVertexId = this.conf.getInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, -1);
+    int destinationVertexId = 
this.conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1);
+    assert sourceVertexId != -1;
+    assert destinationVertexId != -1;
+    this.shuffleId =
+        RssTezUtils.computeShuffleId(tezDAGID.getId(), sourceVertexId, 
destinationVertexId);
+    this.applicationAttemptId =
+        ApplicationAttemptId.newInstance(
+            getContext().getApplicationId(), 
getContext().getDAGAttemptNumber());
+    return Collections.emptyList();
+  }
+
+  @Override
+  public synchronized void start() throws IOException {
+    if (!isStarted.get()) {
+      memoryUpdateCallbackHandler.validateUpdateReceived();
+      // Start the shuffle - copy and merge
+      shuffle = createRssShuffle();
+      shuffle.run();
+      List<Event> pending = new LinkedList<>();
+      pendingEvents.drainTo(pending);
+      if (pending.size() > 0) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug(
+              "NoAutoStart delay in processing first event: "
+                  + (System.currentTimeMillis() - firstEventReceivedTime));
+        }
+        shuffle.handleEvents(pending);
+      }
+      isStarted.set(true);
+    }
+  }
+
+  @VisibleForTesting
+  RMRssShuffle createRssShuffle() throws IOException {
+    return new RMRssShuffle(
+        getContext(), conf, getNumPhysicalInputs(), shuffleId, 
applicationAttemptId);
+  }
+
+  @Override
+  public synchronized List<Event> close() throws IOException {
+    if (this.getNumPhysicalInputs() != 0 && rawIter != null) {
+      rawIter.close();
+    }
+    if (shuffle != null) {
+      shuffle.shutdown();
+    }
+
+    long dataSize =
+        
getContext().getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_DECOMPRESSED).getValue();
+    getContext().getStatisticsReporter().reportDataSize(dataSize);
+    long inputRecords =
+        
getContext().getCounters().findCounter(TaskCounter.REDUCE_INPUT_RECORDS).getValue();
+    getContext().getStatisticsReporter().reportItemsProcessed(inputRecords);
+
+    return Collections.emptyList();
+  }
+
+  /**
+   * Get a KVReader for the Input. This method will block until the input is 
ready - i.e. the copy
+   * and merge stages are complete. Users can use the isInputReady method to 
check if the input is
+   * ready, which gives an indication of whether this method will block or not.
+   *
+   * <p>NOTE: All values for the current K-V pair must be read prior to 
invoking moveToNext. Once
+   * moveToNext() is called, the valueIterator from the previous K-V pair will 
throw an Exception
+   *
+   * @return a KVReader over the sorted input.
+   * @throws {@link IOInterruptedException} if IO was performing a blocking 
operation and was
+   *     interrupted
+   */
+  @Override
+  public KeyValuesReader getReader() throws Exception {
+    synchronized (this) {
+      if (getNumPhysicalInputs() == 0) {
+        return new KeyValuesReader() {
+          @Override
+          public boolean next() throws IOException {
+            getContext().notifyProgress();
+            hasCompletedProcessing();
+            completedProcessing = true;
+            return false;
+          }
+
+          @Override
+          public Object getCurrentKey() throws IOException {
+            throw new RssException("No data available in Input");
+          }
+
+          @Override
+          public Iterable<Object> getCurrentValues() throws IOException {
+            throw new RssException("No data available in Input");
+          }
+        };
+      }
+    }
+    shuffle.waitForEvents();
+    org.apache.uniffle.client.record.reader.KeyValuesReader keyValuesReader =
+        shuffle.getKeyValuesReader();
+    return new KeyValuesReader() {
+      @Override
+      public boolean next() throws IOException {
+        return keyValuesReader.next();
+      }
+
+      @Override
+      public Object getCurrentKey() throws IOException {
+        return keyValuesReader.getCurrentKey();
+      }
+
+      @Override
+      public Iterable<Object> getCurrentValues() throws IOException {
+        return keyValuesReader.getCurrentValues();
+      }
+    };
+  }
+
+  @Override
+  public float getProgress() throws ProgressFailedException, 
InterruptedException {
+    // TODO: add progress
+    return super.getProgress();
+  }
+
+  @Override
+  public void handleEvents(List<Event> inputEvents) throws IOException {
+    RMRssShuffle shuffleLocalRef;
+    synchronized (this) {
+      if (getNumPhysicalInputs() == 0) {
+        throw new RssException("No input events expected as numInputs is 0");
+      }
+      if (!isStarted.get()) {
+        if (firstEventReceivedTime == -1) {
+          firstEventReceivedTime = System.currentTimeMillis();
+        }
+        pendingEvents.addAll(inputEvents);
+        return;
+      }
+      shuffleLocalRef = shuffle;
+    }
+    shuffleLocalRef.handleEvents(inputEvents);
+  }
+
+  private static final Set<String> CONF_KEYS = new HashSet<String>();
+
+  static {
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD_BYTES);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_IO_FILE_BUFFER_SIZE);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_IO_SORT_FACTOR);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMBINE_MIN_SPILLS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMBINER_CLASS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_USE_ASYNC_HTTP);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES_LIMIT);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_MAX_TASK_OUTPUT_AT_ONCE);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_CONNECT_TIMEOUT);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_KEEP_ALIVE_ENABLED);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_KEEP_ALIVE_MAX_CONNECTIONS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_READ_TIMEOUT);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_BUFFER_SIZE);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_VERIFY_DISK_CHECKSUM);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_BUFFER_PERCENT);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MEMTOMEM_SEGMENTS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_ENABLE_MEMTOMEM);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_SOURCE_ATTEMPT_ABORT_LIMIT);
+    CONF_KEYS.add(
+        
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_ACCEPTABLE_HOST_FETCH_FAILURE_FRACTION);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_FAILURES_PER_HOST);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_STALL_TIME_FRACTION);
+    CONF_KEYS.add(
+        
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_ALLOWED_FAILED_FETCH_ATTEMPT_FRACTION);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_REQUIRED_PROGRESS_FRACTION);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FAILED_CHECK_SINCE_LAST_COMPLETION);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCHER_USE_SHARED_POOL);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_INPUT_POST_MERGE_BUFFER_PERCENT);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_GROUP_COMPARATOR_CLASS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS_CODEC);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_SECONDARY_COMPARATOR_CLASS);
+    CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_OPTIMIZE_LOCAL_FETCH);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_CONVERT_USER_PAYLOAD_TO_HISTORY_TEXT);
+    CONF_KEYS.add(TezConfiguration.TEZ_COUNTERS_MAX);
+    CONF_KEYS.add(TezConfiguration.TEZ_COUNTERS_GROUP_NAME_MAX_LENGTH);
+    CONF_KEYS.add(TezConfiguration.TEZ_COUNTERS_COUNTER_NAME_MAX_LENGTH);
+    CONF_KEYS.add(TezConfiguration.TEZ_COUNTERS_MAX_GROUPS);
+    
CONF_KEYS.add(TezRuntimeConfiguration.TEZ_RUNTIME_CLEANUP_FILES_ON_INTERRUPT);
+    CONF_KEYS.add(Constants.TEZ_RUNTIME_TASK_MEMORY);
+    CONF_KEYS.add(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID);
+  }
+
+  @InterfaceAudience.Private
+  public static Set<String> getConfigurationKeySet() {
+    return Collections.unmodifiableSet(CONF_KEYS);
+  }
+}
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
index 997dcdfa6..d95033c53 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -30,9 +30,11 @@ import java.util.zip.Deflater;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
+import org.apache.commons.lang.ClassUtils;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceAudience.Public;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.WritableComparator;
 import org.apache.hadoop.ipc.RPC;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.security.Credentials;
@@ -43,6 +45,7 @@ import 
org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.tez.common.GetShuffleServerRequest;
 import org.apache.tez.common.GetShuffleServerResponse;
+import org.apache.tez.common.RssTezConfig;
 import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezCommonUtils;
 import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
@@ -175,9 +178,29 @@ public class RssOrderedPartitionedKVOutput extends 
AbstractLogicalOutput {
     }
     this.shuffleId =
         RssTezUtils.computeShuffleId(tezDAGID.getId(), sourceVertexId, 
destinationVertexId);
+    String keyClassName = 
conf.get(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, "");
+    String valueClassName = 
conf.get(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, "");
+    // For Hive on Tez, a dedicated serializer and comparator are normally
+    // used, namely TezBytesWritableSerialization and TezBytesComparator. But
+    // in remote merge mode we use different serializers, so we must also use
+    // the matching comparators.
+    String comparatorClassName =
+        
WritableComparator.get(ClassUtils.getClass(keyClassName)).getClass().getName();
+    boolean remoteMergeEnable =
+        conf.getBoolean(
+            RssTezConfig.RSS_REMOTE_MERGE_ENABLE, 
RssTezConfig.RSS_REMOTE_MERGE_ENABLE_DEFAULT);
     GetShuffleServerRequest request =
-        new GetShuffleServerRequest(
-            this.taskAttemptId, this.mapNum, this.numOutputs, this.shuffleId);
+        remoteMergeEnable
+            ? new GetShuffleServerRequest(
+                this.taskAttemptId,
+                this.mapNum,
+                this.numOutputs,
+                this.shuffleId,
+                keyClassName,
+                valueClassName,
+                comparatorClassName)
+            : new GetShuffleServerRequest(
+                this.taskAttemptId, this.mapNum, this.numOutputs, 
this.shuffleId);
     GetShuffleServerResponse response = 
umbilical.getShuffleAssignments(request);
     this.partitionToServers =
         response
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java
new file mode 100644
index 000000000..b42b57bda
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java
@@ -0,0 +1,443 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.zip.Deflater;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.ExecutionContext;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.api.events.DataMovementEvent;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.KeyValuesReader;
+import org.apache.uniffle.client.record.reader.MockedShuffleServerClient;
+import org.apache.uniffle.client.record.reader.MockedShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.util.BlockIdLayout;
+
+import static 
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+import static 
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS;
+import static 
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS;
+import static 
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS;
+import static 
org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+public class RMRssShuffleTest {
+
+  private static final int RECORDS_NUM = 1009;
+  private static final ApplicationAttemptId APPLICATION_ATTEMPT_ID =
+      ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0);
+  private static final int SHUFFLE_ID = 0;
+  private static final int PARTITION_ID = 0;
+
+  @Test
+  public void testReadShuffleData() throws Exception {
+    // 1 basic parameter
+    final Class keyClass = Text.class;
+    final Class valueClass = IntWritable.class;
+    final Comparator comparator = new Text.Comparator();
+    final Configuration conf = new Configuration();
+    conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
+    conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
+    final RssConf rssConf = new RssConf();
+    final List<ShuffleServerInfo> serverInfos =
+        Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+    final int taskAttemptId = 0;
+    BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+    final long[] blockIds = new long[] {blockIdLayout.getBlockId(0, 
PARTITION_ID, taskAttemptId)};
+    final int duplicated = 5;
+
+    // 2 mock input context
+    InputContext inputContext = mock(InputContext.class);
+    when(inputContext.getSourceVertexName()).thenReturn("Map 0");
+    TezCounters tezCounters = new TezCounters();
+    when(inputContext.getCounters()).thenReturn(tezCounters);
+    TezTaskAttemptID tezTaskAttemptID =
+        TezTaskAttemptID.getInstance(
+            TezTaskID.getInstance(
+                TezVertexID.getInstance(
+                    
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                0),
+            0);
+    when(inputContext.getUniqueIdentifier())
+        .thenReturn(String.format("%s_%05d", tezTaskAttemptID.toString(), 0));
+    when(inputContext.getDagIdentifier()).thenReturn(0);
+    
when(inputContext.getApplicationId()).thenReturn(APPLICATION_ATTEMPT_ID.getApplicationId());
+    ExecutionContext executionContext = mock(ExecutionContext.class);
+    when(executionContext.getHostName()).thenReturn("hostname");
+    when(inputContext.getExecutionContext()).thenReturn(executionContext);
+    DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
+    dataOutputBuffer.writeInt(-1);
+    when(inputContext.getServiceProviderMetaData(anyString()))
+        .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0, 
dataOutputBuffer.getLength()));
+    Token<JobTokenIdentifier> sessionToken =
+        new Token<JobTokenIdentifier>(
+            new JobTokenIdentifier(new Text("text")), new 
JobTokenSecretManager());
+    ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
+    
doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());
+    conf.setClass(TEZ_RUNTIME_KEY_CLASS, keyClass, Writable.class);
+    conf.setClass(TEZ_RUNTIME_VALUE_CLASS, valueClass, Writable.class);
+    conf.setClass(TEZ_RUNTIME_KEY_COMPARATOR_CLASS, comparator.getClass(), 
Comparator.class);
+
+    // 3 mock recordsReader
+    RMRecordsReader recordsReader =
+        new RMRecordsReader(
+            APPLICATION_ATTEMPT_ID.toString(),
+            SHUFFLE_ID,
+            Sets.newHashSet(PARTITION_ID),
+            ImmutableMap.of(PARTITION_ID, serverInfos),
+            rssConf,
+            keyClass,
+            valueClass,
+            comparator,
+            true,
+            null,
+            false,
+            null);
+    RMRecordsReader recordsReaderSpy = spy(recordsReader);
+    ByteBuffer[][] buffers =
+        new ByteBuffer[][] {
+          {
+            ByteBuffer.wrap(
+                genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, 
RECORDS_NUM, duplicated))
+          }
+        };
+    ShuffleServerClient serverClient =
+        new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, 
blockIds);
+    
doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+    // 4 mock shuffle
+    RMRssShuffle rssShuffle = new RMRssShuffle(inputContext, conf, 5, 0, 
APPLICATION_ATTEMPT_ID);
+    RMRssShuffle rssShuffleSpy = spy(rssShuffle);
+    
doReturn(recordsReaderSpy).when(rssShuffleSpy).createRMRecordsReader(anySet());
+
+    try (MockedStatic<UmbilicalUtils> umbilicalUtils = 
Mockito.mockStatic(UmbilicalUtils.class);
+        MockedStatic<RssTezUtils> tezUtils = 
Mockito.mockStatic(RssTezUtils.class)) {
+      umbilicalUtils
+          .when(() -> UmbilicalUtils.requestShuffleServer(any(), any(), any(), 
anyInt()))
+          .thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos));
+      ShuffleWriteClient writeClient = new MockedShuffleWriteClient();
+      writeClient.reportShuffleResult(
+          ImmutableMap.of(
+              serverInfos.get(0), ImmutableMap.of(PARTITION_ID, 
Sets.newHashSet(blockIds[0]))),
+          APPLICATION_ATTEMPT_ID.toString(),
+          SHUFFLE_ID,
+          taskAttemptId,
+          0);
+      tezUtils.when(() -> 
RssTezUtils.createShuffleClient(any())).thenReturn(writeClient);
+      tezUtils
+          .when(() -> RssTezUtils.fetchAllRssTaskIds(anySet(), anyInt(), 
anyInt(), anyInt()))
+          .thenReturn(Roaring64NavigableMap.bitmapOf(taskAttemptId));
+
+      // 5 run shuffle
+      rssShuffleSpy.run();
+
+      // 6 send and handle dme
+      List<Event> events = new ArrayList<>();
+      for (int i = 0; i < 5; i++) {
+        TezTaskAttemptID taskAttemptID =
+            TezTaskAttemptID.getInstance(
+                TezTaskID.getInstance(
+                    TezVertexID.getInstance(
+                        
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                    i),
+                0);
+        events.add(
+            createDataMovementEvent(
+                PARTITION_ID, String.format("%s_%05d", 
taskAttemptID.toString(), 1)));
+      }
+      rssShuffleSpy.handleEvents(events);
+      rssShuffleSpy.waitForEvents();
+    }
+
+    // 7 verify result
+    int index = 0;
+    KeyValuesReader keyValuesReader = rssShuffleSpy.getKeyValuesReader();
+    while (keyValuesReader.next()) {
+      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValuesReader.getCurrentKey());
+      Iterator iterator = keyValuesReader.getCurrentValues().iterator();
+      int dup = 0;
+      while (iterator.hasNext()) {
+        assertEquals(SerializerUtils.genData(valueClass, index), 
iterator.next());
+        dup++;
+      }
+      assertEquals(duplicated, dup);
+      index++;
+    }
+    assertEquals(RECORDS_NUM, index);
+  }
+
+  @Test
+  public void testReadMultiPartitionShuffleData() throws Exception {
+    // 1 basic parameter
+    final Class keyClass = Text.class;
+    final Class valueClass = IntWritable.class;
+    final Comparator comparator = new Text.Comparator();
+    final Configuration conf = new Configuration();
+    conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
+    conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
+    final RssConf rssConf = new RssConf();
+    final List<ShuffleServerInfo> serverInfos =
+        Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+    final int taskAttemptId = 0;
+    BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+    final long[] blockIds =
+        new long[] {
+          blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId),
+          blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId + 1),
+          blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId + 2),
+          blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId + 3),
+          blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId + 4),
+          blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId + 5),
+        };
+    final int duplicated = 3;
+
+    // 2 mock input context
+    InputContext inputContext = mock(InputContext.class);
+    when(inputContext.getSourceVertexName()).thenReturn("Map 0");
+    TezCounters tezCounters = new TezCounters();
+    when(inputContext.getCounters()).thenReturn(tezCounters);
+    TezTaskAttemptID tezTaskAttemptID =
+        TezTaskAttemptID.getInstance(
+            TezTaskID.getInstance(
+                TezVertexID.getInstance(
+                    
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                0),
+            0);
+    when(inputContext.getUniqueIdentifier())
+        .thenReturn(String.format("%s_%05d", tezTaskAttemptID.toString(), 0));
+    when(inputContext.getDagIdentifier()).thenReturn(0);
+    
when(inputContext.getApplicationId()).thenReturn(APPLICATION_ATTEMPT_ID.getApplicationId());
+    ExecutionContext executionContext = mock(ExecutionContext.class);
+    when(executionContext.getHostName()).thenReturn("hostname");
+    when(inputContext.getExecutionContext()).thenReturn(executionContext);
+    DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
+    dataOutputBuffer.writeInt(-1);
+    when(inputContext.getServiceProviderMetaData(anyString()))
+        .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0, 
dataOutputBuffer.getLength()));
+    Token<JobTokenIdentifier> sessionToken =
+        new Token<JobTokenIdentifier>(
+            new JobTokenIdentifier(new Text("text")), new 
JobTokenSecretManager());
+    ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
+    
doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());
+    conf.setClass(TEZ_RUNTIME_KEY_CLASS, keyClass, Writable.class);
+    conf.setClass(TEZ_RUNTIME_VALUE_CLASS, valueClass, Writable.class);
+    conf.setClass(TEZ_RUNTIME_KEY_COMPARATOR_CLASS, comparator.getClass(), 
Comparator.class);
+
+    // 3 mock recordsReader
+    RMRecordsReader recordsReader =
+        new RMRecordsReader(
+            APPLICATION_ATTEMPT_ID.toString(),
+            SHUFFLE_ID,
+            Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2),
+            ImmutableMap.of(
+                PARTITION_ID,
+                serverInfos,
+                PARTITION_ID + 1,
+                serverInfos,
+                PARTITION_ID + 2,
+                serverInfos),
+            rssConf,
+            keyClass,
+            valueClass,
+            comparator,
+            true,
+            null,
+            false,
+            null);
+    RMRecordsReader recordsReaderSpy = spy(recordsReader);
+    ByteBuffer[][] buffers = new ByteBuffer[3][2];
+    for (int i = 0; i < 3; i++) {
+      buffers[i][0] =
+          ByteBuffer.wrap(
+              genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, 
RECORDS_NUM, duplicated));
+      buffers[i][1] =
+          ByteBuffer.wrap(
+              genSortedRecordBytes(
+                  rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, 
RECORDS_NUM, duplicated));
+    }
+    ShuffleServerClient serverClient =
+        new MockedShuffleServerClient(
+            new int[] {PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2}, 
buffers, blockIds);
+    
doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+    // 4 mock shuffle
+    RMRssShuffle rssShuffle = new RMRssShuffle(inputContext, conf, 6, 0, 
APPLICATION_ATTEMPT_ID);
+    RMRssShuffle rssShuffleSpy = spy(rssShuffle);
+    
doReturn(recordsReaderSpy).when(rssShuffleSpy).createRMRecordsReader(anySet());
+    try (MockedStatic<UmbilicalUtils> umbilicalUtils = 
Mockito.mockStatic(UmbilicalUtils.class);
+        MockedStatic<RssTezUtils> tezUtils = 
Mockito.mockStatic(RssTezUtils.class)) {
+      umbilicalUtils
+          .when(() -> UmbilicalUtils.requestShuffleServer(any(), any(), any(), 
anyInt()))
+          .thenReturn(
+              ImmutableMap.of(
+                  PARTITION_ID,
+                  serverInfos,
+                  PARTITION_ID + 1,
+                  serverInfos,
+                  PARTITION_ID + 2,
+                  serverInfos));
+      ShuffleWriteClient writeClient = new MockedShuffleWriteClient();
+      writeClient.reportShuffleResult(
+          ImmutableMap.of(
+              serverInfos.get(0),
+              ImmutableMap.of(
+                  PARTITION_ID,
+                  Sets.newHashSet(blockIds[0], blockIds[1]),
+                  PARTITION_ID + 1,
+                  Sets.newHashSet(blockIds[2], blockIds[3]),
+                  PARTITION_ID + 2,
+                  Sets.newHashSet(blockIds[4], blockIds[5]))),
+          APPLICATION_ATTEMPT_ID.toString(),
+          SHUFFLE_ID,
+          taskAttemptId,
+          0);
+      tezUtils.when(() -> 
RssTezUtils.createShuffleClient(any())).thenReturn(writeClient);
+      tezUtils
+          .when(() -> RssTezUtils.fetchAllRssTaskIds(anySet(), anyInt(), 
anyInt(), anyInt()))
+          .thenReturn(Roaring64NavigableMap.bitmapOf(taskAttemptId));
+
+      // 5 run shuffle
+      rssShuffleSpy.run();
+
+      // 6 send and handle dme
+      List<Event> events = new ArrayList<>();
+      for (int i = 0; i < 6; i++) {
+        TezTaskAttemptID taskAttemptID =
+            TezTaskAttemptID.getInstance(
+                TezTaskID.getInstance(
+                    TezVertexID.getInstance(
+                        
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                    i),
+                0);
+        events.add(
+            createDataMovementEvent(
+                PARTITION_ID, String.format("%s_%05d", 
taskAttemptID.toString(), 1)));
+      }
+      rssShuffleSpy.handleEvents(events);
+      rssShuffleSpy.waitForEvents();
+    }
+
+    // 7 verify result
+    int index = 0;
+    KeyValuesReader keyValuesReader = rssShuffleSpy.getKeyValuesReader();
+    while (keyValuesReader.next()) {
+      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValuesReader.getCurrentKey());
+      Iterator iterator = keyValuesReader.getCurrentValues().iterator();
+      int dup = 0;
+      while (iterator.hasNext()) {
+        assertEquals(SerializerUtils.genData(valueClass, index), 
iterator.next());
+        dup++;
+      }
+      assertEquals(duplicated, dup);
+      index++;
+    }
+    assertEquals(RECORDS_NUM * 6, index);
+  }
+
+  public static DataMovementEvent createDataMovementEvent(int partition, 
String path)
+      throws Exception {
+    OutputContext context = mock(OutputContext.class);
+    ExecutionContext executionContext = mock(ExecutionContext.class);
+    when(executionContext.getHostName()).thenReturn("");
+    when(context.getExecutionContext()).thenReturn(executionContext);
+    DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
+    dataOutputBuffer.writeInt(-1);
+    when(context.getServiceProviderMetaData(anyString()))
+        .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0, 
dataOutputBuffer.getLength()));
+    Method method =
+        ShuffleUtils.class.getDeclaredMethod(
+            "generateDMEPayload",
+            new Class[] {
+              boolean.class,
+              int.class,
+              TezSpillRecord.class,
+              OutputContext.class,
+              int.class,
+              boolean.class,
+              boolean.class,
+              String.class,
+              String.class,
+              Deflater.class
+            });
+    method.setAccessible(true);
+    ByteBuffer byteBuffer =
+        (ByteBuffer)
+            method.invoke(
+                null,
+                false,
+                0,
+                null,
+                context,
+                -1,
+                true,
+                true,
+                path,
+                "mapreduce_shuffle",
+                TezCommonUtils.newBestCompressionDeflater());
+    return DataMovementEvent.create(partition, byteBuffer);
+  }
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index dbe40fd06..724cec6c0 100644
--- 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -19,7 +19,9 @@ package org.apache.tez.runtime.library.common.sort.buffer;
 
 import java.io.File;
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -30,10 +32,12 @@ import java.util.stream.Collectors;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
+import io.netty.buffer.ByteBuf;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.io.RawComparator;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparator;
@@ -44,6 +48,7 @@ import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezRuntimeFrameworkConfigs;
 import org.apache.tez.common.counters.TaskCounter;
 import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.common.counters.TezCounters;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.runtime.api.OutputContext;
 import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
@@ -66,7 +71,13 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.records.RecordsReader;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerializerFactory;
+import org.apache.uniffle.common.serializer.SerializerInstance;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -75,6 +86,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.Mockito.mock;
 
 public class WriteBufferManagerTest {
+
+  private static final int RECORDS = 1009;
+
   @Test
   public void testWriteException(@TempDir File tmpDir) throws IOException, 
InterruptedException {
     TezTaskAttemptID tezTaskAttemptID =
@@ -155,7 +169,10 @@ public class WriteBufferManagerTest {
             shuffleId,
             true,
             mapOutputByteCounter,
-            mapOutputRecordCounter);
+            mapOutputRecordCounter,
+            false,
+            null,
+            null);
     partitionToServers.put(1, 
Lists.newArrayList(mock(ShuffleServerInfo.class)));
     Random random = new Random();
     for (int i = 0; i < 1000; i++) {
@@ -259,7 +276,10 @@ public class WriteBufferManagerTest {
             shuffleId,
             true,
             mapOutputByteCounter,
-            mapOutputRecordCounter);
+            mapOutputRecordCounter,
+            false,
+            null,
+            null);
 
     Random random = new Random();
     for (int i = 0; i < 1000; i++) {
@@ -286,8 +306,6 @@ public class WriteBufferManagerTest {
     }
 
     assertEquals(1175900, mapOutputByteCounter.getValue());
-    assert (1 == bufferManager.getWaitSendBuffers().size());
-    assert (4928 == bufferManager.getWaitSendBuffers().get(0).getDataLength());
 
     bufferManager.waitSendFinished();
     assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
@@ -374,10 +392,13 @@ public class WriteBufferManagerTest {
             shuffleId,
             true,
             mapOutputByteCounter,
-            mapOutputRecordCounter);
+            mapOutputRecordCounter,
+            false,
+            null,
+            null);
 
     Random random = new Random();
-    for (int i = 0; i < 10000; i++) {
+    for (int i = 0; i < 1000; i++) {
       byte[] key = new byte[20];
       byte[] value = new byte[1024];
       random.nextBytes(key);
@@ -388,8 +409,8 @@ public class WriteBufferManagerTest {
     }
     bufferManager.waitSendFinished();
 
-    assertEquals(10000, mapOutputRecordCounter.getValue());
-    assertEquals(10520000, mapOutputByteCounter.getValue());
+    assertEquals(1000, mapOutputRecordCounter.getValue());
+    assertEquals(1052000, mapOutputByteCounter.getValue());
     assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
     assertEquals(
         writeClient.mockedShuffleServer.getFinishBlockSize(),
@@ -478,7 +499,10 @@ public class WriteBufferManagerTest {
             shuffleId,
             true,
             mapOutputByteCounter,
-            mapOutputRecordCounter);
+            mapOutputRecordCounter,
+            false,
+            null,
+            null);
 
     Random random = new Random();
     RssException rssException =
@@ -506,6 +530,107 @@ public class WriteBufferManagerTest {
     assertTrue(mapOutputByteCounter.getValue() < 10520000);
   }
 
+  @Test
+  public void testWriteWithRemoteMerge() throws Exception {
+    MockShuffleWriteClient client = new MockShuffleWriteClient();
+    client.setMode(3);
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = 
JavaUtils.newConcurrentMap();
+    partitionToServers.put(0, new ArrayList());
+    partitionToServers.get(0).add(new ShuffleServerInfo("host", 39998));
+    Set<Long> successBlockIds = Sets.newConcurrentHashSet();
+    Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
+    RssConf rssConf = new RssConf();
+    TezCounters counter = new TezCounters();
+    TezCounter mapOutputByteCounter = counter.findCounter("group", "bytes");
+    TezCounter mapOutputRecordCounter = counter.findCounter("group", 
"records");
+    TezTaskAttemptID tezTaskAttemptID =
+        
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+    final long maxMemSize = 102400L;
+    final String appId = "application_1681717153064_3770270";
+    final int taskAttemptId = 0;
+    long maxSegmentSize = 3 * 1024;
+    long maxBufferSize = 14 * 1024 * 1024;
+    int shuffleId =
+        RssTezUtils.computeShuffleId(
+            tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 
2);
+
+    WriteBufferManager<Text, Text> manager =
+        new WriteBufferManager<Text, Text>(
+            tezTaskAttemptID,
+            maxMemSize,
+            appId,
+            taskAttemptId,
+            successBlockIds,
+            failedBlockIds,
+            client,
+            new Text.Comparator(),
+            maxSegmentSize,
+            null,
+            null,
+            maxBufferSize,
+            0.8f,
+            1,
+            0.2f,
+            50,
+            rssConf,
+            partitionToServers,
+            1,
+            false,
+            500L,
+            5 * 1000,
+            1,
+            shuffleId,
+            true,
+            mapOutputByteCounter,
+            mapOutputRecordCounter,
+            true,
+            Text.class,
+            Text.class);
+
+    List<Integer> indexes = new ArrayList<>();
+    for (int i = 0; i < RECORDS; i++) {
+      indexes.add(i);
+    }
+    Collections.shuffle(indexes);
+    for (Integer index : indexes) {
+      manager.addRecord(
+          0,
+          (Text) SerializerUtils.genData(Text.class, index),
+          (Text) SerializerUtils.genData(Text.class, index + 1));
+    }
+    manager.waitSendFinished();
+    assertEquals(RECORDS, mapOutputRecordCounter.getValue());
+    SerializerFactory factory = new SerializerFactory(rssConf);
+    org.apache.uniffle.common.serializer.Serializer serializer = 
factory.getSerializer(Text.class);
+    SerializerInstance instance = serializer.newInstance();
+    DataOutputBuffer keyBuffer = new DataOutputBuffer();
+    instance.serialize(SerializerUtils.genData(Text.class, 0), keyBuffer);
+    assertEquals(RECORDS * keyBuffer.getLength() * 2, 
mapOutputByteCounter.getValue());
+    assertTrue(manager.getWaitSendBuffers().isEmpty());
+
+    // check blocks
+    List<ShuffleBlockInfo> blockInfos = client.getCachedBlockInfos();
+    assertEquals(1, blockInfos.size());
+    ByteBuf buf = blockInfos.get(0).getData();
+    byte[] bytes = new byte[blockInfos.get(0).getLength()];
+    buf.readBytes(bytes);
+    RecordsReader<Text, Text> reader =
+        new RecordsReader<>(
+            rssConf,
+            PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes)),
+            Text.class,
+            Text.class,
+            false);
+    int index = 0;
+    while (reader.next()) {
+      assertEquals(SerializerUtils.genData(Text.class, index), 
reader.getCurrentKey());
+      assertEquals(SerializerUtils.genData(Text.class, index + 1), 
reader.getCurrentValue());
+      index++;
+    }
+    reader.close();
+    assertEquals(RECORDS, index);
+  }
+
   class MockShuffleServer {
     private List<ShuffleBlockInfo> cachedBlockInfos = new ArrayList<>();
     private List<ShuffleBlockInfo> flushBlockInfos = new ArrayList<>();
@@ -530,6 +655,10 @@ public class WriteBufferManagerTest {
     public synchronized int getFinishBlockSize() {
       return finishedBlockInfos.size();
     }
+
+    public List<ShuffleBlockInfo> getCachedBlockInfos() {
+      return cachedBlockInfos;
+    }
   }
 
   class MockShuffleWriteClient implements ShuffleWriteClient {
@@ -697,5 +826,9 @@ public class WriteBufferManagerTest {
         int shuffleId,
         int partitionId,
         Roaring64NavigableMap expectedTaskIds) {}
+
+    public List<ShuffleBlockInfo> getCachedBlockInfos() {
+      return mockedShuffleServer.getCachedBlockInfos();
+    }
   }
 }
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
index dc3ebf707..73d399062 100644
--- 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
@@ -20,11 +20,17 @@ package org.apache.tez.runtime.library.common.sort.buffer;
 import java.io.ByteArrayInputStream;
 import java.io.DataInputStream;
 import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.Random;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparator;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.hadoop.io.serializer.Deserializer;
@@ -33,11 +39,23 @@ import org.apache.hadoop.io.serializer.Serializer;
 import org.apache.hadoop.mapred.JobConf;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.serializer.DeserializationStream;
+import org.apache.uniffle.common.serializer.PartialInputStream;
+import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerializerFactory;
+import org.apache.uniffle.common.serializer.SerializerInstance;
+
 import static com.google.common.collect.Maps.newConcurrentMap;
+import static org.apache.uniffle.common.serializer.SerializerUtils.genData;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class WriteBufferTest {
 
+  private static final int RECORDS_NUM = 1009;
+
   @Test
   public void testReadWrite() throws IOException {
 
@@ -57,8 +75,10 @@ public class WriteBufferTest {
             1,
             WritableComparator.get(BytesWritable.class),
             1024L,
+            false,
             keySerializer,
-            valSerializer);
+            valSerializer,
+            null);
 
     long recordLength = buffer.addRecord(key, value);
     assertEquals(20, buffer.getData().length);
@@ -88,8 +108,10 @@ public class WriteBufferTest {
             1,
             WritableComparator.get(BytesWritable.class),
             528L,
+            false,
             keySerializer,
-            valSerializer);
+            valSerializer,
+            null);
     long start = buffer.getDataLength();
     assertEquals(0, start);
     keyStr = "key3";
@@ -161,6 +183,75 @@ public class WriteBufferTest {
     assertEquals(bigWritableValue, valueRead);
   }
 
+  @Test
+  public void testReadWriteWithRemoteMergeAndNoSort() throws IOException {
+    RssConf rssConf = new RssConf();
+    SerializerFactory factory = new SerializerFactory(rssConf);
+    org.apache.uniffle.common.serializer.Serializer serializer = 
factory.getSerializer(Text.class);
+    SerializerInstance instance = serializer.newInstance();
+    WriteBuffer buffer =
+        new WriteBuffer<BytesWritable, BytesWritable>(
+            false,
+            1,
+            WritableComparator.get(BytesWritable.class),
+            1024L,
+            true,
+            null,
+            null,
+            instance);
+    for (int i = 0; i < RECORDS_NUM; i++) {
+      buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i));
+    }
+    byte[] bytes = buffer.getData();
+    PartialInputStream inputStream = 
PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes));
+    DeserializationStream dStream =
+        instance.deserializeStream(inputStream, Text.class, IntWritable.class, 
false);
+    for (int i = 0; i < RECORDS_NUM; i++) {
+      assertTrue(dStream.nextRecord());
+      assertEquals(genData(Text.class, i), dStream.getCurrentKey());
+      assertEquals(genData(IntWritable.class, i), dStream.getCurrentValue());
+    }
+    assertFalse(dStream.nextRecord());
+    dStream.close();
+  }
+
+  @Test
+  public void testReadWriteWithRemoteMergeAndSort() throws IOException {
+    RssConf rssConf = new RssConf();
+    SerializerFactory factory = new SerializerFactory(rssConf);
+    org.apache.uniffle.common.serializer.Serializer serializer = 
factory.getSerializer(Text.class);
+    SerializerInstance instance = serializer.newInstance();
+    WriteBuffer buffer =
+        new WriteBuffer<BytesWritable, BytesWritable>(
+            false,
+            1,
+            WritableComparator.get(BytesWritable.class),
+            1024L,
+            true,
+            null,
+            null,
+            instance);
+    List<Integer> indices = new ArrayList<>();
+    for (int i = 0; i < RECORDS_NUM; i++) {
+      indices.add(i);
+    }
+    Collections.shuffle(indices);
+    for (int i = 0; i < RECORDS_NUM; i++) {
+      buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i));
+    }
+    byte[] bytes = buffer.getData();
+    PartialInputStream inputStream = 
PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes));
+    DeserializationStream dStream =
+        instance.deserializeStream(inputStream, Text.class, IntWritable.class, 
false);
+    for (int i = 0; i < RECORDS_NUM; i++) {
+      assertTrue(dStream.nextRecord());
+      assertEquals(genData(Text.class, i), dStream.getCurrentKey());
+      assertEquals(genData(IntWritable.class, i), dStream.getCurrentValue());
+    }
+    assertFalse(dStream.nextRecord());
+    dStream.close();
+  }
+
   int readInt(DataInputStream dStream) throws IOException {
     return WritableUtils.readVInt(dStream);
   }
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java
new file mode 100644
index 000000000..24157c883
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java
@@ -0,0 +1,481 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.input;
+
+import java.io.ByteArrayOutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.ExecutionContext;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.MemoryUpdateCallback;
+import org.apache.tez.runtime.api.events.DataMovementEvent;
+import org.apache.tez.runtime.library.api.KeyValuesReader;
+import org.apache.tez.runtime.library.common.MemoryUpdateCallbackHandler;
+import 
org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RMRssShuffle;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.MockedShuffleServerClient;
+import org.apache.uniffle.client.record.reader.MockedShuffleWriteClient;
+import org.apache.uniffle.client.record.reader.RMRecordsReader;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.merger.Merger;
+import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.serializer.SerializerUtils;
+import org.apache.uniffle.common.util.BlockIdLayout;
+
+import static 
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
+import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+import static 
org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+public class RMRssOrderedGroupedKVInputTest {
+
+  private static final int RECORDS_NUM = 1009;
+  private static final ApplicationAttemptId APPLICATION_ATTEMPT_ID =
+      ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0);
+  private static final int SHUFFLE_ID = 0;
+  private static final int PARTITION_ID = 0;
+
+  @Test
+  public void testRMRssOrderedGroupedKVInput() throws Exception {
+    // 1 basic parameter
+    final Class keyClass = Text.class;
+    final Class valueClass = IntWritable.class;
+    final Comparator comparator = new Text.Comparator();
+    final Configuration conf = new Configuration();
+    conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
+    conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
+    final RssConf rssConf = new RssConf();
+    final List<ShuffleServerInfo> serverInfos =
+        Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+    final int taskAttemptId = 0;
+    BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+    final long[] blockIds = new long[] {blockIdLayout.getBlockId(0, 
PARTITION_ID, taskAttemptId)};
+
+    // 2 mock input context
+    InputContext inputContext = mock(InputContext.class);
+    when(inputContext.getSourceVertexName()).thenReturn("Map 0");
+    TezCounters tezCounters = new TezCounters();
+    when(inputContext.getCounters()).thenReturn(tezCounters);
+    TezTaskAttemptID tezTaskAttemptID =
+        TezTaskAttemptID.getInstance(
+            TezTaskID.getInstance(
+                TezVertexID.getInstance(
+                    
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                0),
+            0);
+    when(inputContext.getUniqueIdentifier())
+        .thenReturn(String.format("%s_%05d", tezTaskAttemptID.toString(), 0));
+    
when(inputContext.getUserPayload()).thenReturn(TezUtils.createUserPayloadFromConf(conf));
+    when(inputContext.getWorkDirs()).thenReturn(new String[] {"/dummy"});
+    doAnswer(
+            new Answer<Void>() {
+              @Override
+              public Void answer(InvocationOnMock invocation) throws Throwable 
{
+                long requestedSize = (Long) invocation.getArguments()[0];
+                MemoryUpdateCallbackHandler callback =
+                    (MemoryUpdateCallbackHandler) invocation.getArguments()[1];
+                callback.memoryAssigned(requestedSize);
+                return null;
+              }
+            })
+        .when(inputContext)
+        .requestInitialMemory(anyLong(), any(MemoryUpdateCallback.class));
+    when(inputContext.getDagIdentifier()).thenReturn(0);
+    
when(inputContext.getApplicationId()).thenReturn(APPLICATION_ATTEMPT_ID.getApplicationId());
+    ExecutionContext executionContext = mock(ExecutionContext.class);
+    when(executionContext.getHostName()).thenReturn("hostname");
+    when(inputContext.getExecutionContext()).thenReturn(executionContext);
+    DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
+    dataOutputBuffer.writeInt(-1);
+    when(inputContext.getServiceProviderMetaData(anyString()))
+        .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0, 
dataOutputBuffer.getLength()));
+    Token<JobTokenIdentifier> sessionToken =
+        new Token<JobTokenIdentifier>(
+            new JobTokenIdentifier(new Text("text")), new 
JobTokenSecretManager());
+    ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
+    
doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());
+
+    // 3 mock recordsReader
+    List<Segment> segments = new ArrayList<>();
+    segments.add(
+        SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 0L, 0, 
2, RECORDS_NUM));
+    segments.add(
+        SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 
2, RECORDS_NUM));
+    segments.add(
+        SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 
2, RECORDS_NUM));
+    ByteArrayOutputStream output = new ByteArrayOutputStream();
+    Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, 
false);
+    output.close();
+    ByteBuffer[][] buffers = new ByteBuffer[][] 
{{ByteBuffer.wrap(output.toByteArray())}};
+    ShuffleServerClient serverClient =
+        new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, 
blockIds);
+    RMRecordsReader recordsReader =
+        new RMRecordsReader(
+            APPLICATION_ATTEMPT_ID.toString(),
+            SHUFFLE_ID,
+            Sets.newHashSet(PARTITION_ID),
+            ImmutableMap.of(PARTITION_ID, serverInfos),
+            rssConf,
+            keyClass,
+            valueClass,
+            comparator,
+            true,
+            null,
+            false,
+            null);
+    RMRecordsReader recordsReaderSpy = spy(recordsReader);
+    
doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+    // 4 mock shuffle
+    RMRssShuffle rssShuffle = new RMRssShuffle(inputContext, conf, 5, 0, 
APPLICATION_ATTEMPT_ID);
+    RMRssShuffle rssShuffleSpy = spy(rssShuffle);
+    
doReturn(recordsReaderSpy).when(rssShuffleSpy).createRMRecordsReader(anySet());
+    // rssShuffleSpy.setEventHandler(new RMShuffleEventHandler(rssShuffleSpy));
+
+    try (MockedStatic<UmbilicalUtils> umbilicalUtils = 
Mockito.mockStatic(UmbilicalUtils.class);
+        MockedStatic<RssTezUtils> tezUtils = 
Mockito.mockStatic(RssTezUtils.class)) {
+      umbilicalUtils
+          .when(() -> UmbilicalUtils.requestShuffleServer(any(), any(), any(), 
anyInt()))
+          .thenReturn(ImmutableMap.of(PARTITION_ID, serverInfos));
+      ShuffleWriteClient writeClient = new MockedShuffleWriteClient();
+
+      Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds =
+          ImmutableMap.of(
+              serverInfos.get(0), ImmutableMap.of(PARTITION_ID, 
Sets.newHashSet(blockIds[0])));
+      writeClient.reportShuffleResult(
+          serverToPartitionToBlockIds,
+          APPLICATION_ATTEMPT_ID.toString(),
+          SHUFFLE_ID,
+          taskAttemptId,
+          0);
+      tezUtils.when(() -> 
RssTezUtils.createShuffleClient(any())).thenReturn(writeClient);
+      tezUtils
+          .when(() -> RssTezUtils.fetchAllRssTaskIds(anySet(), anyInt(), 
anyInt(), anyInt()))
+          .thenReturn(Roaring64NavigableMap.bitmapOf(taskAttemptId));
+      tezUtils
+          .when(() -> RssTezUtils.uniqueIdentifierToAttemptId(anyString()))
+          .thenReturn(tezTaskAttemptID.toString());
+
+      // 5 init and start kv input
+      RMRssOrderedGroupedKVInput input = new 
RMRssOrderedGroupedKVInput(inputContext, 5);
+      RMRssOrderedGroupedKVInput inputSpy = spy(input);
+      doReturn(rssShuffleSpy).when(inputSpy).createRssShuffle();
+      inputSpy.initialize();
+      List<Event> events = new ArrayList<>();
+      for (int i = 0; i < 2; i++) {
+        TezTaskAttemptID taskAttemptID =
+            TezTaskAttemptID.getInstance(
+                TezTaskID.getInstance(
+                    TezVertexID.getInstance(
+                        
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                    i),
+                0);
+        events.add(
+            createDataMovementEvent(
+                PARTITION_ID, String.format("%s_%05d", 
taskAttemptID.toString(), 1)));
+      }
+      inputSpy.handleEvents(events);
+      inputSpy.start();
+      events.clear();
+      for (int i = 2; i < 5; i++) {
+        TezTaskAttemptID taskAttemptID =
+            TezTaskAttemptID.getInstance(
+                TezTaskID.getInstance(
+                    TezVertexID.getInstance(
+                        
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                    i),
+                0);
+        events.add(
+            createDataMovementEvent(
+                PARTITION_ID, String.format("%s_%05d", 
taskAttemptID.toString(), 1)));
+      }
+      inputSpy.handleEvents(events);
+
+      // 6 verify result
+      KeyValuesReader reader = inputSpy.getReader();
+      int index = 0;
+      while (reader.next()) {
+        assertEquals(SerializerUtils.genData(keyClass, index), 
reader.getCurrentKey());
+        Iterator iterator = reader.getCurrentValues().iterator();
+        int i = 0;
+        while (iterator.hasNext()) {
+          assertEquals(SerializerUtils.genData(valueClass, index), 
iterator.next());
+          i++;
+        }
+        assertEquals(index % 2 == 0 ? 2 : 1, i);
+        index++;
+      }
+      assertEquals(RECORDS_NUM * 2, index);
+    }
+  }
+
+  @Test
+  public void testRMRssOrderedGroupedKVInputMulitPartition() throws Exception {
+    // 1 basic parameter
+    final Class keyClass = Text.class;
+    final Class valueClass = IntWritable.class;
+    final Comparator comparator = new Text.Comparator();
+    final Configuration conf = new Configuration();
+    conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
+    conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
+    final RssConf rssConf = new RssConf();
+    final List<ShuffleServerInfo> serverInfos =
+        Lists.newArrayList(new ShuffleServerInfo("dummy", -1));
+    final int taskAttemptId = 0;
+    BlockIdLayout blockIdLayout = BlockIdLayout.from(rssConf);
+    final long[] blockIds =
+        new long[] {
+          blockIdLayout.getBlockId(0, PARTITION_ID, taskAttemptId),
+          blockIdLayout.getBlockId(1, PARTITION_ID, taskAttemptId),
+          blockIdLayout.getBlockId(0, PARTITION_ID + 1, taskAttemptId),
+          blockIdLayout.getBlockId(1, PARTITION_ID + 1, taskAttemptId),
+          blockIdLayout.getBlockId(0, PARTITION_ID + 2, taskAttemptId),
+          blockIdLayout.getBlockId(1, PARTITION_ID + 2, taskAttemptId)
+        };
+
+    // 2 mock input context
+    InputContext inputContext = mock(InputContext.class);
+    when(inputContext.getSourceVertexName()).thenReturn("Map 0");
+    TezCounters tezCounters = new TezCounters();
+    when(inputContext.getCounters()).thenReturn(tezCounters);
+    TezTaskAttemptID tezTaskAttemptID =
+        TezTaskAttemptID.getInstance(
+            TezTaskID.getInstance(
+                TezVertexID.getInstance(
+                    
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                0),
+            0);
+    when(inputContext.getUniqueIdentifier())
+        .thenReturn(String.format("%s_%05d", tezTaskAttemptID.toString(), 0));
+    
when(inputContext.getUserPayload()).thenReturn(TezUtils.createUserPayloadFromConf(conf));
+    when(inputContext.getWorkDirs()).thenReturn(new String[] {"/dummy"});
+    doAnswer(
+            new Answer<Void>() {
+              @Override
+              public Void answer(InvocationOnMock invocation) throws Throwable 
{
+                long requestedSize = (Long) invocation.getArguments()[0];
+                MemoryUpdateCallbackHandler callback =
+                    (MemoryUpdateCallbackHandler) invocation.getArguments()[1];
+                callback.memoryAssigned(requestedSize);
+                return null;
+              }
+            })
+        .when(inputContext)
+        .requestInitialMemory(anyLong(), any(MemoryUpdateCallback.class));
+    when(inputContext.getDagIdentifier()).thenReturn(0);
+    
when(inputContext.getApplicationId()).thenReturn(APPLICATION_ATTEMPT_ID.getApplicationId());
+    ExecutionContext executionContext = mock(ExecutionContext.class);
+    when(executionContext.getHostName()).thenReturn("hostname");
+    when(inputContext.getExecutionContext()).thenReturn(executionContext);
+    DataOutputBuffer dataOutputBuffer = new DataOutputBuffer();
+    dataOutputBuffer.writeInt(-1);
+    when(inputContext.getServiceProviderMetaData(anyString()))
+        .thenReturn(ByteBuffer.wrap(dataOutputBuffer.getData(), 0, 
dataOutputBuffer.getLength()));
+    Token<JobTokenIdentifier> sessionToken =
+        new Token<JobTokenIdentifier>(
+            new JobTokenIdentifier(new Text("text")), new 
JobTokenSecretManager());
+    ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
+    
doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());
+
+    // 3 mock recordsReader
+    RMRecordsReader recordsReader =
+        new RMRecordsReader(
+            APPLICATION_ATTEMPT_ID.toString(),
+            SHUFFLE_ID,
+            Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2),
+            ImmutableMap.of(
+                PARTITION_ID,
+                serverInfos,
+                PARTITION_ID + 1,
+                serverInfos,
+                PARTITION_ID + 2,
+                serverInfos),
+            rssConf,
+            keyClass,
+            valueClass,
+            comparator,
+            true,
+            null,
+            false,
+            null);
+    RMRecordsReader recordsReaderSpy = spy(recordsReader);
+    ByteBuffer[][] buffers = new ByteBuffer[3][2];
+    for (int i = 0; i < 3; i++) {
+      buffers[i][0] =
+          ByteBuffer.wrap(
+              genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, 
RECORDS_NUM, 1));
+      buffers[i][1] =
+          ByteBuffer.wrap(
+              genSortedRecordBytes(
+                  rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, 
RECORDS_NUM, 1));
+    }
+    ShuffleServerClient serverClient =
+        new MockedShuffleServerClient(
+            new int[] {PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2}, 
buffers, blockIds);
+    
doReturn(serverClient).when(recordsReaderSpy).createShuffleServerClient(any());
+
+    // 4 mock shuffle
+    RMRssShuffle rssShuffle = new RMRssShuffle(inputContext, conf, 5, 0, 
APPLICATION_ATTEMPT_ID);
+    RMRssShuffle rssShuffleSpy = spy(rssShuffle);
+    
doReturn(recordsReaderSpy).when(rssShuffleSpy).createRMRecordsReader(anySet());
+    // rssShuffleSpy.setEventHandler(new RMShuffleEventHandler(rssShuffleSpy));
+
+    try (MockedStatic<UmbilicalUtils> umbilicalUtils = 
Mockito.mockStatic(UmbilicalUtils.class);
+        MockedStatic<RssTezUtils> tezUtils = 
Mockito.mockStatic(RssTezUtils.class)) {
+      umbilicalUtils
+          .when(() -> UmbilicalUtils.requestShuffleServer(any(), any(), any(), 
anyInt()))
+          .thenReturn(
+              ImmutableMap.of(
+                  PARTITION_ID,
+                  serverInfos,
+                  PARTITION_ID + 1,
+                  serverInfos,
+                  PARTITION_ID + 2,
+                  serverInfos));
+      ShuffleWriteClient writeClient = new MockedShuffleWriteClient();
+      Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds =
+          ImmutableMap.of(
+              serverInfos.get(0),
+              ImmutableMap.of(
+                  PARTITION_ID,
+                  Sets.newHashSet(blockIds[0], blockIds[1]),
+                  PARTITION_ID + 1,
+                  Sets.newHashSet(blockIds[2], blockIds[3]),
+                  PARTITION_ID + 2,
+                  Sets.newHashSet(blockIds[4], blockIds[5])));
+      writeClient.reportShuffleResult(
+          serverToPartitionToBlockIds,
+          APPLICATION_ATTEMPT_ID.toString(),
+          SHUFFLE_ID,
+          taskAttemptId,
+          0);
+      tezUtils.when(() -> 
RssTezUtils.createShuffleClient(any())).thenReturn(writeClient);
+      tezUtils
+          .when(() -> RssTezUtils.fetchAllRssTaskIds(anySet(), anyInt(), 
anyInt(), anyInt()))
+          .thenReturn(Roaring64NavigableMap.bitmapOf(taskAttemptId));
+      tezUtils
+          .when(() -> RssTezUtils.uniqueIdentifierToAttemptId(anyString()))
+          .thenReturn(tezTaskAttemptID.toString());
+
+      // 5 init and start kv input
+      RMRssOrderedGroupedKVInput input = new 
RMRssOrderedGroupedKVInput(inputContext, 5);
+      RMRssOrderedGroupedKVInput inputSpy = spy(input);
+      doReturn(rssShuffleSpy).when(inputSpy).createRssShuffle();
+      inputSpy.initialize();
+      List<Event> events = new ArrayList<>();
+      for (int i = 0; i < 2; i++) {
+        TezTaskAttemptID taskAttemptID =
+            TezTaskAttemptID.getInstance(
+                TezTaskID.getInstance(
+                    TezVertexID.getInstance(
+                        
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                    i),
+                0);
+        events.add(
+            createDataMovementEvent(i / 2, String.format("%s_%05d", 
taskAttemptID.toString(), 1)));
+      }
+      inputSpy.handleEvents(events);
+      inputSpy.start();
+      events.clear();
+      for (int i = 2; i < 5; i++) {
+        TezTaskAttemptID taskAttemptID =
+            TezTaskAttemptID.getInstance(
+                TezTaskID.getInstance(
+                    TezVertexID.getInstance(
+                        
TezDAGID.getInstance(APPLICATION_ATTEMPT_ID.getApplicationId(), 0), 0),
+                    i),
+                0);
+        events.add(
+            createDataMovementEvent(i / 2, String.format("%s_%05d", 
taskAttemptID.toString(), 1)));
+      }
+      inputSpy.handleEvents(events);
+
+      // 6 verify result
+      KeyValuesReader reader = inputSpy.getReader();
+      int index = 0;
+      while (reader.next()) {
+        assertEquals(SerializerUtils.genData(keyClass, index), 
reader.getCurrentKey());
+        Iterator iterator = reader.getCurrentValues().iterator();
+        int i = 0;
+        while (iterator.hasNext()) {
+          assertEquals(SerializerUtils.genData(valueClass, index), 
iterator.next());
+          i++;
+        }
+        assertEquals(1, i);
+        index++;
+      }
+      assertEquals(RECORDS_NUM * 6, index);
+    }
+  }
+
+  private static DataMovementEvent createDataMovementEvent(int partition, 
String path) {
+    ShuffleUserPayloads.DataMovementEventPayloadProto proto =
+        ShuffleUserPayloads.DataMovementEventPayloadProto.newBuilder()
+            .setPathComponent(path)
+            .build();
+    return DataMovementEvent.create(partition, 
proto.toByteString().asReadOnlyByteBuffer());
+  }
+}
diff --git 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java
similarity index 63%
copy from 
integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
copy to 
integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java
index 4bd2f7e8a..eb5ee51dc 100644
--- 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
+++ 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java
@@ -28,23 +28,59 @@ import com.google.common.collect.Lists;
 import org.apache.hadoop.fs.FSDataOutputStream;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.util.Tool;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.examples.OrderedWordCount;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 
-public class TezOrderedWordCountTest extends TezIntegrationTestBase {
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.server.ShuffleServerConf;
 
-  private String inputPath = "ordered_word_count_input";
-  private String outputPath = "ordered_word_count_output";
+import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_ENABLE;
+
+public class RMTezOrderedWordCountTest extends TezIntegrationTestBase {
+
+  private String inputPath = "rm_ordered_word_count_input";
+  private String outputPath = "rm_ordered_word_count_output";
   private List<String> wordTable =
       Lists.newArrayList(
           "apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan", 
"tomato");
 
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    ShuffleServerConf serverConf = new ShuffleServerConf();
+    serverConf.set(SERVER_MERGE_ENABLE, true);
+    TezIntegrationTestBase.setupServers(serverConf);
+  }
+
   @Test
   public void orderedWordCountTest() throws Exception {
     generateInputFile();
     run();
   }
 
+  public void run() throws Exception {
+    // 1 Run original Tez examples
+    TezConfiguration appConf = new 
TezConfiguration(miniTezCluster.getConfig());
+    updateCommonConfiguration(appConf);
+    runTezApp(appConf, getTestTool(), getTestArgs("origin"));
+    final String originPath = getOutputDir("origin");
+
+    // Run RSS tests with different configurations
+    runRemoteMergeRssTest(ClientType.GRPC, "rss-grpc", originPath);
+  }
+
+  private void runRemoteMergeRssTest(ClientType clientType, String testName, 
String originPath)
+      throws Exception {
+    TezConfiguration appConf = new 
TezConfiguration(miniTezCluster.getConfig());
+    appConf.set(RssTezConfig.RSS_REMOTE_MERGE_ENABLE, "true");
+    updateRssConfiguration(appConf, clientType);
+    appendAndUploadRssJars(appConf);
+    runTezApp(appConf, getTestTool(), getTestArgs(testName));
+    verifyResults(originPath, getOutputDir(testName));
+  }
+
   private void generateInputFile() throws Exception {
     // For ordered word count, the key of last ordered sorter is the summation 
of word, the value is
     // the word. So it means this key may not be unique. Because Sorter can 
only make sure key is
diff --git 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezCartesianProductTest.java
 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezCartesianProductTest.java
index 7e8b2a031..6edb99cd4 100644
--- 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezCartesianProductTest.java
+++ 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezCartesianProductTest.java
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.FSDataOutputStream;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.util.Tool;
 import org.apache.tez.examples.CartesianProduct;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 
 public class TezCartesianProductTest extends TezIntegrationTestBase {
@@ -31,6 +32,11 @@ public class TezCartesianProductTest extends 
TezIntegrationTestBase {
   private String inputPath3 = "cartesian_product_input3";
   private String outputPath = "cartesian_product_output";
 
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    TezIntegrationTestBase.setupServers(null);
+  }
+
   @Test
   public void cartesianProductTest() throws Exception {
     generateInputFile();
diff --git 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
index cdaffffb1..d077e3608 100644
--- 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
+++ 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
@@ -75,6 +75,9 @@ public class TezIntegrationTestBase extends 
IntegrationTestBase {
       miniTezCluster.init(conf);
       miniTezCluster.start();
     }
+  }
+
+  protected static void setupServers(ShuffleServerConf serverConf) throws 
Exception {
     LOG.info("Starting coordinators and shuffle servers");
     CoordinatorConf coordinatorConf = getCoordinatorConf();
     Map<String, String> dynamicConf = new HashMap<>();
@@ -83,8 +86,14 @@ public class TezIntegrationTestBase extends IntegrationTestBase {
     addDynamicConf(coordinatorConf, dynamicConf);
     createCoordinatorServer(coordinatorConf);
    ShuffleServerConf grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC);
+    if (serverConf != null) {
+      grpcShuffleServerConf.addAll(serverConf);
+    }
     createShuffleServer(grpcShuffleServerConf);
    ShuffleServerConf nettyShuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY);
+    if (serverConf != null) {
+      nettyShuffleServerConf.addAll(serverConf);
+    }
     createShuffleServer(nettyShuffleServerConf);
     startServers();
   }
diff --git a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezJoinIntegrationTestBase.java b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezJoinIntegrationTestBase.java
index 641e40d34..6d87237e5 100644
--- a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezJoinIntegrationTestBase.java
+++ b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezJoinIntegrationTestBase.java
@@ -22,6 +22,7 @@ import org.apache.hadoop.util.ToolRunner;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.examples.JoinDataGen;
 import org.apache.tez.examples.JoinValidate;
+import org.junit.jupiter.api.BeforeAll;
 
 import org.apache.uniffle.common.ClientType;
 
@@ -36,6 +37,11 @@ public class TezJoinIntegrationTestBase extends TezIntegrationTestBase {
   protected static final String JOIN_EXPECTED_PATH = "join_expected";
   protected static final String NUM_TASKS = "2";
 
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    TezIntegrationTestBase.setupServers(null);
+  }
+
   protected void generateInputFile() throws Exception {
     fs.delete(new Path(STREAM_INPUT_PATH), true);
     fs.delete(new Path(HASH_INPUT_PATH), true);
diff --git a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
index 4bd2f7e8a..820995b4b 100644
--- a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
+++ b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezOrderedWordCountTest.java
@@ -29,6 +29,7 @@ import org.apache.hadoop.fs.FSDataOutputStream;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.util.Tool;
 import org.apache.tez.examples.OrderedWordCount;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 
 public class TezOrderedWordCountTest extends TezIntegrationTestBase {
@@ -39,6 +40,11 @@ public class TezOrderedWordCountTest extends TezIntegrationTestBase {
       Lists.newArrayList(
          "apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan", "tomato");
 
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    TezIntegrationTestBase.setupServers(null);
+  }
+
   @Test
   public void orderedWordCountTest() throws Exception {
     generateInputFile();
diff --git a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezSimpleSessionExampleTest.java b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezSimpleSessionExampleTest.java
index 788465d61..54da4e25a 100644
--- a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezSimpleSessionExampleTest.java
+++ b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezSimpleSessionExampleTest.java
@@ -31,6 +31,7 @@ import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.util.Tool;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.examples.SimpleSessionExample;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.common.ClientType;
@@ -44,6 +45,11 @@ public class TezSimpleSessionExampleTest extends TezIntegrationTestBase {
       Lists.newArrayList(
          "apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan", "tomato");
 
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    TezIntegrationTestBase.setupServers(null);
+  }
+
   @Test
   public void simpleSessionExampleTest() throws Exception {
     generateInputFile();
diff --git a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountTest.java b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountTest.java
index e2e9b564b..a99bf9bc0 100644
--- a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountTest.java
+++ b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountTest.java
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.FSDataOutputStream;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.util.Tool;
 import org.apache.tez.examples.WordCount;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 
 public class TezWordCountTest extends TezIntegrationTestBase {
@@ -35,6 +36,11 @@ public class TezWordCountTest extends TezIntegrationTestBase {
       Lists.newArrayList(
          "apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan", "tomato");
 
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    TezIntegrationTestBase.setupServers(null);
+  }
+
   @Test
   public void wordCountTest() throws Exception {
     generateInputFile();

Reply via email to