This is an automated email from the ASF dual-hosted git repository.

zhengchenyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 45ffa5fda [#1300] feat(mr): Support combine operation in map stage for mr engine. (#1301)
45ffa5fda is described below

commit 45ffa5fda189af21bac5b6a1b43ebe8cb0566c05
Author: QI Jiale <[email protected]>
AuthorDate: Tue Nov 14 10:35:15 2023 +0800

    [#1300] feat(mr): Support combine operation in map stage for mr engine. (#1301)
    
    ### What changes were proposed in this pull request?
    
    Support combine operation in map stage.
    
    ### Why are the changes needed?
    
    Fix: #1300
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Tested by UT, especially `WordCountTest`
---
 .../hadoop/mapred/RssCombineOutputCollector.java   |  32 ++++++
 .../hadoop/mapred/RssMapOutputCollector.java       |  11 +-
 .../org/apache/hadoop/mapred/SortWriteBuffer.java  |  86 ++++++++++++---
 .../hadoop/mapred/SortWriteBufferManager.java      |  34 +++++-
 .../hadoop/mapred/SortWriteBufferManagerTest.java  | 121 ++++++++++++++++++++-
 .../apache/hadoop/mapred/SortWriteBufferTest.java  |   2 +
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  |   3 +-
 7 files changed, 263 insertions(+), 26 deletions(-)

diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssCombineOutputCollector.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssCombineOutputCollector.java
new file mode 100644
index 000000000..ba11daebb
--- /dev/null
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssCombineOutputCollector.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred;
+
+import java.io.IOException;
+
+public class RssCombineOutputCollector<K, V> implements OutputCollector<K, V> {
+  private SortWriteBuffer<K, V> writer;
+
+  public synchronized void setWriter(SortWriteBuffer<K, V> writer) {
+    this.writer = writer;
+  }
+
+  public synchronized void collect(K key, V value) throws IOException {
+    writer.addRecord(key, value);
+  }
+}
diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
index 1066960c1..3acf0b417 100644
--- 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
@@ -55,6 +55,7 @@ public class RssMapOutputCollector<K extends Object, V 
extends Object>
   private int partitions;
   private SortWriteBufferManager bufferManager;
   private ShuffleWriteClient shuffleClient;
+  private Task.CombinerRunner<K, V> combinerRunner;
 
   @Override
   public void init(Context context) throws IOException, ClassNotFoundException 
{
@@ -78,6 +79,13 @@ public class RssMapOutputCollector<K extends Object, V 
extends Object>
       throw new IOException("Invalid  sort memory use threshold : " + 
sortThreshold);
     }
 
+    // combiner
+    final Counters.Counter combineInputCounter =
+        reporter.getCounter(TaskCounter.COMBINE_INPUT_RECORDS);
+    combinerRunner =
+        Task.CombinerRunner.create(
+            mrJobConf, mapTask.getTaskID(), combineInputCounter, reporter, 
null);
+
     int batch =
         RssMRUtils.getInt(
             rssJobConf,
@@ -165,7 +173,8 @@ public class RssMapOutputCollector<K extends Object, V 
extends Object>
             sendThreadNum,
             sendThreshold,
             maxBufferSize,
-            RssMRConfig.toRssConf(rssJobConf));
+            RssMRConfig.toRssConf(rssJobConf),
+            combinerRunner);
   }
 
   private Map<Integer, List<ShuffleServerInfo>> 
