This is an automated email from the ASF dual-hosted git repository.
xianjingfeng pushed a commit to branch branch-0.9
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/branch-0.9 by this push:
new 4944d5481 [#1751][0.9] improvement: support gluten (#1753)
4944d5481 is described below
commit 4944d5481e7b64e75ddf8bf6eee03b27490a3667
Author: xianjingfeng <[email protected]>
AuthorDate: Tue Jun 18 09:13:24 2024 +0800
[#1751][0.9] improvement: support gluten (#1753)
* support gluten
* optimize
* fix bug
* nit
* fix spotless
* nit
* nit
* fix bug
* optimize
* optimize
* nit
* nit
* nit
* nit
* nit
* Update RssShuffleWriter.java
---
.../apache/spark/shuffle/RssShuffleManager.java | 24 +++++++-------
.../spark/shuffle/writer/RssShuffleWriter.java | 6 ++--
.../apache/spark/shuffle/RssShuffleManager.java | 37 ++++++++--------------
.../spark/shuffle/writer/RssShuffleWriter.java | 19 ++++++++---
4 files changed, 44 insertions(+), 42 deletions(-)
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 78bcc2c17..45d338e39 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -475,15 +475,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
- } else {
- shuffleHandleInfo =
- new ShuffleHandleInfo(
- shuffleId, rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
- }
ShuffleWriteMetrics writeMetrics =
context.taskMetrics().shuffleWriteMetrics();
return new RssShuffleWriter<>(
rssHandle.getAppId(),
@@ -496,8 +487,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
shuffleWriteClient,
rssHandle,
this::markFailedTask,
- context,
- shuffleHandleInfo);
+ context);
} else {
throw new RssException("Unexpected ShuffleHandle:" +
handle.getClass().getName());
}
@@ -806,6 +796,18 @@ public class RssShuffleManager extends RssShuffleManagerBase {
.createShuffleManagerClient(ClientType.GRPC, host, port);
}
+ public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?>
rssHandle) {
+ if (shuffleManagerRpcServiceEnabled) {
+ // Get the ShuffleServer list from the Driver based on the shuffleId
+ return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
+ } else {
+ return new ShuffleHandleInfo(
+ rssHandle.getShuffleId(),
+ rssHandle.getPartitionToServers(),
+ rssHandle.getRemoteStorage());
+ }
+ }
+
/**
* Get the ShuffleServer list from the Driver based on the shuffleId
*
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 9e64b2fd5..37576c1c9 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -188,8 +188,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
- TaskContext context,
- ShuffleHandleInfo shuffleHandleInfo) {
+ TaskContext context) {
this(
appId,
shuffleId,
@@ -201,9 +200,10 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
shuffleWriteClient,
rssHandle,
taskFailureCallback,
- shuffleHandleInfo,
+ shuffleManager.getShuffleHandleInfo(rssHandle),
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
+ ShuffleHandleInfo shuffleHandleInfo =
shuffleManager.getShuffleHandleInfo(rssHandle);
final WriteBufferManager bufferManager =
new WriteBufferManager(
shuffleId,
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 6d9487ca4..700b7691b 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -141,7 +141,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
private boolean rssResubmitStage;
private boolean taskBlockSendFailureRetryEnabled;
-
private boolean shuffleManagerRpcServiceEnabled;
/** A list of shuffleServer for Write failures */
private Set<String> failuresShuffleServerIds;
@@ -514,15 +513,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
} else {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
- } else {
- shuffleHandleInfo =
- new ShuffleHandleInfo(
- shuffleId, rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
- }
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(),
rssHandle.getShuffleId());
return new RssShuffleWriter<>(
@@ -536,8 +526,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
shuffleWriteClient,
rssHandle,
this::markFailedTask,
- context,
- shuffleHandleInfo);
+ context);
}
@Override
@@ -656,17 +645,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>)
handle;
final int partitionNum =
rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
- } else {
- shuffleHandleInfo =
- new ShuffleHandleInfo(
- shuffleId,
- rssShuffleHandle.getPartitionToServers(),
- rssShuffleHandle.getRemoteStorage());
- }
+ ShuffleHandleInfo shuffleHandleInfo =
getShuffleHandleInfo(rssShuffleHandle);
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
shuffleHandleInfo.getPartitionToServers();
Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
@@ -1101,6 +1080,18 @@ public class RssShuffleManager extends RssShuffleManagerBase {
.createShuffleManagerClient(ClientType.GRPC, host, port);
}
+ public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?>
rssHandle) {
+ if (shuffleManagerRpcServiceEnabled) {
+ // Get the ShuffleServer list from the Driver based on the shuffleId
+ return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
+ } else {
+ return new ShuffleHandleInfo(
+ rssHandle.getShuffleId(),
+ rssHandle.getPartitionToServers(),
+ rssHandle.getRemoteStorage());
+ }
+ }
+
/**
* Get the ShuffleServer list from the Driver based on the shuffleId
*
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 8a22b73ba..70ae3d8f6 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -95,6 +95,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final String appId;
private final int shuffleId;
+ private final ShuffleHandleInfo shuffleHandleInfo;
private WriteBufferManager bufferManager;
private String taskId;
private final int numMaps;
@@ -110,7 +111,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleWriteClient shuffleWriteClient;
private final Set<ShuffleServerInfo> shuffleServersForData;
private final long[] partitionLengths;
- private final boolean isMemoryShuffleEnabled;
+ // Gluten needs this variable
+ protected final boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
private TaskContext taskContext;
@@ -195,6 +197,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
+ this.shuffleHandleInfo = shuffleHandleInfo;
this.taskContext = context;
this.sparkConf = sparkConf;
this.blockFailSentRetryEnabled =
@@ -204,6 +207,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue());
}
+ // Gluten needs this constructor
public RssShuffleWriter(
String appId,
int shuffleId,
@@ -215,8 +219,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
- TaskContext context,
- ShuffleHandleInfo shuffleHandleInfo) {
+ TaskContext context) {
this(
appId,
shuffleId,
@@ -228,7 +231,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
shuffleWriteClient,
rssHandle,
taskFailureCallback,
- shuffleHandleInfo,
+ shuffleManager.getShuffleHandleInfo(rssHandle),
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
final WriteBufferManager bufferManager =
@@ -264,7 +267,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
}
}
- private void writeImpl(Iterator<Product2<K, V>> records) {
+ // Gluten needs this method.
+ protected void writeImpl(Iterator<Product2<K, V>> records) {
List<ShuffleBlockInfo> shuffleBlockInfos;
boolean isCombine = shuffleDependency.mapSideCombine();
Function1<V, C> createCombiner = null;
@@ -322,6 +326,11 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
+ bufferManager.getManagerCostInfo());
}
+ // Gluten needs this method
+ protected void internalCheckBlockSendResult() {
+ this.checkBlockSendResult(this.blockIds);
+ }
+
private void checkSentRecordCount(long recordCount) {
if (recordCount != bufferManager.getRecordCount()) {
String errorMsg =