This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new e9342ccd [#855] feat(tez): Support Tez Output
OrderedPartitionedKVOutput (#930)
e9342ccd is described below
commit e9342ccddb364916dbd72f5cc5e4c645465a5dc4
Author: bin41215 <[email protected]>
AuthorDate: Sat Jun 10 10:52:30 2023 +0800
[#855] feat(tez): Support Tez Output OrderedPartitionedKVOutput (#930)
### What changes were proposed in this pull request?
Support writing Tez shuffle output through OrderedPartitionedKVOutput.
### Why are the changes needed?
Fix: #855
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
UT
Co-authored-by: bin3.zhang <[email protected]>
---
.../library/common/sort/buffer/WriteBuffer.java | 332 +++++++++++++++++
.../common/sort/buffer/WriteBufferManager.java | 392 +++++++++++++++++++++
.../library/common/sort/impl/RssSorter.java | 206 +++++++++++
.../common/sort/impl/RssTezPerPartitionRecord.java | 83 +++++
.../output/RssOrderedPartitionedKVOutput.java | 292 +++++++++++++++
.../common/sort/buffer/WriteBufferManagerTest.java | 388 ++++++++++++++++++++
.../common/sort/buffer/WriteBufferTest.java | 165 +++++++++
.../library/common/sort/impl/RssSorterTest.java | 134 +++++++
.../sort/impl/RssTezPerPartitionRecordTest.java | 55 +++
.../runtime/library/output/OutputTestHelpers.java | 69 ++++
.../output/RssOrderedPartitionedKVOutputTest.java | 106 ++++++
11 files changed, 2222 insertions(+)
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
new file mode 100644
index 00000000..61bee2d6
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBuffer.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.buffer;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Comparator;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+
+public class WriteBuffer<K,V> extends OutputStream {
+
+ private static final Logger LOG = LoggerFactory.getLogger(WriteBuffer.class);
+
+ private int partitionId;
+ private Serializer<K> keySerializer;
+ private Serializer<V> valSerializer;
+ private long maxSegmentSize;
+ private final RawComparator<K> comparator;
+ private int dataLength = 0;
+ private int currentOffset = 0;
+ private int currentIndex = 0;
+ private long sortTime = 0;
+ private long copyTime = 0;
+ private boolean isNeedSorted = false;
+ private final List<WrappedBuffer> buffers = Lists.newArrayList();
+ private final List<Record<K>> records = Lists.newArrayList();
+
+ public WriteBuffer(
+ boolean isNeedSorted,
+ int partitionId,
+ RawComparator<K> comparator,
+ long maxSegmentSize,
+ Serializer<K> keySerializer,
+ Serializer<V> valueSerializer) {
+ this.partitionId = partitionId;
+ this.comparator = comparator;
+ this.maxSegmentSize = maxSegmentSize;
+ this.keySerializer = keySerializer;
+ this.valSerializer = valueSerializer;
+ this.isNeedSorted = isNeedSorted;
+ }
+
+ /**
+ * add records
+ */
+ public int addRecord(K key, V value) throws IOException {
+ keySerializer.open(this);
+ valSerializer.open(this);
+ int lastOffSet = currentOffset;
+ int lastIndex = currentIndex;
+ int lastDataLength = dataLength;
+ int keyIndex = lastIndex;
+ keySerializer.serialize(key);
+ int keyLength = dataLength - lastDataLength;
+ int keyOffset = lastOffSet;
+ if (compact(lastIndex, lastOffSet, keyLength)) {
+ keyOffset = lastOffSet;
+ keyIndex = lastIndex;
+ }
+ lastDataLength = dataLength;
+ valSerializer.serialize(value);
+ int valueLength = dataLength - lastDataLength;
+ records.add(new Record<K>(keyIndex, keyOffset, keyLength, valueLength));
+ return keyLength + valueLength;
+ }
+
+ public void clear() {
+ buffers.clear();
+ records.clear();
+ }
+
+ /**
+ * get data
+ */
+ public synchronized byte[] getData() {
+ int extraSize = 0;
+ for (Record<K> record : records) {
+ extraSize += WritableUtils.getVIntSize(record.getKeyLength());
+ extraSize += WritableUtils.getVIntSize(record.getValueLength());
+ }
+ extraSize += WritableUtils.getVIntSize(-1);
+ extraSize += WritableUtils.getVIntSize(-1);
+ byte[] data = new byte[dataLength + extraSize];
+ int offset = 0;
+ long startSort = System.currentTimeMillis();
+ if (this.isNeedSorted) {
+ records.sort(new Comparator<Record<K>>() {
+ @Override
+ public int compare(Record<K> o1, Record<K> o2) {
+ return comparator.compare(
+ buffers.get(o1.getKeyIndex()).getBuffer(),
+ o1.getKeyOffSet(),
+ o1.getKeyLength(),
+ buffers.get(o2.getKeyIndex()).getBuffer(),
+ o2.getKeyOffSet(),
+ o2.getKeyLength());
+ }
+ });
+ }
+ long startCopy = System.currentTimeMillis();
+ sortTime += startCopy - startSort;
+
+ for (Record<K> record : records) {
+ offset = writeDataInt(data, offset, record.getKeyLength());
+ offset = writeDataInt(data, offset, record.getValueLength());
+ int recordLength = record.getKeyLength() + record.getValueLength();
+ int copyOffset = record.getKeyOffSet();
+ int copyIndex = record.getKeyIndex();
+ while (recordLength > 0) {
+ byte[] srcBytes = buffers.get(copyIndex).getBuffer();
+ int length = copyOffset + recordLength;
+ int copyLength = recordLength;
+ if (length > srcBytes.length) {
+ copyLength = srcBytes.length - copyOffset;
+ }
+ System.arraycopy(srcBytes, copyOffset, data, offset, copyLength);
+ copyOffset = 0;
+ copyIndex++;
+ recordLength -= copyLength;
+ offset += copyLength;
+ }
+ }
+ offset = writeDataInt(data, offset, -1);
+ writeDataInt(data, offset, -1);
+ copyTime += System.currentTimeMillis() - startCopy;
+ return data;
+ }
+
+ private boolean compact(int lastIndex, int lastOffset, int dataLength) {
+ if (lastIndex != currentIndex) {
+ LOG.debug("compact lastIndex {}, currentIndex {}, lastOffset {}
currentOffset {} dataLength {}",
+ lastIndex, currentIndex, lastOffset, currentOffset, dataLength);
+ WrappedBuffer buffer = new WrappedBuffer(lastOffset + dataLength);
+ // copy data
+ int offset = 0;
+ for (int i = lastIndex; i < currentIndex; i++) {
+ byte[] sourceBuffer = buffers.get(i).getBuffer();
+ System.arraycopy(sourceBuffer, 0, buffer.getBuffer(), offset,
sourceBuffer.length);
+ offset += sourceBuffer.length;
+ }
+ System.arraycopy(buffers.get(currentIndex).getBuffer(), 0,
buffer.getBuffer(), offset, currentOffset);
+ // remove data
+ for (int i = currentIndex; i >= lastIndex; i--) {
+ buffers.remove(i);
+ }
+ buffers.add(buffer);
+ currentOffset = 0;
+ WrappedBuffer anotherBuffer = new WrappedBuffer((int)maxSegmentSize);
+ buffers.add(anotherBuffer);
+ currentIndex = buffers.size() - 1;
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ if (buffers.isEmpty()) {
+ buffers.add(new WrappedBuffer((int)maxSegmentSize));
+ }
+ if (1 + currentOffset > maxSegmentSize) {
+ currentIndex++;
+ currentOffset = 0;
+ buffers.add(new WrappedBuffer((int)maxSegmentSize));
+ }
+ WrappedBuffer buffer = buffers.get(currentIndex);
+ buffer.getBuffer()[currentOffset] = (byte) b;
+ currentOffset++;
+ dataLength++;
+ }
+
+ @Override
+ public void write(byte[] b, int off, int len) throws IOException {
+ if (b == null) {
+ throw new NullPointerException();
+ } else if ((off < 0) || (off > b.length) || (len < 0)
+ || ((off + len) > b.length) || ((off + len) < 0)) {
+ throw new IndexOutOfBoundsException();
+ } else if (len == 0) {
+ return;
+ }
+ if (buffers.isEmpty()) {
+ buffers.add(new WrappedBuffer((int) maxSegmentSize));
+ }
+ int bufferNum = (int)((currentOffset + len) / maxSegmentSize);
+
+ for (int i = 0; i < bufferNum; i++) {
+ buffers.add(new WrappedBuffer((int) maxSegmentSize));
+ }
+
+ int index = currentIndex;
+ int offset = currentOffset;
+ int srcPos = 0;
+
+ while (len > 0) {
+ int copyLength = 0;
+ if (offset + len >= maxSegmentSize) {
+ copyLength = (int) (maxSegmentSize - offset);
+ currentOffset = 0;
+ } else {
+ copyLength = len;
+ currentOffset += len;
+ }
+ System.arraycopy(b, srcPos, buffers.get(index).getBuffer(), offset,
copyLength);
+ offset = 0;
+ srcPos += copyLength;
+ index++;
+ len -= copyLength;
+ dataLength += copyLength;
+ }
+ currentIndex += bufferNum;
+ }
+
+ private int writeDataInt(byte[] data, int offset, long dataInt) {
+ if (dataInt >= -112L && dataInt <= 127L) {
+ data[offset] = (byte)((int)dataInt);
+ offset++;
+ } else {
+ int len = -112;
+ if (dataInt < 0L) {
+ dataInt = ~dataInt;
+ len = -120;
+ }
+ for (long tmp = dataInt; tmp != 0L; --len) {
+ tmp >>= 8;
+ }
+ data[offset] = (byte)len;
+ offset++;
+ len = len < -120 ? -(len + 120) : -(len + 112);
+ for (int idx = len; idx != 0; --idx) {
+ int shiftBits = (idx - 1) * 8;
+ long mask = 255L << shiftBits;
+ data[offset] = ((byte)((int)((dataInt & mask) >> shiftBits)));
+ offset++;
+ }
+ }
+ return offset;
+ }
+
+ public int getDataLength() {
+ return dataLength;
+ }
+
+ public int getPartitionId() {
+ return partitionId;
+ }
+
+ public long getCopyTime() {
+ return copyTime;
+ }
+
+ public long getSortTime() {
+ return sortTime;
+ }
+
+ private static final class Record<K> {
+ private final int keyIndex;
+ private final int keyOffSet;
+ private final int keyLength;
+ private final int valueLength;
+
+ Record(int keyIndex,
+ int keyOffset,
+ int keyLength,
+ int valueLength) {
+ this.keyIndex = keyIndex;
+ this.keyOffSet = keyOffset;
+ this.keyLength = keyLength;
+ this.valueLength = valueLength;
+ }
+
+ public int getKeyIndex() {
+ return keyIndex;
+ }
+
+ public int getKeyOffSet() {
+ return keyOffSet;
+ }
+
+ public int getKeyLength() {
+ return keyLength;
+ }
+
+ public int getValueLength() {
+ return valueLength;
+ }
+ }
+
+ private static final class WrappedBuffer {
+ private byte[] buffer;
+ private int size;
+
+ WrappedBuffer(int size) {
+ this.buffer = new byte[size];
+ this.size = size;
+ }
+
+ public byte[] getBuffer() {
+ return buffer;
+ }
+
+ public int getSize() {
+ return size;
+ }
+ }
+
+}
+
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
new file mode 100644
index 00000000..96fa6f29
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
@@ -0,0 +1,392 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.buffer;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ChecksumUtils;
+import org.apache.uniffle.common.util.ThreadUtils;
+
+
/**
 * Buffers shuffle records per partition in {@link WriteBuffer}s and asynchronously
 * sends them to remote shuffle servers as compressed {@link ShuffleBlockInfo}
 * blocks through the supplied {@link ShuffleWriteClient}.
 *
 * <p>Producers call {@link #addRecord}; {@link #waitSendFinished()} flushes what is
 * left and blocks until every block id has been acknowledged or a timeout hits.
 */
public class WriteBufferManager<K,V> {
  private static final Logger LOG = LoggerFactory.getLogger(WriteBufferManager.class);
  // Accumulated time (ms) spent copying record bytes out of buffers.
  private long copyTime = 0;
  // Accumulated time (ms) spent sorting records inside buffers.
  private long sortTime = 0;
  // Accumulated time (ms) spent compressing block payloads.
  private long compressTime = 0;
  // Next block sequence number per partition, used when composing block ids.
  private final Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
  // Total uncompressed bytes handed to the codec.
  private long uncompressedDataLen = 0;
  // Memory budget; addRecord() blocks while memoryUsedSize exceeds this.
  private final long maxMemSize;
  // Single-threaded executor performing the actual network sends.
  private final ExecutorService sendExecutorService;
  private final ShuffleWriteClient shuffleWriteClient;
  private final String appId;
  // Block ids acknowledged by shuffle servers (set is shared with the caller).
  private final Set<Long> successBlockIds;
  // Block ids whose send failed (set is shared with the caller).
  private final Set<Long> failedBlockIds;
  // Guards the wait/signal protocol around memoryUsedSize.
  private final ReentrantLock memoryLock = new ReentrantLock();
  private final AtomicLong memoryUsedSize = new AtomicLong(0);
  // Bytes queued in the send executor but not yet released.
  private final AtomicLong inSendListBytes = new AtomicLong(0);
  // Signalled when a finished send releases memory.
  private final Condition full = memoryLock.newCondition();
  private final RawComparator<K> comparator;
  private final long maxSegmentSize;
  private final Serializer<K> keySerializer;
  private final Serializer<V> valSerializer;
  // Buffers filled but not yet handed to the send executor.
  private final List<WriteBuffer<K, V>> waitSendBuffers = Lists.newLinkedList();
  // Active buffer per partition.
  private final Map<Integer, WriteBuffer<K, V>> buffers = Maps.newConcurrentMap();
  // A single buffer is flushed as soon as its data length exceeds this.
  private final long maxBufferSize;
  // Fraction of maxMemSize above which a batched flush is triggered.
  private final double memoryThreshold;
  // Batched flush only runs while in-flight bytes stay <= maxMemSize * sendThreshold.
  private final double sendThreshold;
  // Maximum number of buffers flushed per sendBuffersToServers() call.
  private final int batch;
  private final Codec codec;
  private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
  // Every block id produced by this task; drained as acks arrive.
  private final Set<Long> allBlockIds = Sets.newConcurrentHashSet();
  private final Map<Integer, List<Long>> partitionToBlocks = Maps.newConcurrentMap();
  private final int numMaps;
  // When true, data stays in shuffle-server memory and sendCommit() is skipped.
  private final boolean isMemoryShuffleEnabled;
  private final long sendCheckInterval;
  private final long sendCheckTimeout;
  private final int bitmapSplitNum;
  private final long taskAttemptId;
  private TezTaskAttemptID tezTaskAttemptID;
  private final RssConf rssConf;
  private final int shuffleId;
  private final boolean isNeedSorted;

  /**
   * Creates a manager. All sizes and thresholds are supplied by the caller; sends
   * run on a dedicated single-threaded executor named "send-thread".
   */
  public WriteBufferManager(
      TezTaskAttemptID tezTaskAttemptID,
      long maxMemSize,
      String appId,
      long taskAttemptId,
      Set<Long> successBlockIds,
      Set<Long> failedBlockIds,
      ShuffleWriteClient shuffleWriteClient,
      RawComparator<K> comparator,
      long maxSegmentSize,
      Serializer<K> keySerializer,
      Serializer<V> valSerializer,
      long maxBufferSize,
      double memoryThreshold,
      double sendThreshold,
      int batch,
      RssConf rssConf,
      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
      int numMaps,
      boolean isMemoryShuffleEnabled,
      long sendCheckInterval,
      long sendCheckTimeout,
      int bitmapSplitNum,
      int shuffleId,
      boolean isNeedSorted) {
    this.tezTaskAttemptID = tezTaskAttemptID;
    this.maxMemSize = maxMemSize;
    this.appId = appId;
    this.taskAttemptId = taskAttemptId;
    this.successBlockIds = successBlockIds;
    this.failedBlockIds = failedBlockIds;
    this.shuffleWriteClient = shuffleWriteClient;
    this.comparator = comparator;
    this.maxSegmentSize = maxSegmentSize;
    this.keySerializer = keySerializer;
    this.valSerializer = valSerializer;
    this.maxBufferSize = maxBufferSize;
    this.memoryThreshold = memoryThreshold;
    this.sendThreshold = sendThreshold;
    this.batch = batch;
    this.codec = Codec.newInstance(rssConf);
    this.partitionToServers = partitionToServers;
    this.numMaps = numMaps;
    this.isMemoryShuffleEnabled = isMemoryShuffleEnabled;
    this.sendCheckInterval = sendCheckInterval;
    this.sendCheckTimeout = sendCheckTimeout;
    this.bitmapSplitNum = bitmapSplitNum;
    this.rssConf = rssConf;
    this.shuffleId = shuffleId;
    this.isNeedSorted = isNeedSorted;
    this.sendExecutorService = Executors.newFixedThreadPool(
        1,
        ThreadUtils.getThreadFactory("send-thread"));
  }

  /**
   * Adds one record to the buffer of {@code partitionId}. Blocks while total
   * buffered bytes exceed {@code maxMemSize}; flushes the partition buffer once it
   * exceeds {@code maxBufferSize}, and triggers a batched flush when overall
   * memory use crosses {@code memoryThreshold} while in-flight bytes are low.
   */
  public void addRecord(int partitionId, K key, V value) throws InterruptedException, IOException {
    memoryLock.lock();
    try {
      // Back-pressure: wait for a finished send to release memory.
      while (memoryUsedSize.get() > maxMemSize) {
        LOG.warn("memoryUsedSize {} is more than {}, inSendListBytes {}",
            memoryUsedSize, maxMemSize, inSendListBytes);
        full.await();
      }
    } finally {
      memoryLock.unlock();
    }

    if (!buffers.containsKey(partitionId)) {
      WriteBuffer<K, V> sortWriterBuffer = new WriteBuffer(
          isNeedSorted, partitionId, comparator, maxSegmentSize, keySerializer, valSerializer);
      buffers.putIfAbsent(partitionId, sortWriterBuffer);
      waitSendBuffers.add(sortWriterBuffer);
    }
    WriteBuffer<K, V> buffer = buffers.get(partitionId);
    int length = buffer.addRecord(key, value);
    // A single record bigger than the whole memory budget can never be flushed.
    if (length > maxMemSize) {
      throw new RssException("record is too big");
    }

    memoryUsedSize.addAndGet(length);
    if (buffer.getDataLength() > maxBufferSize) {
      if (waitSendBuffers.remove(buffer)) {
        sendBufferToServers(buffer);
      } else {
        LOG.error("waitSendBuffers don't contain buffer {}", buffer);
      }
    }
    if (memoryUsedSize.get() > maxMemSize * memoryThreshold
        && inSendListBytes.get() <= maxMemSize * sendThreshold) {
      sendBuffersToServers();
    }
  }

  /** Flushes a single buffer as one shuffle block. */
  private void sendBufferToServers(WriteBuffer<K, V> buffer) {
    List<ShuffleBlockInfo> shuffleBlocks = Lists.newArrayList();
    prepareBufferForSend(shuffleBlocks, buffer);
    sendShuffleBlocks(shuffleBlocks);
  }

  /** Flushes up to {@code batch} pending buffers, largest first. */
  void sendBuffersToServers() {
    // Sort descending by buffered data length so the biggest memory consumers
    // are released first.
    waitSendBuffers.sort(new Comparator<WriteBuffer<K, V>>() {
      @Override
      public int compare(WriteBuffer<K, V> o1, WriteBuffer<K, V> o2) {
        return o2.getDataLength() - o1.getDataLength();
      }
    });

    int sendSize = batch;
    if (batch > waitSendBuffers.size()) {
      sendSize = waitSendBuffers.size();
    }

    Iterator<WriteBuffer<K, V>> iterator = waitSendBuffers.iterator();
    int index = 0;
    List<ShuffleBlockInfo> shuffleBlocks = Lists.newArrayList();
    while (iterator.hasNext() && index < sendSize) {
      WriteBuffer<K, V> buffer = iterator.next();
      prepareBufferForSend(shuffleBlocks, buffer);
      iterator.remove();
      index++;
    }
    sendShuffleBlocks(shuffleBlocks);
  }

  /** Converts a buffer into a block, registers its id, and clears the buffer. */
  private void prepareBufferForSend(List<ShuffleBlockInfo> shuffleBlocks, WriteBuffer buffer) {
    buffers.remove(buffer.getPartitionId());
    ShuffleBlockInfo block = createShuffleBlock(buffer);
    buffer.clear();
    shuffleBlocks.add(block);
    allBlockIds.add(block.getBlockId());
    if (!partitionToBlocks.containsKey(block.getPartitionId())) {
      partitionToBlocks.putIfAbsent(block.getPartitionId(), Lists.newArrayList());
    }
    partitionToBlocks.get(block.getPartitionId()).add(block.getBlockId());
  }

  /**
   * Submits the blocks to the send executor. Whatever the outcome, the bytes
   * accounted for these blocks are released and blocked writers are woken up.
   */
  private void sendShuffleBlocks(List<ShuffleBlockInfo> shuffleBlocks) {
    sendExecutorService.submit(new Runnable() {
      @Override
      public void run() {
        long size = 0;
        try {
          for (ShuffleBlockInfo block : shuffleBlocks) {
            size += block.getFreeMemory();
          }
          SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(appId, shuffleBlocks, () -> false);
          successBlockIds.addAll(result.getSuccessBlockIds());
          failedBlockIds.addAll(result.getFailedBlockIds());
        } catch (Throwable t) {
          // Failures surface later through failedBlockIds in waitSendFinished().
          LOG.warn("send shuffle data exception ", t);
        } finally {
          // NOTE(review): lock() is called inside the try; conventional form is
          // lock() before try — behavior is the same when lock() succeeds.
          try {
            memoryLock.lock();
            LOG.debug("memoryUsedSize {} decrease {}", memoryUsedSize, size);
            memoryUsedSize.addAndGet(-size);
            inSendListBytes.addAndGet(-size);
            // Wake writers blocked in addRecord().
            full.signalAll();
          } finally {
            memoryLock.unlock();
          }
        }
      }
    });
  }

  /**
   * Flushes all remaining buffers, commits when the shuffle is not memory-only,
   * then polls until every block id is acknowledged (throwing on failed blocks or
   * on timeout), and finally reports the shuffle result to the servers.
   */
  public void waitSendFinished() {
    while (!waitSendBuffers.isEmpty()) {
      sendBuffersToServers();
    }
    long start = System.currentTimeMillis();
    long commitDuration = 0;
    if (!isMemoryShuffleEnabled) {
      long s = System.currentTimeMillis();
      sendCommit();
      commitDuration = System.currentTimeMillis() - s;
    }
    while (true) {
      if (failedBlockIds.size() > 0) {
        String errorMsg = "Send failed: failed because " + failedBlockIds.size()
            + " blocks can't be sent to shuffle server.";
        LOG.error(errorMsg);
        throw new RssException(errorMsg);
      }
      allBlockIds.removeAll(successBlockIds);
      if (allBlockIds.isEmpty()) {
        break;
      }
      LOG.info("Wait " + allBlockIds.size() + " blocks sent to shuffle server");
      Uninterruptibles.sleepUninterruptibly(sendCheckInterval, TimeUnit.MILLISECONDS);
      if (System.currentTimeMillis() - start > sendCheckTimeout) {
        String errorMsg = "Timeout: failed because " + allBlockIds.size()
            + " blocks can't be sent to shuffle server in " + sendCheckTimeout + " ms.";
        LOG.error(errorMsg);
        throw new RssException(errorMsg);
      }
    }
    start = System.currentTimeMillis();
    TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
    TezDAGID tezDAGID = tezVertexID.getDAGId();
    LOG.info("tezVertexID is {}, tezDAGID is {}, shuffleId is {}", tezVertexID, tezDAGID, shuffleId);
    shuffleWriteClient.reportShuffleResult(partitionToServers, appId, shuffleId,
        taskAttemptId, partitionToBlocks, bitmapSplitNum);
    LOG.info("Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms",
        taskAttemptId, bitmapSplitNum, (System.currentTimeMillis() - start));
    LOG.info("Task uncompressed data length {} compress time cost {} ms, commit time cost {} ms,"
        + " copy time cost {} ms, sort time cost {} ms",
        uncompressedDataLen, compressTime, commitDuration, copyTime, sortTime);
  }

  /**
   * Materializes, compresses and checksums a buffer's data and wraps it into a
   * {@link ShuffleBlockInfo} with a freshly composed block id.
   */
  ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) {
    byte[] data = wb.getData();
    copyTime += wb.getCopyTime();
    sortTime += wb.getSortTime();
    int partitionId = wb.getPartitionId();
    final int uncompressLength = data.length;
    long start = System.currentTimeMillis();

    final byte[] compressed = codec.compress(data);
    final long crc32 = ChecksumUtils.getCrc32(compressed);
    compressTime += System.currentTimeMillis() - start;
    final long blockId = RssTezUtils.getBlockId((long)partitionId, taskAttemptId, getNextSeqNo(partitionId));
    LOG.info("blockId is {}", blockId);
    uncompressedDataLen += data.length;
    // add memory to indicate bytes which will be sent to shuffle server
    inSendListBytes.addAndGet(wb.getDataLength());

    TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
    TezDAGID tezDAGID = tezVertexID.getDAGId();
    LOG.info("tezVertexID is {}, tezDAGID is {}, shuffleId is {}", tezVertexID, tezDAGID, shuffleId);
    return new ShuffleBlockInfo(shuffleId, partitionId, blockId, compressed.length, crc32,
        compressed, partitionToServers.get(partitionId),
        uncompressLength, wb.getDataLength(), taskAttemptId);
  }

  /**
   * Sends a commit to every shuffle server involved in this task and waits for
   * the result, polling the future with exponential backoff (200ms up to 5s).
   */
  protected void sendCommit() {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    Set<ShuffleServerInfo> serverInfos = Sets.newHashSet();
    for (List<ShuffleServerInfo> serverInfoList : partitionToServers.values()) {
      for (ShuffleServerInfo serverInfo : serverInfoList) {
        serverInfos.add(serverInfo);
      }
    }
    LOG.info("sendCommit shuffle id is {}", shuffleId);
    Future<Boolean> future = executor.submit(
        () -> shuffleWriteClient.sendCommit(serverInfos, appId, shuffleId, numMaps));
    long start = System.currentTimeMillis();
    int currentWait = 200;
    int maxWait = 5000;
    while (!future.isDone()) {
      LOG.info("Wait commit to shuffle server for task[" + taskAttemptId + "] cost "
          + (System.currentTimeMillis() - start) + " ms");
      Uninterruptibles.sleepUninterruptibly(currentWait, TimeUnit.MILLISECONDS);
      currentWait = Math.min(currentWait * 2, maxWait);
    }
    try {
      if (!future.get()) {
        throw new RssException("Failed to commit task to shuffle server");
      }
    } catch (InterruptedException ie) {
      LOG.warn("Ignore the InterruptedException which should be caused by internal killed");
    } catch (Exception e) {
      throw new RssException("Exception happened when get commit status", e);
    } finally {
      executor.shutdown();
    }
  }

  // Package-private accessor used by tests.
  List<WriteBuffer<K,V>> getWaitSendBuffers() {
    return waitSendBuffers;
  }

  /** Returns, then increments, the block sequence number for the partition. */
  private int getNextSeqNo(int partitionId) {
    partitionToSeqNo.putIfAbsent(partitionId, 0);
    int seqNo = partitionToSeqNo.get(partitionId);
    partitionToSeqNo.put(partitionId, seqNo + 1);
    return seqNo;
  }

  /** Stops the send executor immediately; in-flight sends may be interrupted. */
  public void freeAllResources() {
    sendExecutorService.shutdownNow();
  }

}
+
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
new file mode 100644
index 00000000..45b9f5d6
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Sets;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.tez.common.IdUtils;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.library.common.sort.buffer.WriteBufferManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteUnit;
+import org.apache.uniffle.storage.util.StorageType;
+
+
/**
 * {@link RssSorter} is an {@link ExternalSorter} implementation that sends map
 * output to remote shuffle servers through a {@link WriteBufferManager} instead
 * of sorting and spilling to local disk.
 */
public class RssSorter extends ExternalSorter {

  private static final Logger LOG = LoggerFactory.getLogger(RssSorter.class);
  private WriteBufferManager bufferManager;
  // Block ids acknowledged / failed by the shuffle servers; shared with bufferManager.
  private Set<Long> successBlockIds = Sets.newConcurrentHashSet();
  private Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;

  // Number of records written to each output partition.
  private int[] numRecordsPerPartition;

  /**
   * Reads the Uniffle client settings from {@code conf} and initializes the
   * {@link WriteBufferManager} used by {@link #write}.
   * (keyClass, valClass, partitioner, partitions, comparator, serializers and
   * availableMemoryMb are initialized by the {@link ExternalSorter} constructor.)
   */
  public RssSorter(TezTaskAttemptID tezTaskAttemptID,OutputContext outputContext,
      Configuration conf, int numMaps, int numOutputs,
      long initialMemoryAvailable, int shuffleId,
      Map<Integer, List<ShuffleServerInfo>> partitionToServers) throws IOException {
    super(outputContext, conf, numOutputs, initialMemoryAvailable);
    this.partitionToServers = partitionToServers;

    this.numRecordsPerPartition = new int[numOutputs];

    long sortmb = conf.getLong(RssTezConfig.RSS_RUNTIME_IO_SORT_MB, RssTezConfig.RSS_DEFAULT_RUNTIME_IO_SORT_MB);
    LOG.info("conf.sortmb is {}", sortmb);
    // NOTE(review): the configured value above is immediately overridden by the
    // memory granted to this sorter — confirm the config read is still needed.
    sortmb = this.availableMemoryMb;
    LOG.info("sortmb, availableMemoryMb is {}, {}", sortmb, availableMemoryMb);
    // Rejects values outside [0, 2047] MB (only the low 11 bits may be set).
    if ((sortmb & 0x7FF) != sortmb) {
      throw new IOException(
          "Invalid \"" + RssTezConfig.RSS_RUNTIME_IO_SORT_MB + "\": " + sortmb);
    }
    double sortThreshold = conf.getDouble(RssTezConfig.RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD,
        RssTezConfig.RSS_CLIENT_DEFAULT_SORT_MEMORY_USE_THRESHOLD);
    long taskAttemptId = RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptID, IdUtils.getAppAttemptId());

    long maxSegmentSize = conf.getLong(RssTezConfig.RSS_CLIENT_MAX_BUFFER_SIZE,
        RssTezConfig.RSS_CLIENT_DEFAULT_MAX_BUFFER_SIZE);
    long maxBufferSize = conf.getLong(RssTezConfig.RSS_WRITER_BUFFER_SIZE, RssTezConfig.RSS_DEFAULT_WRITER_BUFFER_SIZE);
    double memoryThreshold = conf.getDouble(RssTezConfig.RSS_CLIENT_MEMORY_THRESHOLD,
        RssTezConfig.RSS_CLIENT_DEFAULT_MEMORY_THRESHOLD);
    double sendThreshold = conf.getDouble(RssTezConfig.RSS_CLIENT_SEND_THRESHOLD,
        RssTezConfig.RSS_CLIENT_DEFAULT_SEND_THRESHOLD);
    int batch = conf.getInt(RssTezConfig.RSS_CLIENT_BATCH_TRIGGER_NUM,
        RssTezConfig.RSS_CLIENT_DEFAULT_BATCH_TRIGGER_NUM);
    String storageType = conf.get(RssTezConfig.RSS_STORAGE_TYPE, RssTezConfig.RSS_DEFAULT_STORAGE_TYPE);
    if (StringUtils.isEmpty(storageType)) {
      throw new RssException("storage type mustn't be empty");
    }
    long sendCheckInterval = conf.getLong(RssTezConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS,
        RssTezConfig.RSS_CLIENT_DEFAULT_SEND_CHECK_INTERVAL_MS);
    long sendCheckTimeout = conf.getLong(RssTezConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS,
        RssTezConfig.RSS_CLIENT_DEFAULT_SEND_CHECK_TIMEOUT_MS);
    int bitmapSplitNum = conf.getInt(RssTezConfig.RSS_CLIENT_BITMAP_NUM,
        RssTezConfig.RSS_CLIENT_DEFAULT_BITMAP_NUM);

    // Dump the effective settings when Hive/Tez runs at DEBUG log level.
    if (conf.get(RssTezConfig.HIVE_TEZ_LOG_LEVEL, RssTezConfig.DEFAULT_HIVE_TEZ_LOG_LEVEL)
        .equalsIgnoreCase(RssTezConfig.DEBUG_HIVE_TEZ_LOG_LEVEL)) {
      LOG.info("sortmb is {}", sortmb);
      LOG.info("sortThreshold is {}", sortThreshold);
      LOG.info("taskAttemptId is {}", taskAttemptId);
      LOG.info("maxSegmentSize is {}", maxSegmentSize);
      LOG.info("maxBufferSize is {}", maxBufferSize);
      LOG.info("memoryThreshold is {}", memoryThreshold);
      LOG.info("sendThreshold is {}", sendThreshold);
      LOG.info("batch is {}", batch);
      LOG.info("storageType is {}", storageType);
      LOG.info("sendCheckInterval is {}", sendCheckInterval);
      LOG.info("sendCheckTimeout is {}", sendCheckTimeout);
      LOG.info("bitmapSplitNum is {}", bitmapSplitNum);
    }


    // The application attempt id is derived from the container this task runs in.
    String containerIdStr =
        System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name());
    ContainerId containerId = ConverterUtils.toContainerId(containerIdStr);
    ApplicationAttemptId applicationAttemptId =
        containerId.getApplicationAttemptId();
    LOG.info("containerIdStr is {}", containerIdStr);
    LOG.info("containerId is {}", containerId);
    LOG.info("applicationAttemptId is {}", applicationAttemptId.toString());


    bufferManager = new WriteBufferManager(
        tezTaskAttemptID,
        (long)(ByteUnit.MiB.toBytes(sortmb) * sortThreshold),
        applicationAttemptId.toString(),
        taskAttemptId,
        successBlockIds,
        failedBlockIds,
        RssTezUtils.createShuffleClient(conf),
        comparator,
        maxSegmentSize,
        keySerializer,
        valSerializer,
        maxBufferSize,
        memoryThreshold,
        sendThreshold,
        batch,
        // NOTE(review): a fresh empty RssConf is passed rather than one derived
        // from conf — confirm codec settings are meant to fall back to defaults.
        new RssConf(),
        partitionToServers,
        numMaps,
        isMemoryShuffleEnabled(storageType),
        sendCheckInterval,
        sendCheckTimeout,
        bitmapSplitNum,
        shuffleId,
        true);
    LOG.info("Initialized WriteBufferManager.");
  }

  /** Flushes all buffered records and waits until the shuffle servers ack them. */
  @Override
  public void flush() throws IOException {
    bufferManager.waitSendFinished();
  }

  /** Closes the sorter and shuts down the buffer manager's send executor. */
  @Override
  public final void close() throws IOException {
    super.close();
    bufferManager.freeAllResources();
  }

  /** Routes one record to its partition and hands it to the buffer manager. */
  @Override
  public void write(Object key, Object value) throws IOException {
    try {
      collect(key, value, partitioner.getPartition(key, value, partitions));
    } catch (InterruptedException e) {
      throw new RssException(e);
    }
  }

  /**
   * Validates the key/value runtime types and the partition index, then buffers
   * the record and counts it for that partition.
   */
  synchronized void collect(Object key, Object value, final int partition) throws IOException, InterruptedException {
    if (key.getClass() != keyClass) {
      throw new IOException("Type mismatch in key from map: expected "
          + keyClass.getName() + ", received "
          + key.getClass().getName());
    }
    if (value.getClass() != valClass) {
      throw new IOException("Type mismatch in value from map: expected "
          + valClass.getName() + ", received "
          + value.getClass().getName());
    }
    if (partition < 0 || partition >= partitions) {
      throw new IOException("Illegal partition for " + key + " ("
          + partition + ")");
    }

    bufferManager.addRecord(partition, key, value);
    numRecordsPerPartition[partition]++;
  }

  /** Returns the number of records written to each output partition. */
  public int[] getNumRecordsPerPartition() {
    return numRecordsPerPartition;
  }

  /** True when the configured storage type keeps shuffle data in server memory. */
  private boolean isMemoryShuffleEnabled(String storageType) {
    return StorageType.withMemory(StorageType.valueOf(storageType));
  }
}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecord.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecord.java
new file mode 100644
index 00000000..97bb2ebf
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecord.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
+import java.io.IOException;
+import java.util.zip.Checksum;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+
+
+public class RssTezPerPartitionRecord extends TezSpillRecord {
+ private int numPartitions;
+ private int[] numRecordsPerPartition;
+
+ public RssTezPerPartitionRecord(int numPartitions) {
+ super(numPartitions);
+ this.numPartitions = numPartitions;
+ }
+
+ public RssTezPerPartitionRecord(int numPartitions, int[]
numRecordsPerPartition) {
+ super(numPartitions);
+ this.numPartitions = numPartitions;
+ this.numRecordsPerPartition = numRecordsPerPartition;
+ }
+
+
+ public RssTezPerPartitionRecord(Path indexFileName, Configuration job)
throws IOException {
+ super(indexFileName, job);
+ }
+
+ public RssTezPerPartitionRecord(Path indexFileName, Configuration job,
String expectedIndexOwner) throws IOException {
+ super(indexFileName, job, expectedIndexOwner);
+ }
+
+ public RssTezPerPartitionRecord(Path indexFileName, Configuration job,
Checksum crc, String expectedIndexOwner)
+ throws IOException {
+ super(indexFileName, job, crc, expectedIndexOwner);
+ }
+
+ @Override
+ public int size() {
+ return numPartitions;
+ }
+
+ @Override
+ public RssTezIndexRecord getIndex(int i) {
+ int records = numRecordsPerPartition[i];
+ RssTezIndexRecord rssTezIndexRecord = new RssTezIndexRecord();
+ rssTezIndexRecord.setData(!(records == 0));
+ return rssTezIndexRecord;
+ }
+
+
+ static class RssTezIndexRecord extends TezIndexRecord {
+ private boolean hasData;
+
+ private void setData(boolean hasData) {
+ this. hasData = hasData;
+ }
+
+ @Override
+ public boolean hasData() {
+ return hasData;
+ }
+ }
+
+}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
new file mode 100644
index 00000000..928ec3c4
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.output;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.security.PrivilegedExceptionAction;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.zip.Deflater;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.ipc.RPC;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.GetShuffleServerRequest;
+import org.apache.tez.common.GetShuffleServerResponse;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.AbstractLogicalOutput;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.api.Writer;
+import org.apache.tez.runtime.library.api.KeyValuesWriter;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.MemoryUpdateCallbackHandler;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import org.apache.tez.runtime.library.common.sort.impl.ExternalSorter;
+import org.apache.tez.runtime.library.common.sort.impl.RssSorter;
+import
org.apache.tez.runtime.library.common.sort.impl.RssTezPerPartitionRecord;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+
+
+/**
+ * {@link RssOrderedPartitionedKVOutput} is an {@link AbstractLogicalOutput}
which
+ * support remote shuffle.
+ */
+@Public
+public class RssOrderedPartitionedKVOutput extends AbstractLogicalOutput {
+ private static final Logger LOG =
LoggerFactory.getLogger(RssOrderedPartitionedKVOutput.class);
+ protected ExternalSorter sorter;
+ protected Configuration conf;
+ protected MemoryUpdateCallbackHandler memoryUpdateCallbackHandler;
+ private long startTime;
+ private long endTime;
+ private final AtomicBoolean isStarted = new AtomicBoolean(false);
+ private final Deflater deflater;
+ private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+ private int mapNum;
+ private int numOutputs;
+ private TezTaskAttemptID taskAttemptId;
+ private ApplicationId applicationId;
+ private boolean sendEmptyPartitionDetails;
+ private OutputContext outputContext;
+ private String host;
+ private int port;
+ private String taskVertexName;
+ private String destinationVertexName;
+ private int shuffleId;
+
+
+ public RssOrderedPartitionedKVOutput(OutputContext outputContext, int
numPhysicalOutputs) {
+ super(outputContext, numPhysicalOutputs);
+ this.outputContext = outputContext;
+ this.deflater = TezCommonUtils.newBestCompressionDeflater();
+ this.numOutputs = getNumPhysicalOutputs();
+ this.mapNum = outputContext.getVertexParallelism();
+ this.applicationId = outputContext.getApplicationId();
+ this.taskAttemptId = TezTaskAttemptID.fromString(
+
RssTezUtils.uniqueIdentifierToAttemptId(outputContext.getUniqueIdentifier()));
+ this.taskVertexName = outputContext.getTaskVertexName();
+ this.destinationVertexName = outputContext.getDestinationVertexName();
+ LOG.info("taskAttemptId is {}", taskAttemptId.toString());
+ LOG.info("taskVertexName is {}", taskVertexName);
+ LOG.info("destinationVertexName is {}", destinationVertexName);
+ LOG.info("Initialized RssOrderedPartitionedKVOutput.");
+ }
+
+ private void getRssConf() {
+ try {
+ JobConf conf = new JobConf(RssTezConfig.RSS_CONF_FILE);
+ this.host = conf.get(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS, "null
host");
+ this.port = conf.getInt(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT, -1);
+
+ LOG.info("Got RssConf am info : host is {}, port is {}", host, port);
+ } catch (Exception e) {
+ LOG.warn("debugRssConf error: ", e);
+ }
+ }
+
+ @Override
+ public List<Event> initialize() throws Exception {
+ this.startTime = System.nanoTime();
+ this.conf =
TezUtils.createConfFromUserPayload(getContext().getUserPayload());
+ this.memoryUpdateCallbackHandler = new MemoryUpdateCallbackHandler();
+
+ long memRequestSize = RssTezUtils.getInitialMemoryRequirement(conf,
getContext().getTotalMemoryAvailableToTask());
+ LOG.info("memRequestSize is {}", memRequestSize);
+ getContext().requestInitialMemory(memRequestSize,
memoryUpdateCallbackHandler);
+ LOG.info("Got initialMemory.");
+
+ getRssConf();
+
+ this.sendEmptyPartitionDetails = conf.getBoolean(
+
TezRuntimeConfiguration.TEZ_RUNTIME_EMPTY_PARTITION_INFO_VIA_EVENTS_ENABLED,
+
TezRuntimeConfiguration.TEZ_RUNTIME_EMPTY_PARTITION_INFO_VIA_EVENTS_ENABLED_DEFAULT);
+
+ final InetSocketAddress address = NetUtils.createSocketAddrForHost(host,
port);
+
+ UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(this.applicationId.toString());
+
+ TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
+ .doAs(new
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
+ @Override
+ public TezRemoteShuffleUmbilicalProtocol run() throws Exception {
+ return RPC.getProxy(TezRemoteShuffleUmbilicalProtocol.class,
+ TezRemoteShuffleUmbilicalProtocol.versionID,
+ address, conf);
+ }
+ });
+ TezVertexID tezVertexID = taskAttemptId.getTaskID().getVertexID();
+ TezDAGID tezDAGID = tezVertexID.getDAGId();
+ this.shuffleId = RssTezUtils.computeShuffleId(tezDAGID.getId(),
this.taskVertexName, this.destinationVertexName);
+ GetShuffleServerRequest request = new
GetShuffleServerRequest(this.taskAttemptId, this.mapNum,
+ this.numOutputs, this.shuffleId);
+ GetShuffleServerResponse response =
umbilical.getShuffleAssignments(request);
+ this.partitionToServers = response.getShuffleAssignmentsInfoWritable()
+ .getShuffleAssignmentsInfo()
+ .getPartitionToServers();
+
+ LOG.info("Got response from am.");
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void handleEvents(List<Event> list) {
+
+ }
+
+ @Override
+ public List<Event> close() throws Exception {
+ List<Event> returnEvents = Lists.newLinkedList();
+ if (sorter != null) {
+ sorter.flush();
+ sorter.close();
+ this.endTime = System.nanoTime();
+ returnEvents.addAll(generateEvents());
+ sorter = null;
+ } else {
+ LOG.warn(getContext().getDestinationVertexName()
+ + ": Attempting to close output {} of type {} before it was started.
Generating empty events",
+ getContext().getDestinationVertexName(),
this.getClass().getSimpleName());
+ returnEvents = generateEmptyEvents();
+ }
+ LOG.info("RssOrderedPartitionedKVOutput close.");
+ return returnEvents;
+ }
+
+ @Override
+ public void start() throws Exception {
+ if (!isStarted.get()) {
+ memoryUpdateCallbackHandler.validateUpdateReceived();
+ sorter = new RssSorter(taskAttemptId, getContext(), conf, mapNum,
numOutputs,
+ memoryUpdateCallbackHandler.getMemoryAssigned(), shuffleId,
+ partitionToServers);
+ LOG.info("Initialized RssSorter.");
+ isStarted.set(true);
+ }
+ }
+
+ @Override
+ public Writer getWriter() throws IOException {
+ Preconditions.checkState(isStarted.get(), "Cannot get writer before
starting the Output");
+
+ return new KeyValuesWriter() {
+ @Override
+ public void write(Object key, Iterable<Object> values) throws
IOException {
+ sorter.write(key, values);
+ }
+
+ @Override
+ public void write(Object key, Object value) throws IOException {
+ sorter.write(key, value);
+ }
+ };
+ }
+
+ private List<Event> generateEvents() throws IOException {
+ List<Event> eventList = Lists.newLinkedList();
+ boolean isLastEvent = true;
+
+ String auxiliaryService =
conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
+ TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);
+
+ int[] numRecordsPerPartition = ((RssSorter)
sorter).getNumRecordsPerPartition();
+
+ RssTezPerPartitionRecord rssTezSpillRecord = new
RssTezPerPartitionRecord(numOutputs, numRecordsPerPartition);
+
+ LOG.info("RssTezPerPartitionRecord is initialized");
+
+ ShuffleUtils.generateEventOnSpill(eventList, true, isLastEvent,
+ getContext(), 0, rssTezSpillRecord,
+ getNumPhysicalOutputs(), sendEmptyPartitionDetails,
getContext().getUniqueIdentifier(),
+ sorter.getPartitionStats(), sorter.reportDetailedPartitionStats(),
auxiliaryService, deflater);
+ LOG.info("Generate events.");
+
+ return eventList;
+ }
+
+ private List<Event> generateEmptyEvents() throws IOException {
+ List<Event> eventList = Lists.newArrayList();
+ ShuffleUtils.generateEventsForNonStartedOutput(eventList,
+ getNumPhysicalOutputs(),
+ getContext(),
+ true,
+ true,
+ deflater);
+ LOG.info("Generate empty events.");
+ return eventList;
+ }
+
+ private static final Set<String> confKeys = new HashSet<String>();
+
+ static {
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD);
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD_BYTES);
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_IO_FILE_BUFFER_SIZE);
+
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_INDEX_CACHE_MEMORY_LIMIT_BYTES);
+
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_UNORDERED_OUTPUT_BUFFER_SIZE_MB);
+
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_UNORDERED_OUTPUT_MAX_PER_BUFFER_SIZE_BYTES);
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS);
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS);
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS);
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS);
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS_CODEC);
+
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_EMPTY_PARTITION_INFO_VIA_EVENTS_ENABLED);
+
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_CONVERT_USER_PAYLOAD_TO_HISTORY_TEXT);
+
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_PIPELINED_SHUFFLE_ENABLED);
+
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_ENABLE_FINAL_MERGE_IN_OUTPUT);
+ confKeys.add(TezConfiguration.TEZ_COUNTERS_MAX);
+ confKeys.add(TezConfiguration.TEZ_COUNTERS_GROUP_NAME_MAX_LENGTH);
+ confKeys.add(TezConfiguration.TEZ_COUNTERS_COUNTER_NAME_MAX_LENGTH);
+ confKeys.add(TezConfiguration.TEZ_COUNTERS_MAX_GROUPS);
+
confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_CLEANUP_FILES_ON_INTERRUPT);
+ confKeys.add(TezRuntimeConfiguration.TEZ_RUNTIME_REPORT_PARTITION_STATS);
+ confKeys.add(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID);
+ confKeys.add(
+
TezRuntimeConfiguration.TEZ_RUNTIME_UNORDERED_PARTITIONED_KVWRITER_BUFFER_MERGE_PERCENT);
+ }
+
+ @InterfaceAudience.Private
+ public static Set<String> getConfigurationKeySet() {
+ return Collections.unmodifiableSet(confKeys);
+ }
+
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
new file mode 100644
index 00000000..ce3a8a48
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -0,0 +1,388 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.buffer;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import com.google.common.collect.Sets;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.storage.util.StorageType;
+
+
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class WriteBufferManagerTest {
+ @Test
+ public void testWriteException() throws IOException, InterruptedException {
+ TezTaskAttemptID tezTaskAttemptID =
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+ long maxMemSize = 10240;
+ String appId = "application_1681717153064_3770270";
+ long taskAttemptId = 0;
+ Set<Long> successBlockIds = Sets.newConcurrentHashSet();
+ Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
+ MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
+ RawComparator comparator = WritableComparator.get(BytesWritable.class);
+ long maxSegmentSize = 3 * 1024;
+ SerializationFactory serializationFactory = new SerializationFactory(new
JobConf());
+ Serializer<BytesWritable> keySerializer =
serializationFactory.getSerializer(BytesWritable.class);
+ Serializer<BytesWritable> valSerializer =
serializationFactory.getSerializer(BytesWritable.class);
+ long maxBufferSize = 14 * 1024 * 1024;
+ double memoryThreshold = 0.8f;
+ double sendThreshold = 0.2f;
+ int batch = 50;
+ int numMaps = 1;
+ String storageType = "MEMORY";
+ RssConf rssConf = new RssConf();
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+ long sendCheckInterval = 500L;
+ long sendCheckTimeout = 5;
+ int bitmapSplitNum = 1;
+ int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+
+ WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
+ new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
+ taskAttemptId, successBlockIds, failedBlockIds, writeClient,
+ comparator, maxSegmentSize, keySerializer,
+ valSerializer, maxBufferSize, memoryThreshold,
+ sendThreshold, batch, rssConf, partitionToServers,
+ numMaps, isMemoryShuffleEnabled(storageType),
+ sendCheckInterval, sendCheckTimeout, bitmapSplitNum, shuffleId, true);
+
+ Random random = new Random();
+ for (int i = 0; i < 1000; i++) {
+ byte[] key = new byte[20];
+ byte[] value = new byte[1024];
+ random.nextBytes(key);
+ random.nextBytes(value);
+ bufferManager.addRecord(1, new BytesWritable(key), new
BytesWritable(value));
+ }
+
+ boolean isException = false;
+ try {
+ bufferManager.waitSendFinished();
+ } catch (RssException re) {
+ isException = true;
+ }
+ assertTrue(isException);
+ }
+
+ @Test
+ public void testWriteNormal() throws IOException, InterruptedException {
+ TezTaskAttemptID tezTaskAttemptID =
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+ long maxMemSize = 10240;
+ String appId = "appattempt_1681717153064_3770270_000001";
+ long taskAttemptId = 0;
+ Set<Long> successBlockIds = Sets.newConcurrentHashSet();
+ Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
+ MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
+ writeClient.setMode(2);
+ RawComparator comparator = WritableComparator.get(BytesWritable.class);
+ long maxSegmentSize = 3 * 1024;
+ SerializationFactory serializationFactory = new SerializationFactory(new
JobConf());
+ Serializer<BytesWritable> keySerializer =
serializationFactory.getSerializer(BytesWritable.class);
+ Serializer<BytesWritable> valSerializer =
serializationFactory.getSerializer(BytesWritable.class);
+ long maxBufferSize = 14 * 1024 * 1024;
+ double memoryThreshold = 0.8f;
+ double sendThreshold = 0.2f;
+ int batch = 50;
+ int numMaps = 1;
+ String storageType = "MEMORY";
+ RssConf rssConf = new RssConf();
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+ long sendCheckInterval = 500L;
+ long sendCheckTimeout = 60 * 1000 * 10L;
+ int bitmapSplitNum = 1;
+ int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+
+ WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
+ new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
+ taskAttemptId, successBlockIds, failedBlockIds, writeClient,
+ comparator, maxSegmentSize, keySerializer,
+ valSerializer, maxBufferSize, memoryThreshold,
+ sendThreshold, batch, rssConf, partitionToServers,
+ numMaps, isMemoryShuffleEnabled(storageType),
+ sendCheckInterval, sendCheckTimeout, bitmapSplitNum, shuffleId, true);
+
+ Random random = new Random();
+ for (int i = 0; i < 1000; i++) {
+ byte[] key = new byte[20];
+ byte[] value = new byte[1024];
+ random.nextBytes(key);
+ random.nextBytes(value);
+ int partitionId = random.nextInt(50);
+ bufferManager.addRecord(partitionId, new BytesWritable(key), new
BytesWritable(value));
+ }
+ bufferManager.waitSendFinished();
+ assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
+
+ for (int i = 0; i < 50; i++) {
+ byte[] key = new byte[20];
+ byte[] value = new byte[i * 100];
+ random.nextBytes(key);
+ random.nextBytes(value);
+ bufferManager.addRecord(i, new BytesWritable(key), new
BytesWritable(value));
+ }
+ assert (1 == bufferManager.getWaitSendBuffers().size());
+ assert (4928 == bufferManager.getWaitSendBuffers().get(0).getDataLength());
+
+ bufferManager.waitSendFinished();
+ assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
+ }
+
+ @Test
+ public void testCommitBlocksWhenMemoryShuffleDisabled() throws IOException,
InterruptedException {
+ TezTaskAttemptID tezTaskAttemptID =
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+ long maxMemSize = 10240;
+ String appId = "application_1681717153064_3770270";
+ long taskAttemptId = 0;
+ Set<Long> successBlockIds = Sets.newConcurrentHashSet();
+ Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
+ MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
+ writeClient.setMode(3);
+ RawComparator comparator = WritableComparator.get(BytesWritable.class);
+ long maxSegmentSize = 3 * 1024;
+ SerializationFactory serializationFactory = new SerializationFactory(new
JobConf());
+ Serializer<BytesWritable> keySerializer =
serializationFactory.getSerializer(BytesWritable.class);
+ Serializer<BytesWritable> valSerializer =
serializationFactory.getSerializer(BytesWritable.class);
+ long maxBufferSize = 14 * 1024 * 1024;
+ double memoryThreshold = 0.8f;
+ double sendThreshold = 0.2f;
+ int batch = 50;
+ int numMaps = 1;
+ String storageType = "MEMORY";
+ RssConf rssConf = new RssConf();
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+ long sendCheckInterval = 500L;
+ long sendCheckTimeout = 60 * 1000 * 10L;
+ int bitmapSplitNum = 1;
+ int shuffleId = getShuffleId(tezTaskAttemptID, "Map 1", "Reducer 2");
+
+ WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
+ new WriteBufferManager(tezTaskAttemptID, maxMemSize, appId,
+ taskAttemptId, successBlockIds, failedBlockIds, writeClient,
+ comparator, maxSegmentSize, keySerializer,
+ valSerializer, maxBufferSize, memoryThreshold,
+ sendThreshold, batch, rssConf, partitionToServers,
+ numMaps, isMemoryShuffleEnabled(storageType),
+ sendCheckInterval, sendCheckTimeout, bitmapSplitNum, shuffleId, true);
+
+ Random random = new Random();
+ for (int i = 0; i < 10000; i++) {
+ byte[] key = new byte[20];
+ byte[] value = new byte[1024];
+ random.nextBytes(key);
+ random.nextBytes(value);
+ int partitionId = random.nextInt(50);
+ bufferManager.addRecord(partitionId, new BytesWritable(key), new
BytesWritable(value));
+ }
+
+ assertTrue(bufferManager.getWaitSendBuffers().isEmpty());
+ assertEquals(writeClient.mockedShuffleServer.getFinishBlockSize(),
+ writeClient.mockedShuffleServer.getFlushBlockSize());
+ }
+
+ private int getShuffleId(TezTaskAttemptID tezTaskAttemptID, String
upVertexName, String downVertexName) {
+ TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
+ int shuffleId =
RssTezUtils.computeShuffleId(tezVertexID.getDAGId().getId(), upVertexName,
downVertexName);
+ return shuffleId;
+ }
+
+ private boolean isMemoryShuffleEnabled(String storageType) {
+ return StorageType.withMemory(StorageType.valueOf(storageType));
+ }
+
+ class MockShuffleServer {
+ private List<ShuffleBlockInfo> cachedBlockInfos = new ArrayList<>();
+ private List<ShuffleBlockInfo> flushBlockInfos = new ArrayList<>();
+ private List<Long> finishedBlockInfos = new ArrayList<>();
+
+ public synchronized void finishShuffle() {
+ flushBlockInfos.addAll(cachedBlockInfos);
+ }
+
+ public synchronized void addCachedBlockInfos(List<ShuffleBlockInfo>
shuffleBlockInfoList) {
+ cachedBlockInfos.addAll(shuffleBlockInfoList);
+ }
+
+ public synchronized void addFinishedBlockInfos(List<Long>
shuffleBlockInfoList) {
+ finishedBlockInfos.addAll(shuffleBlockInfoList);
+ }
+
+ public synchronized int getFlushBlockSize() {
+ return flushBlockInfos.size();
+ }
+
+ public synchronized int getFinishBlockSize() {
+ return finishedBlockInfos.size();
+ }
+ }
+
+ class MockShuffleWriteClient implements ShuffleWriteClient {
+
+ int mode = 0;
+ MockShuffleServer mockedShuffleServer = new MockShuffleServer();
+ int committedMaps = 0;
+
+ public void setMode(int mode) {
+ this.mode = mode;
+ }
+
+ @Override
+ public SendShuffleDataResult sendShuffleData(String appId,
List<ShuffleBlockInfo> shuffleBlockInfoList,
+ Supplier<Boolean>
needCancelRequest) {
+ if (mode == 0) {
+ throw new RssException("send data failed.");
+ } else if (mode == 1) {
+ return new SendShuffleDataResult(Sets.newHashSet(2L),
Sets.newHashSet(1L));
+ } else {
+ if (mode == 3) {
+ try {
+ Thread.sleep(10);
+ mockedShuffleServer.addCachedBlockInfos(shuffleBlockInfoList);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RssException(e.toString());
+ }
+ }
+ Set<Long> successBlockIds = Sets.newHashSet();
+ for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
+ successBlockIds.add(blockInfo.getBlockId());
+ }
+ return new SendShuffleDataResult(successBlockIds, Sets.newHashSet());
+ }
+ }
+
+ @Override
+ public void sendAppHeartbeat(String appId, long timeoutMs) {
+
+ }
+
+ @Override
+ public void registerApplicationInfo(String appId, long timeoutMs, String
user) {
+
+ }
+
+ @Override
+ public void registerShuffle(ShuffleServerInfo shuffleServerInfo, String
appId, int shuffleId,
+ List<PartitionRange> partitionRanges,
RemoteStorageInfo remoteStorage,
+ ShuffleDataDistributionType
dataDistributionType,
+ int maxConcurrencyPerPartitionToWrite) {
+
+ }
+
+
+ @Override
+ public boolean sendCommit(Set<ShuffleServerInfo> shuffleServerInfoSet,
String appId, int shuffleId, int numMaps) {
+ if (mode == 3) {
+ committedMaps++;
+ if (committedMaps >= numMaps) {
+ mockedShuffleServer.finishShuffle();
+ }
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void registerCoordinators(String coordinators) {
+
+ }
+
+ @Override
+ public Map<String, String> fetchClientConf(int timeoutMs) {
+ return null;
+ }
+
+ @Override
+ public RemoteStorageInfo fetchRemoteStorage(String appId) {
+ return null;
+ }
+
+ @Override
+ public void reportShuffleResult(Map<Integer, List<ShuffleServerInfo>>
partitionToServers,
+ String appId, int shuffleId, long
taskAttemptId,
+ Map<Integer, List<Long>>
partitionToBlockIds, int bitmapNum) {
+ if (mode == 3) {
+ mockedShuffleServer.addFinishedBlockInfos(
+ partitionToBlockIds.values().stream().flatMap(it ->
it.stream()).collect(Collectors.toList())
+ );
+ }
+ }
+
+ @Override
+ public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int
shuffleId,
+ int partitionNum, int partitionNumPerRange,
+ Set<String> requiredTags, int assignmentShuffleServerNumber, int
estimateTaskConcurrency) {
+ return null;
+ }
+
+ @Override
+ public Roaring64NavigableMap getShuffleResult(String clientType,
Set<ShuffleServerInfo> shuffleServerInfoSet,
+ String appId, int shuffleId, int partitionId) {
+ return null;
+ }
+
+ @Override
+ public Roaring64NavigableMap getShuffleResultForMultiPart(String
clientType, Map<ShuffleServerInfo,
+ Set<Integer>> serverToPartitions, String appId, int shuffleId) {
+ return null;
+ }
+
+ @Override
+ public void close() {
+
+ }
+
+ @Override
+ public void unregisterShuffle(String appId, int shuffleId) {
+
+ }
+ }
+
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
new file mode 100644
index 00000000..96cd7cd3
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.buffer;
+
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.util.Map;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hadoop.io.serializer.Deserializer;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.hadoop.mapred.JobConf;
+import org.junit.jupiter.api.Test;
+
+
+
+import static com.google.common.collect.Maps.newConcurrentMap;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class WriteBufferTest {
+
+ @Test
+ public void testReadWrite() throws IOException {
+
+ String keyStr = "key";
+ String valueStr = "value";
+ BytesWritable key = new BytesWritable(keyStr.getBytes());
+ BytesWritable value = new BytesWritable(valueStr.getBytes());
+ JobConf jobConf = new JobConf(new Configuration());
+ SerializationFactory serializationFactory = new
SerializationFactory(jobConf);
+ Serializer<BytesWritable> keySerializer =
serializationFactory.getSerializer(BytesWritable.class);
+ Serializer<BytesWritable> valSerializer =
serializationFactory.getSerializer(BytesWritable.class);
+ WriteBuffer<BytesWritable, BytesWritable> buffer =
+ new WriteBuffer<BytesWritable, BytesWritable>(
+ true,
+ 1,
+ WritableComparator.get(BytesWritable.class),
+ 1024L,
+ keySerializer,
+ valSerializer);
+
+ long recordLength = buffer.addRecord(key, value);
+ assertEquals(20, buffer.getData().length);
+ assertEquals(16, recordLength);
+ assertEquals(1, buffer.getPartitionId());
+ byte[] result = buffer.getData();
+ Deserializer<BytesWritable> keyDeserializer =
serializationFactory.getDeserializer(BytesWritable.class);
+ Deserializer<BytesWritable> valDeserializer =
serializationFactory.getDeserializer(BytesWritable.class);
+ ByteArrayInputStream byteArrayInputStream = new
ByteArrayInputStream(result);
+ keyDeserializer.open(byteArrayInputStream);
+ valDeserializer.open(byteArrayInputStream);
+
+ DataInputStream dStream = new DataInputStream(byteArrayInputStream);
+ int keyLen = readInt(dStream);
+ int valueLen = readInt(dStream);
+ assertEquals(recordLength, keyLen + valueLen);
+ BytesWritable keyRead = keyDeserializer.deserialize(null);
+ assertEquals(key, keyRead);
+ BytesWritable valueRead = keyDeserializer.deserialize(null);
+ assertEquals(value, valueRead);
+
+ buffer = new WriteBuffer<BytesWritable, BytesWritable>(
+ true,
+ 1,
+ WritableComparator.get(BytesWritable.class),
+ 528L,
+ keySerializer,
+ valSerializer);
+ long start = buffer.getDataLength();
+ assertEquals(0, start);
+ keyStr = "key3";
+ key = new BytesWritable(keyStr.getBytes());
+ keySerializer.serialize(key);
+ byte[] valueBytes = new byte[200];
+ Map<String, BytesWritable> valueMap = newConcurrentMap();
+ Random random = new Random();
+ random.nextBytes(valueBytes);
+ value = new BytesWritable(valueBytes);
+ valueMap.putIfAbsent(keyStr, value);
+ valSerializer.serialize(value);
+ recordLength = buffer.addRecord(key, value);
+ Map<String, Long> recordLenMap = newConcurrentMap();
+ recordLenMap.putIfAbsent(keyStr, recordLength);
+
+ keyStr = "key1";
+ key = new BytesWritable(keyStr.getBytes());
+ valueBytes = new byte[2032];
+ random.nextBytes(valueBytes);
+ value = new BytesWritable(valueBytes);
+ valueMap.putIfAbsent(keyStr, value);
+ recordLength = buffer.addRecord(key, value);
+ recordLenMap.putIfAbsent(keyStr, recordLength);
+
+ byte[] bigKey = new byte[555];
+ random.nextBytes(bigKey);
+ bigKey[0] = 'k';
+ bigKey[1] = 'e';
+ bigKey[2] = 'y';
+ bigKey[3] = '4';
+ final BytesWritable bigWritableKey = new BytesWritable(bigKey);
+ valueBytes = new byte[253];
+ random.nextBytes(valueBytes);
+ final BytesWritable bigWritableValue = new BytesWritable(valueBytes);
+ final long bigRecordLength = buffer.addRecord(bigWritableKey,
bigWritableValue);
+ keyStr = "key2";
+ key = new BytesWritable(keyStr.getBytes());
+ valueBytes = new byte[3100];
+ value = new BytesWritable(valueBytes);
+ valueMap.putIfAbsent(keyStr, value);
+ recordLength = buffer.addRecord(key, value);
+ recordLenMap.putIfAbsent(keyStr, recordLength);
+
+ result = buffer.getData();
+ byteArrayInputStream = new ByteArrayInputStream(result);
+ keyDeserializer.open(byteArrayInputStream);
+ valDeserializer.open(byteArrayInputStream);
+ for (int i = 1; i <= 3; i++) {
+ dStream = new DataInputStream(byteArrayInputStream);
+ long keyLenTmp = readInt(dStream);
+ long valueLenTmp = readInt(dStream);
+ String tmpStr = "key" + i;
+ assertEquals(recordLenMap.get(tmpStr).longValue(), keyLenTmp +
valueLenTmp);
+ keyRead = keyDeserializer.deserialize(null);
+ valueRead = valDeserializer.deserialize(null);
+ BytesWritable bytesWritable = new BytesWritable(tmpStr.getBytes());
+ assertEquals(bytesWritable, keyRead);
+ assertEquals(valueMap.get(tmpStr), valueRead);
+ }
+
+ dStream = new DataInputStream(byteArrayInputStream);
+ long keyLenTmp = readInt(dStream);
+ long valueLenTmp = readInt(dStream);
+ assertEquals(bigRecordLength, keyLenTmp + valueLenTmp);
+ keyRead = keyDeserializer.deserialize(null);
+ valueRead = valDeserializer.deserialize(null);
+ assertEquals(bigWritableKey, keyRead);
+ assertEquals(bigWritableValue, valueRead);
+ }
+
+ int readInt(DataInputStream dStream) throws IOException {
+ return WritableUtils.readVInt(dStream);
+ }
+
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
new file mode 100644
index 00000000..7d7fa3fa
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.output.OutputTestHelpers;
+import org.apache.tez.runtime.library.partitioner.HashPartitioner;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class RssSorterTest {
+ private static Map<Integer, List<ShuffleServerInfo>> partitionToServers =
new HashMap<>();
+ private Configuration conf;
+ private FileSystem localFs;
+ private Path workingDir;
+
+ /**
+ * set up
+ */
+ @BeforeEach
+ public void setup() throws Exception {
+ conf = new Configuration();
+ localFs = FileSystem.getLocal(conf);
+ workingDir = new Path(System.getProperty("test.build.data",
+ System.getProperty("java.io.tmpdir", "/tmp")),
+ RssSorterTest.class.getName()).makeQualified(
+ localFs.getUri(), localFs.getWorkingDirectory());
+ conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS,
Text.class.getName());
+ conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS,
Text.class.getName());
+ conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS,
+ HashPartitioner.class.getName());
+ conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS,
workingDir.toString());
+
+ Map<String, String> envMap = System.getenv();
+ Map<String, String> env = new HashMap<>();
+ env.putAll(envMap);
+ env.put(ApplicationConstants.Environment.CONTAINER_ID.name(),
"container_e160_1681717153064_3770270_01_000001");
+
+ setEnv(env);
+ }
+
+ @Test
+ public void testCollectAndRecordsPerPartition() throws IOException,
InterruptedException {
+ TezTaskAttemptID tezTaskAttemptID =
+
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+
+ OutputContext outputContext = OutputTestHelpers.createOutputContext(conf,
workingDir);
+
+ long initialMemoryAvailable = 10240000;
+ int shuffleId = 1001;
+
+ RssSorter rssSorter = new RssSorter(tezTaskAttemptID, outputContext, conf,
5, 5, initialMemoryAvailable,
+ shuffleId, partitionToServers);
+
+ rssSorter.collect(new Text("0"), new Text("0"), 0);
+ rssSorter.collect(new Text("0"), new Text("1"), 0);
+ rssSorter.collect(new Text("1"), new Text("1"), 1);
+ rssSorter.collect(new Text("2"), new Text("2"), 2);
+ rssSorter.collect(new Text("3"), new Text("3"), 3);
+ rssSorter.collect(new Text("4"), new Text("4"), 4);
+
+ assertTrue(2 == rssSorter.getNumRecordsPerPartition()[0]);
+ assertTrue(1 == rssSorter.getNumRecordsPerPartition()[1]);
+ assertTrue(1 == rssSorter.getNumRecordsPerPartition()[2]);
+ assertTrue(1 == rssSorter.getNumRecordsPerPartition()[3]);
+ assertTrue(1 == rssSorter.getNumRecordsPerPartition()[4]);
+
+ assertTrue(5 == rssSorter.getNumRecordsPerPartition().length);
+ }
+
+
+
+ protected static void setEnv(Map<String, String> newEnv) throws Exception {
+ try {
+ Class<?> processEnvironmentClass =
Class.forName("java.lang.ProcessEnvironment");
+ Field theEnvironmentField =
processEnvironmentClass.getDeclaredField("theEnvironment");
+ theEnvironmentField.setAccessible(true);
+ Map<String, String> env = (Map<String, String>)
theEnvironmentField.get(null);
+ env.putAll(newEnv);
+ Field theCaseInsensitiveEnvironmentField =
+
processEnvironmentClass.getDeclaredField("theCaseInsensitiveEnvironment");
+ theCaseInsensitiveEnvironmentField.setAccessible(true);
+ Map<String, String> cienv = (Map<String, String>)
theCaseInsensitiveEnvironmentField.get(null);
+ cienv.putAll(newEnv);
+ } catch (NoSuchFieldException e) {
+ Class[] classes = Collections.class.getDeclaredClasses();
+ Map<String, String> env = System.getenv();
+ for (Class cl : classes) {
+ if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
+ Field field = cl.getDeclaredField("m");
+ field.setAccessible(true);
+ Object obj = field.get(env);
+ Map<String, String> map = (Map<String, String>) obj;
+ map.clear();
+ map.putAll(newEnv);
+ }
+ }
+ }
+ }
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecordTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecordTest.java
new file mode 100644
index 00000000..8a8de5ca
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssTezPerPartitionRecordTest.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class RssTezPerPartitionRecordTest {
+
+ @Test
+ public void testNumPartitions() {
+ int[] numRecordsPerPartition = {0, 10, 10, 20, 30};
+ int numOutputs = 5;
+
+ RssTezPerPartitionRecord rssTezPerPartitionRecord
+ = new RssTezPerPartitionRecord(numOutputs, numRecordsPerPartition);
+
+ assertTrue(numOutputs == rssTezPerPartitionRecord.size());
+ }
+
+ @Test
+ public void testRssTezIndexHasData() {
+ int[] numRecordsPerPartition = {0, 10, 10, 20, 30};
+ int numOutputs = 5;
+
+ RssTezPerPartitionRecord rssTezPerPartitionRecord
+ = new RssTezPerPartitionRecord(numOutputs, numRecordsPerPartition);
+
+ for (int i = 0; i < numRecordsPerPartition.length; i++) {
+ if (0 == i) {
+ assertFalse(rssTezPerPartitionRecord.getIndex(i).hasData());
+ }
+ if (0 != i) {
+ assertTrue(rssTezPerPartitionRecord.getIndex(i).hasData());
+ }
+ }
+ }
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java
new file mode 100644
index 00000000..e6f8d7e7
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.output;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.runtime.api.MemoryUpdateCallback;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.api.OutputStatisticsReporter;
+import org.apache.tez.runtime.api.impl.ExecutionContextImpl;
+import org.apache.tez.runtime.library.common.MemoryUpdateCallbackHandler;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+
+public class OutputTestHelpers {
+ /**
+ * help to create output context
+ */
+ public static OutputContext createOutputContext(Configuration conf, Path
workingDir) throws IOException {
+ OutputContext ctx = mock(OutputContext.class);
+
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ long requestedSize = (Long) invocation.getArguments()[0];
+ MemoryUpdateCallbackHandler callback = (MemoryUpdateCallbackHandler)
invocation
+ .getArguments()[1];
+ callback.memoryAssigned(requestedSize);
+ return null;
+ }
+ }).when(ctx).requestInitialMemory(anyLong(),
any(MemoryUpdateCallback.class));
+
doReturn(TezUtils.createUserPayloadFromConf(conf)).when(ctx).getUserPayload();
+ doReturn("Map 1").when(ctx).getTaskVertexName();
+ doReturn("Reducer 2").when(ctx).getDestinationVertexName();
+
doReturn("attempt_1681717153064_3601637_1_13_000096_0").when(ctx).getUniqueIdentifier();
+ doReturn(new String[] { workingDir.toString() }).when(ctx).getWorkDirs();
+ doReturn(200 * 1024 * 1024L).when(ctx).getTotalMemoryAvailableToTask();
+ doReturn(new TezCounters()).when(ctx).getCounters();
+ OutputStatisticsReporter statsReporter =
mock(OutputStatisticsReporter.class);
+ doReturn(statsReporter).when(ctx).getStatisticsReporter();
+ doReturn(new
ExecutionContextImpl("localhost")).when(ctx).getExecutionContext();
+ return ctx;
+ }
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutputTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutputTest.java
new file mode 100644
index 00000000..e4b68f81
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutputTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.output;
+
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.BitSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import com.google.protobuf.ByteString;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.OutputContext;
+import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
+import org.apache.tez.runtime.api.events.VertexManagerEvent;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.partitioner.HashPartitioner;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class RssOrderedPartitionedKVOutputTest {
+ private static Map<Integer, List<ShuffleServerInfo>> partitionToServers =
new HashMap<>();
+ private Configuration conf;
+ private FileSystem localFs;
+ private Path workingDir;
+
+ @BeforeEach
+ public void setup() throws IOException {
+ conf = new Configuration();
+ localFs = FileSystem.getLocal(conf);
+ workingDir = new Path(System.getProperty("test.build.data",
+ System.getProperty("java.io.tmpdir", "/tmp")),
+ RssOrderedPartitionedKVOutputTest.class.getName()).makeQualified(
+ localFs.getUri(), localFs.getWorkingDirectory());
+ conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS,
Text.class.getName());
+ conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS,
Text.class.getName());
+ conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS,
+ HashPartitioner.class.getName());
+ conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS,
workingDir.toString());
+ }
+
+ @AfterEach
+ public void cleanup() throws IOException {
+ localFs.delete(workingDir, true);
+ }
+
+ @Test
+ @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
+ public void testNonStartedOutput() throws Exception {
+ OutputContext outputContext = OutputTestHelpers.createOutputContext(conf,
workingDir);
+ int numPartitions = 10;
+ RssOrderedPartitionedKVOutput output = new
RssOrderedPartitionedKVOutput(outputContext, numPartitions);
+ List<Event> events = output.close();
+ assertEquals(2, events.size());
+ Event event1 = events.get(0);
+ assertTrue(event1 instanceof VertexManagerEvent);
+ Event event2 = events.get(1);
+ assertTrue(event2 instanceof CompositeDataMovementEvent);
+ CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) event2;
+ ByteBuffer bb = cdme.getUserPayload();
+ ShuffleUserPayloads.DataMovementEventPayloadProto shufflePayload =
+
ShuffleUserPayloads.DataMovementEventPayloadProto.parseFrom(ByteString.copyFrom(bb));
+ assertTrue(shufflePayload.hasEmptyPartitions());
+ byte[] emptyPartitions =
TezCommonUtils.decompressByteStringToByteArray(shufflePayload
+ .getEmptyPartitions());
+ BitSet emptyPartionsBitSet =
TezUtilsInternal.fromByteArray(emptyPartitions);
+ assertEquals(numPartitions, emptyPartionsBitSet.cardinality());
+ for (int i = 0; i < numPartitions; i++) {
+ assertTrue(emptyPartionsBitSet.get(i));
+ }
+ }
+
+}