This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 74c1ec0a7 [CELEBORN-1670] Avoid swallowing InterruptedException in
ShuffleClientImpl
74c1ec0a7 is described below
commit 74c1ec0a7fcc4d9efb26d9b96901234eb76e22cd
Author: jiang13021 <[email protected]>
AuthorDate: Mon Dec 16 11:26:17 2024 +0800
[CELEBORN-1670] Avoid swallowing InterruptedException in ShuffleClientImpl
### What changes were proposed in this pull request?
As the title says: avoid swallowing InterruptedException in ShuffleClientImpl.
### Why are the changes needed?
In the ShuffleClientImpl, methods such as pushData and pushMergedData might
encounter interruptions during message transmission via the TransportClient.
However, the InterruptedException is effectively swallowed, because it is
handled like any other exception. As a result, the ShuffleClientImpl continues
its operation (retry or revive) even after an InterruptedException occurs.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a unit test:
org.apache.celeborn.client.ShuffleClientSuiteJ#testPushDataAndInterrupted
Closes #2849 from jiang13021/interrupt_in_shuffle_client.
Authored-by: jiang13021 <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../apache/celeborn/client/ShuffleClientImpl.java | 43 +++++++++++++++++---
.../celeborn/client/ShuffleClientSuiteJ.java | 47 ++++++++++++++++++++--
2 files changed, 80 insertions(+), 10 deletions(-)
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index f1e6a57e3..e50e6093c 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -360,8 +360,12 @@ public class ShuffleClientImpl extends ShuffleClient {
batchId,
newLoc,
e);
- pushDataRpcResponseCallback.onFailure(
- new
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
+ if (e instanceof InterruptedException) {
+ pushDataRpcResponseCallback.onFailure(e);
+ } else {
+ pushDataRpcResponseCallback.onFailure(
+ new
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
+ }
}
}
}
@@ -698,6 +702,9 @@ public class ShuffleClientImpl extends ShuffleClient {
numRetries - 1);
}
} catch (Exception e) {
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
logger.error(
"Exception raised while registering shuffle {} with {} mapper and
{} partitions.",
shuffleId,
@@ -896,6 +903,9 @@ public class ShuffleClientImpl extends ShuffleClient {
partitionIds,
epochs,
e);
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
return null;
}
}
@@ -1158,6 +1168,11 @@ public class ShuffleClientImpl extends ShuffleClient {
if (pushState.exception.get() != null) {
return;
}
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ callback.onFailure(e);
+ return;
+ }
StatusCode cause = getPushDataFailCause(e.getMessage());
if (remainReviveTimes <= 0) {
if (e instanceof CelebornIOException) {
@@ -1239,8 +1254,12 @@ public class ShuffleClientImpl extends ShuffleClient {
nextBatchId,
loc,
e);
- wrappedCallback.onFailure(
- new
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
+ if (e instanceof InterruptedException) {
+ wrappedCallback.onFailure(e);
+ } else {
+ wrappedCallback.onFailure(
+ new
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
+ }
}
} else {
// add batch data
@@ -1570,6 +1589,11 @@ public class ShuffleClientImpl extends ShuffleClient {
if (pushState.exception.get() != null) {
return;
}
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ callback.onFailure(e);
+ return;
+ }
StatusCode cause = getPushDataFailCause(e.getMessage());
if (remainReviveTimes <= 0) {
if (e instanceof CelebornIOException) {
@@ -1650,8 +1674,12 @@ public class ShuffleClientImpl extends ShuffleClient {
Arrays.toString(batchIds),
addressPair,
e);
- wrappedCallback.onFailure(
- new
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
+ if (e instanceof InterruptedException) {
+ wrappedCallback.onFailure(e);
+ } else {
+ wrappedCallback.onFailure(
+ new
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
+ }
}
}
@@ -1764,6 +1792,9 @@ public class ShuffleClientImpl extends ShuffleClient {
}
}
} catch (Exception e) {
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
logger.error("Exception raised while call GetReducerFileGroup for
{}.", shuffleId, e);
exceptionMsg = e.getMessage();
}
diff --git
a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index 7f24db71c..5256ae0fb 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -17,11 +17,10 @@
package org.apache.celeborn.client;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThrows;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -120,6 +119,31 @@ public class ShuffleClientSuiteJ {
}
}
+ @Test
+ public void testPushDataAndInterrupted() throws IOException,
InterruptedException {
+ CelebornConf conf = setupEnv(CompressionCodec.NONE, StatusCode.SUCCESS,
true);
+ try {
+ shuffleClient.pushData(
+ TEST_SHUFFLE_ID,
+ TEST_ATTEMPT_ID,
+ TEST_ATTEMPT_ID,
+ TEST_REDUCRE_ID,
+ TEST_BUF1,
+ 0,
+ TEST_BUF1.length,
+ 1,
+ 1);
+ Thread.sleep(10 * 1000); // waiting for interrupt
+ fail();
+ } catch (Exception e) {
+ if (e instanceof InterruptedException) {
+ assertTrue(true);
+ } else {
+ fail();
+ }
+ }
+ }
+
@Test
public void testMergeData() throws IOException, InterruptedException {
for (CompressionCodec codec : CompressionCodec.values()) {
@@ -205,6 +229,12 @@ public class ShuffleClientSuiteJ {
private CelebornConf setupEnv(CompressionCodec codec, StatusCode statusCode)
throws IOException, InterruptedException {
+ return setupEnv(codec, statusCode, false);
+ }
+
+ private CelebornConf setupEnv(
+ CompressionCodec codec, StatusCode statusCode, boolean
interruptWhenPushData)
+ throws IOException, InterruptedException {
CelebornConf conf = new CelebornConf();
conf.set(CelebornConf.SHUFFLE_COMPRESSION_CODEC().key(), codec.name());
conf.set(CelebornConf.CLIENT_PUSH_RETRY_THREADS().key(), "1");
@@ -346,7 +376,16 @@ public class ShuffleClientSuiteJ {
}
};
- when(client.pushData(any(), anyLong(), any())).thenAnswer(t ->
mockedFuture);
+ if (interruptWhenPushData) {
+ willAnswer(
+ invocation -> {
+ throw new InterruptedException("test");
+ })
+ .given(client)
+ .pushData(any(), anyLong(), any());
+ } else {
+ when(client.pushData(any(), anyLong(), any())).thenAnswer(t ->
mockedFuture);
+ }
when(clientFactory.createClient(
primaryLocation.getHost(), primaryLocation.getPushPort(),
TEST_REDUCRE_ID))
.thenAnswer(t -> client);