This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 1701d06c [#1011] feat(tez): Avoid recompute succeeded task. (#1033)
1701d06c is described below

commit 1701d06c56f185c4e659e2774eb7be67dcd549fc
Author: zhengchenyu <[email protected]>
AuthorDate: Mon Jul 31 09:55:25 2023 +0800

    [#1011] feat(tez): Avoid recompute succeeded task. (#1033)
    
    ### What changes were proposed in this pull request?
    
    Avoid recomputing succeeded tasks. For details, see #1011.
    Only 2.a, 2.b, and 2.c are solved here; 2.d will not be solved in this PR.
    
    ### Why are the changes needed?
    
    Fix: #1011
    
    ### Does this PR introduce _any_ user-facing change?
    
    A new config 'rss.avoid.recompute.succeeded.task' is introduced; its default 
value is false. If set to true, succeeded tasks are not recomputed when the 
recompute is triggered by a node failure.
    
    >Note: I suggest changing the default value to true after it has run on a 
production cluster for some months.
    
    ### How was this patch tested?
    
    Integration tests and tests on a YARN cluster.
---
 .../java/org/apache/tez/common/RssTezConfig.java   |   4 +
 .../org/apache/tez/dag/app/RssDAGAppMaster.java    |  65 +++-
 docs/client_guide.md                               |   7 +
 .../RssDAGAppMasterForWordCountWithFailures.java   | 371 +++++++++++++++++++++
 .../uniffle/test/TezIntegrationTestBase.java       |   2 +-
 .../uniffle/test/TezWordCountWithFailuresTest.java | 370 ++++++++++++++++++++
 6 files changed, 817 insertions(+), 2 deletions(-)

diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java 
b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
index 3b5e38b5..f9111e04 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
@@ -201,6 +201,10 @@ public class RssTezConfig {
   public static final String RSS_SHUFFLE_DESTINATION_VERTEX_ID =
       TEZ_RSS_CONFIG_PREFIX + "rss.shuffle.destination.vertex.id";
 
+  public static final String RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK =
+      TEZ_RSS_CONFIG_PREFIX + "rss.avoid.recompute.succeeded.task";
+  public static final boolean RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT = 
false;
+
   public static RssConf toRssConf(Configuration jobConf) {
     RssConf rssConf = new RssConf();
     for (Map.Entry<String, String> entry : jobConf) {
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java 
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index d996ea7c..608f26d0 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -37,12 +37,14 @@ import org.apache.hadoop.yarn.api.ApplicationConstants;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.event.EventHandler;
 import org.apache.hadoop.yarn.util.Clock;
 import org.apache.hadoop.yarn.util.ConverterUtils;
 import org.apache.hadoop.yarn.util.SystemClock;
 import org.apache.log4j.LogManager;
 import org.apache.log4j.helpers.Loader;
 import org.apache.log4j.helpers.OptionConverter;
+import org.apache.tez.common.AsyncDispatcher;
 import org.apache.tez.common.RssTezConfig;
 import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezClassLoader;
@@ -56,10 +58,15 @@ import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.OutputDescriptor;
 import org.apache.tez.dag.api.TezConstants;
 import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
 import org.apache.tez.dag.api.records.DAGProtos;
 import org.apache.tez.dag.api.records.DAGProtos.AMPluginDescriptorProto;
 import org.apache.tez.dag.app.dag.DAG;
 import org.apache.tez.dag.app.dag.DAGState;
+import org.apache.tez.dag.app.dag.Task;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEvent;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventType;
 import org.apache.tez.dag.app.dag.impl.DAGImpl;
 import org.apache.tez.dag.app.dag.impl.Edge;
 import org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager;
@@ -76,8 +83,12 @@ import static 
org.apache.log4j.LogManager.CONFIGURATOR_CLASS_KEY;
 import static org.apache.log4j.LogManager.DEFAULT_CONFIGURATION_KEY;
 import static 
org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS;
 import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT;
+import static 
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK;
+import static 
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT;
 import static 
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
 import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+import static 
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS;
+import static 
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT;
 
 public class RssDAGAppMaster extends DAGAppMaster {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssDAGAppMaster.class);
@@ -125,6 +136,10 @@ public class RssDAGAppMaster extends DAGAppMaster {
   @Override
   public synchronized void serviceInit(Configuration conf) throws Exception {
     super.serviceInit(conf);
+    if (conf.getBoolean(
+        RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, 
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT)) {
+      overrideTaskAttemptEventDispatcher();
+    }
     initAndStartRSSClient(this, conf);
   }
 
@@ -336,6 +351,16 @@ public class RssDAGAppMaster extends DAGAppMaster {
         }
       }
 
+      if (conf.getBoolean(
+              RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, 
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT)
+          && conf.getBoolean(
+              TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS,
+              TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT)) {
+        LOG.info(
+            "When rss.avoid.recompute.succeeded.task is enable, "
+                + "we can not rescheduler succeeded task on unhealthy node");
+        conf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, false);
+      }
       initAndStartAppMaster(appMaster, conf);
     } catch (Throwable t) {
       LOG.error("Error starting RssDAGAppMaster", t);
@@ -476,7 +501,7 @@ public class RssDAGAppMaster extends DAGAppMaster {
     }
   }
 
-  private static void reconfigureLog4j() {
+  static void reconfigureLog4j() {
     String configuratorClassName = 
OptionConverter.getSystemProperty(CONFIGURATOR_CLASS_KEY, null);
     String configurationOptionStr =
         OptionConverter.getSystemProperty(DEFAULT_CONFIGURATION_KEY, null);
@@ -484,4 +509,42 @@ public class RssDAGAppMaster extends DAGAppMaster {
     OptionConverter.selectAndConfigure(
         url, configuratorClassName, LogManager.getLoggerRepository());
   }
+
+  protected void overrideTaskAttemptEventDispatcher()
+      throws NoSuchFieldException, IllegalAccessException {
+    AsyncDispatcher dispatcher = (AsyncDispatcher) this.getDispatcher();
+    Field field = dispatcher.getClass().getDeclaredField("eventHandlers");
+    field.setAccessible(true);
+    Map<Class<? extends Enum>, EventHandler> eventHandlers =
+        (Map<Class<? extends Enum>, EventHandler>) field.get(dispatcher);
+    eventHandlers.put(TaskAttemptEventType.class, new 
RssTaskAttemptEventDispatcher());
+  }
+
+  private class RssTaskAttemptEventDispatcher implements 
EventHandler<TaskAttemptEvent> {
+    @SuppressWarnings("unchecked")
+    @Override
+    public void handle(TaskAttemptEvent event) {
+      DAG dag = getContext().getCurrentDAG();
+      int eventDagIndex = 
event.getTaskAttemptID().getTaskID().getVertexID().getDAGId().getId();
+      if (dag == null || eventDagIndex != dag.getID().getId()) {
+        return; // event not relevant any more
+      }
+      Task task =
+          dag.getVertex(event.getTaskAttemptID().getTaskID().getVertexID())
+              .getTask(event.getTaskAttemptID().getTaskID());
+      TaskAttempt attempt = task.getAttempt(event.getTaskAttemptID());
+
+      if (attempt.getState() == TaskAttemptState.SUCCEEDED
+          && event.getType() == TaskAttemptEventType.TA_NODE_FAILED) {
+        // Here we only handle TA_NODE_FAILED. TA_KILL_REQUEST and TA_KILLED 
also could trigger
+        // TerminatedAfterSuccessTransition, but the reason is not about bad 
node.
+        LOG.info(
+            "We should not recompute the succeeded task attempt, though task 
attempt {} recieved envent {}",
+            attempt,
+            event);
+        return;
+      }
+      ((EventHandler<TaskAttemptEvent>) attempt).handle(event);
+    }
+  }
 }
diff --git a/docs/client_guide.md b/docs/client_guide.md
index e34e651b..74415260 100644
--- a/docs/client_guide.md
+++ b/docs/client_guide.md
@@ -234,4 +234,11 @@ This experimental feature allows reduce tasks to spill 
data to remote storage (e
 |mapreduce.rss.reduce.remote.spill.retries|5| The retry number to spill data 
to Hadoop FS                            |
 
 Notice: this feature requires the MEMORY_LOCAL_HADOOP mode.
+
+
+### Tez Specialized Setting
+
+| Property Name                  | Default | Description                       
                                      |
+|--------------------------------|---------|-------------------------------------------------------------------------|
+| tez.rss.avoid.recompute.succeeded.task | false   | Whether to avoid 
recompute succeeded task when node is unhealthy or black-listed |
  
\ No newline at end of file
diff --git 
a/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java
 
b/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java
new file mode 100644
index 00000000..9637a2f8
--- /dev/null
+++ 
b/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import java.lang.reflect.Field;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.util.ShutdownHookManager;
+import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.hadoop.yarn.api.records.NodeReport;
+import org.apache.hadoop.yarn.api.records.NodeState;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.event.EventHandler;
+import org.apache.hadoop.yarn.util.Clock;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.hadoop.yarn.util.SystemClock;
+import org.apache.tez.common.AsyncDispatcher;
+import org.apache.tez.common.TezClassLoader;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.common.VersionInfo;
+import org.apache.tez.dag.api.TezConstants;
+import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
+import org.apache.tez.dag.api.records.DAGProtos;
+import org.apache.tez.dag.app.dag.DAG;
+import org.apache.tez.dag.app.dag.Task;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEvent;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventAttemptFailed;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventType;
+import org.apache.tez.dag.app.rm.node.AMNodeEventStateChanged;
+import org.apache.tez.dag.records.TaskAttemptTerminationCause;
+import org.apache.tez.runtime.api.TaskFailureType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
+
+import static 
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK;
+import static 
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT;
+import static 
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS;
+import static 
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/*
+ * RssDAGAppMasterForWordCountWithFailures is only used for 
TezWordCountWithFailuresTest.
+ * We want to simulate that some task have succeeded, but the node which these 
task have run is label as black list.
+ * Then we will verify whether these task is recompute or not.
+ *
+ * Two test mode are supported:
+ * (a) testMode 0
+ * The test example is WordCount. The parallelism of Tokenizer is 5, it means 
at lease one node run more than one
+ * container. Here if a task succeeded in node1, will kill the next container 
runs on node1. Because
+ * maxtaskfailures.per.node is set to 1, so the node1 will labeled as black 
list, then verify whether the succeeded
+ * task is recomputed or not.
+ *
+ * (b) testMode 1
+ * The test example is WordCount. The parallelism of Tokenizer is 5, it means 
at lease one node run more than one
+ * container. Here if a task succeeded in node1, then will label node1 as 
DECOMMISSIONED. Then verify whether
+ * the succeeded task is recomputed or not.
+ * */
+public class RssDAGAppMasterForWordCountWithFailures extends RssDAGAppMaster {
+
+  private static final Logger LOG =
+      LoggerFactory.getLogger(RssDAGAppMasterForWordCountWithFailures.class);
+
+  private final int testMode;
+
+  public RssDAGAppMasterForWordCountWithFailures(
+      ApplicationAttemptId applicationAttemptId,
+      ContainerId containerId,
+      String nmHost,
+      int nmPort,
+      int nmHttpPort,
+      Clock clock,
+      long appSubmitTime,
+      boolean isSession,
+      String workingDirectory,
+      String[] localDirs,
+      String[] logDirs,
+      String clientVersion,
+      Credentials credentials,
+      String jobUserName,
+      DAGProtos.AMPluginDescriptorProto pluginDescriptorProto,
+      int testMode) {
+    super(
+        applicationAttemptId,
+        containerId,
+        nmHost,
+        nmPort,
+        nmHttpPort,
+        clock,
+        appSubmitTime,
+        isSession,
+        workingDirectory,
+        localDirs,
+        logDirs,
+        clientVersion,
+        credentials,
+        jobUserName,
+        pluginDescriptorProto);
+    this.testMode = testMode;
+  }
+
+  @Override
+  public synchronized void serviceInit(Configuration conf) throws Exception {
+    super.serviceInit(conf);
+    overrideTaskAttemptEventDispatcher();
+  }
+
+  public static void main(String[] args) {
+    int testMode = 0;
+    try {
+      // We use trick way to introduce RssDAGAppMaster by the config 
tez.am.launch.cmd-opts.
+      // It means some property which is set by command line will be ingored, 
so we must reload it.
+      boolean sessionModeCliOption = false;
+      for (int i = 0; i < args.length; i++) {
+        if (args[i].startsWith("-D")) {
+          String[] property = args[i].split("=");
+          if (property.length < 2) {
+            System.setProperty(property[0].substring(2), "");
+          } else {
+            System.setProperty(property[0].substring(2), property[1]);
+          }
+        } else if (args[i].contains("--session") || args[i].contains("-s")) {
+          sessionModeCliOption = true;
+        } else if (args[i].startsWith("--testMode")) {
+          testMode = Integer.parseInt(args[i].substring(10));
+        }
+      }
+      // Load the log4j config is only init in static code block of 
LogManager, so we must
+      // reconfigure.
+      reconfigureLog4j();
+
+      // Install the tez class loader, which can be used add new resources
+      TezClassLoader.setupTezClassLoader();
+      Thread.setDefaultUncaughtExceptionHandler(new 
YarnUncaughtExceptionHandler());
+      final String pid = System.getenv().get("JVM_PID");
+      String containerIdStr = 
System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name());
+      String appSubmitTimeStr = 
System.getenv(ApplicationConstants.APP_SUBMIT_TIME_ENV);
+      String clientVersion = 
System.getenv(TezConstants.TEZ_CLIENT_VERSION_ENV);
+      if (clientVersion == null) {
+        clientVersion = VersionInfo.UNKNOWN;
+      }
+
+      Objects.requireNonNull(
+          appSubmitTimeStr, ApplicationConstants.APP_SUBMIT_TIME_ENV + " is 
null");
+
+      ContainerId containerId = ConverterUtils.toContainerId(containerIdStr);
+      ApplicationAttemptId applicationAttemptId = 
containerId.getApplicationAttemptId();
+
+      String jobUserName = 
System.getenv(ApplicationConstants.Environment.USER.name());
+
+      LOG.info(
+          "Creating RssDAGAppMaster for "
+              + "applicationId="
+              + applicationAttemptId.getApplicationId()
+              + ", attemptNum="
+              + applicationAttemptId.getAttemptId()
+              + ", AMContainerId="
+              + containerId
+              + ", jvmPid="
+              + pid
+              + ", userFromEnv="
+              + jobUserName
+              + ", cliSessionOption="
+              + sessionModeCliOption
+              + ", pwd="
+              + System.getenv(ApplicationConstants.Environment.PWD.name())
+              + ", localDirs="
+              + 
System.getenv(ApplicationConstants.Environment.LOCAL_DIRS.name())
+              + ", logDirs="
+              + 
System.getenv(ApplicationConstants.Environment.LOG_DIRS.name()));
+
+      Configuration conf = new Configuration(new YarnConfiguration());
+
+      DAGProtos.ConfigurationProto confProto =
+          TezUtilsInternal.readUserSpecifiedTezConfiguration(
+              System.getenv(ApplicationConstants.Environment.PWD.name()));
+      TezUtilsInternal.addUserSpecifiedTezConfiguration(conf, 
confProto.getConfKeyValuesList());
+
+      DAGProtos.AMPluginDescriptorProto amPluginDescriptorProto = null;
+      if (confProto.hasAmPluginDescriptor()) {
+        amPluginDescriptorProto = confProto.getAmPluginDescriptor();
+      }
+
+      UserGroupInformation.setConfiguration(conf);
+      Credentials credentials = 
UserGroupInformation.getCurrentUser().getCredentials();
+
+      TezUtilsInternal.setSecurityUtilConfigration(LOG, conf);
+
+      String nodeHostString = 
System.getenv(ApplicationConstants.Environment.NM_HOST.name());
+      String nodePortString = 
System.getenv(ApplicationConstants.Environment.NM_PORT.name());
+      String nodeHttpPortString =
+          System.getenv(ApplicationConstants.Environment.NM_HTTP_PORT.name());
+      long appSubmitTime = Long.parseLong(appSubmitTimeStr);
+      RssDAGAppMasterForWordCountWithFailures appMaster =
+          new RssDAGAppMasterForWordCountWithFailures(
+              applicationAttemptId,
+              containerId,
+              nodeHostString,
+              Integer.parseInt(nodePortString),
+              Integer.parseInt(nodeHttpPortString),
+              new SystemClock(),
+              appSubmitTime,
+              sessionModeCliOption,
+              System.getenv(ApplicationConstants.Environment.PWD.name()),
+              TezCommonUtils.getTrimmedStrings(
+                  
System.getenv(ApplicationConstants.Environment.LOCAL_DIRS.name())),
+              TezCommonUtils.getTrimmedStrings(
+                  
System.getenv(ApplicationConstants.Environment.LOG_DIRS.name())),
+              clientVersion,
+              credentials,
+              jobUserName,
+              amPluginDescriptorProto,
+              testMode);
+      ShutdownHookManager.get()
+          .addShutdownHook(
+              new RssDAGAppMaster.RssDAGAppMasterShutdownHook(appMaster), 
SHUTDOWN_HOOK_PRIORITY);
+
+      // log the system properties
+      if (LOG.isInfoEnabled()) {
+        String systemPropsToLog = 
TezCommonUtils.getSystemPropertiesToLog(conf);
+        if (systemPropsToLog != null) {
+          LOG.info(systemPropsToLog);
+        }
+      }
+
+      LOG.info(
+          "recompute is {}, reschedule is {}",
+          conf.getBoolean(
+              RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, 
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT),
+          conf.getBoolean(
+              TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS,
+              TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT));
+      if (conf.getBoolean(
+              RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, 
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT)
+          && conf.getBoolean(
+              TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS,
+              TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT)) {
+        LOG.info(
+            "When rss.avoid.recompute.succeeded.task is enable, "
+                + "we can not rescheduler succeeded task on unhealthy node");
+        conf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, false);
+      }
+      initAndStartAppMaster(appMaster, conf);
+    } catch (Throwable t) {
+      LOG.error("Error starting RssDAGAppMaster", t);
+      System.exit(1);
+    }
+  }
+
+  public void overrideTaskAttemptEventDispatcher()
+      throws NoSuchFieldException, IllegalAccessException {
+    AsyncDispatcher dispatcher = (AsyncDispatcher) this.getDispatcher();
+    Field field = dispatcher.getClass().getDeclaredField("eventHandlers");
+    field.setAccessible(true);
+    Map<Class<? extends Enum>, EventHandler> eventHandlers =
+        (Map<Class<? extends Enum>, EventHandler>) field.get(dispatcher);
+    eventHandlers.put(
+        TaskAttemptEventType.class, new 
RssTaskAttemptEventDispatcher(this.getConfig()));
+  }
+
+  private class RssTaskAttemptEventDispatcher implements 
EventHandler<TaskAttemptEvent> {
+
+    Map<NodeId, Integer> succeed = new HashMap<>();
+    boolean killed = false;
+    boolean avoidRecompute;
+
+    RssTaskAttemptEventDispatcher(Configuration conf) {
+      avoidRecompute =
+          conf.getBoolean(
+              RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, 
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void handle(TaskAttemptEvent event) {
+      DAG dag = getContext().getCurrentDAG();
+      int eventDagIndex = 
event.getTaskAttemptID().getTaskID().getVertexID().getDAGId().getId();
+      if (dag == null || eventDagIndex != dag.getID().getId()) {
+        return; // event not relevant any more
+      }
+      Task task =
+          dag.getVertex(event.getTaskAttemptID().getTaskID().getVertexID())
+              .getTask(event.getTaskAttemptID().getTaskID());
+      TaskAttempt attempt = task.getAttempt(event.getTaskAttemptID());
+
+      LOG.info("handle task attempt event: {}", event);
+      if (avoidRecompute) {
+        if (attempt.getState() == TaskAttemptState.SUCCEEDED
+            && event.getType() == TaskAttemptEventType.TA_NODE_FAILED) {
+          LOG.info(
+              "We should not recompute the succeeded task attempt, though 
taskattempt {} recieved event {}",
+              attempt,
+              event);
+          return;
+        }
+      }
+      ((EventHandler<TaskAttemptEvent>) attempt).handle(event);
+      // For Tokenizer, record the first succeeded task and its node. When 
next task runs on this
+      // node, will kill this task or label this node as unhealthy.
+      int vertexId = attempt.getVertexID().getId();
+      if (vertexId == 0) {
+        if (attempt.getState() == TaskAttemptState.SUCCEEDED) {
+          NodeId nodeId = attempt.getAssignedContainer().getNodeId();
+          if (!succeed.containsKey(nodeId)) {
+            succeed.put(nodeId, 1);
+          } else {
+            succeed.put(nodeId, succeed.get(nodeId) + 1);
+          }
+        } else if (attempt.getState() == TaskAttemptState.RUNNING) {
+          NodeId nodeId = attempt.getAssignedContainer().getNodeId();
+          if (succeed.getOrDefault(nodeId, 0) == 1 && !killed) {
+            if (testMode == 0) {
+              TaskAttemptEventAttemptFailed eventAttemptFailed =
+                  new TaskAttemptEventAttemptFailed(
+                      attempt.getID(),
+                      TaskAttemptEventType.TA_FAILED,
+                      TaskFailureType.NON_FATAL,
+                      "Triggerd by " + this.getClass().getName(),
+                      TaskAttemptTerminationCause.CONTAINER_LAUNCH_FAILED);
+              LOG.info(
+                  "Killing running task attempt: {} at node: {}",
+                  attempt,
+                  attempt.getAssignedContainer().getNodeId());
+              ((EventHandler<TaskAttemptEvent>) 
attempt).handle(eventAttemptFailed);
+              dag.getEventHandler().handle(eventAttemptFailed);
+            } else if (testMode == 1) {
+              NodeReport nodeReport = mock(NodeReport.class);
+              
when(nodeReport.getNodeState()).thenReturn(NodeState.DECOMMISSIONED);
+              when(nodeReport.getNodeId()).thenReturn(nodeId);
+              LOG.info(
+                  "Label the node {} as DECOMMISSIONED."
+                      + attempt.getAssignedContainer().getNodeId());
+              dag.getEventHandler().handle(new 
AMNodeEventStateChanged(nodeReport, 0));
+            } else {
+              throw new RssException("testMode " + testMode + " is not 
supported!");
+            }
+            killed = true;
+          }
+        }
+      }
+    }
+  }
+}
diff --git 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
index f9584efd..b5219efe 100644
--- 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
+++ 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
@@ -61,7 +61,7 @@ public class TezIntegrationTestBase extends 
IntegrationTestBase {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(TezIntegrationTestBase.class);
   private static String TEST_ROOT_DIR =
-      "target" + Path.SEPARATOR + TezWordCountTest.class.getName() + "-tmpDir";
+      "target" + Path.SEPARATOR + TezIntegrationTestBase.class.getName() + 
"-tmpDir";
 
   private Path remoteStagingDir = null;
   protected static MiniTezCluster miniTezCluster;
diff --git 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
new file mode 100644
index 00000000..14957df4
--- /dev/null
+++ 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.io.IOException;
+import java.lang.reflect.Method;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.client.CallerContext;
+import org.apache.tez.client.TezClient;
+import org.apache.tez.client.TezClientUtils;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.dag.api.DAG;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezException;
+import org.apache.tez.dag.api.client.DAGClient;
+import org.apache.tez.dag.api.client.DAGStatus;
+import org.apache.tez.dag.api.client.Progress;
+import org.apache.tez.dag.api.client.StatusGetOpts;
+import org.apache.tez.dag.app.RssDAGAppMasterForWordCountWithFailures;
+import org.apache.tez.examples.WordCount;
+import org.apache.tez.hadoop.shim.HadoopShim;
+import org.apache.tez.hadoop.shim.HadoopShimsLoader;
+import org.apache.tez.test.MiniTezCluster;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static 
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK;
+import static 
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_MAX_TASK_FAILURES_PER_NODE;
+import static 
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_BLACKLISTING_ENABLED;
+import static 
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_BLACKLISTING_IGNORE_THRESHOLD;
+import static 
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class TezWordCountWithFailuresTest extends IntegrationTestBase {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(TezIntegrationTestBase.class);
+  private static String TEST_ROOT_DIR =
+      "target" + Path.SEPARATOR + TezWordCountWithFailuresTest.class.getName() 
+ "-tmpDir";
+
+  private Path remoteStagingDir = null;
+  private String inputPath = "word_count_input";
+  private String outputPath = "word_count_output";
+  private List<String> wordTable =
+      Lists.newArrayList(
+          "apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan", 
"tomato");
+
+  protected static MiniTezCluster miniTezCluster;
+
+  @BeforeAll
+  public static void beforeClass() throws Exception {
+    LOG.info("Starting mini tez clusters");
+    if (miniTezCluster == null) {
+      miniTezCluster = new 
MiniTezCluster(TezIntegrationTestBase.class.getName(), 3, 1, 1);
+      miniTezCluster.init(conf);
+      miniTezCluster.start();
+    }
+    LOG.info("Starting corrdinators and shuffer servers");
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    Map<String, String> dynamicConf = new HashMap();
+    dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), 
HDFS_URI + "rss/test");
+    dynamicConf.put(RssTezConfig.RSS_STORAGE_TYPE, 
StorageType.MEMORY_LOCALFILE_HDFS.name());
+    addDynamicConf(coordinatorConf, dynamicConf);
+    createCoordinatorServer(coordinatorConf);
+    ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+    createShuffleServer(shuffleServerConf);
+    startServers();
+  }
+
+  @AfterAll
+  public static void tearDown() throws Exception {
+    if (miniTezCluster != null) {
+      LOG.info("Stopping MiniTezCluster");
+      miniTezCluster.stop();
+      miniTezCluster = null;
+    }
+  }
+
+  @BeforeEach
+  public void setup() throws Exception {
+    remoteStagingDir =
+        fs.makeQualified(new Path(TEST_ROOT_DIR, String.valueOf(new 
Random().nextInt(100000))));
+    TezClientUtils.ensureStagingDirExists(conf, remoteStagingDir);
+    generateInputFile();
+  }
+
+  private void generateInputFile() throws Exception {
+    assertTrue(fs.mkdirs(new Path(inputPath)));
+    for (int j = 0; j < 5; j++) {
+      FSDataOutputStream outputStream = fs.create(new Path(inputPath + 
"/file." + j));
+      Random random = new Random();
+      for (int i = 0; i < 100; i++) {
+        int index = random.nextInt(wordTable.size());
+        String str = wordTable.get(index) + "\n";
+        outputStream.writeBytes(str);
+      }
+      outputStream.close();
+    }
+    FileStatus[] fileStatus = fs.listStatus(new Path(inputPath));
+    for (FileStatus status : fileStatus) {
+      System.out.println("status is " + status);
+    }
+  }
+
+  @AfterEach
+  public void tearDownEach() throws Exception {
+    if (this.remoteStagingDir != null) {
+      fs.delete(this.remoteStagingDir, true);
+    }
+    for (int j = 0; j < 5; j++) {
+      fs.delete(new Path(inputPath + "/file." + j), true);
+    }
+  }
+
+  @Test
+  public void wordCountTestWithTaskFailureWhenAvoidRecomputeEnable() throws 
Exception {
+    // 1 Run Tez examples based on rss
+    TezConfiguration appConf = new 
TezConfiguration(miniTezCluster.getConfig());
+    updateRssConfiguration(appConf, 0, true, false, 1);
+    TezIntegrationTestBase.appendAndUploadRssJars(appConf);
+    runTezApp(appConf, getTestArgs("rss"), 0);
+    final String rssPath = getOutputDir("rss");
+
+    // 2 Run original Tez examples
+    appConf = new TezConfiguration(miniTezCluster.getConfig());
+    updateCommonConfiguration(appConf);
+    runTezApp(appConf, getTestArgs("origin"), -1);
+    final String originPath = getOutputDir("origin");
+
+    // 3 verify the results
+    TezIntegrationTestBase.verifyResultEqual(originPath, rssPath);
+  }
+
+  @Test
+  public void wordCountTestWithTaskFailureWhenAvoidRecomputeDisable() throws 
Exception {
+    // 1 Run Tez examples based on rss
+    TezConfiguration appConf = new 
TezConfiguration(miniTezCluster.getConfig());
+    updateRssConfiguration(appConf, 0, false, false, 1);
+    TezIntegrationTestBase.appendAndUploadRssJars(appConf);
+    runTezApp(appConf, getTestArgs("rss"), 1);
+    final String rssPath = getOutputDir("rss");
+
+    // 2 Run original Tez examples
+    appConf = new TezConfiguration(miniTezCluster.getConfig());
+    updateCommonConfiguration(appConf);
+    runTezApp(appConf, getTestArgs("origin"), -1);
+    final String originPath = getOutputDir("origin");
+
+    // 3 verify the results
+    TezIntegrationTestBase.verifyResultEqual(originPath, rssPath);
+  }
+
+  @Test
+  public void wordCountTestWithNodeUnhealthyWhenAvoidRecomputeEnable() throws 
Exception {
+    // 1 Run Tez examples based on rss
+    TezConfiguration appConf = new 
TezConfiguration(miniTezCluster.getConfig());
+    updateRssConfiguration(appConf, 1, true, true, 100);
+    TezIntegrationTestBase.appendAndUploadRssJars(appConf);
+    runTezApp(appConf, getTestArgs("rss"), 0);
+    final String rssPath = getOutputDir("rss");
+
+    // 2 Run original Tez examples
+    appConf = new TezConfiguration(miniTezCluster.getConfig());
+    updateCommonConfiguration(appConf);
+    runTezApp(appConf, getTestArgs("origin"), -1);
+    final String originPath = getOutputDir("origin");
+
+    // 3 verify the results
+    TezIntegrationTestBase.verifyResultEqual(originPath, rssPath);
+  }
+
+  @Test
+  public void wordCountTestWithNodeUnhealthyWhenAvoidRecomputeDisable() throws 
Exception {
+    // 1 Run Tez examples based on rss
+    TezConfiguration appConf = new 
TezConfiguration(miniTezCluster.getConfig());
+    updateRssConfiguration(appConf, 1, false, true, 100);
+    TezIntegrationTestBase.appendAndUploadRssJars(appConf);
+    runTezApp(appConf, getTestArgs("rss"), 1);
+    final String rssPath = getOutputDir("rss");
+
+    // 2 Run original Tez examples
+    appConf = new TezConfiguration(miniTezCluster.getConfig());
+    updateCommonConfiguration(appConf);
+    runTezApp(appConf, getTestArgs("origin"), -1);
+    final String originPath = getOutputDir("origin");
+
+    // 3 verify the results
+    TezIntegrationTestBase.verifyResultEqual(originPath, rssPath);
+  }
+
+  /*
+   * Two verify modes are supported:
+   * (a) verifyMode 0
+   *     tez.rss.avoid.recompute.succeeded.task is enabled, should not 
recompute the task when this node is
+   *     black-listed as unhealthy.
+   *
+   * (b) verifyMode 1
+   *     tez.rss.avoid.recompute.succeeded.task is disabled, will recompute the 
task when this node is
+   *     black-listed as unhealthy.
+   * */
+  protected void runTezApp(TezConfiguration tezConf, String[] args, int 
verifyMode)
+      throws Exception {
+    assertEquals(
+        0,
+        ToolRunner.run(tezConf, new WordCountWithFailures(verifyMode), args),
+        "WordCountWithFailures failed");
+  }
+
+  public String[] getTestArgs(String uniqueOutputName) {
+    return new String[] {
+      "-disableSplitGrouping", inputPath, outputPath + "/" + uniqueOutputName, 
"2"
+    };
+  }
+
+  public String getOutputDir(String uniqueOutputName) {
+    return outputPath + "/" + uniqueOutputName;
+  }
+
+  /*
+   * In this integration test, the mini cluster has three NMs with 4G each
+   * (YarnConfiguration.DEFAULT_YARN_MINICLUSTER_NM_PMEM_MB). The request of 
am is 4G, the request of task is 4G.
+   * It means that the node running the am container runs nothing else, so it 
won't be labeled as a
+   * black-listed or unhealthy node.
+   * */
+  public void updateRssConfiguration(
+      Configuration appConf,
+      int testMode,
+      boolean avoidRecompute,
+      boolean rescheduleWhenUnhealthy,
+      int maxFailures)
+      throws Exception {
+    appConf.set(TezConfiguration.TEZ_AM_STAGING_DIR, 
remoteStagingDir.toString());
+    appConf.setInt(TezConfiguration.TEZ_AM_RESOURCE_MEMORY_MB, 4096);
+    appConf.setInt(TezConfiguration.TEZ_TASK_RESOURCE_MEMORY_MB, 4096);
+    appConf.setBoolean(TEZ_AM_NODE_BLACKLISTING_ENABLED, true);
+    appConf.setInt(TEZ_AM_NODE_BLACKLISTING_IGNORE_THRESHOLD, 99);
+    appConf.setInt(TEZ_AM_MAX_TASK_FAILURES_PER_NODE, maxFailures);
+    appConf.set(RssTezConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM);
+    appConf.set(RssTezConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name());
+    appConf.set(
+        TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS,
+        TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS_DEFAULT
+            + " "
+            + RssDAGAppMasterForWordCountWithFailures.class.getName()
+            + " --testMode"
+            + testMode);
+    appConf.setBoolean(RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, avoidRecompute);
+    appConf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, 
rescheduleWhenUnhealthy);
+  }
+
+  public void updateCommonConfiguration(Configuration appConf) {
+    appConf.set(TezConfiguration.TEZ_AM_STAGING_DIR, 
remoteStagingDir.toString());
+    appConf.setInt(TezConfiguration.TEZ_AM_RESOURCE_MEMORY_MB, 512);
+    appConf.set(TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS, " -Xmx384m");
+    appConf.setInt(TezConfiguration.TEZ_TASK_RESOURCE_MEMORY_MB, 512);
+    appConf.set(TezConfiguration.TEZ_TASK_LAUNCH_CMD_OPTS, " -Xmx384m");
+  }
+
+  public class WordCountWithFailures extends WordCount {
+
+    TezClient tezClientInternal = null;
+    private HadoopShim hadoopShim;
+    int verifyMode = -1;
+
+    WordCountWithFailures(int assertMode) {
+      this.verifyMode = assertMode;
+    }
+
+    @Override
+    protected int runJob(String[] args, TezConfiguration tezConf, TezClient 
tezClient)
+        throws Exception {
+      this.tezClientInternal = tezClient;
+      Method method =
+          WordCount.class.getDeclaredMethod(
+              "createDAG", TezConfiguration.class, String.class, String.class, 
int.class);
+      method.setAccessible(true);
+      DAG dag =
+          (DAG)
+              method.invoke(
+                  this,
+                  tezConf,
+                  args[0],
+                  args[1],
+                  args.length == 3 ? Integer.parseInt(args[2]) : 1);
+      LOG.info("Running WordCountWithFailures");
+      return runDag(dag, isCountersLog(), LOG);
+    }
+
+    public int runDag(DAG dag, boolean printCounters, Logger logger)
+        throws TezException, InterruptedException, IOException {
+      tezClientInternal.waitTillReady();
+
+      CallerContext callerContext =
+          CallerContext.create("TezExamples", "Tez Example DAG: " + 
dag.getName());
+      ApplicationId appId = tezClientInternal.getAppMasterApplicationId();
+      if (hadoopShim == null) {
+        Configuration conf = (getConf() == null ? new Configuration(false) : 
getConf());
+        hadoopShim = new HadoopShimsLoader(conf).getHadoopShim();
+      }
+
+      if (appId != null) {
+        TezUtilsInternal.setHadoopCallerContext(hadoopShim, appId);
+        callerContext.setCallerIdAndType(appId.toString(), 
"TezExampleApplication");
+      }
+      dag.setCallerContext(callerContext);
+
+      DAGClient dagClient = tezClientInternal.submitDAG(dag);
+      Set<StatusGetOpts> getOpts = Sets.newHashSet();
+      if (printCounters) {
+        getOpts.add(StatusGetOpts.GET_COUNTERS);
+      }
+
+      DAGStatus dagStatus = 
dagClient.waitForCompletionWithStatusUpdates(getOpts);
+      if (dagStatus.getState() != DAGStatus.State.SUCCEEDED) {
+        logger.info("DAG diagnostics: " + dagStatus.getDiagnostics());
+        return -1;
+      }
+
+      Map<String, Progress> progressMap = dagStatus.getVertexProgress();
+      if (verifyMode == 0) {
+        // verifyMode is 0: avoid recompute succeeded task is true
+        Assertions.assertEquals(0, 
progressMap.get("Tokenizer").getKilledTaskAttemptCount());
+      } else if (verifyMode == 1) {
+        // verifyMode is 1: avoid recompute succeeded task is false
+        
Assertions.assertTrue(progressMap.get("Tokenizer").getKilledTaskAttemptCount() 
> 0);
+      }
+      return 0;
+    }
+  }
+}


Reply via email to