This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new d26da9cb7 [#1165] improvement(tez): Unregister shuffle data after
completing the execution of a DAG. (#1166)
d26da9cb7 is described below
commit d26da9cb7eae5b5eea67932649833d9f9dda9d5b
Author: Fantasy-Jay <[email protected]>
AuthorDate: Fri Aug 25 10:22:18 2023 +0800
[#1165] improvement(tez): Unregister shuffle data after completing the
execution of a DAG. (#1166)
### What changes were proposed in this pull request?
Generally, one application executes multiple DAGs, and the DAGs are independent
of one another: a DAG does not read the shuffle data written by another DAG.
Therefore, once a DAG has finished executing, its shuffle data can be
unregistered. Otherwise, when an application runs many DAGs, it holds a large
amount of remote shuffle resources for a long period of time.
### Why are the changes needed?
Fix: https://github.com/apache/incubator-uniffle/issues/1165
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
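Existing test cases, plus the new `RssTezUtilsTest#testParseDagId` and the
extended `RssDAGAppMasterTest#testDagStateChangeCallback`, which now verifies that
the DAG's shuffle data is unregistered when the DAG reaches a final state.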
---
.../java/org/apache/tez/common/RssTezUtils.java | 9 ++++
.../org/apache/tez/dag/app/RssDAGAppMaster.java | 44 ++++++++++++++++++-
.../tez/dag/app/TezRemoteShuffleManager.java | 49 ++++++++++++++++++----
.../org/apache/tez/common/RssTezUtilsTest.java | 20 ++++++---
.../apache/tez/dag/app/RssDAGAppMasterTest.java | 31 ++++++++++----
5 files changed, 129 insertions(+), 24 deletions(-)
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
index c44080cb9..4b9f055a4 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
@@ -267,6 +267,15 @@ public class RssTezUtils {
return shuffleId;
}
+ public static int parseDagId(int shuffleId) {
+ Preconditions.checkArgument(shuffleId > 0, "shuffleId should be
positive.");
+ int dagId = shuffleId / (SHUFFLE_ID_MAGIC * SHUFFLE_ID_MAGIC);
+ if (dagId == 0) {
+ throw new RssException("Illegal shuffleId: " + shuffleId);
+ }
+ return dagId;
+ }
+
/**
* @param vertexName: vertex name, like "Map 1" or "Reducer 2"
* @return Map vertex name of String type to int type. Split vertex name,
get vertex type and
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index 281c53aa3..742fe6616 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -20,6 +20,8 @@ package org.apache.tez.dag.app;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
+import java.util.Arrays;
+import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executors;
@@ -461,6 +463,46 @@ public class RssDAGAppMaster extends DAGAppMaster {
public static void registerStateEnteredCallback(DAGImpl dag, RssDAGAppMaster
appMaster) {
StateMachineTez stateMachine = (StateMachineTez) getPrivateField(dag,
"stateMachine");
stateMachine.registerStateEnteredCallback(DAGState.INITED, new
DagInitialCallback(appMaster));
+ overrideDAGFinalStateCallback(
+ appMaster,
+ (Map) getPrivateField(stateMachine, "callbackMap"),
+ Arrays.asList(DAGState.SUCCEEDED, DAGState.FAILED, DAGState.KILLED,
DAGState.ERROR));
+ }
+
+ private static void overrideDAGFinalStateCallback(
+ RssDAGAppMaster appMaster, Map callbackMap, List<DAGState> finalStates) {
+ finalStates.forEach(
+ finalState ->
+ callbackMap.put(
+ finalState,
+ new DagFinalStateCallback(
+ appMaster, (OnStateChangedCallback)
callbackMap.get(finalState))));
+ }
+
+ static class DagFinalStateCallback implements
OnStateChangedCallback<DAGState, DAGImpl> {
+
+ private RssDAGAppMaster appMaster;
+ private OnStateChangedCallback callback;
+
+ DagFinalStateCallback(RssDAGAppMaster appMaster, OnStateChangedCallback
callback) {
+ this.appMaster = appMaster;
+ this.callback = callback;
+ }
+
+ @Override
+ public void onStateChanged(DAGImpl dag, DAGState dagState) {
+ callback.onStateChanged(dag, dagState);
+ LOG.info("Receive a dag state change event, dagId={}, dagState={}",
dag.getID(), dagState);
+ long startTime = System.currentTimeMillis();
+ // Generally, one application will execute multiple DAGs, and there is
no correlation between
+ // the DAGs.
+ // Therefore, after executing a DAG, you can unregister the relevant
shuffle data.
+
appMaster.getTezRemoteShuffleManager().unregisterShuffleByDagId(dag.getID());
+ LOG.info(
+ "Complete the task of unregister shuffle, dagId={}, cost={}ms ",
+ dag.getID(),
+ System.currentTimeMillis() - startTime);
+ }
}
static class DagInitialCallback implements OnStateChangedCallback<DAGState,
DAGImpl> {
@@ -593,7 +635,7 @@ public class RssDAGAppMaster extends DAGAppMaster {
// Here we only handle TA_NODE_FAILED. TA_KILL_REQUEST and TA_KILLED
also could trigger
// TerminatedAfterSuccessTransition, but the reason is not about bad
node.
LOG.info(
- "We should not recompute the succeeded task attempt, though task
attempt {} recieved envent {}",
+ "We should not recompute the succeeded task attempt, though task
attempt {} received event {}",
attempt,
event);
return;
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
index 808810e1c..59bfc8f19 100644
---
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
+++
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -21,11 +21,12 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
-import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
import com.google.common.collect.Sets;
import org.apache.commons.lang3.StringUtils;
@@ -51,6 +52,7 @@ import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.app.security.authorize.RssTezAMPolicyProvider;
+import org.apache.tez.dag.records.TezDAGID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -109,15 +111,41 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
@Override
public void shutdown() throws Exception {
+ unregisterShuffle();
+ server.stop();
+ }
+
+ public void unregisterShuffle() {
if (rssClient != null) {
LOG.info("unregister all shuffle for appid {}", appId);
- Map<Integer, ShuffleAssignmentsInfo> infos =
- tezRemoteShuffleUmbilical.getShuffleIdToShuffleAssignsInfo();
- for (Map.Entry<Integer, ShuffleAssignmentsInfo> entry :
infos.entrySet()) {
- rssClient.unregisterShuffle(appId, entry.getKey());
- }
+ rssClient.unregisterShuffle(appId);
}
- server.stop();
+ }
+
+ public boolean unregisterShuffleByDagId(TezDAGID dagId) {
+ try {
+ Set<Integer> shuffleIds =
+
tezRemoteShuffleUmbilical.getShuffleIdToShuffleAssignsInfo().keySet().stream()
+ .filter(shuffleId -> dagId.getId() ==
RssTezUtils.parseDagId(shuffleId))
+ .collect(Collectors.toSet());
+
+ shuffleIds.forEach(
+ shuffleId -> {
+ long startTime = System.currentTimeMillis();
+ rssClient.unregisterShuffle(appId, shuffleId);
+ tezRemoteShuffleUmbilical.removeShuffleInfo(shuffleId);
+ LOG.info(
+ "Unregister shuffle successfully, appId={}, dagId={},
shuffleId={}, cost={}ms",
+ appId,
+ dagId,
+ shuffleId,
+ System.currentTimeMillis() - startTime);
+ });
+ } catch (Exception e) {
+ LOG.info("Failed to unregister shuffle by dagId: {}", dagId, e);
+ return false;
+ }
+ return true;
}
public InetSocketAddress getAddress() {
@@ -125,7 +153,8 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
}
private class TezRemoteShuffleUmbilicalProtocolImpl implements
TezRemoteShuffleUmbilicalProtocol {
- private Map<Integer, ShuffleAssignmentsInfo> shuffleIdToShuffleAssignsInfo
= new HashMap<>();
+ private Map<Integer, ShuffleAssignmentsInfo> shuffleIdToShuffleAssignsInfo
=
+ new ConcurrentHashMap<>();
@Override
public long getProtocolVersion(String s, long l) throws IOException {
@@ -185,6 +214,10 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
Map<Integer, ShuffleAssignmentsInfo> getShuffleIdToShuffleAssignsInfo() {
return shuffleIdToShuffleAssignsInfo;
}
+
+ void removeShuffleInfo(int shuffleId) {
+ shuffleIdToShuffleAssignsInfo.remove(shuffleId);
+ }
}
private ShuffleAssignmentsInfo getShuffleWorks(int partitionNum, int
shuffleId) {
diff --git
a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
index 3aa36bd5c..c21173d15 100644
--- a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
+++ b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
@@ -29,7 +29,6 @@ import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.apache.uniffle.common.ShuffleServerInfo;
@@ -40,6 +39,7 @@ import org.apache.uniffle.storage.util.StorageType;
import static org.apache.tez.common.RssTezConfig.RSS_STORAGE_TYPE;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class RssTezUtilsTest {
@@ -207,9 +207,9 @@ public class RssTezUtilsTest {
dynamic.put(RSS_STORAGE_TYPE, StorageType.LOCALFILE.name());
dynamic.put("config2", "value2");
RssTezUtils.applyDynamicClientConf(conf, dynamic);
- Assertions.assertEquals("value1", conf.get("tez.config1"));
- Assertions.assertEquals("value2", conf.get("tez.config2"));
- Assertions.assertEquals(StorageType.LOCALFILE.name(),
conf.get(RSS_STORAGE_TYPE));
+ assertEquals("value1", conf.get("tez.config1"));
+ assertEquals("value2", conf.get("tez.config2"));
+ assertEquals(StorageType.LOCALFILE.name(), conf.get(RSS_STORAGE_TYPE));
}
@Test
@@ -218,7 +218,15 @@ public class RssTezUtilsTest {
conf1.set("tez.config1", "value1");
conf1.set("config2", "value2");
Configuration conf2 = RssTezUtils.filterRssConf(conf1);
- Assertions.assertEquals("value1", conf2.get("tez.config1"));
- Assertions.assertNull(conf2.get("config2"));
+ assertEquals("value1", conf2.get("tez.config1"));
+ assertNull(conf2.get("config2"));
+ }
+
+ @Test
+ public void testParseDagId() {
+ int shuffleId = RssTezUtils.computeShuffleId(1, 2, 3);
+ assertEquals(1, RssTezUtils.parseDagId(shuffleId));
+ assertThrows(IllegalArgumentException.class, () ->
RssTezUtils.parseDagId(-1));
+ assertThrows(RssException.class, () -> RssTezUtils.parseDagId(100));
}
}
diff --git
a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
index 759580a92..44ca49e31 100644
--- a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
+++ b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
@@ -99,8 +99,11 @@ import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
import static org.apache.tez.common.RssTezConfig.RSS_STORAGE_TYPE;
import static
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD_BYTES;
import static org.awaitility.Awaitility.await;
+import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class RssDAGAppMasterTest {
@@ -112,7 +115,7 @@ public class RssDAGAppMasterTest {
.getAbsoluteFile();
@Test
- public void testHookAfterDagInited() throws Exception {
+ public void testDagStateChangeCallback() throws Exception {
// 1 Init and mock some basic module
AppContext appContext = mock(AppContext.class);
ApplicationAttemptId appAttemptId =
@@ -132,6 +135,7 @@ public class RssDAGAppMasterTest {
TezRemoteShuffleManager shuffleManager =
mock(TezRemoteShuffleManager.class);
InetSocketAddress address = NetUtils.createSocketAddrForHost("host", 0);
when(shuffleManager.getAddress()).thenReturn(address);
+ when(shuffleManager.unregisterShuffleByDagId(any())).thenReturn(true);
when(appMaster.getTezRemoteShuffleManager()).thenReturn(shuffleManager);
Configuration clientConf = new Configuration(false);
clientConf.set(RSS_STORAGE_TYPE, StorageType.MEMORY_LOCALFILE_HDFS.name());
@@ -187,15 +191,24 @@ public class RssDAGAppMasterTest {
await().atMost(2, TimeUnit.SECONDS).until(() ->
dagImpl.getState().equals(DAGState.INITED));
// 8 verify I/O for vertexImpl
- verfiyOutput(dagImpl, "vertex1",
RssOrderedPartitionedKVOutput.class.getName(), 0, 1);
- verfiyInput(dagImpl, "vertex2", RssOrderedGroupedKVInput.class.getName(),
0, 1);
- verfiyOutput(dagImpl, "vertex2", RssUnorderedKVOutput.class.getName(), 1,
2);
- verfiyInput(dagImpl, "vertex3", RssUnorderedKVInput.class.getName(), 1, 2);
- verfiyOutput(dagImpl, "vertex3",
RssUnorderedPartitionedKVOutput.class.getName(), 2, 3);
- verfiyInput(dagImpl, "vertex4", RssUnorderedKVInput.class.getName(), 2, 3);
+ verifyOutput(dagImpl, "vertex1",
RssOrderedPartitionedKVOutput.class.getName(), 0, 1);
+ verifyInput(dagImpl, "vertex2", RssOrderedGroupedKVInput.class.getName(),
0, 1);
+ verifyOutput(dagImpl, "vertex2", RssUnorderedKVOutput.class.getName(), 1,
2);
+ verifyInput(dagImpl, "vertex3", RssUnorderedKVInput.class.getName(), 1, 2);
+ verifyOutput(dagImpl, "vertex3",
RssUnorderedPartitionedKVOutput.class.getName(), 2, 3);
+ verifyInput(dagImpl, "vertex4", RssUnorderedKVInput.class.getName(), 2, 3);
+
+ // 9 send INTERNAL_ERROR to dispatcher
+ dispatcher.getEventHandler().handle(new DAGEvent(dagImpl.getID(),
DAGEventType.INTERNAL_ERROR));
+
+ // 10 wait DAGImpl transient to INITED state
+ await().atMost(2, TimeUnit.SECONDS).until(() ->
dagImpl.getState().equals(DAGState.ERROR));
+
+ // verify
+ verify(shuffleManager, times(1)).unregisterShuffleByDagId(dagId);
}
- public static void verfiyInput(
+ public static void verifyInput(
DAGImpl dag,
String name,
String expectedInputClassName,
@@ -230,7 +243,7 @@ public class RssDAGAppMasterTest {
expectedDestinationVertexId,
conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1));
}
- public static void verfiyOutput(
+ public static void verifyOutput(
DAGImpl dag,
String name,
String expectedOutputClassName,