createAssignmentMap(Configuration jobConf) {
diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
index d6fc3bbeb..37b06871f 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
@@ -19,13 +19,15 @@ package org.apache.hadoop.mapred;
 
 import java.io.IOException;
 import java.io.OutputStream;
-import java.util.Comparator;
+import java.util.Iterator;
 import java.util.List;
 
 import com.google.common.collect.Lists;
+import org.apache.hadoop.io.DataInputBuffer;
 import org.apache.hadoop.io.RawComparator;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.hadoop.util.Progress;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -84,6 +86,21 @@ public class SortWriteBuffer<K, V> extends OutputStream {
     records.clear();
   }
 
+  public synchronized void sort() {
+    long startSort = System.currentTimeMillis();
+    records.sort(
+        (o1, o2) ->
+            comparator.compare(
+                buffers.get(o1.getKeyIndex()).getBuffer(),
+                o1.getKeyOffSet(),
+                o1.getKeyLength(),
+                buffers.get(o2.getKeyIndex()).getBuffer(),
+                o2.getKeyOffSet(),
+                o2.getKeyLength()));
+    long finishSort = System.currentTimeMillis();
+    sortTime += finishSort - startSort;
+  }
+
   public synchronized byte[] getData() {
     int extraSize = 0;
     for (Record<K> record : records) {
@@ -95,22 +112,8 @@ public class SortWriteBuffer<K, V> extends OutputStream {
     extraSize += WritableUtils.getVIntSize(-1);
     byte[] data = new byte[dataLength + extraSize];
     int offset = 0;
-    long startSort = System.currentTimeMillis();
-    records.sort(
-        new Comparator<Record<K>>() {
-          @Override
-          public int compare(Record<K> o1, Record<K> o2) {
-            return comparator.compare(
-                buffers.get(o1.getKeyIndex()).getBuffer(),
-                o1.getKeyOffSet(),
-                o1.getKeyLength(),
-                buffers.get(o2.getKeyIndex()).getBuffer(),
-                o2.getKeyOffSet(),
-                o2.getKeyLength());
-          }
-        });
-    long startCopy = System.currentTimeMillis();
-    sortTime += startCopy - startSort;
+
+    final long startCopy = System.currentTimeMillis();
 
     for (Record<K> record : records) {
       offset = writeDataInt(data, offset, record.getKeyLength());
@@ -325,4 +328,53 @@ public class SortWriteBuffer<K, V> extends OutputStream {
       return size;
     }
   }
+
+  public static class SortBufferIterator<K, V> implements RawKeyValueIterator {
+    private final SortWriteBuffer<K, V> sortWriteBuffer;
+    private final Iterator<Record<K>> iterator;
+    private final DataInputBuffer keyBuffer = new DataInputBuffer();
+    private final DataInputBuffer valueBuffer = new DataInputBuffer();
+    private SortWriteBuffer.Record<K> currentRecord;
+
+    public SortBufferIterator(SortWriteBuffer<K, V> sortWriteBuffer) {
+      this.sortWriteBuffer = sortWriteBuffer;
+      this.iterator = sortWriteBuffer.records.iterator();
+    }
+
+    @Override
+    public DataInputBuffer getKey() {
+      SortWriteBuffer.WrappedBuffer keyWrappedBuffer =
+          sortWriteBuffer.buffers.get(currentRecord.getKeyIndex());
+      byte[] rawData = keyWrappedBuffer.getBuffer();
+      keyBuffer.reset(rawData, currentRecord.getKeyOffSet(), 
currentRecord.getKeyLength());
+      return keyBuffer;
+    }
+
+    @Override
+    public DataInputBuffer getValue() {
+      SortWriteBuffer.WrappedBuffer valueWrappedBuffer =
+          sortWriteBuffer.buffers.get(currentRecord.getKeyIndex());
+      byte[] rawData = valueWrappedBuffer.getBuffer();
+      int valueOffset = currentRecord.getKeyOffSet() + 
currentRecord.getKeyLength();
+      valueBuffer.reset(rawData, valueOffset, currentRecord.getValueLength());
+      return valueBuffer;
+    }
+
+    @Override
+    public boolean next() {
+      if (iterator.hasNext()) {
+        currentRecord = iterator.next();
+        return true;
+      }
+      return false;
+    }
+
+    @Override
+    public void close() throws IOException {}
+
+    @Override
+    public Progress getProgress() {
+      return new Progress();
+    }
+  }
 }
diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
index d410fa3b3..0da7cd930 100644
--- 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
@@ -95,6 +95,7 @@ public class SortWriteBufferManager<K, V> {
   private final ExecutorService sendExecutorService;
   private final RssConf rssConf;
   private final Codec codec;
+  private final Task.CombinerRunner<K, V> combinerRunner;
 
   public SortWriteBufferManager(
       long maxMemSize,
@@ -120,7 +121,8 @@ public class SortWriteBufferManager<K, V> {
       int sendThreadNum,
       double sendThreshold,
       long maxBufferSize,
-      RssConf rssConf) {
+      RssConf rssConf,
+      Task.CombinerRunner<K, V> combinerRunner) {
     this.maxMemSize = maxMemSize;
     this.taskAttemptId = taskAttemptId;
     this.batch = batch;
@@ -146,6 +148,7 @@ public class SortWriteBufferManager<K, V> {
     this.sendExecutorService = 
ThreadUtils.getDaemonFixedThreadPool(sendThreadNum, "send-thread");
     this.rssConf = rssConf;
     this.codec = Codec.newInstance(rssConf);
+    this.combinerRunner = combinerRunner;
   }
 
   // todo: Single Buffer should also have its size limit
@@ -231,7 +234,19 @@ public class SortWriteBufferManager<K, V> {
 
   private void prepareBufferForSend(List<ShuffleBlockInfo> shuffleBlocks, 
SortWriteBuffer buffer) {
     buffers.remove(buffer.getPartitionId());
-    ShuffleBlockInfo block = createShuffleBlock(buffer);
+    buffer.sort();
+    ShuffleBlockInfo block;
+    if (combinerRunner != null) {
+      try {
+        buffer = combineBuffer(buffer);
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Successfully finished combining.");
+        }
+      } catch (Exception e) {
+        LOG.error("Error occurred while combining in Map:", e);
+      }
+    }
+    block = createShuffleBlock(buffer);
     buffer.clear();
     shuffleBlocks.add(block);
     allBlockIds.add(block.getBlockId());
@@ -239,6 +254,21 @@ public class SortWriteBufferManager<K, V> {
     partitionToBlocks.get(block.getPartitionId()).add(block.getBlockId());
   }
 
+  public SortWriteBuffer<K, V> combineBuffer(SortWriteBuffer<K, V> buffer)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    RawKeyValueIterator kvIterator = new 
SortWriteBuffer.SortBufferIterator<>(buffer);
+
+    RssCombineOutputCollector<K, V> combineCollector = new 
RssCombineOutputCollector<>();
+
+    SortWriteBuffer<K, V> newBuffer =
+        new SortWriteBuffer<>(
+            buffer.getPartitionId(), comparator, maxSegmentSize, 
keySerializer, valSerializer);
+
+    combineCollector.setWriter(newBuffer);
+    combinerRunner.combine(kvIterator, combineCollector);
+    return newBuffer;
+  }
+
   private void sendShuffleBlocks(List<ShuffleBlockInfo> shuffleBlocks) {
     sendExecutorService.submit(
         new Runnable() {
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 0cd9e1a06..fbb280391 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -17,7 +17,9 @@
 
 package org.apache.hadoop.mapred;
 
+import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
@@ -25,11 +27,15 @@ import java.util.Set;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
+import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparator;
 import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
 import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
@@ -49,6 +55,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
 
 public class SortWriteBufferManagerTest {
 
@@ -89,7 +96,8 @@ public class SortWriteBufferManagerTest {
             5,
             0.2f,
             1024000L,
-            new RssConf());
+            new RssConf(),
+            null);
 
     // case 1
     Random random = new Random();
@@ -151,7 +159,8 @@ public class SortWriteBufferManagerTest {
             5,
             0.2f,
             1024000L,
-            new RssConf());
+            new RssConf(),
+            null);
     byte[] key = new byte[20];
     byte[] value = new byte[1024];
     random.nextBytes(key);
@@ -202,7 +211,8 @@ public class SortWriteBufferManagerTest {
             5,
             0.2f,
             100L,
-            new RssConf());
+            new RssConf(),
+            null);
     Random random = new Random();
     for (int i = 0; i < 1000; i++) {
       byte[] key = new byte[20];
@@ -252,7 +262,8 @@ public class SortWriteBufferManagerTest {
             5,
             0.2f,
             1024000L,
-            new RssConf());
+            new RssConf(),
+            null);
     Random random = new Random();
     for (int i = 0; i < 1000; i++) {
       byte[] key = new byte[20];
@@ -317,7 +328,8 @@ public class SortWriteBufferManagerTest {
             5,
             0.2f,
             1024000L,
-            new RssConf());
+            new RssConf(),
+            null);
     Random random = new Random();
     for (int i = 0; i < 1000; i++) {
       byte[] key = new byte[20];
@@ -335,6 +347,86 @@ public class SortWriteBufferManagerTest {
         client.mockedShuffleServer.getFlushBlockSize());
   }
 
+  @Test
+  public void testCombineBuffer() throws Exception {
+    JobConf jobConf = new JobConf(new Configuration());
+    jobConf.setOutputKeyClass(Text.class);
+    jobConf.setOutputValueClass(IntWritable.class);
+    jobConf.setCombinerClass(Reduce.class);
+    SerializationFactory serializationFactory = new 
SerializationFactory(jobConf);
+    Serializer<Text> keySerializer = 
serializationFactory.getSerializer(Text.class);
+    Serializer<IntWritable> valueSerializer = 
serializationFactory.getSerializer(IntWritable.class);
+    WritableComparator comparator = WritableComparator.get(Text.class);
+
+    Task.TaskReporter reporter = mock(Task.TaskReporter.class);
+
+    final Counters.Counter combineInputCounter = new Counters.Counter();
+
+    Task.CombinerRunner<Text, IntWritable> combinerRunner =
+        Task.CombinerRunner.create(
+            jobConf, new TaskAttemptID(), combineInputCounter, reporter, null);
+
+    SortWriteBuffer<Text, IntWritable> buffer =
+        new SortWriteBuffer<Text, IntWritable>(
+            1, comparator, 10000, keySerializer, valueSerializer);
+
+    List<String> wordTable =
+        Lists.newArrayList(
+            "apple", "banana", "fruit", "cherry", "Chinese", "America", 
"Japan", "tomato");
+    Random random = new Random();
+    for (int i = 0; i < 8; i++) {
+      buffer.addRecord(new Text(wordTable.get(i)), new IntWritable(1));
+    }
+    for (int i = 0; i < 100; i++) {
+      int index = random.nextInt(wordTable.size());
+      buffer.addRecord(new Text(wordTable.get(index)), new IntWritable(1));
+    }
+
+    SortWriteBufferManager<Text, IntWritable> manager =
+        new SortWriteBufferManager<Text, IntWritable>(
+            10240,
+            1L,
+            10,
+            keySerializer,
+            valueSerializer,
+            comparator,
+            0.9,
+            "test",
+            null,
+            500,
+            5 * 1000,
+            null,
+            null,
+            null,
+            null,
+            null,
+            1,
+            100,
+            1,
+            true,
+            5,
+            0.2f,
+            1024000L,
+            new RssConf(),
+            combinerRunner);
+
+    buffer.sort();
+    SortWriteBuffer<Text, IntWritable> newBuffer = 
manager.combineBuffer(buffer);
+
+    RawKeyValueIterator kvIterator1 = new 
SortWriteBuffer.SortBufferIterator<>(buffer);
+    RawKeyValueIterator kvIterator2 = new 
SortWriteBuffer.SortBufferIterator<>(newBuffer);
+    int count1 = 0;
+    while (kvIterator1.next()) {
+      count1++;
+    }
+    int count2 = 0;
+    while (kvIterator2.next()) {
+      count2++;
+    }
+    assertEquals(108, count1);
+    assertEquals(8, count2);
+  }
+
   class MockShuffleServer {
 
     // All methods of MockShuffle are thread safe, because send-thread may do 
something in
@@ -500,4 +592,23 @@ public class SortWriteBufferManagerTest {
     @Override
     public void unregisterShuffle(String appId) {}
   }
+
+  static class Reduce extends MapReduceBase
+      implements Reducer<Text, IntWritable, Text, IntWritable> {
+
+    Reduce() {}
+
+    public void reduce(
+        Text key,
+        Iterator<IntWritable> values,
+        OutputCollector<Text, IntWritable> output,
+        Reporter reporter)
+        throws IOException {
+      int sum = 0;
+      while (values.hasNext()) {
+        sum += values.next().get();
+      }
+      output.collect(key, new IntWritable(sum));
+    }
+  }
 }
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
index 92af295d4..ebfe75dcd 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
@@ -56,6 +56,7 @@ public class SortWriteBufferTest {
             1, WritableComparator.get(BytesWritable.class), 1024L, 
keySerializer, valSerializer);
 
     long recordLength = buffer.addRecord(key, value);
+    buffer.sort();
     assertEquals(20, buffer.getData().length);
     assertEquals(16, recordLength);
     assertEquals(1, buffer.getPartitionId());
@@ -124,6 +125,7 @@ public class SortWriteBufferTest {
     recordLength = buffer.addRecord(key, value);
     recordLenMap.putIfAbsent(keyStr, recordLength);
 
+    buffer.sort();
     result = buffer.getData();
     byteArrayInputStream = new ByteArrayInputStream(result);
     keyDeserializer.open(byteArrayInputStream);
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index f02975404..6a7d36d68 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -367,7 +367,8 @@ public class FetcherTest {
             5,
             0.2f,
             1024000L,
-            new RssConf());
+            new RssConf(),
+            null);
 
     for (String key : keysToValues.keySet()) {
       String value = keysToValues.get(key);

Reply via email to