This is an automated email from the ASF dual-hosted git repository.
xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new f7c6d2da2 [#1751] improvement: support gluten again (#1857)
f7c6d2da2 is described below
commit f7c6d2da237bd487d3cd0e21231108df90559cbe
Author: xianjingfeng <[email protected]>
AuthorDate: Thu Jul 4 10:24:59 2024 +0800
[#1751] improvement: support gluten again (#1857)
### What changes were proposed in this pull request?
support gluten
### Why are the changes needed?
Currently, gluten will fail to compile using client from the master branch
of uniffle.
Fix: #1751
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Existing UTs and manual testing
---
.../spark/shuffle/writer/WriteBufferManager.java | 5 ++++
.../shuffle/manager/RssShuffleManagerBase.java | 16 +++++++++++
.../apache/spark/shuffle/RssShuffleManager.java | 15 +----------
.../spark/shuffle/writer/RssShuffleWriter.java | 7 ++---
.../apache/spark/shuffle/RssShuffleManager.java | 31 +++-------------------
.../spark/shuffle/writer/RssShuffleWriter.java | 18 +++++++++----
6 files changed, 42 insertions(+), 50 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 95add5048..bfd929777 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -324,6 +324,11 @@ public class WriteBufferManager extends MemoryConsumer {
return shuffleBlockInfos;
}
+ // Gluten needs this method.
+ public synchronized List<ShuffleBlockInfo> clear() {
+ return clear(bufferSpillRatio);
+ }
+
// transform all [partition, records] to [partition, ShuffleBlockInfo] and
clear cache
public synchronized List<ShuffleBlockInfo> clear(double bufferSpillRatio) {
List<ShuffleBlockInfo> result = Lists.newArrayList();
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 209ede25c..bbeec90dd 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -44,6 +44,7 @@ import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
+import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.RssStageInfo;
@@ -53,6 +54,7 @@ import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.SparkVersionUtils;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -576,6 +578,20 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterface
sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""),
confItems);
}
+ public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?>
rssHandle) {
+ int shuffleId = rssHandle.getShuffleId();
+ if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
+ // In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
+ return getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+ } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
+ // In Block Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
+ return getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
+ } else {
+ return new SimpleShuffleHandleInfo(
+ shuffleId, rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
+ }
+ }
+
/**
* In Stage Retry mode, obtain the Shuffle Server list from the Driver based
on shuffleId.
*
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index de5d4da63..1e5bb4941 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -424,18 +424,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
- // In Stage Retry mode, Get the ShuffleServer list from the Driver
based on the shuffleId
- shuffleHandleInfo =
getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
- } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
- // In Block Retry mode, Get the ShuffleServer list from the Driver
based on the shuffleId
- shuffleHandleInfo =
getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
- } else {
- shuffleHandleInfo =
- new SimpleShuffleHandleInfo(
- shuffleId, rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
- }
ShuffleWriteMetrics writeMetrics =
context.taskMetrics().shuffleWriteMetrics();
return new RssShuffleWriter<>(
rssHandle.getAppId(),
@@ -448,8 +436,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleWriteClient,
rssHandle,
this::markFailedTask,
- context,
- shuffleHandleInfo);
+ context);
} else {
throw new RssException("Unexpected ShuffleHandle:" +
handle.getClass().getName());
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 65b66df3d..5ac6a7e9e 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -97,6 +97,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private String appId;
private int numMaps;
private int shuffleId;
+ private final ShuffleHandleInfo shuffleHandleInfo;
private int bitmapSplitNum;
private String taskId;
private long taskAttemptId;
@@ -176,6 +177,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
+ this.shuffleHandleInfo = shuffleHandleInfo;
this.taskContext = context;
this.sparkConf = sparkConf;
}
@@ -191,8 +193,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
- TaskContext context,
- ShuffleHandleInfo shuffleHandleInfo) {
+ TaskContext context) {
this(
appId,
shuffleId,
@@ -204,7 +205,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleWriteClient,
rssHandle,
taskFailureCallback,
- shuffleHandleInfo,
+ shuffleManager.getShuffleHandleInfo(rssHandle),
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
final WriteBufferManager bufferManager =
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1d5050790..bf42bf361 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -512,18 +512,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
} else {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
- // In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
- shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
- } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
- // In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
- shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
- } else {
- shuffleHandleInfo =
- new SimpleShuffleHandleInfo(
- shuffleId, rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
- }
+
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
return new RssShuffleWriter<>(
rssHandle.getAppId(),
@@ -536,8 +525,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
shuffleWriteClient,
rssHandle,
this::markFailedTask,
- context,
- shuffleHandleInfo);
+ context);
}
@Override
@@ -656,20 +644,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>)
handle;
final int partitionNum =
rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
- // In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
- shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
- } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
- // In Block Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
- shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
- } else {
- shuffleHandleInfo =
- new SimpleShuffleHandleInfo(
- shuffleId,
- rssShuffleHandle.getPartitionToServers(),
- rssShuffleHandle.getRemoteStorage());
- }
+ ShuffleHandleInfo shuffleHandleInfo =
getShuffleHandleInfo(rssShuffleHandle);
Map<ShuffleServerInfo, Set<Integer>> serverToPartitions =
getPartitionDataServers(shuffleHandleInfo, startPartition,
endPartition);
long start = System.currentTimeMillis();
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 50eb47001..6660a5e7b 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -104,6 +104,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final String appId;
private final int shuffleId;
+ private final ShuffleHandleInfo shuffleHandleInfo;
private WriteBufferManager bufferManager;
private String taskId;
private final int numMaps;
@@ -119,7 +120,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final ShuffleWriteClient shuffleWriteClient;
private final Set<ShuffleServerInfo> shuffleServersForData;
private final long[] partitionLengths;
- private final boolean isMemoryShuffleEnabled;
+ // Gluten needs this variable
+ protected final boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
private TaskContext taskContext;
@@ -211,6 +213,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
+ this.shuffleHandleInfo = shuffleHandleInfo;
this.taskContext = context;
this.sparkConf = sparkConf;
this.blockFailSentRetryEnabled =
@@ -233,8 +236,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
- TaskContext context,
- ShuffleHandleInfo shuffleHandleInfo) {
+ TaskContext context) {
this(
appId,
shuffleId,
@@ -246,7 +248,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleWriteClient,
rssHandle,
taskFailureCallback,
- shuffleHandleInfo,
+ shuffleManager.getShuffleHandleInfo(rssHandle),
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
final WriteBufferManager bufferManager =
@@ -288,7 +290,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
- private void writeImpl(Iterator<Product2<K, V>> records) {
+ // Gluten needs this method.
+ protected void writeImpl(Iterator<Product2<K, V>> records) {
List<ShuffleBlockInfo> shuffleBlockInfos;
boolean isCombine = shuffleDependency.mapSideCombine();
@@ -454,6 +457,11 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
return futures;
}
+ // Gluten needs this method
+ protected void internalCheckBlockSendResult() {
+ this.checkBlockSendResult(this.blockIds);
+ }
+
@VisibleForTesting
protected void checkBlockSendResult(Set<Long> blockIds) {
boolean interrupted = false;