This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new d26da9cb7 [#1165] improvement(tez): Unregister shuffle data after 
completing the execution of a DAG.  (#1166)
d26da9cb7 is described below

commit d26da9cb7eae5b5eea67932649833d9f9dda9d5b
Author: Fantasy-Jay <[email protected]>
AuthorDate: Fri Aug 25 10:22:18 2023 +0800

    [#1165] improvement(tez): Unregister shuffle data after completing the 
execution of a DAG.  (#1166)
    
    ### What changes were proposed in this pull request?
    
    Generally, one application executes multiple DAGs, and those DAGs are 
independent of one another. Therefore, once a DAG has finished executing, its 
shuffle data can be unregistered. Otherwise, when an application runs many 
DAGs, it holds a large amount of resources for a long period of time.
    
    ### Why are the changes needed?
    
    Fix: https://github.com/apache/incubator-uniffle/issues/1165
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
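    Existing test cases, plus a new `RssTezUtilsTest#testParseDagId` and an 
extended `RssDAGAppMasterTest#testDagStateChangeCallback` (formerly 
`testHookAfterDagInited`), which verifies that the DAG's shuffle data is 
unregistered once the DAG enters the ERROR state.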
---
 .../java/org/apache/tez/common/RssTezUtils.java    |  9 ++++
 .../org/apache/tez/dag/app/RssDAGAppMaster.java    | 44 ++++++++++++++++++-
 .../tez/dag/app/TezRemoteShuffleManager.java       | 49 ++++++++++++++++++----
 .../org/apache/tez/common/RssTezUtilsTest.java     | 20 ++++++---
 .../apache/tez/dag/app/RssDAGAppMasterTest.java    | 31 ++++++++++----
 5 files changed, 129 insertions(+), 24 deletions(-)

diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java 
b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
index c44080cb9..4b9f055a4 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
@@ -267,6 +267,15 @@ public class RssTezUtils {
     return shuffleId;
   }
 
+  public static int parseDagId(int shuffleId) {
+    Preconditions.checkArgument(shuffleId > 0, "shuffleId should be 
positive.");
+    int dagId = shuffleId / (SHUFFLE_ID_MAGIC * SHUFFLE_ID_MAGIC);
+    if (dagId == 0) {
+      throw new RssException("Illegal shuffleId: " + shuffleId);
+    }
+    return dagId;
+  }
+
   /**
    * @param vertexName: vertex name, like "Map 1" or "Reducer 2"
    * @return Map vertex name of String type to int type. Split vertex name, 
get vertex type and
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java 
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index 281c53aa3..742fe6616 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -20,6 +20,8 @@ package org.apache.tez.dag.app;
 import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.URL;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.Executors;
@@ -461,6 +463,46 @@ public class RssDAGAppMaster extends DAGAppMaster {
   public static void registerStateEnteredCallback(DAGImpl dag, RssDAGAppMaster 
appMaster) {
     StateMachineTez stateMachine = (StateMachineTez) getPrivateField(dag, 
"stateMachine");
     stateMachine.registerStateEnteredCallback(DAGState.INITED, new 
DagInitialCallback(appMaster));
+    overrideDAGFinalStateCallback(
+        appMaster,
+        (Map) getPrivateField(stateMachine, "callbackMap"),
+        Arrays.asList(DAGState.SUCCEEDED, DAGState.FAILED, DAGState.KILLED, 
DAGState.ERROR));
+  }
+
+  private static void overrideDAGFinalStateCallback(
+      RssDAGAppMaster appMaster, Map callbackMap, List<DAGState> finalStates) {
+    finalStates.forEach(
+        finalState ->
+            callbackMap.put(
+                finalState,
+                new DagFinalStateCallback(
+                    appMaster, (OnStateChangedCallback) 
callbackMap.get(finalState))));
+  }
+
+  static class DagFinalStateCallback implements 
OnStateChangedCallback<DAGState, DAGImpl> {
+
+    private RssDAGAppMaster appMaster;
+    private OnStateChangedCallback callback;
+
+    DagFinalStateCallback(RssDAGAppMaster appMaster, OnStateChangedCallback 
callback) {
+      this.appMaster = appMaster;
+      this.callback = callback;
+    }
+
+    @Override
+    public void onStateChanged(DAGImpl dag, DAGState dagState) {
+      callback.onStateChanged(dag, dagState);
+      LOG.info("Receive a dag state change event, dagId={}, dagState={}", 
dag.getID(), dagState);
+      long startTime = System.currentTimeMillis();
+      // Generally, one application will execute multiple DAGs, and there is 
no correlation between
+      // the DAGs.
+      // Therefore, after executing a DAG, you can unregister the relevant 
shuffle data.
+      
appMaster.getTezRemoteShuffleManager().unregisterShuffleByDagId(dag.getID());
+      LOG.info(
+          "Complete the task of unregister shuffle, dagId={}, cost={}ms ",
+          dag.getID(),
+          System.currentTimeMillis() - startTime);
+    }
   }
 
   static class DagInitialCallback implements OnStateChangedCallback<DAGState, 
DAGImpl> {
@@ -593,7 +635,7 @@ public class RssDAGAppMaster extends DAGAppMaster {
         // Here we only handle TA_NODE_FAILED. TA_KILL_REQUEST and TA_KILLED 
also could trigger
         // TerminatedAfterSuccessTransition, but the reason is not about bad 
node.
         LOG.info(
-            "We should not recompute the succeeded task attempt, though task 
attempt {} recieved envent {}",
+            "We should not recompute the succeeded task attempt, though task 
attempt {} received event {}",
             attempt,
             event);
         return;
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java 
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
index 808810e1c..59bfc8f19 100644
--- 
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
+++ 
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -21,11 +21,12 @@ import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.security.PrivilegedExceptionAction;
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
 
 import com.google.common.collect.Sets;
 import org.apache.commons.lang3.StringUtils;
@@ -51,6 +52,7 @@ import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.app.security.authorize.RssTezAMPolicyProvider;
+import org.apache.tez.dag.records.TezDAGID;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -109,15 +111,41 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
 
   @Override
   public void shutdown() throws Exception {
+    unregisterShuffle();
+    server.stop();
+  }
+
+  public void unregisterShuffle() {
     if (rssClient != null) {
       LOG.info("unregister all shuffle for appid {}", appId);
-      Map<Integer, ShuffleAssignmentsInfo> infos =
-          tezRemoteShuffleUmbilical.getShuffleIdToShuffleAssignsInfo();
-      for (Map.Entry<Integer, ShuffleAssignmentsInfo> entry : 
infos.entrySet()) {
-        rssClient.unregisterShuffle(appId, entry.getKey());
-      }
+      rssClient.unregisterShuffle(appId);
     }
-    server.stop();
+  }
+
+  public boolean unregisterShuffleByDagId(TezDAGID dagId) {
+    try {
+      Set<Integer> shuffleIds =
+          
tezRemoteShuffleUmbilical.getShuffleIdToShuffleAssignsInfo().keySet().stream()
+              .filter(shuffleId -> dagId.getId() == 
RssTezUtils.parseDagId(shuffleId))
+              .collect(Collectors.toSet());
+
+      shuffleIds.forEach(
+          shuffleId -> {
+            long startTime = System.currentTimeMillis();
+            rssClient.unregisterShuffle(appId, shuffleId);
+            tezRemoteShuffleUmbilical.removeShuffleInfo(shuffleId);
+            LOG.info(
+                "Unregister shuffle successfully, appId={}, dagId={}, 
shuffleId={}, cost={}ms",
+                appId,
+                dagId,
+                shuffleId,
+                System.currentTimeMillis() - startTime);
+          });
+    } catch (Exception e) {
+      LOG.info("Failed to unregister shuffle by dagId: {}", dagId, e);
+      return false;
+    }
+    return true;
   }
 
   public InetSocketAddress getAddress() {
@@ -125,7 +153,8 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
   }
 
   private class TezRemoteShuffleUmbilicalProtocolImpl implements 
TezRemoteShuffleUmbilicalProtocol {
-    private Map<Integer, ShuffleAssignmentsInfo> shuffleIdToShuffleAssignsInfo 
= new HashMap<>();
+    private Map<Integer, ShuffleAssignmentsInfo> shuffleIdToShuffleAssignsInfo 
=
+        new ConcurrentHashMap<>();
 
     @Override
     public long getProtocolVersion(String s, long l) throws IOException {
@@ -185,6 +214,10 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
     Map<Integer, ShuffleAssignmentsInfo> getShuffleIdToShuffleAssignsInfo() {
       return shuffleIdToShuffleAssignsInfo;
     }
+
+    void removeShuffleInfo(int shuffleId) {
+      shuffleIdToShuffleAssignsInfo.remove(shuffleId);
+    }
   }
 
   private ShuffleAssignmentsInfo getShuffleWorks(int partitionNum, int 
shuffleId) {
diff --git 
a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java 
b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
index 3aa36bd5c..c21173d15 100644
--- a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
+++ b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
@@ -29,7 +29,6 @@ import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.common.ShuffleServerInfo;
@@ -40,6 +39,7 @@ import org.apache.uniffle.storage.util.StorageType;
 import static org.apache.tez.common.RssTezConfig.RSS_STORAGE_TYPE;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class RssTezUtilsTest {
@@ -207,9 +207,9 @@ public class RssTezUtilsTest {
     dynamic.put(RSS_STORAGE_TYPE, StorageType.LOCALFILE.name());
     dynamic.put("config2", "value2");
     RssTezUtils.applyDynamicClientConf(conf, dynamic);
-    Assertions.assertEquals("value1", conf.get("tez.config1"));
-    Assertions.assertEquals("value2", conf.get("tez.config2"));
-    Assertions.assertEquals(StorageType.LOCALFILE.name(), 
conf.get(RSS_STORAGE_TYPE));
+    assertEquals("value1", conf.get("tez.config1"));
+    assertEquals("value2", conf.get("tez.config2"));
+    assertEquals(StorageType.LOCALFILE.name(), conf.get(RSS_STORAGE_TYPE));
   }
 
   @Test
@@ -218,7 +218,15 @@ public class RssTezUtilsTest {
     conf1.set("tez.config1", "value1");
     conf1.set("config2", "value2");
     Configuration conf2 = RssTezUtils.filterRssConf(conf1);
-    Assertions.assertEquals("value1", conf2.get("tez.config1"));
-    Assertions.assertNull(conf2.get("config2"));
+    assertEquals("value1", conf2.get("tez.config1"));
+    assertNull(conf2.get("config2"));
+  }
+
+  @Test
+  public void testParseDagId() {
+    int shuffleId = RssTezUtils.computeShuffleId(1, 2, 3);
+    assertEquals(1, RssTezUtils.parseDagId(shuffleId));
+    assertThrows(IllegalArgumentException.class, () -> 
RssTezUtils.parseDagId(-1));
+    assertThrows(RssException.class, () -> RssTezUtils.parseDagId(100));
   }
 }
diff --git 
a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java 
b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
index 759580a92..44ca49e31 100644
--- a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
+++ b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
@@ -99,8 +99,11 @@ import static 
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
 import static org.apache.tez.common.RssTezConfig.RSS_STORAGE_TYPE;
 import static 
org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_IFILE_READAHEAD_BYTES;
 import static org.awaitility.Awaitility.await;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public class RssDAGAppMasterTest {
@@ -112,7 +115,7 @@ public class RssDAGAppMasterTest {
           .getAbsoluteFile();
 
   @Test
-  public void testHookAfterDagInited() throws Exception {
+  public void testDagStateChangeCallback() throws Exception {
     // 1 Init and mock some basic module
     AppContext appContext = mock(AppContext.class);
     ApplicationAttemptId appAttemptId =
@@ -132,6 +135,7 @@ public class RssDAGAppMasterTest {
     TezRemoteShuffleManager shuffleManager = 
mock(TezRemoteShuffleManager.class);
     InetSocketAddress address = NetUtils.createSocketAddrForHost("host", 0);
     when(shuffleManager.getAddress()).thenReturn(address);
+    when(shuffleManager.unregisterShuffleByDagId(any())).thenReturn(true);
     when(appMaster.getTezRemoteShuffleManager()).thenReturn(shuffleManager);
     Configuration clientConf = new Configuration(false);
     clientConf.set(RSS_STORAGE_TYPE, StorageType.MEMORY_LOCALFILE_HDFS.name());
@@ -187,15 +191,24 @@ public class RssDAGAppMasterTest {
     await().atMost(2, TimeUnit.SECONDS).until(() -> 
dagImpl.getState().equals(DAGState.INITED));
 
     // 8 verify I/O for vertexImpl
-    verfiyOutput(dagImpl, "vertex1", 
RssOrderedPartitionedKVOutput.class.getName(), 0, 1);
-    verfiyInput(dagImpl, "vertex2", RssOrderedGroupedKVInput.class.getName(), 
0, 1);
-    verfiyOutput(dagImpl, "vertex2", RssUnorderedKVOutput.class.getName(), 1, 
2);
-    verfiyInput(dagImpl, "vertex3", RssUnorderedKVInput.class.getName(), 1, 2);
-    verfiyOutput(dagImpl, "vertex3", 
RssUnorderedPartitionedKVOutput.class.getName(), 2, 3);
-    verfiyInput(dagImpl, "vertex4", RssUnorderedKVInput.class.getName(), 2, 3);
+    verifyOutput(dagImpl, "vertex1", 
RssOrderedPartitionedKVOutput.class.getName(), 0, 1);
+    verifyInput(dagImpl, "vertex2", RssOrderedGroupedKVInput.class.getName(), 
0, 1);
+    verifyOutput(dagImpl, "vertex2", RssUnorderedKVOutput.class.getName(), 1, 
2);
+    verifyInput(dagImpl, "vertex3", RssUnorderedKVInput.class.getName(), 1, 2);
+    verifyOutput(dagImpl, "vertex3", 
RssUnorderedPartitionedKVOutput.class.getName(), 2, 3);
+    verifyInput(dagImpl, "vertex4", RssUnorderedKVInput.class.getName(), 2, 3);
+
+    // 9 send INTERNAL_ERROR to dispatcher
+    dispatcher.getEventHandler().handle(new DAGEvent(dagImpl.getID(), 
DAGEventType.INTERNAL_ERROR));
+
+    // 10 wait DAGImpl transient to INITED state
+    await().atMost(2, TimeUnit.SECONDS).until(() -> 
dagImpl.getState().equals(DAGState.ERROR));
+
+    // verify
+    verify(shuffleManager, times(1)).unregisterShuffleByDagId(dagId);
   }
 
-  public static void verfiyInput(
+  public static void verifyInput(
       DAGImpl dag,
       String name,
       String expectedInputClassName,
@@ -230,7 +243,7 @@ public class RssDAGAppMasterTest {
         expectedDestinationVertexId, 
conf.getInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, -1));
   }
 
-  public static void verfiyOutput(
+  public static void verifyOutput(
       DAGImpl dag,
       String name,
       String expectedOutputClassName,

Reply via email to