This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 735650e63 [CELEBORN-1211] Add extension for celeborn shuffle handler
735650e63 is described below
commit 735650e634508edda1e335751fbadb8f6054170a
Author: mingji <[email protected]>
AuthorDate: Fri Jan 5 15:56:29 2024 +0800
[CELEBORN-1211] Add extension for celeborn shuffle handler
### What changes were proposed in this pull request?
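1. Add an extension field (an opaque byte array) to CelebornShuffleHandle and pass it through to ShuffleClient via a new `ShuffleClient.get(...)` overload and `setExtension(byte[])`.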
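### Why are the changes needed?
So that derived shuffle client implementations can receive custom, implementation-specific data (the extension byte array) carried by the shuffle handle.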
### Does this PR introduce _any_ user-facing change?
NO.
### How was this patch tested?
GA (GitHub Actions CI).
Closes #2206 from FMX/b1211.
Authored-by: mingji <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
---
.../shuffle/celeborn/SparkShuffleManager.java | 3 ++-
.../shuffle/celeborn/CelebornShuffleHandle.scala | 24 ++++++++++++++++++++--
.../shuffle/celeborn/CelebornShuffleReader.scala | 3 ++-
.../shuffle/celeborn/SparkShuffleManager.java | 3 ++-
.../shuffle/celeborn/CelebornShuffleHandle.scala | 24 ++++++++++++++++++++--
.../shuffle/celeborn/CelebornShuffleReader.scala | 3 ++-
.../org/apache/celeborn/client/ShuffleClient.java | 18 ++++++++++++++++
.../apache/celeborn/client/ShuffleClientImpl.java | 7 +++++++
.../apache/celeborn/client/DummyShuffleClient.java | 3 +++
.../tests/spark/CelebornFetchFailureSuite.scala | 3 ++-
10 files changed, 82 insertions(+), 9 deletions(-)
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index df2dee534..cda992b29 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -185,7 +185,8 @@ public class SparkShuffleManager implements ShuffleManager {
h.lifecycleManagerHost(),
h.lifecycleManagerPort(),
celebornConf,
- h.userIdentifier());
+ h.userIdentifier(),
+ h.extension());
int shuffleId = SparkUtils.celebornShuffleId(client, h, context, true);
shuffleIdTracker.track(h.shuffleId(), shuffleId);
diff --git
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
index 4f67edaf3..4ae52720c 100644
---
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
@@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C](
shuffleId: Int,
val throwsFetchFailure: Boolean,
numMappers: Int,
- dependency: ShuffleDependency[K, V, C])
- extends BaseShuffleHandle(shuffleId, numMappers, dependency)
+ dependency: ShuffleDependency[K, V, C],
+ val extension: Array[Byte])
+ extends BaseShuffleHandle(shuffleId, numMappers, dependency) {
+ def this(
+ appUniqueId: String,
+ lifecycleManagerHost: String,
+ lifecycleManagerPort: Int,
+ userIdentifier: UserIdentifier,
+ shuffleId: Int,
+ throwsFetchFailure: Boolean,
+ numMappers: Int,
+ dependency: ShuffleDependency[K, V, C]) = this(
+ appUniqueId,
+ lifecycleManagerHost,
+ lifecycleManagerPort,
+ userIdentifier,
+ shuffleId,
+ throwsFetchFailure,
+ numMappers,
+ dependency,
+ null)
+}
diff --git
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index dec305225..ad1a620bd 100644
---
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -52,7 +52,8 @@ class CelebornShuffleReader[K, C](
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
conf,
- handle.userIdentifier)
+ handle.userIdentifier,
+ handle.extension)
private val exceptionRef = new AtomicReference[IOException]
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 7cf7b0979..264bf7766 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -233,7 +233,8 @@ public class SparkShuffleManager implements ShuffleManager {
h.lifecycleManagerHost(),
h.lifecycleManagerPort(),
celebornConf,
- h.userIdentifier());
+ h.userIdentifier(),
+ h.extension());
int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h,
context, true);
shuffleIdTracker.track(h.shuffleId(), shuffleId);
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
index 18a3053e0..3d0180a26 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
@@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C](
shuffleId: Int,
val throwsFetchFailure: Boolean,
val numMappers: Int,
- dependency: ShuffleDependency[K, V, C])
- extends BaseShuffleHandle(shuffleId, dependency)
+ dependency: ShuffleDependency[K, V, C],
+ val extension: Array[Byte])
+ extends BaseShuffleHandle(shuffleId, dependency) {
+ def this(
+ appUniqueId: String,
+ lifecycleManagerHost: String,
+ lifecycleManagerPort: Int,
+ userIdentifier: UserIdentifier,
+ shuffleId: Int,
+ throwsFetchFailure: Boolean,
+ numMappers: Int,
+ dependency: ShuffleDependency[K, V, C]) = this(
+ appUniqueId,
+ lifecycleManagerHost,
+ lifecycleManagerPort,
+ userIdentifier,
+ shuffleId,
+ throwsFetchFailure,
+ numMappers,
+ dependency,
+ null)
+}
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index fe7af8309..663983b58 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -53,7 +53,8 @@ class CelebornShuffleReader[K, C](
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
conf,
- handle.userIdentifier)
+ handle.userIdentifier,
+ handle.extension)
private val exceptionRef = new AtomicReference[IOException]
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 72230a536..af7b4e3d5 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -61,6 +61,16 @@ public abstract class ShuffleClient {
int port,
CelebornConf conf,
UserIdentifier userIdentifier) {
+ return ShuffleClient.get(appUniqueId, driverHost, port, conf,
userIdentifier, null);
+ }
+
+ public static ShuffleClient get(
+ String appUniqueId,
+ String driverHost,
+ int port,
+ CelebornConf conf,
+ UserIdentifier userIdentifier,
+ byte[] extension) {
if (null == _instance || !initialized) {
synchronized (ShuffleClient.class) {
if (null == _instance) {
@@ -72,11 +82,13 @@ public abstract class ShuffleClient {
// when communicating with LifecycleManager, it will cause a
NullPointerException.
_instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier);
_instance.setupLifecycleManagerRef(driverHost, port);
+ _instance.setExtension(extension);
initialized = true;
} else if (!initialized) {
_instance.shutdown();
_instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier);
_instance.setupLifecycleManagerRef(driverHost, port);
+ _instance.setExtension(extension);
initialized = true;
}
}
@@ -122,6 +134,12 @@ public abstract class ShuffleClient {
public abstract void setupLifecycleManagerRef(RpcEndpointRef endpointRef);
+ /**
+ * @param extension Extension for shuffle client, it's a byte array. Used in
derived shuffle
+ * client implementation.
+ */
+ public abstract void setExtension(byte[] extension);
+
/**
* Write data to a specific reduce partition
*
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 27aaf1817..a7b4e7083 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -86,6 +86,8 @@ public class ShuffleClientImpl extends ShuffleClient {
protected final int BATCH_HEADER_SIZE = 4 * 4;
+ protected byte[] extension;
+
// key: appShuffleIdentifier, value: shuffleId
protected Map<String, Integer> shuffleIdCache =
JavaUtils.newConcurrentHashMap();
@@ -1703,6 +1705,11 @@ public class ShuffleClientImpl extends ShuffleClient {
lifecycleManagerRef = endpointRef;
}
+ @Override
+ public void setExtension(byte[] extension) {
+ this.extension = extension;
+ }
+
boolean mapperEnded(int shuffleId, int mapId) {
return (mapperEndMap.containsKey(shuffleId) &&
mapperEndMap.get(shuffleId).contains(mapId))
|| isStageEnded(shuffleId);
diff --git
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index bb4f4fe41..dda76b1da 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -62,6 +62,9 @@ public class DummyShuffleClient extends ShuffleClient {
@Override
public void setupLifecycleManagerRef(RpcEndpointRef endpointRef) {}
+ @Override
+ public void setExtension(byte[] extension) {}
+
@Override
public int pushData(
int shuffleId,
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
index ebb916d1a..8983f6bd6 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
@@ -72,7 +72,8 @@ class CelebornFetchFailureSuite extends AnyFunSuite
h.lifecycleManagerHost,
h.lifecycleManagerPort,
conf,
- h.userIdentifier)
+ h.userIdentifier,
+ h.extension)
val celebornShuffleId =
SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
val datafile =
workerDirs.map(dir => {