This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 636d0be21 [#1341] fix(mr): Fix MR Combiner
ArrayIndexOutOfBoundsException Bug. (#1666)
636d0be21 is described below
commit 636d0be219b24898fcc7bbd59d74ba2d558acaa9
Author: QI Jiale <[email protected]>
AuthorDate: Tue Apr 30 11:10:11 2024 +0800
[#1341] fix(mr): Fix MR Combiner ArrayIndexOutOfBoundsException Bug. (#1666)
### What changes were proposed in this pull request?
The current implementation of the SortBufferIterator.getKey() and
.getValue() methods in the SortWriteBuffer class assumes that keys and values
are always stored within a single buffer. This assumption can lead to runtime
exceptions, specifically ArrayIndexOutOfBoundsException, when keys or values
span multiple WrappedBuffer instances.
In the current implementation, due to the execution of `compact()`, the
key data will never span buffers. However, the value data may be located
in different buffers from the key and may span multiple buffers.
This PR updates getKey() to use fetchDataFromBuffers() to retrieve the key
data. It also updates getValue() to first adjust the starting index and
offset based on the length of the key and then use fetchDataFromBuffers() to
retrieve the value data.
Another solution to this bug would be to modify the `addRecord()` method to
keep the key and value in the same WrappedBuffer.
### Why are the changes needed?
Fix: #1341
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tested manually. I'll submit a new integration test later.
---
.../org/apache/hadoop/mapred/SortWriteBuffer.java | 45 ++++++++++++++----
.../hadoop/mapred/SortWriteBufferManagerTest.java | 7 ++-
.../apache/hadoop/mapred/SortWriteBufferTest.java | 53 ++++++++++++++++++++++
3 files changed, 93 insertions(+), 12 deletions(-)
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
index 37b06871f..20765c53d 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBuffer.java
@@ -341,22 +341,51 @@ public class SortWriteBuffer<K, V> extends OutputStream {
this.iterator = sortWriteBuffer.records.iterator();
}
+ private byte[] fetchDataFromBuffers(int index, int offset, int length) {
+ // Adjust start index and offset for the start of the value
+ while (offset >= sortWriteBuffer.buffers.get(index).getSize()) {
+ offset -= sortWriteBuffer.buffers.get(index).getSize();
+ index++;
+ }
+
+ byte[] data = new byte[length]; // Create a new array to store the
complete data
+ int copyDestPos = 0;
+
+ while (length > 0) {
+ WrappedBuffer currentBuffer = sortWriteBuffer.buffers.get(index);
+ byte[] currentBufferData = currentBuffer.getBuffer();
+ int currentBufferCapacity = currentBuffer.getSize();
+ int copyLength = Math.min(currentBufferCapacity - offset, length);
+
+ // Copy data from the current buffer to the data array
+ System.arraycopy(currentBufferData, offset, data, copyDestPos,
copyLength);
+ length -= copyLength;
+ copyDestPos += copyLength;
+
+ // Move to the next buffer
+ index++;
+ offset = 0; // Start position in the new buffer is 0
+ }
+ return data;
+ }
+
@Override
public DataInputBuffer getKey() {
- SortWriteBuffer.WrappedBuffer keyWrappedBuffer =
- sortWriteBuffer.buffers.get(currentRecord.getKeyIndex());
- byte[] rawData = keyWrappedBuffer.getBuffer();
- keyBuffer.reset(rawData, currentRecord.getKeyOffSet(),
currentRecord.getKeyLength());
+ int keyIndex = currentRecord.getKeyIndex();
+ int keyOffset = currentRecord.getKeyOffSet();
+ int keyLength = currentRecord.getKeyLength();
+ byte[] keyData = fetchDataFromBuffers(keyIndex, keyOffset, keyLength);
+ keyBuffer.reset(keyData, 0, keyLength);
return keyBuffer;
}
@Override
public DataInputBuffer getValue() {
- SortWriteBuffer.WrappedBuffer valueWrappedBuffer =
- sortWriteBuffer.buffers.get(currentRecord.getKeyIndex());
- byte[] rawData = valueWrappedBuffer.getBuffer();
+ int keyIndex = currentRecord.getKeyIndex();
int valueOffset = currentRecord.getKeyOffSet() +
currentRecord.getKeyLength();
- valueBuffer.reset(rawData, valueOffset, currentRecord.getValueLength());
+ int valueLength = currentRecord.getValueLength();
+ byte[] valueData = fetchDataFromBuffers(keyIndex, valueOffset,
valueLength);
+ valueBuffer.reset(valueData, 0, valueLength);
return valueBuffer;
}
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index e5a25b67f..5c9f401b8 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -373,8 +373,7 @@ public class SortWriteBufferManagerTest {
jobConf, new TaskAttemptID(), combineInputCounter, reporter, null);
SortWriteBuffer<Text, IntWritable> buffer =
- new SortWriteBuffer<Text, IntWritable>(
- 1, comparator, 10000, keySerializer, valueSerializer);
+ new SortWriteBuffer<Text, IntWritable>(1, comparator, 3072,
keySerializer, valueSerializer);
List<String> wordTable =
Lists.newArrayList(
@@ -383,7 +382,7 @@ public class SortWriteBufferManagerTest {
for (int i = 0; i < 8; i++) {
buffer.addRecord(new Text(wordTable.get(i)), new IntWritable(1));
}
- for (int i = 0; i < 100; i++) {
+ for (int i = 0; i < 10000; i++) {
int index = random.nextInt(wordTable.size());
buffer.addRecord(new Text(wordTable.get(index)), new IntWritable(1));
}
@@ -429,7 +428,7 @@ public class SortWriteBufferManagerTest {
while (kvIterator2.next()) {
count2++;
}
- assertEquals(108, count1);
+ assertEquals(10008, count1);
assertEquals(8, count2);
}
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
index ebfe75dcd..e222992fb 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferTest.java
@@ -20,11 +20,15 @@ package org.apache.hadoop.mapred;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
+import java.util.List;
import java.util.Map;
import java.util.Random;
+import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.io.serializer.Deserializer;
@@ -153,6 +157,55 @@ public class SortWriteBufferTest {
assertEquals(bigWritableValue, valueRead);
}
+ @Test
+ public void testSortBufferIterator() throws IOException {
+ SerializationFactory serializationFactory =
+ new SerializationFactory(new JobConf(new Configuration()));
+ Serializer<Text> keySerializer =
serializationFactory.getSerializer(Text.class);
+ Deserializer<Text> keyDeserializer =
serializationFactory.getDeserializer(Text.class);
+ Serializer<IntWritable> valueSerializer =
serializationFactory.getSerializer(IntWritable.class);
+ Deserializer<IntWritable> valueDeserializer =
+ serializationFactory.getDeserializer(IntWritable.class);
+
+ SortWriteBuffer<Text, IntWritable> buffer =
+ new SortWriteBuffer<Text, IntWritable>(1, null, 3072, keySerializer,
valueSerializer);
+
+ List<String> wordTable =
+ Lists.newArrayList(
+ "apple", "banana", "fruit", "cherry", "Chinese", "America",
"Japan", "tomato");
+
+ List<String> keys = Lists.newArrayList();
+
+ Random random = new Random();
+ for (int i = 0; i < 8; i++) {
+ buffer.addRecord(new Text(wordTable.get(i)), new IntWritable(1));
+ keys.add(wordTable.get(i));
+ }
+ for (int i = 0; i < 10000; i++) {
+ int index = random.nextInt(wordTable.size());
+ buffer.addRecord(new Text(wordTable.get(index)), new IntWritable(1));
+ keys.add(wordTable.get(index));
+ }
+
+ SortWriteBuffer.SortBufferIterator<Text, IntWritable> iterator =
+ new SortWriteBuffer.SortBufferIterator<>(buffer);
+
+ int ind = 0;
+
+ Text key = new Text();
+ IntWritable value = new IntWritable();
+ while (iterator.next()) {
+ iterator.getKey().getData();
+ iterator.getValue().getData();
+ keyDeserializer.open(iterator.getKey());
+ valueDeserializer.open(iterator.getValue());
+ keyDeserializer.deserialize(key);
+ valueDeserializer.deserialize(value);
+ assertEquals(keys.get(ind), key.toString());
+ ind++;
+ }
+ }
+
int readInt(DataInputStream dStream) throws IOException {
return WritableUtils.readVInt(dStream);
}