This is an automated email from the ASF dual-hosted git repository.
zhengchenyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 45ffa5fda [#1300] feat(mr): Support combine operation in map stage for
mr engine. (#1301)
45ffa5fda is described below
commit 45ffa5fda189af21bac5b6a1b43ebe8cb0566c05
Author: QI Jiale <[email protected]>
AuthorDate: Tue Nov 14 10:35:15 2023 +0800
[#1300] feat(mr): Support combine operation in map stage for mr engine.
(#1301)
### What changes were proposed in this pull request?
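Support the combine operation in the map stage for the MR engine: when a combiner is configured, each write buffer is sorted and run through the combiner before its shuffle block is created.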
### Why are the changes needed?
Fix: #1300
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
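Tested by unit tests, especially `WordCountTest`; this commit also adds `SortWriteBufferManagerTest#testCombineBuffer` to cover the combine path.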
---
.../hadoop/mapred/RssCombineOutputCollector.java | 32 ++++++
.../hadoop/mapred/RssMapOutputCollector.java | 11 +-
.../org/apache/hadoop/mapred/SortWriteBuffer.java | 86 ++++++++++++---
.../hadoop/mapred/SortWriteBufferManager.java | 34 +++++-
.../hadoop/mapred/SortWriteBufferManagerTest.java | 121 ++++++++++++++++++++-
.../apache/hadoop/mapred/SortWriteBufferTest.java | 2 +
.../hadoop/mapreduce/task/reduce/FetcherTest.java | 3 +-
7 files changed, 263 insertions(+), 26 deletions(-)
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssCombineOutputCollector.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssCombineOutputCollector.java
new file mode 100644
index 000000000..ba11daebb
--- /dev/null
+++
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssCombineOutputCollector.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred;
+
+import java.io.IOException;
+
+public class RssCombineOutputCollector<K, V> implements OutputCollector<K, V> {
+ private SortWriteBuffer<K, V> writer;
+
+ public synchronized void setWriter(SortWriteBuffer<K, V> writer) {
+ this.writer = writer;
+ }
+
+ public synchronized void collect(K key, V value) throws IOException {
+ writer.addRecord(key, value);
+ }
+}
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
index 1066960c1..3acf0b417 100644
---
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
+++
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
@@ -55,6 +55,7 @@ public class RssMapOutputCollector<K extends Object, V
extends Object>
private int partitions;
private SortWriteBufferManager bufferManager;
private ShuffleWriteClient shuffleClient;
+ private Task.CombinerRunner<K, V> combinerRunner;
@Override
public void init(Context context) throws IOException, ClassNotFoundException
{
@@ -78,6 +79,13 @@ public class RssMapOutputCollector<K extends Object, V
extends Object>
throw new IOException("Invalid sort memory use threshold : " +
sortThreshold);
}
+ // combiner
+ final Counters.Counter combineInputCounter =
+ reporter.getCounter(TaskCounter.COMBINE_INPUT_RECORDS);
+ combinerRunner =
+ Task.CombinerRunner.create(
+ mrJobConf, mapTask.getTaskID(), combineInputCounter, reporter,
null);
+
int batch =
RssMRUtils.getInt(
rssJobConf,
@@ -165,7 +173,8 @@ public class RssMapOutputCollector<K extends Object, V
extends Object>
sendThreadNum,
sendThreshold,
maxBufferSize,
- RssMRConfig.toRssConf(rssJobConf));
+ RssMRConfig.toRssConf(rssJobConf),
+ combinerRunner);
}
private Map<Integer, List<ShuffleServerInfo>>
createAssignmentMap(Configuration jobConf) {
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
index d6fc3bbeb..37b06871f 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
@@ -19,13 +19,15 @@ package org.apache.hadoop.mapred;
import java.io.IOException;
import java.io.OutputStream;
-import java.util.Comparator;
+import java.util.Iterator;
import java.util.List;
import com.google.common.collect.Lists;
+import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.hadoop.util.Progress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -84,6 +86,21 @@ public class SortWriteBuffer<K, V> extends OutputStream {
records.clear();
}
+ public synchronized void sort() {
+ long startSort = System.currentTimeMillis();
+ records.sort(
+ (o1, o2) ->
+ comparator.compare(
+ buffers.get(o1.getKeyIndex()).getBuffer(),
+ o1.getKeyOffSet(),
+ o1.getKeyLength(),
+ buffers.get(o2.getKeyIndex()).getBuffer(),
+ o2.getKeyOffSet(),
+ o2.getKeyLength()));
+ long finishSort = System.currentTimeMillis();
+ sortTime += finishSort - startSort;
+ }
+
public synchronized byte[] getData() {
int extraSize = 0;
for (Record<K> record : records) {
@@ -95,22 +112,8 @@ public class SortWriteBuffer<K, V> extends OutputStream {
extraSize += WritableUtils.getVIntSize(-1);
byte[] data = new byte[dataLength + extraSize];
int offset = 0;
- long startSort = System.currentTimeMillis();
- records.sort(
- new Comparator<Record<K>>() {
- @Override
- public int compare(Record<K> o1, Record<K> o2) {
- return comparator.compare(
- buffers.get(o1.getKeyIndex()).getBuffer(),
- o1.getKeyOffSet(),
- o1.getKeyLength(),
- buffers.get(o2.getKeyIndex()).getBuffer(),
- o2.getKeyOffSet(),
- o2.getKeyLength());
- }
- });
- long startCopy = System.currentTimeMillis();
- sortTime += startCopy - startSort;
+
+ final long startCopy = System.currentTimeMillis();
for (Record<K> record : records) {
offset = writeDataInt(data, offset, record.getKeyLength());
@@ -325,4 +328,53 @@ public class SortWriteBuffer<K, V> extends OutputStream {
return size;
}
}
+
+ public static class SortBufferIterator<K, V> implements RawKeyValueIterator {
+ private final SortWriteBuffer<K, V> sortWriteBuffer;
+ private final Iterator<Record<K>> iterator;
+ private final DataInputBuffer keyBuffer = new DataInputBuffer();
+ private final DataInputBuffer valueBuffer = new DataInputBuffer();
+ private SortWriteBuffer.Record<K> currentRecord;
+
+ public SortBufferIterator(SortWriteBuffer<K, V> sortWriteBuffer) {
+ this.sortWriteBuffer = sortWriteBuffer;
+ this.iterator = sortWriteBuffer.records.iterator();
+ }
+
+ @Override
+ public DataInputBuffer getKey() {
+ SortWriteBuffer.WrappedBuffer keyWrappedBuffer =
+ sortWriteBuffer.buffers.get(currentRecord.getKeyIndex());
+ byte[] rawData = keyWrappedBuffer.getBuffer();
+ keyBuffer.reset(rawData, currentRecord.getKeyOffSet(),
currentRecord.getKeyLength());
+ return keyBuffer;
+ }
+
+ @Override
+ public DataInputBuffer getValue() {
+ SortWriteBuffer.WrappedBuffer valueWrappedBuffer =
+ sortWriteBuffer.buffers.get(currentRecord.getKeyIndex());
+ byte[] rawData = valueWrappedBuffer.getBuffer();
+ int valueOffset = currentRecord.getKeyOffSet() +
currentRecord.getKeyLength();
+ valueBuffer.reset(rawData, valueOffset, currentRecord.getValueLength());
+ return valueBuffer;
+ }
+
+ @Override
+ public boolean next() {
+ if (iterator.hasNext()) {
+ currentRecord = iterator.next();
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void close() throws IOException {}
+
+ @Override
+ public Progress getProgress() {
+ return new Progress();
+ }
+ }
}
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
index d410fa3b3..0da7cd930 100644
---
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
+++
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
@@ -95,6 +95,7 @@ public class SortWriteBufferManager<K, V> {
private final ExecutorService sendExecutorService;
private final RssConf rssConf;
private final Codec codec;
+ private final Task.CombinerRunner<K, V> combinerRunner;
public SortWriteBufferManager(
long maxMemSize,
@@ -120,7 +121,8 @@ public class SortWriteBufferManager<K, V> {
int sendThreadNum,
double sendThreshold,
long maxBufferSize,
- RssConf rssConf) {
+ RssConf rssConf,
+ Task.CombinerRunner<K, V> combinerRunner) {
this.maxMemSize = maxMemSize;
this.taskAttemptId = taskAttemptId;
this.batch = batch;
@@ -146,6 +148,7 @@ public class SortWriteBufferManager<K, V> {
this.sendExecutorService =
ThreadUtils.getDaemonFixedThreadPool(sendThreadNum, "send-thread");
this.rssConf = rssConf;
this.codec = Codec.newInstance(rssConf);
+ this.combinerRunner = combinerRunner;
}
// todo: Single Buffer should also have its size limit
@@ -231,7 +234,19 @@ public class SortWriteBufferManager<K, V> {
private void prepareBufferForSend(List<ShuffleBlockInfo> shuffleBlocks,
SortWriteBuffer buffer) {
buffers.remove(buffer.getPartitionId());
- ShuffleBlockInfo block = createShuffleBlock(buffer);
+ buffer.sort();
+ ShuffleBlockInfo block;
+ if (combinerRunner != null) {
+ try {
+ buffer = combineBuffer(buffer);
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Successfully finished combining.");
+ }
+ } catch (Exception e) {
+ LOG.error("Error occurred while combining in Map:", e);
+ }
+ }
+ block = createShuffleBlock(buffer);
buffer.clear();
shuffleBlocks.add(block);
allBlockIds.add(block.getBlockId());
@@ -239,6 +254,21 @@ public class SortWriteBufferManager<K, V> {
partitionToBlocks.get(block.getPartitionId()).add(block.getBlockId());
}
+ public SortWriteBuffer<K, V> combineBuffer(SortWriteBuffer<K, V> buffer)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ RawKeyValueIterator kvIterator = new
SortWriteBuffer.SortBufferIterator<>(buffer);
+
+ RssCombineOutputCollector<K, V> combineCollector = new
RssCombineOutputCollector<>();
+
+ SortWriteBuffer<K, V> newBuffer =
+ new SortWriteBuffer<>(
+ buffer.getPartitionId(), comparator, maxSegmentSize,
keySerializer, valSerializer);
+
+ combineCollector.setWriter(newBuffer);
+ combinerRunner.combine(kvIterator, combineCollector);
+ return newBuffer;
+ }
+
private void sendShuffleBlocks(List<ShuffleBlockInfo> shuffleBlocks) {
sendExecutorService.submit(
new Runnable() {
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 0cd9e1a06..fbb280391 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -17,7 +17,9 @@
package org.apache.hadoop.mapred;
+import java.io.IOException;
import java.util.ArrayList;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
@@ -25,11 +27,15 @@ import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
+import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -49,6 +55,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
public class SortWriteBufferManagerTest {
@@ -89,7 +96,8 @@ public class SortWriteBufferManagerTest {
5,
0.2f,
1024000L,
- new RssConf());
+ new RssConf(),
+ null);
// case 1
Random random = new Random();
@@ -151,7 +159,8 @@ public class SortWriteBufferManagerTest {
5,
0.2f,
1024000L,
- new RssConf());
+ new RssConf(),
+ null);
byte[] key = new byte[20];
byte[] value = new byte[1024];
random.nextBytes(key);
@@ -202,7 +211,8 @@ public class SortWriteBufferManagerTest {
5,
0.2f,
100L,
- new RssConf());
+ new RssConf(),
+ null);
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
@@ -252,7 +262,8 @@ public class SortWriteBufferManagerTest {
5,
0.2f,
1024000L,
- new RssConf());
+ new RssConf(),
+ null);
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
@@ -317,7 +328,8 @@ public class SortWriteBufferManagerTest {
5,
0.2f,
1024000L,
- new RssConf());
+ new RssConf(),
+ null);
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
@@ -335,6 +347,86 @@ public class SortWriteBufferManagerTest {
client.mockedShuffleServer.getFlushBlockSize());
}
+ @Test
+ public void testCombineBuffer() throws Exception {
+ JobConf jobConf = new JobConf(new Configuration());
+ jobConf.setOutputKeyClass(Text.class);
+ jobConf.setOutputValueClass(IntWritable.class);
+ jobConf.setCombinerClass(Reduce.class);
+ SerializationFactory serializationFactory = new
SerializationFactory(jobConf);
+ Serializer<Text> keySerializer =
serializationFactory.getSerializer(Text.class);
+ Serializer<IntWritable> valueSerializer =
serializationFactory.getSerializer(IntWritable.class);
+ WritableComparator comparator = WritableComparator.get(Text.class);
+
+ Task.TaskReporter reporter = mock(Task.TaskReporter.class);
+
+ final Counters.Counter combineInputCounter = new Counters.Counter();
+
+ Task.CombinerRunner<Text, IntWritable> combinerRunner =
+ Task.CombinerRunner.create(
+ jobConf, new TaskAttemptID(), combineInputCounter, reporter, null);
+
+ SortWriteBuffer<Text, IntWritable> buffer =
+ new SortWriteBuffer<Text, IntWritable>(
+ 1, comparator, 10000, keySerializer, valueSerializer);
+
+ List<String> wordTable =
+ Lists.newArrayList(
+ "apple", "banana", "fruit", "cherry", "Chinese", "America",
"Japan", "tomato");
+ Random random = new Random();
+ for (int i = 0; i < 8; i++) {
+ buffer.addRecord(new Text(wordTable.get(i)), new IntWritable(1));
+ }
+ for (int i = 0; i < 100; i++) {
+ int index = random.nextInt(wordTable.size());
+ buffer.addRecord(new Text(wordTable.get(index)), new IntWritable(1));
+ }
+
+ SortWriteBufferManager<Text, IntWritable> manager =
+ new SortWriteBufferManager<Text, IntWritable>(
+ 10240,
+ 1L,
+ 10,
+ keySerializer,
+ valueSerializer,
+ comparator,
+ 0.9,
+ "test",
+ null,
+ 500,
+ 5 * 1000,
+ null,
+ null,
+ null,
+ null,
+ null,
+ 1,
+ 100,
+ 1,
+ true,
+ 5,
+ 0.2f,
+ 1024000L,
+ new RssConf(),
+ combinerRunner);
+
+ buffer.sort();
+ SortWriteBuffer<Text, IntWritable> newBuffer =
manager.combineBuffer(buffer);
+
+ RawKeyValueIterator kvIterator1 = new
SortWriteBuffer.SortBufferIterator<>(buffer);
+ RawKeyValueIterator kvIterator2 = new
SortWriteBuffer.SortBufferIterator<>(newBuffer);
+ int count1 = 0;
+ while (kvIterator1.next()) {
+ count1++;
+ }
+ int count2 = 0;
+ while (kvIterator2.next()) {
+ count2++;
+ }
+ assertEquals(108, count1);
+ assertEquals(8, count2);
+ }
+
class MockShuffleServer {
// All methods of MockShuffle are thread safe, because send-thread may do
something in
@@ -500,4 +592,23 @@ public class SortWriteBufferManagerTest {
@Override
public void unregisterShuffle(String appId) {}
}
+
+ static class Reduce extends MapReduceBase
+ implements Reducer<Text, IntWritable, Text, IntWritable> {
+
+ Reduce() {}
+
+ public void reduce(
+ Text key,
+ Iterator<IntWritable> values,
+ OutputCollector<Text, IntWritable> output,
+ Reporter reporter)
+ throws IOException {
+ int sum = 0;
+ while (values.hasNext()) {
+ sum += values.next().get();
+ }
+ output.collect(key, new IntWritable(sum));
+ }
+ }
}
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
index 92af295d4..ebfe75dcd 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
@@ -56,6 +56,7 @@ public class SortWriteBufferTest {
1, WritableComparator.get(BytesWritable.class), 1024L,
keySerializer, valSerializer);
long recordLength = buffer.addRecord(key, value);
+ buffer.sort();
assertEquals(20, buffer.getData().length);
assertEquals(16, recordLength);
assertEquals(1, buffer.getPartitionId());
@@ -124,6 +125,7 @@ public class SortWriteBufferTest {
recordLength = buffer.addRecord(key, value);
recordLenMap.putIfAbsent(keyStr, recordLength);
+ buffer.sort();
result = buffer.getData();
byteArrayInputStream = new ByteArrayInputStream(result);
keyDeserializer.open(byteArrayInputStream);
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index f02975404..6a7d36d68 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -367,7 +367,8 @@ public class FetcherTest {
5,
0.2f,
1024000L,
- new RssConf());
+ new RssConf(),
+ null);
for (String key : keysToValues.keySet()) {
String value = keysToValues.get(key);