This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new e9342ccd [#855] feat(tez): Support Tez Output 
OrderedPartitionedKVOutput (#930)
e9342ccd is described below

commit e9342ccddb364916dbd72f5cc5e4c645465a5dc4
Author: bin41215 <[email protected]>
AuthorDate: Sat Jun 10 10:52:30 2023 +0800

    [#855] feat(tez): Support Tez Output OrderedPartitionedKVOutput (#930)
    
    ### What changes were proposed in this pull request?
    
    support tez write OrderedPartitionedKVOutput
    
    ### Why are the changes needed?
    
    Fix: #855
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    UT
    
    Co-authored-by: bin3.zhang <[email protected]>
---
 .../library/common/sort/buffer/WriteBuffer.java    | 332 +++++++++++++++++
 .../common/sort/buffer/WriteBufferManager.java     | 392 +++++++++++++++++++++
 .../library/common/sort/impl/RssSorter.java        | 206 +++++++++++
 .../common/sort/impl/RssTezPerPartitionRecord.java |  83 +++++
 .../output/RssOrderedPartitionedKVOutput.java      | 292 +++++++++++++++
 .../common/sort/buffer/WriteBufferManagerTest.java | 388 ++++++++++++++++++++
 .../common/sort/buffer/WriteBufferTest.java        | 165 +++++++++
 .../library/common/sort/impl/RssSorterTest.java    | 134 +++++++
 .../sort/impl/RssTezPerPartitionRecordTest.java    |  55 +++
 .../runtime/library/output/OutputTestHelpers.java  |  69 ++++
 .../output/RssOrderedPartitionedKVOutputTest.java  | 106 ++++++
 11 files changed, 2222 insertions(+)

diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
new file mode 100644
index 00000000..61bee2d6
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.buffer;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Comparator;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+
/**
 * A partition-local, in-memory buffer of serialized key/value records.
 *
 * <p>Bytes are stored in a chain of fixed-size segments ({@link WrappedBuffer})
 * so the buffer can grow without reallocating one large array. Record
 * boundaries (key index/offset/length and value length) are tracked separately
 * in {@link Record} entries. {@link #getData()} optionally sorts the records by
 * raw key bytes and flattens everything into a single array in
 * [key-length vint][value-length vint][key bytes][value bytes] layout,
 * terminated by a pair of -1 vints.
 *
 * <p>This class extends {@link OutputStream} so the Hadoop serializers can
 * write straight into the segment chain via {@link #write(int)} /
 * {@link #write(byte[], int, int)}.
 *
 * <p>NOTE(review): only {@link #getData()} is synchronized; the class is
 * otherwise assumed to be written from a single thread — confirm with callers.
 */
public class WriteBuffer<K,V> extends OutputStream {

  private static final Logger LOG = LoggerFactory.getLogger(WriteBuffer.class);

  // Partition this buffer collects records for.
  private int partitionId;
  private Serializer<K> keySerializer;
  private Serializer<V> valSerializer;
  // Capacity of each underlying byte-array segment.
  private long maxSegmentSize;
  // Raw-byte key comparator used when sorting records in getData().
  private final RawComparator<K> comparator;
  // Total serialized bytes written so far, across all segments.
  private int dataLength = 0;
  // Write position inside the segment currently being filled.
  private int currentOffset = 0;
  // Index (into buffers) of the segment currently being filled.
  private int currentIndex = 0;
  // Accumulated milliseconds spent sorting / copying in getData().
  private long sortTime = 0;
  private long copyTime = 0;
  // Whether getData() must sort records by key before flattening.
  private boolean isNeedSorted = false;
  // Chain of fixed-size segments holding the raw serialized bytes.
  private final List<WrappedBuffer> buffers = Lists.newArrayList();
  // One entry per record: where its key starts and how long key/value are.
  private final List<Record<K>> records = Lists.newArrayList();

  /**
   * @param isNeedSorted   sort records by key in getData() when true
   * @param partitionId    partition this buffer belongs to
   * @param comparator     raw-byte key comparator (used only when sorting)
   * @param maxSegmentSize size of each backing segment in bytes
   * @param keySerializer  serializer that writes keys into this stream
   * @param valueSerializer serializer that writes values into this stream
   */
  public WriteBuffer(
      boolean isNeedSorted,
      int partitionId,
      RawComparator<K> comparator,
      long maxSegmentSize,
      Serializer<K> keySerializer,
      Serializer<V> valueSerializer) {
    this.partitionId = partitionId;
    this.comparator = comparator;
    this.maxSegmentSize = maxSegmentSize;
    this.keySerializer = keySerializer;
    this.valSerializer = valueSerializer;
    this.isNeedSorted = isNeedSorted;
  }

  /**
   * Serializes one key/value pair into this buffer and records its location.
   *
   * <p>If the key's bytes ended up spanning a segment boundary, the spanned
   * segments are merged first ({@link #compact}) so every key is contiguous —
   * the raw comparator in getData() compares a key as a single byte range.
   *
   * @return number of serialized bytes added (key length + value length)
   * @throws IOException propagated from the serializers
   */
  public int addRecord(K key, V value) throws IOException {
    // Serializer.open is cheap/idempotent here; it just binds this stream.
    keySerializer.open(this);
    valSerializer.open(this);
    int lastOffSet = currentOffset;
    int lastIndex = currentIndex;
    int lastDataLength = dataLength;
    int keyIndex = lastIndex;
    keySerializer.serialize(key);
    int keyLength = dataLength - lastDataLength;
    int keyOffset = lastOffSet;
    // compact() returns true when the key crossed a segment boundary and the
    // spanned segments were merged; the merged segment lands back at lastIndex,
    // so the key's (index, offset) are the pre-serialization values again.
    if (compact(lastIndex, lastOffSet, keyLength)) {
      keyOffset = lastOffSet;
      keyIndex = lastIndex;
    }
    lastDataLength = dataLength;
    valSerializer.serialize(value);
    int valueLength = dataLength - lastDataLength;
    records.add(new Record<K>(keyIndex, keyOffset, keyLength, valueLength));
    return keyLength + valueLength;
  }

  /** Drops all segments and record metadata (counters are kept). */
  public void clear() {
    buffers.clear();
    records.clear();
  }

  /**
   * Flattens all buffered records into one byte array.
   *
   * <p>Optionally sorts records by raw key bytes first, then for each record
   * emits [key-length vint][value-length vint][key bytes][value bytes]; the
   * stream is terminated with two -1 vints (the conventional end-of-data
   * marker for this record format). Also accumulates sort/copy timings.
   */
  public synchronized byte[] getData() {
    // Pre-compute the extra space needed for the per-record length vints
    // plus the trailing (-1, -1) end marker.
    int extraSize = 0;
    for (Record<K> record : records) {
      extraSize += WritableUtils.getVIntSize(record.getKeyLength());
      extraSize += WritableUtils.getVIntSize(record.getValueLength());
    }
    extraSize += WritableUtils.getVIntSize(-1);
    extraSize += WritableUtils.getVIntSize(-1);
    byte[] data = new byte[dataLength + extraSize];
    int offset = 0;
    long startSort = System.currentTimeMillis();
    if (this.isNeedSorted) {
      // Keys are contiguous within one segment (guaranteed by compact()),
      // so each side of the comparison is a single (buffer, offset, length).
      records.sort(new Comparator<Record<K>>() {
        @Override
        public int compare(Record<K> o1, Record<K> o2) {
          return comparator.compare(
                  buffers.get(o1.getKeyIndex()).getBuffer(),
                  o1.getKeyOffSet(),
                  o1.getKeyLength(),
                  buffers.get(o2.getKeyIndex()).getBuffer(),
                  o2.getKeyOffSet(),
                  o2.getKeyLength());
        }
      });
    }
    long startCopy =  System.currentTimeMillis();
    sortTime += startCopy - startSort;

    for (Record<K> record : records) {
      offset = writeDataInt(data, offset, record.getKeyLength());
      offset = writeDataInt(data, offset, record.getValueLength());
      int recordLength = record.getKeyLength() + record.getValueLength();
      int copyOffset = record.getKeyOffSet();
      int copyIndex = record.getKeyIndex();
      // A record's value may span several segments; copy piece by piece,
      // continuing at offset 0 of each following segment.
      while (recordLength > 0) {
        byte[] srcBytes = buffers.get(copyIndex).getBuffer();
        int length = copyOffset + recordLength;
        int copyLength = recordLength;
        if (length > srcBytes.length) {
          copyLength = srcBytes.length - copyOffset;
        }
        System.arraycopy(srcBytes, copyOffset, data, offset, copyLength);
        copyOffset = 0;
        copyIndex++;
        recordLength -= copyLength;
        offset += copyLength;
      }
    }
    offset = writeDataInt(data, offset, -1);
    writeDataInt(data, offset, -1);
    copyTime += System.currentTimeMillis() - startCopy;
    return data;
  }

  /**
   * Merges segments when the just-serialized key crossed a segment boundary.
   *
   * <p>If currentIndex advanced past lastIndex while the key was written, the
   * segments from lastIndex through currentIndex are copied into one
   * contiguous buffer which replaces them at position lastIndex, and a fresh
   * empty segment is appended for subsequent writes.
   *
   * @return true when compaction happened (key is now contiguous at
   *         (lastIndex, lastOffset)); false when the key fit in one segment
   */
  private boolean compact(int lastIndex, int lastOffset, int dataLength) {
    if (lastIndex != currentIndex) {
      LOG.debug("compact lastIndex {}, currentIndex {}, lastOffset {} currentOffset {} dataLength {}",
          lastIndex, currentIndex, lastOffset, currentOffset, dataLength);
      // Sized to hold everything already in the starting segment before the
      // key (lastOffset bytes) plus the key itself (dataLength bytes).
      WrappedBuffer buffer = new WrappedBuffer(lastOffset + dataLength);
      // copy data
      int offset = 0;
      for (int i = lastIndex; i < currentIndex; i++) {
        byte[] sourceBuffer = buffers.get(i).getBuffer();
        System.arraycopy(sourceBuffer, 0, buffer.getBuffer(), offset, sourceBuffer.length);
        offset += sourceBuffer.length;
      }
      // The tail of the key sits in the current segment at [0, currentOffset).
      System.arraycopy(buffers.get(currentIndex).getBuffer(), 0, buffer.getBuffer(), offset, currentOffset);
      // remove data — removing back-to-front keeps indices stable; the merged
      // buffer is then appended and ends up exactly at position lastIndex.
      for (int i = currentIndex; i >= lastIndex; i--) {
        buffers.remove(i);
      }
      buffers.add(buffer);
      currentOffset = 0;
      WrappedBuffer anotherBuffer = new WrappedBuffer((int)maxSegmentSize);
      buffers.add(anotherBuffer);
      currentIndex = buffers.size() - 1;
      return true;
    }
    return false;
  }

  /**
   * Appends a single byte, rolling over to a new segment when the current
   * one is full. Used by the Hadoop serializers.
   */
  @Override
  public void write(int b) throws IOException {
    if (buffers.isEmpty()) {
      buffers.add(new WrappedBuffer((int)maxSegmentSize));
    }
    if (1 + currentOffset > maxSegmentSize) {
      currentIndex++;
      currentOffset = 0;
      buffers.add(new WrappedBuffer((int)maxSegmentSize));
    }
    WrappedBuffer buffer = buffers.get(currentIndex);
    buffer.getBuffer()[currentOffset] = (byte) b;
    currentOffset++;
    dataLength++;
  }

  /**
   * Bulk write that may span multiple segments. New segments needed to hold
   * the data are allocated up front, then the input is copied piecewise,
   * filling the remainder of the current segment and continuing at offset 0
   * of each following one.
   *
   * @throws NullPointerException      if {@code b} is null
   * @throws IndexOutOfBoundsException if (off, len) is outside {@code b}
   */
  @Override
  public void write(byte[] b, int off, int len) throws IOException {
    if (b == null) {
      throw new NullPointerException();
    } else if ((off < 0) || (off > b.length) || (len < 0)
        || ((off + len) > b.length) || ((off + len) < 0)) {
      throw new IndexOutOfBoundsException();
    } else if (len == 0) {
      return;
    }
    if (buffers.isEmpty()) {
      buffers.add(new WrappedBuffer((int) maxSegmentSize));
    }
    // Number of segment boundaries this write will cross.
    int bufferNum = (int)((currentOffset + len) / maxSegmentSize);

    for (int i = 0; i < bufferNum; i++) {
      buffers.add(new WrappedBuffer((int) maxSegmentSize));
    }

    int index = currentIndex;
    int offset = currentOffset;
    int srcPos = 0;

    while (len > 0) {
      int copyLength = 0;
      if (offset + len >= maxSegmentSize) {
        // Fill the rest of this segment; the next iteration starts at 0.
        copyLength = (int) (maxSegmentSize - offset);
        currentOffset = 0;
      } else {
        copyLength = len;
        currentOffset += len;
      }
      System.arraycopy(b, srcPos, buffers.get(index).getBuffer(), offset, copyLength);
      offset = 0;
      srcPos += copyLength;
      index++;
      len -= copyLength;
      dataLength += copyLength;
    }
    currentIndex += bufferNum;
  }

  /**
   * Writes {@code dataInt} into {@code data} at {@code offset} using the
   * same variable-length encoding as Hadoop's WritableUtils.writeVLong
   * (single byte for [-112, 127], otherwise a length marker byte followed
   * by the value's big-endian bytes).
   *
   * @return the offset just past the written vint
   */
  private int writeDataInt(byte[] data, int offset, long dataInt) {
    if (dataInt >= -112L && dataInt <= 127L) {
      data[offset] = (byte)((int)dataInt);
      offset++;
    } else {
      int len = -112;
      if (dataInt < 0L) {
        dataInt = ~dataInt;
        len = -120;
      }
      // Decrement len once per significant byte of the value.
      for (long tmp = dataInt; tmp != 0L; --len) {
        tmp >>= 8;
      }
      data[offset] = (byte)len;
      offset++;
      len = len < -120 ? -(len + 120) : -(len + 112);
      for (int idx = len; idx != 0; --idx) {
        int shiftBits = (idx - 1) * 8;
        long mask = 255L << shiftBits;
        data[offset] = ((byte)((int)((dataInt & mask) >> shiftBits)));
        offset++;
      }
    }
    return offset;
  }

  /** @return total serialized bytes currently buffered */
  public int getDataLength() {
    return dataLength;
  }

  public int getPartitionId() {
    return partitionId;
  }

  /** @return accumulated milliseconds spent copying in getData() */
  public long getCopyTime() {
    return copyTime;
  }

  /** @return accumulated milliseconds spent sorting in getData() */
  public long getSortTime() {
    return sortTime;
  }

  /**
   * Location metadata for one buffered record: which segment the key starts
   * in, its offset there, and the key/value serialized lengths. The value
   * bytes always directly follow the key bytes.
   */
  private static final class Record<K> {
    private final int keyIndex;
    private final int keyOffSet;
    private final int keyLength;
    private final int valueLength;

    Record(int keyIndex,
           int keyOffset,
           int keyLength,
           int valueLength) {
      this.keyIndex = keyIndex;
      this.keyOffSet = keyOffset;
      this.keyLength = keyLength;
      this.valueLength = valueLength;
    }

    public int getKeyIndex() {
      return keyIndex;
    }

    public int getKeyOffSet() {
      return keyOffSet;
    }

    public int getKeyLength() {
      return keyLength;
    }

    public int getValueLength() {
      return valueLength;
    }
  }

  /** A fixed-size byte-array segment of the buffer chain. */
  private static final class WrappedBuffer {
    private byte[] buffer;
    private int size;

    WrappedBuffer(int size) {
      this.buffer = new byte[size];
      this.size = size;
    }

    public byte[] getBuffer() {
      return buffer;
    }

    public int getSize() {
      return size;
    }
  }

}
+
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
new file mode 100644
index 00000000..96fa6f29
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
@@ -0,0 +1,392 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.buffer;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ChecksumUtils;
+import org.apache.uniffle.common.util.ThreadUtils;
+
+
/**
 * Manages the per-partition {@link WriteBuffer}s for one Tez task attempt and
 * ships them to Uniffle shuffle servers.
 *
 * <p>Responsibilities visible in this class:
 * <ul>
 *   <li>route each record to its partition's buffer ({@link #addRecord});</li>
 *   <li>enforce a total memory budget — producers block on {@code full} while
 *       {@code memoryUsedSize} exceeds {@code maxMemSize}, and are signalled
 *       by the sender thread once data is acknowledged;</li>
 *   <li>drain buffers into compressed {@link ShuffleBlockInfo}s and send them
 *       asynchronously on a single-thread executor;</li>
 *   <li>at task end, wait for all block acks, optionally commit, and report
 *       the shuffle result ({@link #waitSendFinished}).</li>
 * </ul>
 *
 * <p>NOTE(review): record-producing methods appear to be called from a single
 * task thread (e.g. {@code partitionToSeqNo} and {@code waitSendBuffers} are
 * not thread-safe collections); only the memory accounting and block-id sets
 * are shared with the sender thread — confirm with callers.
 */
public class WriteBufferManager<K,V> {
  private static final Logger LOG = LoggerFactory.getLogger(WriteBufferManager.class);
  // Timing accumulators, folded in from drained WriteBuffers / compression.
  private long copyTime = 0;
  private long sortTime = 0;
  private long compressTime = 0;
  // Next block sequence number per partition (single-threaded access).
  private final Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
  private long uncompressedDataLen = 0;
  // Total memory budget; producers block once memoryUsedSize exceeds it.
  private final long maxMemSize;
  // Single-thread executor performing the actual sends.
  private final ExecutorService sendExecutorService;
  private final ShuffleWriteClient shuffleWriteClient;
  private final String appId;
  // Block-id sets shared with the caller for tracking ack status.
  private final Set<Long> successBlockIds;
  private final Set<Long> failedBlockIds;
  // Guards the memory back-pressure condition below.
  private final ReentrantLock memoryLock = new ReentrantLock();
  // Bytes currently buffered (including bytes queued for send).
  private final AtomicLong memoryUsedSize = new AtomicLong(0);
  // Bytes handed to the sender but not yet acknowledged.
  private final AtomicLong inSendListBytes = new AtomicLong(0);
  // Signalled by the sender thread when memory is released.
  private final Condition full = memoryLock.newCondition();
  private final RawComparator<K> comparator;
  private final long maxSegmentSize;
  private final Serializer<K> keySerializer;
  private final Serializer<V> valSerializer;
  // Buffers not yet handed to the sender, drained largest-first.
  private final List<WriteBuffer<K, V>> waitSendBuffers = Lists.newLinkedList();
  // Active buffer per partition.
  private final Map<Integer, WriteBuffer<K, V>> buffers = Maps.newConcurrentMap();
  // Per-buffer size that triggers an immediate send of that buffer.
  private final long maxBufferSize;
  // Fraction of maxMemSize above which a batched send is triggered.
  private final double memoryThreshold;
  // Send only while in-flight bytes stay under maxMemSize * sendThreshold.
  private final double sendThreshold;
  // Max number of buffers drained per batched send.
  private final int batch;
  private final Codec codec;
  private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
  // All block ids produced; drained as acks arrive in waitSendFinished().
  private final Set<Long> allBlockIds = Sets.newConcurrentHashSet();
  private final Map<Integer, List<Long>> partitionToBlocks = Maps.newConcurrentMap();
  private final int numMaps;
  // When true, skip the commit phase in waitSendFinished().
  private final boolean isMemoryShuffleEnabled;
  private final long sendCheckInterval;
  private final long sendCheckTimeout;
  private final int bitmapSplitNum;
  private final long taskAttemptId;
  private TezTaskAttemptID tezTaskAttemptID;
  private final RssConf rssConf;
  private final int shuffleId;
  private final boolean isNeedSorted;

  /**
   * Creates the manager; all collaborators and thresholds are injected.
   * Spawns the single "send-thread" executor used for async sends.
   */
  public WriteBufferManager(
      TezTaskAttemptID tezTaskAttemptID,
      long maxMemSize,
      String appId,
      long taskAttemptId,
      Set<Long> successBlockIds,
      Set<Long> failedBlockIds,
      ShuffleWriteClient shuffleWriteClient,
      RawComparator<K> comparator,
      long maxSegmentSize,
      Serializer<K> keySerializer,
      Serializer<V> valSerializer,
      long maxBufferSize,
      double memoryThreshold,
      double sendThreshold,
      int batch,
      RssConf rssConf,
      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
      int numMaps,
      boolean isMemoryShuffleEnabled,
      long sendCheckInterval,
      long sendCheckTimeout,
      int bitmapSplitNum,
      int shuffleId,
      boolean isNeedSorted) {
    this.tezTaskAttemptID = tezTaskAttemptID;
    this.maxMemSize = maxMemSize;
    this.appId = appId;
    this.taskAttemptId = taskAttemptId;
    this.successBlockIds = successBlockIds;
    this.failedBlockIds = failedBlockIds;
    this.shuffleWriteClient = shuffleWriteClient;
    this.comparator = comparator;
    this.maxSegmentSize = maxSegmentSize;
    this.keySerializer = keySerializer;
    this.valSerializer = valSerializer;
    this.maxBufferSize = maxBufferSize;
    this.memoryThreshold = memoryThreshold;
    this.sendThreshold = sendThreshold;
    this.batch = batch;
    this.codec = Codec.newInstance(rssConf);
    this.partitionToServers = partitionToServers;
    this.numMaps = numMaps;
    this.isMemoryShuffleEnabled = isMemoryShuffleEnabled;
    this.sendCheckInterval = sendCheckInterval;
    this.sendCheckTimeout = sendCheckTimeout;
    this.bitmapSplitNum = bitmapSplitNum;
    this.rssConf = rssConf;
    this.shuffleId = shuffleId;
    this.isNeedSorted = isNeedSorted;
    this.sendExecutorService = Executors.newFixedThreadPool(
            1,
            ThreadUtils.getThreadFactory("send-thread"));
  }

  /**
   * Adds one record to its partition's buffer, blocking first while the
   * total buffered bytes exceed the memory budget (the sender thread
   * signals {@code full} as acks release memory).
   *
   * <p>Triggers sends in two ways: the single buffer exceeding
   * {@code maxBufferSize} is sent immediately, and crossing the overall
   * memory threshold (while in-flight bytes are low enough) triggers a
   * batched largest-first send.
   *
   * @throws InterruptedException if interrupted while waiting for memory
   * @throws IOException          propagated from serialization
   * @throws RssException         if a single record exceeds maxMemSize
   */
  public void addRecord(int partitionId, K key, V value) throws InterruptedException, IOException {
    memoryLock.lock();
    try {
      while (memoryUsedSize.get() > maxMemSize) {
        LOG.warn("memoryUsedSize {} is more than {}, inSendListBytes {}",
            memoryUsedSize, maxMemSize, inSendListBytes);
        full.await();
      }
    } finally {
      memoryLock.unlock();
    }

    if (!buffers.containsKey(partitionId)) {
      WriteBuffer<K, V> sortWriterBuffer = new WriteBuffer(
          isNeedSorted, partitionId, comparator, maxSegmentSize, keySerializer, valSerializer);
      buffers.putIfAbsent(partitionId, sortWriterBuffer);
      waitSendBuffers.add(sortWriterBuffer);
    }
    WriteBuffer<K, V> buffer = buffers.get(partitionId);
    int length = buffer.addRecord(key, value);
    if (length > maxMemSize) {
      throw new RssException("record is too big");
    }

    memoryUsedSize.addAndGet(length);
    if (buffer.getDataLength() > maxBufferSize) {
      // This buffer alone is big enough: ship it on its own.
      if (waitSendBuffers.remove(buffer)) {
        sendBufferToServers(buffer);
      } else {
        LOG.error("waitSendBuffers don't contain buffer {}", buffer);
      }
    }
    if (memoryUsedSize.get() > maxMemSize * memoryThreshold
        && inSendListBytes.get() <= maxMemSize * sendThreshold) {
      sendBuffersToServers();
    }
  }

  /** Drains a single buffer into a block and submits it for async send. */
  private void sendBufferToServers(WriteBuffer<K, V> buffer) {
    List<ShuffleBlockInfo> shuffleBlocks = Lists.newArrayList();
    prepareBufferForSend(shuffleBlocks, buffer);
    sendShuffleBlocks(shuffleBlocks);
  }

  /**
   * Drains up to {@code batch} pending buffers, largest first, into blocks
   * and submits them for async send.
   */
  void sendBuffersToServers() {
    // Descending by buffered size so the biggest memory consumers go first.
    waitSendBuffers.sort(new Comparator<WriteBuffer<K, V>>() {
      @Override
      public int compare(WriteBuffer<K, V> o1, WriteBuffer<K, V> o2) {
        return o2.getDataLength() - o1.getDataLength();
      }
    });

    int sendSize = batch;
    if (batch > waitSendBuffers.size()) {
      sendSize = waitSendBuffers.size();
    }

    Iterator<WriteBuffer<K, V>> iterator = waitSendBuffers.iterator();
    int index = 0;
    List<ShuffleBlockInfo> shuffleBlocks = Lists.newArrayList();
    while (iterator.hasNext() && index < sendSize) {
      WriteBuffer<K, V> buffer = iterator.next();
      prepareBufferForSend(shuffleBlocks, buffer);
      iterator.remove();
      index++;
    }
    sendShuffleBlocks(shuffleBlocks);
  }

  /**
   * Converts a buffer into a ShuffleBlockInfo, clears the buffer, detaches
   * it from the active-partition map, and records the block id for later
   * ack tracking and the final shuffle-result report.
   */
  private void prepareBufferForSend(List<ShuffleBlockInfo> shuffleBlocks, WriteBuffer buffer) {
    buffers.remove(buffer.getPartitionId());
    ShuffleBlockInfo block = createShuffleBlock(buffer);
    buffer.clear();
    shuffleBlocks.add(block);
    allBlockIds.add(block.getBlockId());
    if (!partitionToBlocks.containsKey(block.getPartitionId())) {
      partitionToBlocks.putIfAbsent(block.getPartitionId(), Lists.newArrayList());
    }
    partitionToBlocks.get(block.getPartitionId()).add(block.getBlockId());
  }

  /**
   * Submits the blocks to the send executor. On completion (success or not)
   * the accounted bytes are released and producers blocked in addRecord()
   * are signalled. Send failures surface via {@code failedBlockIds}, not
   * via an exception from here.
   */
  private void sendShuffleBlocks(List<ShuffleBlockInfo> shuffleBlocks) {
    sendExecutorService.submit(new Runnable() {
      @Override
      public void run() {
        long size = 0;
        try {
          for (ShuffleBlockInfo block : shuffleBlocks) {
            size += block.getFreeMemory();
          }
          SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(appId, shuffleBlocks, () -> false);
          successBlockIds.addAll(result.getSuccessBlockIds());
          failedBlockIds.addAll(result.getFailedBlockIds());
        } catch (Throwable t) {
          LOG.warn("send shuffle data exception ", t);
        } finally {
          try {
            memoryLock.lock();
            LOG.debug("memoryUsedSize {} decrease {}", memoryUsedSize, size);
            memoryUsedSize.addAndGet(-size);
            inSendListBytes.addAndGet(-size);
            // Wake producers waiting on the memory budget in addRecord().
            full.signalAll();
          } finally {
            memoryLock.unlock();
          }
        }
      }
    });
  }

  /**
   * Flushes all remaining buffers, optionally commits (non-memory shuffle),
   * then polls until every produced block id is acknowledged or the timeout
   * elapses, and finally reports the shuffle result to the servers.
   *
   * @throws RssException on any failed block, or on ack timeout
   */
  public void waitSendFinished() {
    while (!waitSendBuffers.isEmpty()) {
      sendBuffersToServers();
    }
    long start = System.currentTimeMillis();
    long commitDuration = 0;
    if (!isMemoryShuffleEnabled) {
      long s = System.currentTimeMillis();
      sendCommit();
      commitDuration = System.currentTimeMillis() - s;
    }
    while (true) {
      if (failedBlockIds.size() > 0) {
        String errorMsg = "Send failed: failed because " + failedBlockIds.size()
            + " blocks can't be sent to shuffle server.";
        LOG.error(errorMsg);
        throw new RssException(errorMsg);
      }
      // Drain acked ids; done when nothing is left outstanding.
      allBlockIds.removeAll(successBlockIds);
      if (allBlockIds.isEmpty()) {
        break;
      }
      LOG.info("Wait " + allBlockIds.size() + " blocks sent to shuffle server");
      Uninterruptibles.sleepUninterruptibly(sendCheckInterval, TimeUnit.MILLISECONDS);
      if (System.currentTimeMillis() - start > sendCheckTimeout) {
        String errorMsg = "Timeout: failed because " + allBlockIds.size()
            + " blocks can't be sent to shuffle server in " + sendCheckTimeout + " ms.";
        LOG.error(errorMsg);
        throw new RssException(errorMsg);
      }
    }
    start = System.currentTimeMillis();
    TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
    TezDAGID tezDAGID = tezVertexID.getDAGId();
    LOG.info("tezVertexID is {}, tezDAGID is {}, shuffleId is {}", tezVertexID, tezDAGID, shuffleId);
    shuffleWriteClient.reportShuffleResult(partitionToServers, appId, shuffleId,
            taskAttemptId, partitionToBlocks, bitmapSplitNum);
    LOG.info("Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms",
            taskAttemptId, bitmapSplitNum, (System.currentTimeMillis() - start));
    LOG.info("Task uncompressed data length {} compress time cost {} ms, commit time cost {} ms,"
                  + " copy time cost {} ms, sort time cost {} ms",
            uncompressedDataLen, compressTime, commitDuration, copyTime, sortTime);
  }

  /**
   * Drains a WriteBuffer into a compressed, checksummed ShuffleBlockInfo
   * with a fresh per-partition block id, folding the buffer's sort/copy
   * timings into this manager's accumulators and adding the drained bytes
   * to the in-flight counter.
   */
  ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) {
    byte[] data = wb.getData();
    copyTime += wb.getCopyTime();
    sortTime += wb.getSortTime();
    int partitionId = wb.getPartitionId();
    final int uncompressLength = data.length;
    long start = System.currentTimeMillis();

    final byte[] compressed = codec.compress(data);
    final long crc32 = ChecksumUtils.getCrc32(compressed);
    compressTime += System.currentTimeMillis() - start;
    final long blockId = RssTezUtils.getBlockId((long)partitionId, taskAttemptId, getNextSeqNo(partitionId));
    LOG.info("blockId is {}", blockId);
    uncompressedDataLen += data.length;
    // add memory to indicate bytes which will be sent to shuffle server
    inSendListBytes.addAndGet(wb.getDataLength());

    TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
    TezDAGID tezDAGID = tezVertexID.getDAGId();
    LOG.info("tezVertexID is {}, tezDAGID is {}, shuffleId is {}", tezVertexID, tezDAGID, shuffleId);
    return new ShuffleBlockInfo(shuffleId, partitionId, blockId, compressed.length, crc32,
          compressed, partitionToServers.get(partitionId),
          uncompressLength, wb.getDataLength(), taskAttemptId);
  }

  /**
   * Sends a commit to every distinct shuffle server on a temporary
   * single-thread executor, polling with exponential backoff (200 ms
   * doubling up to 5 s) until the future completes.
   *
   * @throws RssException if the commit returns false or fails
   */
  protected void sendCommit() {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    // Deduplicate servers across partitions.
    Set<ShuffleServerInfo> serverInfos = Sets.newHashSet();
    for (List<ShuffleServerInfo> serverInfoList : partitionToServers.values()) {
      for (ShuffleServerInfo serverInfo : serverInfoList) {
        serverInfos.add(serverInfo);
      }
    }
    LOG.info("sendCommit  shuffle id is {}", shuffleId);
    Future<Boolean> future = executor.submit(
          () -> shuffleWriteClient.sendCommit(serverInfos, appId, shuffleId, numMaps));
    long start = System.currentTimeMillis();
    int currentWait = 200;
    int maxWait = 5000;
    while (!future.isDone()) {
      LOG.info("Wait commit to shuffle server for task[" + taskAttemptId + "] cost "
            + (System.currentTimeMillis() - start) + " ms");
      Uninterruptibles.sleepUninterruptibly(currentWait, TimeUnit.MILLISECONDS);
      currentWait = Math.min(currentWait * 2, maxWait);
    }
    try {
      if (!future.get()) {
        throw new RssException("Failed to commit task to shuffle server");
      }
    } catch (InterruptedException ie) {
      LOG.warn("Ignore the InterruptedException which should be caused by internal killed");
    } catch (Exception e) {
      throw new RssException("Exception happened when get commit status", e);
    } finally {
      executor.shutdown();
    }
  }

  /** Visible for testing: buffers not yet handed to the sender. */
  List<WriteBuffer<K,V>> getWaitSendBuffers() {
    return waitSendBuffers;
  }

  /**
   * Returns the current sequence number for the partition and advances it.
   * Backed by a plain HashMap — relies on single-threaded callers.
   */
  private int getNextSeqNo(int partitionId) {
    partitionToSeqNo.putIfAbsent(partitionId, 0);
    int seqNo = partitionToSeqNo.get(partitionId);
    partitionToSeqNo.put(partitionId, seqNo + 1);
    return seqNo;
  }

  /** Stops the send executor immediately; in-flight sends may be dropped. */
  public void freeAllResources() {
    sendExecutorService.shutdownNow();
  }

}
+
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
new file mode 100644
index 00000000..45b9f5d6
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Sets;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.tez.common.IdUtils;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.library.common.sort.buffer.WriteBufferManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteUnit;
+import org.apache.uniffle.storage.util.StorageType;
+
+
+/**{@link RssSorter} is an {@link ExternalSorter}
+ */
/**
 * {@link RssSorter} is an {@link ExternalSorter} that pushes map output to
 * remote shuffle servers through a {@link WriteBufferManager} instead of
 * spilling to local disk. Records are routed per partition and per-partition
 * record counts are tracked so empty partitions can be reported downstream.
 */
public class RssSorter extends ExternalSorter {

  private static final Logger LOG = LoggerFactory.getLogger(RssSorter.class);
  // Buffers, sorts and ships records to the assigned shuffle servers.
  private WriteBufferManager bufferManager;
  // Block ids acknowledged by the shuffle servers; shared with bufferManager.
  private Set<Long> successBlockIds = Sets.newConcurrentHashSet();
  // Block ids whose send failed; shared with bufferManager.
  private Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
  // Partition id -> shuffle servers responsible for that partition.
  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;

  // Record count per partition index; exposed via getNumRecordsPerPartition().
  private int[] numRecordsPerPartition;

  /**
   * Initialization: reads RSS client settings from {@code conf}, resolves the
   * container/attempt identity from the environment and builds the
   * {@link WriteBufferManager} that performs the actual remote writes.
   *
   * @param tezTaskAttemptID       attempt id used to derive a unique long task id
   * @param outputContext          Tez output context (passed to the base sorter)
   * @param conf                   configuration carrying RSS client settings
   * @param numMaps                number of upstream map tasks
   * @param numOutputs             number of output partitions
   * @param initialMemoryAvailable memory granted to this output by Tez
   * @param shuffleId              id of this shuffle within the DAG
   * @param partitionToServers     shuffle-server assignment per partition
   * @throws IOException if the effective sort buffer size is out of range
   */
  public RssSorter(TezTaskAttemptID tezTaskAttemptID, OutputContext outputContext,
                   Configuration conf, int numMaps, int numOutputs,
                   long initialMemoryAvailable, int shuffleId,
                   Map<Integer, List<ShuffleServerInfo>> partitionToServers) throws IOException {
    super(outputContext, conf, numOutputs, initialMemoryAvailable);
    this.partitionToServers = partitionToServers;

    this.numRecordsPerPartition = new int[numOutputs];

    long sortmb = conf.getLong(RssTezConfig.RSS_RUNTIME_IO_SORT_MB, RssTezConfig.RSS_DEFAULT_RUNTIME_IO_SORT_MB);
    LOG.info("conf.sortmb is {}", sortmb);
    // NOTE(review): the configured value above is only logged; the effective
    // sort buffer size is the memory actually assigned by the framework.
    sortmb = this.availableMemoryMb;
    LOG.info("sortmb, availableMemoryMb is {}, {}", sortmb, availableMemoryMb);
    // Sanity bound: the value must fit in 11 bits (0..2047 MB).
    if ((sortmb & 0x7FF) != sortmb) {
      throw new IOException(
          "Invalid \"" + RssTezConfig.RSS_RUNTIME_IO_SORT_MB + "\": " + sortmb);
    }
    // Fraction of the sort buffer that may be filled before sorting/sending.
    double sortThreshold = conf.getDouble(RssTezConfig.RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD,
          RssTezConfig.RSS_CLIENT_DEFAULT_SORT_MEMORY_USE_THRESHOLD);
    // Unique numeric task id derived from the attempt id and AM attempt.
    long taskAttemptId = RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptID, IdUtils.getAppAttemptId());

    long maxSegmentSize = conf.getLong(RssTezConfig.RSS_CLIENT_MAX_BUFFER_SIZE,
          RssTezConfig.RSS_CLIENT_DEFAULT_MAX_BUFFER_SIZE);
    long maxBufferSize = conf.getLong(RssTezConfig.RSS_WRITER_BUFFER_SIZE, RssTezConfig.RSS_DEFAULT_WRITER_BUFFER_SIZE);
    double memoryThreshold = conf.getDouble(RssTezConfig.RSS_CLIENT_MEMORY_THRESHOLD,
          RssTezConfig.RSS_CLIENT_DEFAULT_MEMORY_THRESHOLD);
    double sendThreshold = conf.getDouble(RssTezConfig.RSS_CLIENT_SEND_THRESHOLD,
          RssTezConfig.RSS_CLIENT_DEFAULT_SEND_THRESHOLD);
    // Number of buffers that triggers a batched send.
    int batch = conf.getInt(RssTezConfig.RSS_CLIENT_BATCH_TRIGGER_NUM,
          RssTezConfig.RSS_CLIENT_DEFAULT_BATCH_TRIGGER_NUM);
    String storageType = conf.get(RssTezConfig.RSS_STORAGE_TYPE, RssTezConfig.RSS_DEFAULT_STORAGE_TYPE);
    if (StringUtils.isEmpty(storageType)) {
      throw new RssException("storage type mustn't be empty");
    }
    long sendCheckInterval = conf.getLong(RssTezConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS,
          RssTezConfig.RSS_CLIENT_DEFAULT_SEND_CHECK_INTERVAL_MS);
    long sendCheckTimeout = conf.getLong(RssTezConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS,
          RssTezConfig.RSS_CLIENT_DEFAULT_SEND_CHECK_TIMEOUT_MS);
    int bitmapSplitNum = conf.getInt(RssTezConfig.RSS_CLIENT_BITMAP_NUM,
          RssTezConfig.RSS_CLIENT_DEFAULT_BITMAP_NUM);

    // Dump every effective setting when the Hive/Tez log level is DEBUG.
    if (conf.get(RssTezConfig.HIVE_TEZ_LOG_LEVEL, RssTezConfig.DEFAULT_HIVE_TEZ_LOG_LEVEL)
          .equalsIgnoreCase(RssTezConfig.DEBUG_HIVE_TEZ_LOG_LEVEL)) {
      LOG.info("sortmb is {}", sortmb);
      LOG.info("sortThreshold is {}", sortThreshold);
      LOG.info("taskAttemptId is {}", taskAttemptId);
      LOG.info("maxSegmentSize is {}", maxSegmentSize);
      LOG.info("maxBufferSize is {}", maxBufferSize);
      LOG.info("memoryThreshold is {}", memoryThreshold);
      LOG.info("sendThreshold is {}", sendThreshold);
      LOG.info("batch is {}", batch);
      LOG.info("storageType is {}", storageType);
      LOG.info("sendCheckInterval is {}", sendCheckInterval);
      LOG.info("sendCheckTimeout is {}", sendCheckTimeout);
      LOG.info("bitmapSplitNum is {}", bitmapSplitNum);
    }

    // Application attempt identity is recovered from the container id that
    // YARN exports into the environment.
    String containerIdStr =
        System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name());
    ContainerId containerId = ConverterUtils.toContainerId(containerIdStr);
    ApplicationAttemptId applicationAttemptId =
        containerId.getApplicationAttemptId();
    LOG.info("containerIdStr is {}", containerIdStr);
    LOG.info("containerId is {}", containerId);
    LOG.info("applicationAttemptId is {}", applicationAttemptId.toString());

    bufferManager = new WriteBufferManager(
        tezTaskAttemptID,
        (long)(ByteUnit.MiB.toBytes(sortmb) * sortThreshold),
        applicationAttemptId.toString(),
        taskAttemptId,
        successBlockIds,
        failedBlockIds,
        RssTezUtils.createShuffleClient(conf),
        comparator,
        maxSegmentSize,
        keySerializer,
        valSerializer,
        maxBufferSize,
        memoryThreshold,
        sendThreshold,
        batch,
        new RssConf(),
        partitionToServers,
        numMaps,
        isMemoryShuffleEnabled(storageType),
        sendCheckInterval,
        sendCheckTimeout,
        bitmapSplitNum,
        shuffleId,
        true);
    LOG.info("Initialized WriteBufferManager.");
  }

  /** Blocks until all buffered records have been sent to the shuffle servers. */
  @Override
  public void flush() throws IOException {
    bufferManager.waitSendFinished();
  }

  /** Closes the base sorter, then releases the buffer manager's resources. */
  @Override
  public final void close() throws IOException {
    super.close();
    bufferManager.freeAllResources();
  }

  /**
   * Routes one key/value pair to its partition as chosen by the partitioner.
   */
  @Override
  public void write(Object key, Object value) throws IOException {
    try {
      collect(key, value, partitioner.getPartition(key, value, partitions));
    } catch (InterruptedException e) {
      throw new RssException(e);
    }
  }

  /**
   * Validates key/value types and the partition index, then hands the record
   * to the buffer manager and bumps the partition's record count.
   *
   * @throws IOException if the key/value class or partition index is invalid
   */
  synchronized void collect(Object key, Object value, final int partition) throws IOException, InterruptedException {
    if (key.getClass() != keyClass) {
      throw new IOException("Type mismatch in key from map: expected "
        + keyClass.getName() + ", received "
        + key.getClass().getName());
    }
    if (value.getClass() != valClass) {
      throw new IOException("Type mismatch in value from map: expected "
        + valClass.getName() + ", received "
        + value.getClass().getName());
    }
    if (partition < 0 || partition >= partitions) {
      throw new IOException("Illegal partition for " + key + " ("
        + partition + ")");
    }

    bufferManager.addRecord(partition, key, value);
    numRecordsPerPartition[partition]++;
  }

  /** @return record counts per partition index — the live array, not a copy */
  public int[] getNumRecordsPerPartition() {
    return numRecordsPerPartition;
  }

  // True when the chosen storage type keeps shuffle data (partly) in memory.
  private boolean isMemoryShuffleEnabled(String storageType) {
    return StorageType.withMemory(StorageType.valueOf(storageType));
  }
}
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecord.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecord.java
new file mode 100644
index 00000000..97bb2ebf
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecord.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
+import java.io.IOException;
+import java.util.zip.Checksum;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+
+
+public class RssTezPerPartitionRecord extends TezSpillRecord {
+  private int numPartitions;
+  private int[] numRecordsPerPartition;
+
+  public RssTezPerPartitionRecord(int numPartitions) {
+    super(numPartitions);
+    this.numPartitions = numPartitions;
+  }
+
+  public RssTezPerPartitionRecord(int numPartitions, int[] 
numRecordsPerPartition) {
+    super(numPartitions);
+    this.numPartitions = numPartitions;
+    this.numRecordsPerPartition = numRecordsPerPartition;
+  }
+
+
+  public RssTezPerPartitionRecord(Path indexFileName, Configuration job) 
throws IOException {
+    super(indexFileName, job);
+  }
+
+  public RssTezPerPartitionRecord(Path indexFileName, Configuration job, 
String expectedIndexOwner) throws IOException {
+    super(indexFileName, job, expectedIndexOwner);
+  }
+
+  public RssTezPerPartitionRecord(Path indexFileName, Configuration job, 
Checksum crc, String expectedIndexOwner)
+      throws IOException {
+    super(indexFileName, job, crc, expectedIndexOwner);
+  }
+
+  @Override
+  public int size() {
+    return numPartitions;
+  }
+
+  @Override
+  public RssTezIndexRecord getIndex(int i) {
+    int records = numRecordsPerPartition[i];
+    RssTezIndexRecord rssTezIndexRecord = new RssTezIndexRecord();
+    rssTezIndexRecord.setData(!(records == 0));
+    return rssTezIndexRecord;
+  }
+
+
+  static class RssTezIndexRecord extends TezIndexRecord {
+    private boolean hasData;
+
+    private void setData(boolean hasData) {
+      this. hasData = hasData;
+    }
+
+    @Override
+    public boolean hasData() {
+      return hasData;
+    }
+  }
+
+}
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
new file mode 100644
index 00000000..928ec3c4
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.output;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.security.PrivilegedExceptionAction;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.zip.Deflater;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.ipc.RPC;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.GetShuffleServerRequest;
+import org.apache.tez.common.GetShuffleServerResponse;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.AbstractLogicalOutput;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.api.Writer;
+import org.apache.tez.runtime.library.api.KeyValuesWriter;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.MemoryUpdateCallbackHandler;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import org.apache.tez.runtime.library.common.sort.impl.ExternalSorter;
+import org.apache.tez.runtime.library.common.sort.impl.RssSorter;
+import 
org.apache.tez.runtime.library.common.sort.impl.RssTezPerPartitionRecord;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+
+
+/**
+ * {@link RssOrderedPartitionedKVOutput} is an {@link AbstractLogicalOutput} 
which
+ * support remote shuffle.
+ */
+@Public
+public class RssOrderedPartitionedKVOutput extends AbstractLogicalOutput {
+  private static final Logger LOG = 
LoggerFactory.getLogger(RssOrderedPartitionedKVOutput.class);
+  protected ExternalSorter sorter;
+  protected Configuration conf;
+  protected MemoryUpdateCallbackHandler memoryUpdateCallbackHandler;
+  private long startTime;
+  private long endTime;
+  private final AtomicBoolean isStarted = new AtomicBoolean(false);
+  private final Deflater deflater;
+  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+  private int mapNum;
+  private int numOutputs;
+  private TezTaskAttemptID taskAttemptId;
+  private ApplicationId applicationId;
+  private boolean sendEmptyPartitionDetails;
+  private OutputContext outputContext;
+  private String host;
+  private int port;
+  private String taskVertexName;
+  private String destinationVertexName;
+  private int shuffleId;
+
+
+  public RssOrderedPartitionedKVOutput(OutputContext outputContext, int 
numPhysicalOutputs) {
+    super(outputContext, numPhysicalOutputs);
+    this.outputContext = outputContext;
+    this.deflater = TezCommonUtils.newBestCompressionDeflater();
+    this.numOutputs = getNumPhysicalOutputs();
+    this.mapNum = outputContext.getVertexParallelism();
+    this.applicationId = outputContext.getApplicationId();
+    this.taskAttemptId = TezTaskAttemptID.fromString(
+      
RssTezUtils.uniqueIdentifierToAttemptId(outputContext.getUniqueIdentifier()));
+    this.taskVertexName = outputContext.getTaskVertexName();
+    this.destinationVertexName = outputContext.getDestinationVertexName();
+    LOG.info("taskAttemptId is {}", taskAttemptId.toString());
+    LOG.info("taskVertexName is {}", taskVertexName);
+    LOG.info("destinationVertexName is {}", destinationVertexName);
+    LOG.info("Initialized RssOrderedPartitionedKVOutput.");
+  }
+
+  private void getRssConf() {
+    try {
+      JobConf conf = new JobConf(RssTezConfig.RSS_CONF_FILE);
+      this.host = conf.get(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS, "null 
host");
+      this.port = conf.getInt(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT, -1);
+
+      LOG.info("Got RssConf am info : host is {}, port is {}", host, port);
+    } catch (Exception e) {
+      LOG.warn("debugRssConf error: ", e);
+    }
+  }
+
+  @Override
+  public List<Event> initialize() throws Exception {
+    this.startTime = System.nanoTime();
+    this.conf = 
TezUtils.createConfFromUserPayload(getContext().getUserPayload());
+    this.memoryUpdateCallbackHandler = new MemoryUpdateCallbackHandler();
+
+    long memRequestSize = RssTezUtils.getInitialMemoryRequirement(conf, 
getContext().getTotalMemoryAvailableToTask());
+    LOG.info("memRequestSize is {}", memRequestSize);
+    getContext().requestInitialMemory(memRequestSize, 
memoryUpdateCallbackHandler);
+    LOG.info("Got initialMemory.");
+
+    getRssConf();
+
+    this.sendEmptyPartitionDetails = conf.getBoolean(
+        
TezRuntimeConfiguration.TEZ_RUNTIME_EMPTY_PARTITION_INFO_VIA_EVENTS_ENABLED,
+        
TezRuntimeConfiguration.TEZ_RUNTIME_EMPTY_PARTITION_INFO_VIA_EVENTS_ENABLED_DEFAULT);
+
+    final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, 
port);
+
+    UserGroupInformation taskOwner = 
UserGroupInformation.createRemoteUser(this.applicationId.toString());
+
+    TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
+        .doAs(new 
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
+          @Override
+          public TezRemoteShuffleUmbilicalProtocol run() throws Exception {
+            return RPC.getProxy(TezRemoteShuffleUmbilicalProtocol.class,
+              TezRemoteShuffleUmbilicalProtocol.versionID,
+              address, conf);
+          }
+        });
+    TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
+    TezDAGID tezDAGID = tezVertexID.getDAGId();
+    this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(), 
this.taskVertexName, this.destinationVertexName);
+    GetShuffleServerRequest request = new 
GetShuffleServerRequest(this.taskAttemptId, this.mapNum,
+        this.numOutputs, this.shuffleId);
+    GetShuffleServerResponse response = 
umbilical.getShuffleAssignments(request);
+    this.partitionToServers = response.getShuffleAssignmentsInfoWritable()
+        .getShuffleAssignmentsInfo()
+        .getPartitionToServers();
+
+    LOG.info("Got response from am.");
+    return Collections.emptyList();
+  }
+
+  @Override
+  public void handleEvents(List<Event> list) {
+
+  }
+
+  @Override
+  public List<Event> close() throws Exception {
+    List<Event> returnEvents = Lists.newLinkedList();
+    if (sorter != null) {
+      sorter.flush();
+      sorter.close();
+      this.endTime = System.nanoTime();
+      returnEvents.addAll(generateEvents());
+      sorter = null;
+    } else {
+      LOG.warn(getContext().getDestinationVertexName()
+          + ": Attempting to close output {} of type {} before it was started. 
Generating empty events",
+          getContext().getDestinationVertexName(), 
this.getClass().getSimpleName());
+      returnEvents = generateEmptyEvents();
+    }
+    LOG.info("RssOrderedPartitionedKVOutput close.");
+    return returnEvents;
+  }
+
+  @Override
+  public void start() throws Exception {
+    if (!isStarted.get()) {
+      memoryUpdateCallbackHandler.validateUpdateReceived();
+      sorter = new RssSorter(taskAttemptId, getContext(), conf, mapNum, 
numOutputs,
+          memoryUpdateCallbackHandler.getMemoryAssigned(), shuffleId,
+          partitionToServers);
+      LOG.info("Initialized RssSorter.");
+      isStarted.set(true);
+    }
+  }
+
+  @Override
+  public Writer getWriter() throws IOException {
+    Preconditions.checkState(isStarted.get(), "Cannot get writer before 
starting the Output");
+
+    return new KeyValuesWriter() {
+      @Override
+      public void write(Object key, Iterable<Object> values) throws 
IOException {
+        sorter.write(key, values);
+      }
+
+      @Override
+      public void write(Object key, Object value) throws IOException {
+        sorter.write(key, value);
+      }
+    };
+  }
+
+  private List<Event> generateEvents() throws IOException {
+    List<Event> eventList = Lists.newLinkedList();
+    boolean isLastEvent = true;
+
+    String auxiliaryService = 
conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
+        TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);
+
+    int[] numRecordsPerPartition = ((RssSorter) 
sorter).getNumRecordsPerPartition();
+
+    RssTezPerPartitionRecord rssTezSpillRecord = new 
RssTezPerPartitionRecord(numOutputs, numRecordsPerPartition);
+
+    LOG.info("RssTezPerPartitionRecord is initialized");
+
+    ShuffleUtils.generateEventOnSpill(eventList, true, isLastEvent,
+        getContext(), 0, rssTezSpillRecord,
+        getNumPhysicalOutputs(), sendEmptyPartitionDetails, 
getContext().getUniqueIdentifier(),
+        sorter.getPartitionStats(), sorter.reportDetailedPartitionStats(), 
auxiliaryService, deflater);
+    LOG.info("Generate events.");
+
+    return eventList;
+  }
+
+  private List<Event> generateEmptyEvents() throws IOException {
+    List<Event> eventList = Lists.newArrayList();
+    ShuffleUtils.generateEventsForNonStartedOutput(eventList,
+        getNumPhysicalOutputs(),
+        getContext(),
+        true,
+        true,
+        deflater);
+    LOG.info("Generate empty events.");
+    return eventList;
+  }
+
+  private static final Set<String> confKeys = new HashSet<String>();
+
+  static {
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD);
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD_BYTES);
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_IO_FILE_BUFFER_SIZE);
+    
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_INDEX_CACHE_MEMORY_LIMIT_BYTES);
+    
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_UNORDERED_OUTPUT_BUFFER_SIZE_MB);
+    
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_UNORDERED_OUTPUT_MAX_PER_BUFFER_SIZE_BYTES);
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS);
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS);
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS);
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS);
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS_CODEC);
+    
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_EMPTY_PARTITION_INFO_VIA_EVENTS_ENABLED);
+    
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_CONVERT_USER_PAYLOAD_TO_HISTORY_TEXT);
+    
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_PIPELINED_SHUFFLE_ENABLED);
+    
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_ENABLE_FINAL_MERGE_IN_OUTPUT);
+    confKeys.add(TezConfiguration.TEZ_COUNTERS_MAX);
+    confKeys.add(TezConfiguration.TEZ_COUNTERS_GROUP_NAME_MAX_LENGTH);
+    confKeys.add(TezConfiguration.TEZ_COUNTERS_COUNTER_NAME_MAX_LENGTH);
+    confKeys.add(TezConfiguration.TEZ_COUNTERS_MAX_GROUPS);
+    
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_CLEANUP_FILES_ON_INTERRUPT);
+    confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_REPORT_PARTITION_STATS);
+    confKeys.add(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID);
+    confKeys.add(
+        
TezRuntimeConfiguration.TEZ_RUNTIME_UNORDERED_PARTITIONED_KVWRITER_BUFFER_MERGE_PERCENT);
+  }
+
+  @InterfaceAudience.Private
+  public static Set<String> getConfigurationKeySet() {
+    return Collections.unmodifiableSet(confKeys);
+  }
+
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
new file mode 100644
index 00000000..ce3a8a48
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -0,0 +1,388 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.buffer;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import com.google.common.collect.Sets;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.storage.util.StorageType;
+
+
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class WriteBufferManagerTest {
+  @Test
+  public void testWriteException() throws IOException, InterruptedException {
+    TezTaskAttemptID tezTaskAttemptID = 
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+    long maxMemSize = 10240;
+    String appId = "application_1681717153064_3770270";
+    long taskAttemptId = 0;
+    Set<Long> successBlockIds = Sets.newConcurrentHashSet();
+    Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
+    MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
+    RawComparator comparator = WritableComparator.get(BytesWritable.class);
+    long maxSegmentSize = 3 * 1024;
+    SerializationFactory serializationFactory = new SerializationFactory(new 
JobConf());
+    Serializer<BytesWritable> keySerializer =  
serializationFactory.getSerializer(BytesWritable.class);
+    Serializer<BytesWritable> valSerializer = 
serializationFactory.getSerializer(BytesWritable.class);
+    long maxBufferSize = 14 * 1024 * 1024;
+    double memoryThreshold = 0.8f;
+    double sendThreshold = 0.2f;
+    int batch = 50;
+    int numMaps = 1;
+    String storageType = "MEMORY";
+    RssConf rssConf = new RssConf();
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    long sendCheckInterval = 500L;
+    long sendCheckTimeout = 5;
+    int bitmapSplitNum = 1;
+    int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+
+    WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
+        new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
+        taskAttemptId, successBlockIds, failedBlockIds, writeClient,
+        comparator, maxSegmentSize, keySerializer,
+        valSerializer, maxBufferSize, memoryThreshold,
+        sendThreshold, batch, rssConf, partitionToServers,
+        numMaps, isMemoryShuffleEnabled(storageType),
+        sendCheckInterval, sendCheckTimeout, bitmapSplitNum, shuffleId, true);
+
+    Random random = new Random();
+    for (int i = 0; i < 1000; i++) {
+      byte[] key = new byte[20];
+      byte[] value = new byte[1024];
+      random.nextBytes(key);
+      random.nextBytes(value);
+      bufferManager.addRecord(1, new BytesWritable(key), new 
BytesWritable(value));
+    }
+
+    boolean isException = false;
+    try {
+      bufferManager.waitSendFinished();
+    } catch (RssException re) {
+      isException = true;
+    }
+    assertTrue(isException);
+  }
+
+  @Test
+  public void testWriteNormal() throws IOException, InterruptedException {
+    TezTaskAttemptID tezTaskAttemptID = 
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+    long maxMemSize = 10240;
+    String appId = "appattempt_1681717153064_3770270_000001";
+    long taskAttemptId = 0;
+    Set<Long> successBlockIds = Sets.newConcurrentHashSet();
+    Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
+    MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
+    writeClient.setMode(2);
+    RawComparator comparator = WritableComparator.get(BytesWritable.class);
+    long maxSegmentSize = 3 * 1024;
+    SerializationFactory serializationFactory = new SerializationFactory(new 
JobConf());
+    Serializer<BytesWritable> keySerializer =  
serializationFactory.getSerializer(BytesWritable.class);
+    Serializer<BytesWritable> valSerializer = 
serializationFactory.getSerializer(BytesWritable.class);
+    long maxBufferSize = 14 * 1024 * 1024;
+    double memoryThreshold = 0.8f;
+    double sendThreshold = 0.2f;
+    int batch = 50;
+    int numMaps = 1;
+    String storageType = "MEMORY";
+    RssConf rssConf = new RssConf();
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    long sendCheckInterval = 500L;
+    long sendCheckTimeout = 60 * 1000 * 10L;
+    int bitmapSplitNum = 1;
+    int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+
+    WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
+        new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
+        taskAttemptId, successBlockIds, failedBlockIds, writeClient,
+        comparator, maxSegmentSize, keySerializer,
+        valSerializer, maxBufferSize, memoryThreshold,
+        sendThreshold, batch, rssConf, partitionToServers,
+        numMaps, isMemoryShuffleEnabled(storageType),
+        sendCheckInterval, sendCheckTimeout, bitmapSplitNum, shuffleId, true);
+
+    Random random = new Random();
+    for (int i = 0; i < 1000; i++) {
+      byte[] key = new byte[20];
+      byte[] value = new byte[1024];
+      random.nextBytes(key);
+      random.nextBytes(value);
+      int partitionId = random.nextInt(50);
+      bufferManager.addRecord(partitionId, new BytesWritable(key), new 
BytesWritable(value));
+    }
+    bufferManager.waitSendFinished();
+    assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
+
+    for (int i = 0; i < 50; i++) {
+      byte[] key = new byte[20];
+      byte[] value = new byte[i * 100];
+      random.nextBytes(key);
+      random.nextBytes(value);
+      bufferManager.addRecord(i, new BytesWritable(key), new 
BytesWritable(value));
+    }
+    assert (1 == bufferManager.getWaitSendBuffers().size());
+    assert (4928 == bufferManager.getWaitSendBuffers().get(0).getDataLength());
+
+    bufferManager.waitSendFinished();
+    assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
+  }
+
+  @Test
+  public void testCommitBlocksWhenMemoryShuffleDisabled() throws IOException, 
InterruptedException {
+    TezTaskAttemptID tezTaskAttemptID = 
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+    long maxMemSize = 10240;
+    String appId = "application_1681717153064_3770270";
+    long taskAttemptId = 0;
+    Set<Long> successBlockIds = Sets.newConcurrentHashSet();
+    Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
+    MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
+    writeClient.setMode(3);
+    RawComparator comparator = WritableComparator.get(BytesWritable.class);
+    long maxSegmentSize = 3 * 1024;
+    SerializationFactory serializationFactory = new SerializationFactory(new 
JobConf());
+    Serializer<BytesWritable> keySerializer =  
serializationFactory.getSerializer(BytesWritable.class);
+    Serializer<BytesWritable> valSerializer = 
serializationFactory.getSerializer(BytesWritable.class);
+    long maxBufferSize = 14 * 1024 * 1024;
+    double memoryThreshold = 0.8f;
+    double sendThreshold = 0.2f;
+    int batch = 50;
+    int numMaps = 1;
+    String storageType = "MEMORY";
+    RssConf rssConf = new RssConf();
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    long sendCheckInterval = 500L;
+    long sendCheckTimeout = 60 * 1000 * 10L;
+    int bitmapSplitNum = 1;
+    int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+
+    WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
+        new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
+        taskAttemptId, successBlockIds, failedBlockIds, writeClient,
+        comparator, maxSegmentSize, keySerializer,
+        valSerializer, maxBufferSize, memoryThreshold,
+        sendThreshold, batch, rssConf, partitionToServers,
+        numMaps, isMemoryShuffleEnabled(storageType),
+        sendCheckInterval, sendCheckTimeout, bitmapSplitNum, shuffleId, true);
+
+    Random random = new Random();
+    for (int i = 0; i < 10000; i++) {
+      byte[] key = new byte[20];
+      byte[] value = new byte[1024];
+      random.nextBytes(key);
+      random.nextBytes(value);
+      int partitionId = random.nextInt(50);
+      bufferManager.addRecord(partitionId, new BytesWritable(key), new 
BytesWritable(value));
+    }
+
+    assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
+    assertEquals(writeClient.mockedShuffleServer.getFinishBlockSize(),
+        writeClient.mockedShuffleServer.getFlushBlockSize());
+  }
+
+  private int getShuffleId(TezTaskAttemptID tezTaskAttemptID, String 
upVertexName, String downVertexName) {
+    TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
+    int shuffleId = 
RssTezUtils.computeShuffleId(tezVertexID.getDAGId().getId(), upVertexName, 
downVertexName);
+    return shuffleId;
+  }
+
+  private boolean isMemoryShuffleEnabled(String storageType) {
+    return StorageType.withMemory(StorageType.valueOf(storageType));
+  }
+
+  class MockShuffleServer {
+    private List<ShuffleBlockInfo> cachedBlockInfos = new ArrayList<>();
+    private List<ShuffleBlockInfo> flushBlockInfos = new ArrayList<>();
+    private List<Long> finishedBlockInfos = new ArrayList<>();
+
+    public synchronized void finishShuffle() {
+      flushBlockInfos.addAll(cachedBlockInfos);
+    }
+
+    public synchronized void addCachedBlockInfos(List<ShuffleBlockInfo> 
shuffleBlockInfoList) {
+      cachedBlockInfos.addAll(shuffleBlockInfoList);
+    }
+
+    public synchronized void addFinishedBlockInfos(List<Long> 
shuffleBlockInfoList) {
+      finishedBlockInfos.addAll(shuffleBlockInfoList);
+    }
+
+    public synchronized int getFlushBlockSize() {
+      return flushBlockInfos.size();
+    }
+
+    public synchronized int getFinishBlockSize() {
+      return finishedBlockInfos.size();
+    }
+  }
+
+  class MockShuffleWriteClient implements ShuffleWriteClient {
+
+    int mode = 0;
+    MockShuffleServer mockedShuffleServer = new MockShuffleServer();
+    int committedMaps = 0;
+
+    public void setMode(int mode) {
+      this.mode = mode;
+    }
+
+    @Override
+    public SendShuffleDataResult sendShuffleData(String appId, 
List<ShuffleBlockInfo> shuffleBlockInfoList,
+                                                 Supplier<Boolean> 
needCancelRequest) {
+      if (mode == 0) {
+        throw new RssException("send data failed.");
+      } else if (mode == 1) {
+        return new SendShuffleDataResult(Sets.newHashSet(2L), 
Sets.newHashSet(1L));
+      } else {
+        if (mode == 3) {
+          try {
+            Thread.sleep(10);
+            mockedShuffleServer.addCachedBlockInfos(shuffleBlockInfoList);
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new RssException(e.toString());
+          }
+        }
+        Set<Long> successBlockIds = Sets.newHashSet();
+        for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
+          successBlockIds.add(blockInfo.getBlockId());
+        }
+        return new SendShuffleDataResult(successBlockIds, Sets.newHashSet());
+      }
+    }
+
+    @Override
+    public void sendAppHeartbeat(String appId, long timeoutMs) {
+
+    }
+
+    @Override
+    public void registerApplicationInfo(String appId, long timeoutMs, String 
user) {
+
+    }
+
+    @Override
+    public void registerShuffle(ShuffleServerInfo shuffleServerInfo, String 
appId, int shuffleId,
+                                List<PartitionRange> partitionRanges, 
RemoteStorageInfo remoteStorage,
+                                ShuffleDataDistributionType 
dataDistributionType,
+                                int maxConcurrencyPerPartitionToWrite) {
+
+    }
+
+
+    @Override
+    public boolean sendCommit(Set<ShuffleServerInfo> shuffleServerInfoSet, 
String appId, int shuffleId, int numMaps) {
+      if (mode == 3) {
+        committedMaps++;
+        if (committedMaps >= numMaps) {
+          mockedShuffleServer.finishShuffle();
+        }
+        return true;
+      }
+      return false;
+    }
+
+    @Override
+    public void registerCoordinators(String coordinators) {
+
+    }
+
+    @Override
+    public Map<String, String> fetchClientConf(int timeoutMs) {
+      return null;
+    }
+
+    @Override
+    public RemoteStorageInfo fetchRemoteStorage(String appId) {
+      return null;
+    }
+
+    @Override
+    public void reportShuffleResult(Map<Integer, List<ShuffleServerInfo>> 
partitionToServers,
+                                    String appId, int shuffleId, long 
taskAttemptId,
+                                    Map<Integer, List<Long>> 
partitionToBlockIds, int bitmapNum) {
+      if (mode == 3) {
+        mockedShuffleServer.addFinishedBlockInfos(
+            partitionToBlockIds.values().stream().flatMap(it -> 
it.stream()).collect(Collectors.toList())
+        );
+      }
+    }
+
+    @Override
+    public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int 
shuffleId,
+        int partitionNum, int partitionNumPerRange,
+        Set<String> requiredTags, int assignmentShuffleServerNumber, int 
estimateTaskConcurrency) {
+      return null;
+    }
+
+    @Override
+    public Roaring64NavigableMap getShuffleResult(String clientType, 
Set<ShuffleServerInfo> shuffleServerInfoSet,
+        String appId, int shuffleId, int partitionId) {
+      return null;
+    }
+
+    @Override
+    public Roaring64NavigableMap getShuffleResultForMultiPart(String 
clientType, Map<ShuffleServerInfo,
+        Set<Integer>> serverToPartitions, String appId, int shuffleId) {
+      return null;
+    }
+
+    @Override
+    public void close() {
+
+    }
+
+    @Override
+    public void unregisterShuffle(String appId, int shuffleId) {
+
+    }
+  }
+
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
new file mode 100644
index 00000000..96cd7cd3
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.buffer;
+
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.util.Map;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hadoop.io.serializer.Deserializer;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.hadoop.mapred.JobConf;
+import org.junit.jupiter.api.Test;
+
+
+
+import static com.google.common.collect.Maps.newConcurrentMap;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class WriteBufferTest {
+
+  @Test
+  public void testReadWrite() throws IOException {
+
+    String keyStr = "key";
+    String valueStr = "value";
+    BytesWritable key = new BytesWritable(keyStr.getBytes());
+    BytesWritable value = new BytesWritable(valueStr.getBytes());
+    JobConf jobConf = new JobConf(new Configuration());
+    SerializationFactory serializationFactory = new 
SerializationFactory(jobConf);
+    Serializer<BytesWritable> keySerializer = 
serializationFactory.getSerializer(BytesWritable.class);
+    Serializer<BytesWritable> valSerializer = 
serializationFactory.getSerializer(BytesWritable.class);
+    WriteBuffer<BytesWritable, BytesWritable> buffer =
+        new WriteBuffer<BytesWritable, BytesWritable>(
+          true,
+          1,
+          WritableComparator.get(BytesWritable.class),
+          1024L,
+          keySerializer,
+          valSerializer);
+
+    long recordLength = buffer.addRecord(key, value);
+    assertEquals(20, buffer.getData().length);
+    assertEquals(16, recordLength);
+    assertEquals(1, buffer.getPartitionId());
+    byte[] result = buffer.getData();
+    Deserializer<BytesWritable> keyDeserializer = 
serializationFactory.getDeserializer(BytesWritable.class);
+    Deserializer<BytesWritable> valDeserializer = 
serializationFactory.getDeserializer(BytesWritable.class);
+    ByteArrayInputStream byteArrayInputStream = new 
ByteArrayInputStream(result);
+    keyDeserializer.open(byteArrayInputStream);
+    valDeserializer.open(byteArrayInputStream);
+
+    DataInputStream dStream = new DataInputStream(byteArrayInputStream);
+    int keyLen = readInt(dStream);
+    int valueLen = readInt(dStream);
+    assertEquals(recordLength, keyLen + valueLen);
+    BytesWritable keyRead = keyDeserializer.deserialize(null);
+    assertEquals(key, keyRead);
+    BytesWritable valueRead = keyDeserializer.deserialize(null);
+    assertEquals(value, valueRead);
+
+    buffer = new WriteBuffer<BytesWritable, BytesWritable>(
+        true,
+        1,
+        WritableComparator.get(BytesWritable.class),
+        528L,
+        keySerializer,
+        valSerializer);
+    long start = buffer.getDataLength();
+    assertEquals(0, start);
+    keyStr = "key3";
+    key = new BytesWritable(keyStr.getBytes());
+    keySerializer.serialize(key);
+    byte[] valueBytes = new byte[200];
+    Map<String, BytesWritable> valueMap = newConcurrentMap();
+    Random random = new Random();
+    random.nextBytes(valueBytes);
+    value = new BytesWritable(valueBytes);
+    valueMap.putIfAbsent(keyStr, value);
+    valSerializer.serialize(value);
+    recordLength = buffer.addRecord(key, value);
+    Map<String, Long> recordLenMap = newConcurrentMap();
+    recordLenMap.putIfAbsent(keyStr, recordLength);
+
+    keyStr = "key1";
+    key = new BytesWritable(keyStr.getBytes());
+    valueBytes = new byte[2032];
+    random.nextBytes(valueBytes);
+    value = new BytesWritable(valueBytes);
+    valueMap.putIfAbsent(keyStr, value);
+    recordLength = buffer.addRecord(key, value);
+    recordLenMap.putIfAbsent(keyStr, recordLength);
+
+    byte[] bigKey = new byte[555];
+    random.nextBytes(bigKey);
+    bigKey[0] = 'k';
+    bigKey[1] = 'e';
+    bigKey[2] = 'y';
+    bigKey[3] = '4';
+    final BytesWritable bigWritableKey = new BytesWritable(bigKey);
+    valueBytes = new byte[253];
+    random.nextBytes(valueBytes);
+    final BytesWritable bigWritableValue = new BytesWritable(valueBytes);
+    final long bigRecordLength = buffer.addRecord(bigWritableKey, 
bigWritableValue);
+    keyStr = "key2";
+    key = new BytesWritable(keyStr.getBytes());
+    valueBytes = new byte[3100];
+    value = new BytesWritable(valueBytes);
+    valueMap.putIfAbsent(keyStr, value);
+    recordLength = buffer.addRecord(key, value);
+    recordLenMap.putIfAbsent(keyStr, recordLength);
+
+    result = buffer.getData();
+    byteArrayInputStream = new ByteArrayInputStream(result);
+    keyDeserializer.open(byteArrayInputStream);
+    valDeserializer.open(byteArrayInputStream);
+    for (int i = 1; i <= 3; i++) {
+      dStream = new DataInputStream(byteArrayInputStream);
+      long keyLenTmp = readInt(dStream);
+      long valueLenTmp = readInt(dStream);
+      String tmpStr = "key" + i;
+      assertEquals(recordLenMap.get(tmpStr).longValue(), keyLenTmp + 
valueLenTmp);
+      keyRead = keyDeserializer.deserialize(null);
+      valueRead = valDeserializer.deserialize(null);
+      BytesWritable bytesWritable = new BytesWritable(tmpStr.getBytes());
+      assertEquals(bytesWritable, keyRead);
+      assertEquals(valueMap.get(tmpStr), valueRead);
+    }
+
+    dStream = new DataInputStream(byteArrayInputStream);
+    long keyLenTmp = readInt(dStream);
+    long valueLenTmp = readInt(dStream);
+    assertEquals(bigRecordLength, keyLenTmp + valueLenTmp);
+    keyRead = keyDeserializer.deserialize(null);
+    valueRead = valDeserializer.deserialize(null);
+    assertEquals(bigWritableKey, keyRead);
+    assertEquals(bigWritableValue, valueRead);
+  }
+
+  int readInt(DataInputStream dStream) throws IOException {
+    return WritableUtils.readVInt(dStream);
+  }
+
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
new file mode 100644
index 00000000..7d7fa3fa
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.output.OutputTestHelpers;
+import org.apache.tez.runtime.library.partitioner.HashPartitioner;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class RssSorterTest {
+  private static Map<Integer, List<ShuffleServerInfo>> partitionToServers = 
new HashMap<>();
+  private Configuration conf;
+  private FileSystem localFs;
+  private Path workingDir;
+
+  /**
+   * set up
+   */
+  @BeforeEach
+  public void setup() throws Exception {
+    conf = new Configuration();
+    localFs = FileSystem.getLocal(conf);
+    workingDir = new Path(System.getProperty("test.build.data",
+        System.getProperty("java.io.tmpdir", "/tmp")),
+        RssSorterTest.class.getName()).makeQualified(
+        localFs.getUri(), localFs.getWorkingDirectory());
+    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, 
Text.class.getName());
+    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, 
Text.class.getName());
+    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS,
+        HashPartitioner.class.getName());
+    conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, 
workingDir.toString());
+
+    Map<String, String> envMap = System.getenv();
+    Map<String, String> env = new HashMap<>();
+    env.putAll(envMap);
+    env.put(ApplicationConstants.Environment.CONTAINER_ID.name(), 
"container_e160_1681717153064_3770270_01_000001");
+
+    setEnv(env);
+  }
+
+  @Test
+  public void testCollectAndRecordsPerPartition() throws IOException, 
InterruptedException {
+    TezTaskAttemptID tezTaskAttemptID =
+        
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+
+    OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, 
workingDir);
+
+    long initialMemoryAvailable = 10240000;
+    int shuffleId = 1001;
+
+    RssSorter rssSorter = new RssSorter(tezTaskAttemptID, outputContext, conf, 
5, 5, initialMemoryAvailable,
+        shuffleId, partitionToServers);
+
+    rssSorter.collect(new Text("0"), new Text("0"), 0);
+    rssSorter.collect(new Text("0"), new Text("1"), 0);
+    rssSorter.collect(new Text("1"), new Text("1"), 1);
+    rssSorter.collect(new Text("2"), new Text("2"), 2);
+    rssSorter.collect(new Text("3"), new Text("3"), 3);
+    rssSorter.collect(new Text("4"), new Text("4"), 4);
+
+    assertTrue(2 == rssSorter.getNumRecordsPerPartition()[0]);
+    assertTrue(1 == rssSorter.getNumRecordsPerPartition()[1]);
+    assertTrue(1 == rssSorter.getNumRecordsPerPartition()[2]);
+    assertTrue(1 == rssSorter.getNumRecordsPerPartition()[3]);
+    assertTrue(1 == rssSorter.getNumRecordsPerPartition()[4]);
+
+    assertTrue(5 == rssSorter.getNumRecordsPerPartition().length);
+  }
+
+
+
+  protected static void setEnv(Map<String, String> newEnv) throws Exception {
+    try {
+      Class<?> processEnvironmentClass = 
Class.forName("java.lang.ProcessEnvironment");
+      Field theEnvironmentField = 
processEnvironmentClass.getDeclaredField("theEnvironment");
+      theEnvironmentField.setAccessible(true);
+      Map<String, String> env = (Map<String, String>) 
theEnvironmentField.get(null);
+      env.putAll(newEnv);
+      Field theCaseInsensitiveEnvironmentField =
+          
processEnvironmentClass.getDeclaredField("theCaseInsensitiveEnvironment");
+      theCaseInsensitiveEnvironmentField.setAccessible(true);
+      Map<String, String> cienv = (Map<String, String>) 
theCaseInsensitiveEnvironmentField.get(null);
+      cienv.putAll(newEnv);
+    } catch (NoSuchFieldException e) {
+      Class[] classes = Collections.class.getDeclaredClasses();
+      Map<String, String> env = System.getenv();
+      for (Class cl : classes) {
+        if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
+          Field field = cl.getDeclaredField("m");
+          field.setAccessible(true);
+          Object obj = field.get(env);
+          Map<String, String> map = (Map<String, String>) obj;
+          map.clear();
+          map.putAll(newEnv);
+        }
+      }
+    }
+  }
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecordTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecordTest.java
new file mode 100644
index 00000000..8a8de5ca
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecordTest.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class RssTezPerPartitionRecordTest {
+
+  @Test
+  public void testNumPartitions() {
+    int[] numRecordsPerPartition = {0, 10, 10, 20, 30};
+    int numOutputs = 5;
+
+    RssTezPerPartitionRecord rssTezPerPartitionRecord
+        = new RssTezPerPartitionRecord(numOutputs, numRecordsPerPartition);
+
+    assertTrue(numOutputs == rssTezPerPartitionRecord.size());
+  }
+
+  @Test
+  public void testRssTezIndexHasData() {
+    int[] numRecordsPerPartition = {0, 10, 10, 20, 30};
+    int numOutputs = 5;
+
+    RssTezPerPartitionRecord rssTezPerPartitionRecord
+        = new RssTezPerPartitionRecord(numOutputs, numRecordsPerPartition);
+
+    for (int i = 0; i < numRecordsPerPartition.length; i++) {
+      if (0 == i) {
+        assertFalse(rssTezPerPartitionRecord.getIndex(i).hasData());
+      }
+      if (0 != i) {
+        assertTrue(rssTezPerPartitionRecord.getIndex(i).hasData());
+      }
+    }
+  }
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java
new file mode 100644
index 00000000..e6f8d7e7
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.output;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.runtime.api.MemoryUpdateCallback;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.api.OutputStatisticsReporter;
+import org.apache.tez.runtime.api.impl.ExecutionContextImpl;
+import org.apache.tez.runtime.library.common.MemoryUpdateCallbackHandler;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Static helpers for output-related unit tests.
+ *
+ * <p>Not instantiable: all members are static.
+ */
+public final class OutputTestHelpers {
+
+  private OutputTestHelpers() {
+    // utility class — no instances
+  }
+
+  /**
+   * Builds a mocked {@link OutputContext} suitable for exercising outputs in isolation.
+   *
+   * <p>The mock grants every initial-memory request in full (the callback is invoked
+   * with the requested size), carries the given {@code conf} as its user payload, and
+   * answers the identity/work-dir/counter accessors with fixed test values.
+   *
+   * @param conf       configuration to serialize into the context's user payload
+   * @param workingDir single work directory reported by {@code getWorkDirs()}
+   * @return a fully stubbed mock output context
+   * @throws IOException if the configuration cannot be serialized to a payload
+   */
+  public static OutputContext createOutputContext(Configuration conf, Path workingDir)
+      throws IOException {
+    OutputContext ctx = mock(OutputContext.class);
+
+    // Grant whatever memory the output asks for, immediately and synchronously.
+    doAnswer(invocation -> {
+      long requestedSize = (Long) invocation.getArguments()[0];
+      MemoryUpdateCallbackHandler callback =
+          (MemoryUpdateCallbackHandler) invocation.getArguments()[1];
+      callback.memoryAssigned(requestedSize);
+      return null;
+    }).when(ctx).requestInitialMemory(anyLong(), any(MemoryUpdateCallback.class));
+    doReturn(TezUtils.createUserPayloadFromConf(conf)).when(ctx).getUserPayload();
+    doReturn("Map 1").when(ctx).getTaskVertexName();
+    doReturn("Reducer 2").when(ctx).getDestinationVertexName();
+    doReturn("attempt_1681717153064_3601637_1_13_000096_0").when(ctx).getUniqueIdentifier();
+    doReturn(new String[] { workingDir.toString() }).when(ctx).getWorkDirs();
+    doReturn(200 * 1024 * 1024L).when(ctx).getTotalMemoryAvailableToTask();
+    doReturn(new TezCounters()).when(ctx).getCounters();
+    OutputStatisticsReporter statsReporter = mock(OutputStatisticsReporter.class);
+    doReturn(statsReporter).when(ctx).getStatisticsReporter();
+    doReturn(new ExecutionContextImpl("localhost")).when(ctx).getExecutionContext();
+    return ctx;
+  }
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutputTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutputTest.java
new file mode 100644
index 00000000..e4b68f81
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutputTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.output;
+
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.BitSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import com.google.protobuf.ByteString;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
+import org.apache.tez.runtime.api.events.VertexManagerEvent;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.partitioner.HashPartitioner;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Tests for {@link RssOrderedPartitionedKVOutput}.
+ */
+public class RssOrderedPartitionedKVOutputTest {
+  private static Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+  private Configuration conf;
+  private FileSystem localFs;
+  private Path workingDir;
+
+  @BeforeEach
+  public void setup() throws IOException {
+    conf = new Configuration();
+    localFs = FileSystem.getLocal(conf);
+    // Place the working directory under the standard test data dir (fall back to tmp),
+    // qualified against the local filesystem.
+    String baseDir = System.getProperty("test.build.data",
+        System.getProperty("java.io.tmpdir", "/tmp"));
+    workingDir = new Path(baseDir, RssOrderedPartitionedKVOutputTest.class.getName())
+        .makeQualified(localFs.getUri(), localFs.getWorkingDirectory());
+    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, Text.class.getName());
+    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, Text.class.getName());
+    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS,
+        HashPartitioner.class.getName());
+    conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, workingDir.toString());
+  }
+
+  @AfterEach
+  public void cleanup() throws IOException {
+    localFs.delete(workingDir, true);
+  }
+
+  /**
+   * Closing an output that was never started must still emit a VertexManagerEvent
+   * followed by a CompositeDataMovementEvent whose payload marks every partition empty.
+   */
+  @Test
+  @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
+  public void testNonStartedOutput() throws Exception {
+    OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, workingDir);
+    int numPartitions = 10;
+    RssOrderedPartitionedKVOutput output =
+        new RssOrderedPartitionedKVOutput(outputContext, numPartitions);
+
+    List<Event> events = output.close();
+    assertEquals(2, events.size());
+    assertTrue(events.get(0) instanceof VertexManagerEvent);
+    assertTrue(events.get(1) instanceof CompositeDataMovementEvent);
+
+    // Decode the data-movement payload and verify the empty-partition bitset
+    // covers all partitions.
+    CompositeDataMovementEvent dmEvent = (CompositeDataMovementEvent) events.get(1);
+    ByteBuffer payloadBuffer = dmEvent.getUserPayload();
+    ShuffleUserPayloads.DataMovementEventPayloadProto shufflePayload =
+        ShuffleUserPayloads.DataMovementEventPayloadProto
+            .parseFrom(ByteString.copyFrom(payloadBuffer));
+    assertTrue(shufflePayload.hasEmptyPartitions());
+
+    byte[] decompressed =
+        TezCommonUtils.decompressByteStringToByteArray(shufflePayload.getEmptyPartitions());
+    BitSet emptyPartitionsBitSet = TezUtilsInternal.fromByteArray(decompressed);
+    assertEquals(numPartitions, emptyPartitionsBitSet.cardinality());
+    for (int partition = 0; partition < numPartitions; partition++) {
+      assertTrue(emptyPartitionsBitSet.get(partition));
+    }
+  }
+
+}


Reply via email to