mridulm commented on code in PR #46805:
URL: https://github.com/apache/spark/pull/46805#discussion_r1643750605
##########
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java:
##########
@@ -190,6 +217,52 @@ public void onFailure(Throwable t) {
       }
     }
+  private <T> void retry(
+      final int retryCountValue,
+      final int saslRetryCountValue,
+      final int maxRetries,
+      final boolean enableSaslRetries,
+      int delayMs,
+      Supplier<CompletableFuture<T>> action,
+      CompletableFuture<T> future) {
+    action.get()
+      .thenAccept(future::complete)
+      .exceptionally(e -> {
+        int retryCount = retryCountValue;
+        int saslRetryCount = saslRetryCountValue;
+        boolean isIOException = e instanceof IOException
+          || (e.getCause() != null && e.getCause() instanceof IOException);
+        boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+        if (!isSaslTimeout && saslRetryCount > 0) {
+          Preconditions.checkState(retryCount >= saslRetryCount,
+            "retryCount must be greater than or equal to saslRetryCount");
+          retryCount -= saslRetryCount;
+          saslRetryCount = 0;
+        }
+        boolean hasRemainingRetries = retryCount < maxRetries;
+        boolean shouldRetry = (isSaslTimeout || isIOException) && hasRemainingRetries;
+        if (!shouldRetry) {
+          future.completeExceptionally(e);
+        } else {
+          if (enableSaslRetries && e instanceof SaslTimeoutException) {
+            saslRetryCount += 1;
+          }
+          retryCount += 1;
+          int finalRetryCount = retryCount;
+          int finalSaslRetryCount = saslRetryCount;
Review Comment:
super nit:
```suggestion
          final int finalRetryCount = retryCount;
          final int finalSaslRetryCount = saslRetryCount;
```
These do get inferred to be effectively final, so the keyword is no longer strictly necessary, but it is more aligned with the intent and the variable names.
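For context, the copies exist at all because locals captured by a lambda must be final or effectively final, and `retryCount`/`saslRetryCount` are reassigned above. A minimal standalone sketch of that constraint (hypothetical names, not Spark code):
```java
import java.util.function.Supplier;

class EffectivelyFinalDemo {
  static Supplier<Integer> capture(int seed) {
    int count = seed;
    count += 1;                    // `count` is reassigned, so it is not effectively final
    final int finalCount = count;  // fresh copy, assigned exactly once
    // return () -> count;         // would not compile: a captured local must be (effectively) final
    return () -> finalCount;       // compiles: the lambda captures the final copy
  }
}
```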
##########
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java:
##########
@@ -190,6 +217,52 @@ public void onFailure(Throwable t) {
       }
     }
+  private <T> void retry(
+      final int retryCountValue,
+      final int saslRetryCountValue,
+      final int maxRetries,
+      final boolean enableSaslRetries,
+      int delayMs,
+      Supplier<CompletableFuture<T>> action,
+      CompletableFuture<T> future) {
+    action.get()
+      .thenAccept(future::complete)
+      .exceptionally(e -> {
+        int retryCount = retryCountValue;
+        int saslRetryCount = saslRetryCountValue;
+        boolean isIOException = e instanceof IOException
+          || (e.getCause() != null && e.getCause() instanceof IOException);
+        boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+        if (!isSaslTimeout && saslRetryCount > 0) {
+          Preconditions.checkState(retryCount >= saslRetryCount,
+            "retryCount must be greater than or equal to saslRetryCount");
+          retryCount -= saslRetryCount;
+          saslRetryCount = 0;
+        }
+        boolean hasRemainingRetries = retryCount < maxRetries;
+        boolean shouldRetry = (isSaslTimeout || isIOException) && hasRemainingRetries;
+        if (!shouldRetry) {
+          future.completeExceptionally(e);
+        } else {
+          if (enableSaslRetries && e instanceof SaslTimeoutException) {
Review Comment:
```suggestion
          if (isSaslTimeout) {
```
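This also keeps the branch consistent with the `shouldRetry` computation above, which already relies on the same `isSaslTimeout` flag instead of re-evaluating the condition.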
##########
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala:
##########
@@ -364,6 +368,72 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1)
   }
+  test("BlockStoreClient getHostLocalDirs RPC supports IOException retry") {
+    val mockClientFactory = mock(classOf[TransportClientFactory])
+    val mockTransportClient = mock(classOf[TransportClient])
+    val execToDirs = Map("exec-1" ->
+      Array("loc2.1", "loc2.2"))
+
+    var sendRpcThrowIOExceptionIdx = 0
+    var sendRpcThrowIOExceptionMaxCnt = 2
Review Comment:
The initial value is a bit confusing: it is always reset below.
```suggestion
    var sendRpcThrowIOExceptionMaxCnt = 0
```
##########
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala:
##########
@@ -364,6 +368,72 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1)
   }
+  test("BlockStoreClient getHostLocalDirs RPC supports IOException retry") {
+    val mockClientFactory = mock(classOf[TransportClientFactory])
+    val mockTransportClient = mock(classOf[TransportClient])
+    val execToDirs = Map("exec-1" ->
+      Array("loc2.1", "loc2.2"))
+
+    var sendRpcThrowIOExceptionIdx = 0
+    var sendRpcThrowIOExceptionMaxCnt = 2
+    when(mockClientFactory.createClient(any(), any())).thenAnswer(_ => {
+      mockTransportClient
+    })
+    when(mockTransportClient.sendRpc(any(), any())).thenAnswer(invocationOnMock => {
+      if (sendRpcThrowIOExceptionIdx < sendRpcThrowIOExceptionMaxCnt) {
+        sendRpcThrowIOExceptionIdx += 1
+        throw new IOException("sendRpc failed " + sendRpcThrowIOExceptionIdx + " times")
+      }
+      val callback = invocationOnMock.getArgument(1).asInstanceOf[RpcResponseCallback]
+      callback.onSuccess(new LocalDirsForExecutors(execToDirs.asJava).toByteBuffer)
+      null
+    })
+
+    class TestExternalBlockStoreClient(
+        conf: TransportConf,
+        secretKeyHolder: SecretKeyHolder,
+        authEnabled: Boolean,
+        registrationTimeoutMs: Long) extends ExternalBlockStoreClient(
+        conf, secretKeyHolder, authEnabled, registrationTimeoutMs) {
+      override def init(appId: String): Unit = {
+        super.init(appId)
+        this.clientFactory = mockClientFactory
+      }
+    }
+
+    val config = new java.util.HashMap[String, String]
+    config.put("spark.shuffle.io.maxRetries", "3")
+    val clientConf: TransportConf = new TransportConf("shuffle", new MapConfigProvider(config))
+    val testExternalBlockStoreClient = new TestExternalBlockStoreClient(
+      clientConf, null, false, 5000)
+    try {
+      testExternalBlockStoreClient.init("APP_ID")
+      Seq((0, true), (2, true), (3, true), (4, false)).foreach { case (maxCnt, success) =>
+        sendRpcThrowIOExceptionIdx = 0
+        sendRpcThrowIOExceptionMaxCnt = maxCnt
+        val hostLocalDirsCompletable = new CompletableFuture[java.util.Map[String, Array[String]]]
+        testExternalBlockStoreClient.getHostLocalDirs("exec-1", 1,
+          Array("exec-1"), hostLocalDirsCompletable)
+        try {
+          val result = hostLocalDirsCompletable.get()
+          assert(success)
+          assert(result.size() == 1)
+          assert(result.keySet() == execToDirs.asJava.keySet())
+          assert(result.values().iterator().next() sameElements
+            execToDirs.asJava.values().iterator().next())
+        } catch {
+          case e: ExecutionException =>
+            assert(e.getCause.isInstanceOf[IOException])
+            assert(!success)
Review Comment:
`assert(sendRpcThrowIOExceptionIdx == maxRetries + 1)` should hold here (the initial attempt plus `maxRetries` retries all throw in this case), right?
##########
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java:
##########
@@ -161,6 +172,22 @@ public void getHostLocalDirs(
       String[] execIds,
       CompletableFuture<Map<String, String[]>> hostLocalDirsCompletable) {
     checkInit();
+    int maxRetries = transportConf.maxIORetries();
+    int retryWaitTime = transportConf.ioRetryWaitTimeMs();
+    boolean enableSaslRetries = transportConf.enableSaslRetries();
+    retry(0, 0, maxRetries, enableSaslRetries, retryWaitTime, () -> {
+      CompletableFuture<Map<String, String[]>> tempHostLocalDirsCompletable =
+        new CompletableFuture<>();
+      getHostLocalDirsInternal(host, port, execIds, tempHostLocalDirsCompletable);
+      return tempHostLocalDirsCompletable;
+    }, hostLocalDirsCompletable);
+  }
+
+  private void getHostLocalDirsInternal(
+      String host,
+      int port,
+      String[] execIds,
+      CompletableFuture<Map<String, String[]>> hostLocalDirsCompletable) {
Review Comment:
While it does minimize the diff, it can get slightly confusing as things evolve: the `hostLocalDirsCompletable` passed into `retry`, which keeps propagating through subsequent invocations of `retry`, is different from the `hostLocalDirsCompletable` passed in as a parameter here.
I would suggest refactoring this method to return the future rather than take it as a parameter.
```suggestion
  private CompletableFuture<Map<String, String[]>> getHostLocalDirsInternal(
      String host,
      int port,
      String[] execIds) {
    CompletableFuture<Map<String, String[]>> result = new CompletableFuture<>();
```
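With that shape, the call site in `getHostLocalDirs` would presumably reduce to something like the following sketch (not the final code):
```java
retry(0, 0, maxRetries, enableSaslRetries, retryWaitTime,
  () -> getHostLocalDirsInternal(host, port, execIds),
  hostLocalDirsCompletable);
```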
##########
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java:
##########
@@ -190,6 +217,52 @@ public void onFailure(Throwable t) {
       }
     }
+  private <T> void retry(
+      final int retryCountValue,
+      final int saslRetryCountValue,
+      final int maxRetries,
+      final boolean enableSaslRetries,
+      int delayMs,
+      Supplier<CompletableFuture<T>> action,
+      CompletableFuture<T> future) {
+    action.get()
+      .thenAccept(future::complete)
+      .exceptionally(e -> {
+        int retryCount = retryCountValue;
+        int saslRetryCount = saslRetryCountValue;
+        boolean isIOException = e instanceof IOException
+          || (e.getCause() != null && e.getCause() instanceof IOException);
Review Comment:
`null` check is not required.
```suggestion
        boolean isIOException = e instanceof IOException ||
          e.getCause() instanceof IOException;
```
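For reference, `instanceof` is defined to evaluate to `false` for a `null` operand rather than throw, which is why the guard is redundant. A trivial standalone check:
```java
import java.io.IOException;

class InstanceofNullDemo {
  public static void main(String[] args) {
    Throwable cause = null;
    // instanceof never throws on a null operand; it simply yields false
    System.out.println(cause instanceof IOException);  // prints: false
  }
}
```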
##########
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala:
##########
@@ -364,6 +368,72 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1)
   }
+  test("BlockStoreClient getHostLocalDirs RPC supports IOException retry") {
+    val mockClientFactory = mock(classOf[TransportClientFactory])
+    val mockTransportClient = mock(classOf[TransportClient])
+    val execToDirs = Map("exec-1" ->
+      Array("loc2.1", "loc2.2"))
+
+    var sendRpcThrowIOExceptionIdx = 0
+    var sendRpcThrowIOExceptionMaxCnt = 2
+    when(mockClientFactory.createClient(any(), any())).thenAnswer(_ => {
+      mockTransportClient
+    })
+    when(mockTransportClient.sendRpc(any(), any())).thenAnswer(invocationOnMock => {
+      if (sendRpcThrowIOExceptionIdx < sendRpcThrowIOExceptionMaxCnt) {
+        sendRpcThrowIOExceptionIdx += 1
+        throw new IOException("sendRpc failed " + sendRpcThrowIOExceptionIdx + " times")
+      }
+      val callback = invocationOnMock.getArgument(1).asInstanceOf[RpcResponseCallback]
+      callback.onSuccess(new LocalDirsForExecutors(execToDirs.asJava).toByteBuffer)
+      null
+    })
+
+    class TestExternalBlockStoreClient(
+        conf: TransportConf,
+        secretKeyHolder: SecretKeyHolder,
+        authEnabled: Boolean,
+        registrationTimeoutMs: Long) extends ExternalBlockStoreClient(
+        conf, secretKeyHolder, authEnabled, registrationTimeoutMs) {
+      override def init(appId: String): Unit = {
+        super.init(appId)
+        this.clientFactory = mockClientFactory
+      }
+    }
+
+    val config = new java.util.HashMap[String, String]
+    config.put("spark.shuffle.io.maxRetries", "3")
+    val clientConf: TransportConf = new TransportConf("shuffle", new MapConfigProvider(config))
+    val testExternalBlockStoreClient = new TestExternalBlockStoreClient(
+      clientConf, null, false, 5000)
+    try {
+      testExternalBlockStoreClient.init("APP_ID")
+      Seq((0, true), (2, true), (3, true), (4, false)).foreach { case (maxCnt, success) =>
Review Comment:
nit:
```suggestion
      ((0 to maxRetries).map((_, true)) ++ Seq((maxRetries + 1, false)))
        .foreach { case (maxCnt, success) =>
```
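Note the parentheses around the concatenation: without them, `.foreach` would bind to the trailing `Seq` alone. For `maxRetries = 3` this expands to `(0, true), (1, true), (2, true), (3, true), (4, false)`, covering every boundary up to the first failing count.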
##########
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala:
##########
@@ -364,6 +368,72 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1)
   }
+  test("BlockStoreClient getHostLocalDirs RPC supports IOException retry") {
+    val mockClientFactory = mock(classOf[TransportClientFactory])
+    val mockTransportClient = mock(classOf[TransportClient])
+    val execToDirs = Map("exec-1" ->
+      Array("loc2.1", "loc2.2"))
+
+    var sendRpcThrowIOExceptionIdx = 0
+    var sendRpcThrowIOExceptionMaxCnt = 2
+    when(mockClientFactory.createClient(any(), any())).thenAnswer(_ => {
+      mockTransportClient
+    })
+    when(mockTransportClient.sendRpc(any(), any())).thenAnswer(invocationOnMock => {
+      if (sendRpcThrowIOExceptionIdx < sendRpcThrowIOExceptionMaxCnt) {
+        sendRpcThrowIOExceptionIdx += 1
+        throw new IOException("sendRpc failed " + sendRpcThrowIOExceptionIdx + " times")
+      }
+      val callback = invocationOnMock.getArgument(1).asInstanceOf[RpcResponseCallback]
+      callback.onSuccess(new LocalDirsForExecutors(execToDirs.asJava).toByteBuffer)
+      null
+    })
+
+    class TestExternalBlockStoreClient(
+        conf: TransportConf,
+        secretKeyHolder: SecretKeyHolder,
+        authEnabled: Boolean,
+        registrationTimeoutMs: Long) extends ExternalBlockStoreClient(
+        conf, secretKeyHolder, authEnabled, registrationTimeoutMs) {
+      override def init(appId: String): Unit = {
+        super.init(appId)
+        this.clientFactory = mockClientFactory
+      }
+    }
+
+    val config = new java.util.HashMap[String, String]
+    config.put("spark.shuffle.io.maxRetries", "3")
Review Comment:
```suggestion
    val maxRetries = 3
    config.put("spark.shuffle.io.maxRetries", maxRetries.toString)
```
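Deriving the expected counts from a named constant also keeps the retry-count cases in the loop below (see the earlier suggestion) in sync with the configured value.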