This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new de5044215 [#1554] feat(spark): Fetch dynamic client conf as early as
possible (#1557)
de5044215 is described below
commit de50442158fec331769d471c40042f5c26623096
Author: Enrico Minack <[email protected]>
AuthorDate: Thu Mar 14 02:37:20 2024 +0100
[#1554] feat(spark): Fetch dynamic client conf as early as possible (#1557)
### What changes were proposed in this pull request?
Fetches dynamic client config as early as possible, to be able to use
dynamic client config to create the shuffle client with updated config.
### Why are the changes needed?
Providing config for the shuffle client via coordinator is operationally
useful as cluster-wide settings can be deployed through the cluster and changed
over time. Clients and apps do not need to change configs.
Fixes Spark part of #1554
### Does this PR introduce _any_ user-facing change?
More configs can be provided via coordinators.
### How was this patch tested?
Existing and
[follow-up](https://github.com/apache/incubator-uniffle/pull/1528/files#diff-ea644edb1c0bf0e80f9a960adbc1615c99cb6a3a0d5fe24f788307f1daf22f46R127-R131)
unit tests.
---
.../apache/spark/shuffle/RssSparkShuffleUtils.java | 10 ++-
.../shuffle/manager/RssShuffleManagerBase.java | 33 +++++++++
.../shuffle/manager/RssShuffleManagerBaseTest.java | 84 +++++++++++++++++++++-
.../apache/spark/shuffle/RssShuffleManager.java | 24 ++++---
.../shuffle/DelegationRssShuffleManagerTest.java | 1 +
.../apache/spark/shuffle/RssShuffleManager.java | 21 +++---
.../shuffle/DelegationRssShuffleManagerTest.java | 1 +
.../apache/uniffle/test/RssShuffleManagerTest.java | 31 +++++---
8 files changed, 169 insertions(+), 36 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index e846eb7f4..cf49d3ed8 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -130,9 +130,13 @@ public class RssSparkShuffleUtils {
sparkConfKey = RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + sparkConfKey;
}
String confVal = kv.getValue();
- if (!sparkConf.contains(sparkConfKey)
- || RssSparkConfig.RSS_MANDATORY_CLUSTER_CONF.contains(sparkConfKey))
{
- LOG.warn("Use conf dynamic conf {} = {}", sparkConfKey, confVal);
+ boolean isMandatory =
RssSparkConfig.RSS_MANDATORY_CLUSTER_CONF.contains(sparkConfKey);
+ if (!sparkConf.contains(sparkConfKey) || isMandatory) {
+ if (sparkConf.contains(sparkConfKey) && isMandatory) {
+ LOG.warn("Override with mandatory dynamic conf {} = {}",
sparkConfKey, confVal);
+ } else {
+ LOG.info("Use dynamic conf {} = {}", sparkConfKey, confVal);
+ }
sparkConf.set(sparkConfKey, confVal);
}
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index c75207b89..23922644a 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -41,11 +41,17 @@ import org.apache.spark.shuffle.SparkVersionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.client.api.CoordinatorClient;
+import org.apache.uniffle.client.factory.CoordinatorClientFactory;
+import org.apache.uniffle.client.request.RssFetchClientConfRequest;
+import org.apache.uniffle.client.response.RssFetchClientConfResponse;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.rpc.StatusCode;
import static
org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX;
import static
org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED;
@@ -315,6 +321,33 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
return (long) mapIndex << attemptBits | attemptNo;
}
+ protected static void fetchAndApplyDynamicConf(SparkConf sparkConf) {
+ String clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
+ String coordinators =
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
+ CoordinatorClientFactory coordinatorClientFactory =
CoordinatorClientFactory.getInstance();
+ List<CoordinatorClient> coordinatorClients =
+ coordinatorClientFactory.createCoordinatorClient(
+ ClientType.valueOf(clientType), coordinators);
+
+ int timeoutMs =
+ sparkConf.getInt(
+ RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.key(),
+ RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get());
+ for (CoordinatorClient client : coordinatorClients) {
+ RssFetchClientConfResponse response =
+ client.fetchClientConf(new RssFetchClientConfRequest(timeoutMs));
+ if (response.getStatusCode() == StatusCode.SUCCESS) {
+ LOG.info("Success to get conf from {}", client.getDesc());
+ RssSparkShuffleUtils.applyDynamicClientConf(sparkConf,
response.getClientConf());
+ break;
+ } else {
+ LOG.warn("Fail to get conf from {}", client.getDesc());
+ }
+ }
+
+ coordinatorClients.forEach(CoordinatorClient::close);
+ }
+
@Override
public void unregisterAllMapOutput(int shuffleId) throws SparkException {
if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
diff --git
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
index 440c8fb8c..da45c82a8 100644
---
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
+++
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
@@ -18,8 +18,11 @@
package org.apache.uniffle.shuffle.manager;
import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
import java.util.stream.Stream;
+import com.google.common.collect.ImmutableMap;
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.RssSparkConfig;
import org.junit.jupiter.api.Test;
@@ -27,17 +30,33 @@ import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
-
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
+
+import org.apache.uniffle.client.api.CoordinatorClient;
+import org.apache.uniffle.client.factory.CoordinatorClientFactory;
+import org.apache.uniffle.client.request.RssFetchClientConfRequest;
+import org.apache.uniffle.client.response.RssFetchClientConfResponse;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import static org.apache.uniffle.common.rpc.StatusCode.INVALID_REQUEST;
+import static org.apache.uniffle.common.rpc.StatusCode.SUCCESS;
import static
org.apache.uniffle.shuffle.manager.RssShuffleManagerBase.getTaskAttemptIdForBlockId;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
public class RssShuffleManagerBaseTest {
@@ -628,4 +647,67 @@ public class RssShuffleManagerBaseTest {
// check that a lower mapIndex works as expected
assertEquals(bits("11111111|00"), getTaskAttemptIdForBlockId(255, 0, 4,
false, 10));
}
+
+ @Test
+ void testFetchAndApplyDynamicConf() {
+ ClientType clientType = ClientType.GRPC;
+ String coordinators = "host1,host2,host3";
+ int timeout = RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get() /
10;
+
+ SparkConf conf = new SparkConf();
+ conf.set(RssSparkConfig.RSS_CLIENT_TYPE, clientType.toString());
+ conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM, coordinators);
+ conf.set(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS, timeout);
+
+ CoordinatorClientFactory mockFactoryInstance =
mock(CoordinatorClientFactory.class);
+ CoordinatorClient mockClient1 = mock(CoordinatorClient.class);
+ CoordinatorClient mockClient2 = mock(CoordinatorClient.class);
+ CoordinatorClient mockClient3 = mock(CoordinatorClient.class);
+
+ Map<String, String> clientConf1 = ImmutableMap.of("rss.config.from",
"client1");
+ Map<String, String> clientConf2 = ImmutableMap.of("rss.config.from",
"client2");
+ Map<String, String> clientConf3 = ImmutableMap.of("rss.config.from",
"client3");
+
+ when(mockClient1.fetchClientConf(any(RssFetchClientConfRequest.class)))
+ .thenReturn(new RssFetchClientConfResponse(INVALID_REQUEST, "error",
clientConf1));
+ when(mockClient2.fetchClientConf(any(RssFetchClientConfRequest.class)))
+ .thenReturn(new RssFetchClientConfResponse(SUCCESS, "success",
clientConf2));
+ when(mockClient3.fetchClientConf(any(RssFetchClientConfRequest.class)))
+ .thenReturn(new RssFetchClientConfResponse(SUCCESS, "success",
clientConf3));
+
+ List<CoordinatorClient> mockClients = Arrays.asList(mockClient1,
mockClient2, mockClient3);
+ when(mockFactoryInstance.createCoordinatorClient(clientType, coordinators))
+ .thenReturn(mockClients);
+
+ assertFalse(conf.contains("rss.config.from"));
+ assertFalse(conf.contains("spark.rss.config.from"));
+
+ try (MockedStatic<CoordinatorClientFactory> mockFactoryStatic =
+ mockStatic(CoordinatorClientFactory.class)) {
+
mockFactoryStatic.when(CoordinatorClientFactory::getInstance).thenReturn(mockFactoryInstance);
+ RssShuffleManagerBase.fetchAndApplyDynamicConf(conf);
+ }
+
+ assertFalse(conf.contains("rss.config.from"));
+ assertTrue(conf.contains("spark.rss.config.from"));
+ assertEquals("client2", conf.get("spark.rss.config.from"));
+
+ ArgumentCaptor<RssFetchClientConfRequest> request1 =
+ ArgumentCaptor.forClass(RssFetchClientConfRequest.class);
+ ArgumentCaptor<RssFetchClientConfRequest> request2 =
+ ArgumentCaptor.forClass(RssFetchClientConfRequest.class);
+ ArgumentCaptor<RssFetchClientConfRequest> request3 =
+ ArgumentCaptor.forClass(RssFetchClientConfRequest.class);
+
+ verify(mockClient1, times(1)).fetchClientConf(request1.capture());
+ verify(mockClient2, times(1)).fetchClientConf(request2.capture());
+ verify(mockClient3, never()).fetchClientConf(request3.capture());
+
+ assertEquals(timeout, request1.getValue().getTimeoutMs());
+ assertEquals(timeout, request2.getValue().getTimeoutMs());
+
+ verify(mockClient1, times(1)).close();
+ verify(mockClient2, times(1)).close();
+ verify(mockClient3, times(1)).close();
+ }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 75cda7810..07e1a180b 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -106,7 +106,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
private final int dataCommitPoolSize;
private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
private boolean heartbeatStarted = false;
- private boolean dynamicConfEnabled = false;
+ private boolean dynamicConfEnabled;
private final int maxFailures;
private final boolean speculation;
private final BlockIdLayout blockIdLayout;
@@ -144,14 +144,24 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
"Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be
false.");
}
this.sparkConf = sparkConf;
+ this.user = sparkConf.get("spark.rss.quota.user", "user");
+ this.uuid = sparkConf.get("spark.rss.quota.uuid",
Long.toString(System.currentTimeMillis()));
+ this.dynamicConfEnabled =
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
+
+ // fetch client conf and apply them if necessary
+ if (isDriver && this.dynamicConfEnabled) {
+ fetchAndApplyDynamicConf(sparkConf);
+ }
+ RssSparkShuffleUtils.validateRssClientConf(sparkConf);
+
+ // configure block id layout
this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
this.speculation = sparkConf.getBoolean("spark.speculation", false);
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
// configureBlockIdLayout requires maxFailures and speculation to be
initialized
configureBlockIdLayout(sparkConf, rssConf);
this.blockIdLayout = BlockIdLayout.from(rssConf);
- this.user = sparkConf.get("spark.rss.quota.user", "user");
- this.uuid = sparkConf.get("spark.rss.quota.uuid",
Long.toString(System.currentTimeMillis()));
+
// set & check replica config
this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
this.dataReplicaWrite =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
@@ -176,7 +186,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.heartbeatInterval =
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
this.heartbeatTimeout =
sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(),
heartbeatInterval / 2);
- this.dynamicConfEnabled =
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
long retryIntervalMax =
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
int heartBeatThreadNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
@@ -203,13 +212,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
.unregisterRequestTimeSec(unregisterRequestTimeoutSec)
.rssConf(rssConf));
registerCoordinator();
- // fetch client conf and apply them if necessary and disable ESS
- if (isDriver && dynamicConfEnabled) {
- Map<String, String> clusterClientConf =
-
shuffleWriteClient.fetchClientConf(sparkConf.get(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS));
- RssSparkShuffleUtils.applyDynamicClientConf(sparkConf,
clusterClientConf);
- }
- RssSparkShuffleUtils.validateRssClientConf(sparkConf);
// External shuffle service is not supported when using remote shuffle
service
sparkConf.set("spark.shuffle.service.enabled", "false");
LOG.info("Disable external shuffle service in RssShuffleManager.");
diff --git
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java
index 62cabc55d..8f24c5e5c 100644
---
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java
+++
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java
@@ -125,6 +125,7 @@ public class DelegationRssShuffleManagerTest {
conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false");
conf.set(RssSparkConfig.RSS_ACCESS_ID.key(), "mockId");
conf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
+ conf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), "MEMORY_LOCALFILE");
// fall back to SortShuffleManager in driver
assertCreateSortShuffleManager(conf);
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 891ccd65e..d0995734b 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -108,7 +108,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
private final Map<String, FailedBlockSendTracker>
taskToFailedBlockSendTracker;
private ScheduledExecutorService heartBeatScheduledExecutorService;
private boolean heartbeatStarted = false;
- private boolean dynamicConfEnabled = false;
+ private boolean dynamicConfEnabled;
private final ShuffleDataDistributionType dataDistributionType;
private final BlockIdLayout blockIdLayout;
private final int maxConcurrencyPerPartitionToWrite;
@@ -160,6 +160,14 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
this.user = sparkConf.get("spark.rss.quota.user", "user");
this.uuid = sparkConf.get("spark.rss.quota.uuid",
Long.toString(System.currentTimeMillis()));
+ this.dynamicConfEnabled =
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
+
+ // fetch client conf and apply them if necessary
+ if (isDriver && this.dynamicConfEnabled) {
+ fetchAndApplyDynamicConf(sparkConf);
+ }
+ RssSparkShuffleUtils.validateRssClientConf(sparkConf);
+
// set & check replica config
this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
this.dataReplicaWrite =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
@@ -182,7 +190,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(),
heartbeatInterval / 2);
final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
- this.dynamicConfEnabled =
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
this.dataDistributionType = getDataDistributionType(sparkConf);
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
this.maxConcurrencyPerPartitionToWrite =
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
@@ -217,16 +224,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
.unregisterRequestTimeSec(unregisterRequestTimeoutSec)
.rssConf(rssConf));
registerCoordinator();
- // fetch client conf and apply them if necessary and disable ESS
- if (isDriver && dynamicConfEnabled) {
- Map<String, String> clusterClientConf =
- shuffleWriteClient.fetchClientConf(
- sparkConf.getInt(
- RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.key(),
- RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get()));
- RssSparkShuffleUtils.applyDynamicClientConf(sparkConf,
clusterClientConf);
- }
- RssSparkShuffleUtils.validateRssClientConf(sparkConf);
// External shuffle service is not supported when using remote shuffle
service
sparkConf.set("spark.shuffle.service.enabled", "false");
sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false");
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java
index 243492d8d..c7011f27e 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java
@@ -78,6 +78,7 @@ public class DelegationRssShuffleManagerTest extends
RssShuffleManagerTestBase {
conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false");
conf.set(RssSparkConfig.RSS_ACCESS_ID.key(), "mockId");
conf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
+ conf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), "MEMORY_LOCALFILE");
// fall back to SortShuffleManager in driver
assertCreateSortShuffleManager(conf);
diff --git
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
index ac6d739dd..8414cd0b9 100644
---
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
+++
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
@@ -34,11 +34,11 @@ import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
@@ -74,7 +74,7 @@ public class RssShuffleManagerTest extends
SparkIntegrationTestBase {
shutdownServers();
}
- public static void startServers(BlockIdLayout dynamicConfLayout) throws
Exception {
+ public static Map<String, String> startServers(BlockIdLayout
dynamicConfLayout) throws Exception {
Map<String, String> dynamicConf = Maps.newHashMap();
dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(),
HDFS_URI + "rss/test");
dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.MEMORY_LOCALFILE.name());
@@ -95,6 +95,7 @@ public class RssShuffleManagerTest extends
SparkIntegrationTestBase {
createCoordinatorServer(coordinatorConf);
createShuffleServer(getShuffleServerConf(ServerType.GRPC));
startServers();
+ return dynamicConf;
}
@Override
@@ -108,7 +109,7 @@ public class RssShuffleManagerTest extends
SparkIntegrationTestBase {
private static final BlockIdLayout CUSTOM2 = BlockIdLayout.from(22, 18, 23);
public static Stream<Arguments> testBlockIdLayouts() {
- return Stream.of(Arguments.of(CUSTOM1), Arguments.of(CUSTOM2));
+ return Stream.of(Arguments.of(DEFAULT), Arguments.of(CUSTOM1),
Arguments.of(CUSTOM2));
}
@ParameterizedTest
@@ -125,8 +126,6 @@ public class RssShuffleManagerTest extends
SparkIntegrationTestBase {
@ParameterizedTest
@MethodSource("testBlockIdLayouts")
- @Disabled(
- "Dynamic client conf not working for arguments used to create
ShuffleWriteClient: issue #1554")
public void testRssShuffleManagerDynamicClientConf(BlockIdLayout layout)
throws Exception {
doTestRssShuffleManager(null, layout, layout, true);
}
@@ -144,7 +143,7 @@ public class RssShuffleManagerTest extends
SparkIntegrationTestBase {
BlockIdLayout expectedLayout,
boolean enableDynamicCLientConf)
throws Exception {
- startServers(dynamicConfLayout);
+ Map<String, String> dynamicConf = startServers(dynamicConfLayout);
SparkConf conf = createSparkConf();
updateSparkConfWithRss(conf);
@@ -159,7 +158,7 @@ public class RssShuffleManagerTest extends
SparkIntegrationTestBase {
conf.set("spark." + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL,
"1000");
conf.set("spark." + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES,
"10");
- // configure block id layout (if set)
+ // configure client conf block id layout (if set)
if (clientConfLayout != null) {
conf.set(
"spark." + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
@@ -182,12 +181,26 @@ public class RssShuffleManagerTest extends
SparkIntegrationTestBase {
// create a rdd that triggers shuffle registration
long count = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2).groupBy(x
-> x).count();
assertEquals(5, count);
- RssShuffleManagerBase shuffleManager =
- (RssShuffleManagerBase) SparkEnv.get().shuffleManager();
+
+ // configure expected block id layout
+ conf.set(
+ "spark." + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
+ String.valueOf(expectedLayout.sequenceNoBits));
+ conf.set(
+ "spark." + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
+ String.valueOf(expectedLayout.partitionIdBits));
+ conf.set(
+ "spark." + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
+ String.valueOf(expectedLayout.taskAttemptIdBits));
// get written block ids (we know there is one shuffle where two task
attempts wrote two
// partitions)
RssConf rssConf = RssSparkConfig.toRssConf(conf);
+ if (clientConfLayout == null && dynamicConfLayout != null) {
+ RssSparkShuffleUtils.applyDynamicClientConf(conf, dynamicConf);
+ }
+ RssShuffleManagerBase shuffleManager =
+ (RssShuffleManagerBase) SparkEnv.get().shuffleManager();
shuffleManager.configureBlockIdLayout(conf, rssConf);
ShuffleWriteClient shuffleWriteClient =
ShuffleClientFactory.newWriteBuilder()