This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new a1e0d56c8 [#1543] improvement(spark): Speed up reading when both sort and combine are used. (#1640)
a1e0d56c8 is described below
commit a1e0d56c8d60af37d13010db130e86a014c09514
Author: QI Jiale <[email protected]>
AuthorDate: Wed May 8 07:59:26 2024 +0800
[#1543] improvement(spark): Speed up reading when both sort and combine are used. (#1640)
### What changes were proposed in this pull request?
Backport SPARK-46512 to Uniffle.
After the shuffle reader obtains the blocks, it first performs a combine operation and then a sort operation. Both combine and sort may spill temporary files to disk, so performance can be poor when sort and combine are used together. Instead, the combine can be performed during the sort itself, which avoids the separate combine spill files.
See https://issues.apache.org/jira/browse/SPARK-46512 for details.
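The core of the change, stripped of Spark internals: fold the combine into the sorted insertion instead of running it as a separate pass before the sort. A minimal self-contained Java analogue of the idea (illustrative only; `CombineWhileSorting` and `sortAndCombine` are hypothetical names, not Uniffle or Spark APIs):

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.BinaryOperator;

// Illustrative only: merge values per key while inserting into a sorted
// structure, so no separate combine pass (and no combine spill file) is needed.
final class CombineWhileSorting {
  static <K extends Comparable<K>, C> Iterator<Map.Entry<K, C>> sortAndCombine(
      Iterator<Map.Entry<K, C>> records, BinaryOperator<C> mergeCombiners) {
    TreeMap<K, C> sorted = new TreeMap<>(); // keeps keys ordered as records arrive
    while (records.hasNext()) {
      Map.Entry<K, C> e = records.next();
      sorted.merge(e.getKey(), e.getValue(), mergeCombiners); // combine on insert
    }
    return sorted.entrySet().iterator(); // already key-ordered and combined
  }

  public static void main(String[] args) {
    List<Map.Entry<String, Integer>> input =
        List.of(new SimpleEntry<>("b", 1), new SimpleEntry<>("a", 2), new SimpleEntry<>("b", 3));
    sortAndCombine(input.iterator(), Integer::sum)
        .forEachRemaining(e -> System.out.println(e.getKey() + "=" + e.getValue()));
    // prints: a=2 then b=4
  }
}
```

Spark's `ExternalSorter` performs the same per-key merging, with spill support, when it is constructed with an `Aggregator`; that is what the diff below passes in instead of `Option.empty()`.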
### Why are the changes needed?
Fix: #1543
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing UT and Integration Test.
---
.../spark/shuffle/reader/RssShuffleReader.java | 60 ++++++++++++++--------
.../spark/shuffle/reader/RssShuffleReader.java | 53 ++++++++++---------
2 files changed, 67 insertions(+), 46 deletions(-)
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 76bfed608..8f5118e68 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -21,13 +21,16 @@ import java.util.List;
 import java.util.Map;
 
 import scala.Function0;
+import scala.Function2;
 import scala.Option;
 import scala.Product2;
 import scala.collection.Iterator;
 import scala.runtime.AbstractFunction0;
+import scala.runtime.AbstractFunction1;
 import scala.runtime.BoxedUnit;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.spark.Aggregator;
 import org.apache.spark.InterruptibleIterator;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.TaskContext;
@@ -155,33 +158,38 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
           completionIterator.completion();
         });
 
-    Iterator<Product2<K, C>> resultIter = null;
-    Iterator<Product2<K, C>> aggregatedIter = null;
-
-    if (shuffleDependency.aggregator().isDefined()) {
-      if (shuffleDependency.mapSideCombine()) {
-        // We are reading values that are already combined
-        aggregatedIter =
-            shuffleDependency.aggregator().get().combineCombinersByKey(completionIterator, context);
-      } else {
-        // We don't know the value type, but also don't care -- the dependency *should*
-        // have made sure its compatible w/ this aggregator, which will convert the value
-        // type to the combined type C
-        aggregatedIter =
-            shuffleDependency.aggregator().get().combineValuesByKey(completionIterator, context);
-      }
-    } else {
-      aggregatedIter = completionIterator;
-    }
+    Iterator<Product2<K, C>> resultIter;
 
     if (shuffleDependency.keyOrdering().isDefined()) {
       // Create an ExternalSorter to sort the data
-      ExternalSorter<K, C, C> sorter =
+      Option<Aggregator<K, Object, C>> aggregator = Option.empty();
+      if (shuffleDependency.aggregator().isDefined()) {
+        if (shuffleDependency.mapSideCombine()) {
+          aggregator =
+              Option.apply(
+                  (Aggregator<K, Object, C>)
+                      new Aggregator<K, C, C>(
+                          new AbstractFunction1<C, C>() {
+                            @Override
+                            public C apply(C x) {
+                              return x;
+                            }
+                          },
+                          (Function2<C, C, C>)
+                              shuffleDependency.aggregator().get().mergeCombiners(),
+                          (Function2<C, C, C>)
+                              shuffleDependency.aggregator().get().mergeCombiners()));
+        } else {
+          aggregator =
+              Option.apply((Aggregator<K, Object, C>) shuffleDependency.aggregator().get());
+        }
+      }
+      ExternalSorter<K, Object, C> sorter =
           new ExternalSorter<>(
-              context, Option.empty(), Option.empty(), shuffleDependency.keyOrdering(), serializer);
+              context, aggregator, Option.empty(), shuffleDependency.keyOrdering(), serializer);
       LOG.info("Inserting aggregated records to sorter");
       long startTime = System.currentTimeMillis();
-      sorter.insertAll(aggregatedIter);
+      sorter.insertAll(rssShuffleDataIterator);
       LOG.info(
           "Inserted aggregated records to sorter: millis:"
               + (System.currentTimeMillis() - startTime));
@@ -206,8 +214,16 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
             }
           };
       resultIter = CompletionIterator$.MODULE$.apply(sorter.iterator(), fn0);
+    } else if (shuffleDependency.aggregator().isDefined()) {
+      Aggregator<K, Object, C> aggregator =
+          (Aggregator<K, Object, C>) shuffleDependency.aggregator().get();
+      if (shuffleDependency.mapSideCombine()) {
+        resultIter = aggregator.combineCombinersByKey(rssShuffleDataIterator, context);
+      } else {
+        resultIter = aggregator.combineValuesByKey(rssShuffleDataIterator, context);
+      }
     } else {
-      resultIter = aggregatedIter;
+      resultIter = rssShuffleDataIterator;
     }
 
     if (!(resultIter instanceof InterruptibleIterator)) {
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 0c7f3be9e..9b176340b 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -33,6 +33,7 @@ import scala.runtime.BoxedUnit;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.spark.Aggregator;
 import org.apache.spark.InterruptibleIterator;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.TaskContext;
@@ -123,36 +124,32 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   public Iterator<Product2<K, C>> read() {
     LOG.info("Shuffle read started:" + getReadInfo());
 
-    Iterator<Product2<K, C>> aggrIter = null;
-    Iterator<Product2<K, C>> resultIter = null;
+    Iterator<Product2<K, C>> resultIter;
     MultiPartitionIterator rssShuffleDataIterator = new MultiPartitionIterator<K, C>();
 
-    if (shuffleDependency.aggregator().isDefined()) {
-      if (shuffleDependency.mapSideCombine()) {
-        aggrIter =
-            shuffleDependency
-                .aggregator()
-                .get()
-                .combineCombinersByKey(rssShuffleDataIterator, context);
-      } else {
-        aggrIter =
-            shuffleDependency
-                .aggregator()
-                .get()
-                .combineValuesByKey(rssShuffleDataIterator, context);
-      }
-    } else {
-      aggrIter = rssShuffleDataIterator;
-    }
-
     if (shuffleDependency.keyOrdering().isDefined()) {
       // Create an ExternalSorter to sort the data
-      ExternalSorter<K, C, C> sorter =
+      Option<Aggregator<K, Object, C>> aggregator = Option.empty();
+      if (shuffleDependency.aggregator().isDefined()) {
+        if (shuffleDependency.mapSideCombine()) {
+          aggregator =
+              Option.apply(
+                  (Aggregator<K, Object, C>)
+                      new Aggregator<K, C, C>(
+                          x -> x,
+                          shuffleDependency.aggregator().get().mergeCombiners(),
+                          shuffleDependency.aggregator().get().mergeCombiners()));
+        } else {
+          aggregator =
+              Option.apply((Aggregator<K, Object, C>) shuffleDependency.aggregator().get());
+        }
+      }
+      ExternalSorter<K, Object, C> sorter =
           new ExternalSorter<>(
-              context, Option.empty(), Option.empty(), shuffleDependency.keyOrdering(), serializer);
+              context, aggregator, Option.empty(), shuffleDependency.keyOrdering(), serializer);
       LOG.info("Inserting aggregated records to sorter");
       long startTime = System.currentTimeMillis();
-      sorter.insertAll(aggrIter);
+      sorter.insertAll(rssShuffleDataIterator);
       LOG.info(
           "Inserted aggregated records to sorter: millis:"
               + (System.currentTimeMillis() - startTime));
@@ -176,8 +173,16 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
           };
       context.addTaskCompletionListener(fn1);
       resultIter = CompletionIterator$.MODULE$.apply(sorter.iterator(), fn0);
+    } else if (shuffleDependency.aggregator().isDefined()) {
+      Aggregator<K, Object, C> aggregator =
+          (Aggregator<K, Object, C>) shuffleDependency.aggregator().get();
+      if (shuffleDependency.mapSideCombine()) {
+        resultIter = aggregator.combineCombinersByKey(rssShuffleDataIterator, context);
+      } else {
+        resultIter = aggregator.combineValuesByKey(rssShuffleDataIterator, context);
+      }
     } else {
-      resultIter = aggrIter;
+      resultIter = rssShuffleDataIterator;
     }
 
     if (!(resultIter instanceof InterruptibleIterator)) {
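A note on the `mapSideCombine` branches above: when the map side has already combined, the values reaching the reader are already of the combined type `C`, so the user's `createCombiner` (which maps a raw `V` to a `C`) no longer applies. Both readers therefore wrap the original `Aggregator` in one whose `createCombiner` is the identity and whose `mergeValue` and `mergeCombiners` are both the original `mergeCombiners`. A self-contained sketch of that adapter in plain Java (the `java.util.function` types stand in for Scala's `Function1`/`Function2`; `ReadSideAggregator` is a hypothetical name, not a Spark API):

```java
import java.util.function.BinaryOperator;
import java.util.function.UnaryOperator;

// Illustrative adapter: turn an aggregator over raw values V into one over
// already-combined values C, as the read path needs after map-side combine.
final class ReadSideAggregator<C> {
  final UnaryOperator<C> createCombiner = c -> c; // value is already a combiner
  final BinaryOperator<C> mergeValue;             // "values" are combiners now
  final BinaryOperator<C> mergeCombiners;

  ReadSideAggregator(BinaryOperator<C> mergeCombiners) {
    this.mergeValue = mergeCombiners; // both merge steps collapse to one function
    this.mergeCombiners = mergeCombiners;
  }

  public static void main(String[] args) {
    // With mergeCombiners = sum, two partially combined counts merge directly.
    ReadSideAggregator<Integer> agg = new ReadSideAggregator<>(Integer::sum);
    int c = agg.mergeValue.apply(agg.createCombiner.apply(3), 4);
    System.out.println(c); // 7
  }
}
```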