This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 1701d06c [#1011] feat(tez): Avoid recompute succeeded task. (#1033)
1701d06c is described below
commit 1701d06c56f185c4e659e2774eb7be67dcd549fc
Author: zhengchenyu <[email protected]>
AuthorDate: Mon Jul 31 09:55:25 2023 +0800
[#1011] feat(tez): Avoid recompute succeeded task. (#1033)
### What changes were proposed in this pull request?
Avoid recompute succeeded task. Detailed information see #1011
Here only 2.a, 2.b, and 2.c are solved. 2.d will not be solved in this PR.
### Why are the changes needed?
Fix: #1011
### Does this PR introduce _any_ user-facing change?
New config 'rss.avoid.recompute.succeeded.task' was introduced; the default
value is false. If set to true, we won't recompute a succeeded task when the
recomputation is triggered by a node failure.
>Note: I suggest changing the default value to true after running on a production
cluster for some months.
### How was this patch tested?
Integration tests and tests in a YARN cluster.
---
.../java/org/apache/tez/common/RssTezConfig.java | 4 +
.../org/apache/tez/dag/app/RssDAGAppMaster.java | 65 +++-
docs/client_guide.md | 7 +
.../RssDAGAppMasterForWordCountWithFailures.java | 371 +++++++++++++++++++++
.../uniffle/test/TezIntegrationTestBase.java | 2 +-
.../uniffle/test/TezWordCountWithFailuresTest.java | 370 ++++++++++++++++++++
6 files changed, 817 insertions(+), 2 deletions(-)
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
index 3b5e38b5..f9111e04 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
@@ -201,6 +201,10 @@ public class RssTezConfig {
public static final String RSS_SHUFFLE_DESTINATION_VERTEX_ID =
TEZ_RSS_CONFIG_PREFIX + "rss.shuffle.destination.vertex.id";
+ public static final String RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK =
+ TEZ_RSS_CONFIG_PREFIX + "rss.avoid.recompute.succeeded.task";
+ public static final boolean RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT =
false;
+
public static RssConf toRssConf(Configuration jobConf) {
RssConf rssConf = new RssConf();
for (Map.Entry<String, String> entry : jobConf) {
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index d996ea7c..608f26d0 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -37,12 +37,14 @@ import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.hadoop.yarn.util.Clock;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.SystemClock;
import org.apache.log4j.LogManager;
import org.apache.log4j.helpers.Loader;
import org.apache.log4j.helpers.OptionConverter;
+import org.apache.tez.common.AsyncDispatcher;
import org.apache.tez.common.RssTezConfig;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezClassLoader;
@@ -56,10 +58,15 @@ import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.TezConstants;
import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
import org.apache.tez.dag.api.records.DAGProtos;
import org.apache.tez.dag.api.records.DAGProtos.AMPluginDescriptorProto;
import org.apache.tez.dag.app.dag.DAG;
import org.apache.tez.dag.app.dag.DAGState;
+import org.apache.tez.dag.app.dag.Task;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEvent;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventType;
import org.apache.tez.dag.app.dag.impl.DAGImpl;
import org.apache.tez.dag.app.dag.impl.Edge;
import org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager;
@@ -76,8 +83,12 @@ import static
org.apache.log4j.LogManager.CONFIGURATOR_CLASS_KEY;
import static org.apache.log4j.LogManager.DEFAULT_CONFIGURATION_KEY;
import static
org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS;
import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT;
+import static
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK;
+import static
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT;
import static
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
+import static
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS;
+import static
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT;
public class RssDAGAppMaster extends DAGAppMaster {
private static final Logger LOG =
LoggerFactory.getLogger(RssDAGAppMaster.class);
@@ -125,6 +136,10 @@ public class RssDAGAppMaster extends DAGAppMaster {
@Override
public synchronized void serviceInit(Configuration conf) throws Exception {
super.serviceInit(conf);
+ if (conf.getBoolean(
+ RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK,
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT)) {
+ overrideTaskAttemptEventDispatcher();
+ }
initAndStartRSSClient(this, conf);
}
@@ -336,6 +351,16 @@ public class RssDAGAppMaster extends DAGAppMaster {
}
}
+ if (conf.getBoolean(
+ RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK,
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT)
+ && conf.getBoolean(
+ TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS,
+ TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT)) {
+ LOG.info(
+ "When rss.avoid.recompute.succeeded.task is enable, "
+ + "we can not rescheduler succeeded task on unhealthy node");
+ conf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, false);
+ }
initAndStartAppMaster(appMaster, conf);
} catch (Throwable t) {
LOG.error("Error starting RssDAGAppMaster", t);
@@ -476,7 +501,7 @@ public class RssDAGAppMaster extends DAGAppMaster {
}
}
- private static void reconfigureLog4j() {
+ static void reconfigureLog4j() {
String configuratorClassName =
OptionConverter.getSystemProperty(CONFIGURATOR_CLASS_KEY, null);
String configurationOptionStr =
OptionConverter.getSystemProperty(DEFAULT_CONFIGURATION_KEY, null);
@@ -484,4 +509,42 @@ public class RssDAGAppMaster extends DAGAppMaster {
OptionConverter.selectAndConfigure(
url, configuratorClassName, LogManager.getLoggerRepository());
}
+
+ protected void overrideTaskAttemptEventDispatcher()
+ throws NoSuchFieldException, IllegalAccessException {
+ AsyncDispatcher dispatcher = (AsyncDispatcher) this.getDispatcher();
+ Field field = dispatcher.getClass().getDeclaredField("eventHandlers");
+ field.setAccessible(true);
+ Map<Class<? extends Enum>, EventHandler> eventHandlers =
+ (Map<Class<? extends Enum>, EventHandler>) field.get(dispatcher);
+ eventHandlers.put(TaskAttemptEventType.class, new
RssTaskAttemptEventDispatcher());
+ }
+
+ private class RssTaskAttemptEventDispatcher implements
EventHandler<TaskAttemptEvent> {
+ @SuppressWarnings("unchecked")
+ @Override
+ public void handle(TaskAttemptEvent event) {
+ DAG dag = getContext().getCurrentDAG();
+ int eventDagIndex =
event.getTaskAttemptID().getTaskID().getVertexID().getDAGId().getId();
+ if (dag == null || eventDagIndex != dag.getID().getId()) {
+ return; // event not relevant any more
+ }
+ Task task =
+ dag.getVertex(event.getTaskAttemptID().getTaskID().getVertexID())
+ .getTask(event.getTaskAttemptID().getTaskID());
+ TaskAttempt attempt = task.getAttempt(event.getTaskAttemptID());
+
+ if (attempt.getState() == TaskAttemptState.SUCCEEDED
+ && event.getType() == TaskAttemptEventType.TA_NODE_FAILED) {
+ // Here we only handle TA_NODE_FAILED. TA_KILL_REQUEST and TA_KILLED
also could trigger
+ // TerminatedAfterSuccessTransition, but the reason is not about bad
node.
+ LOG.info(
+ "We should not recompute the succeeded task attempt, though task
attempt {} recieved envent {}",
+ attempt,
+ event);
+ return;
+ }
+ ((EventHandler<TaskAttemptEvent>) attempt).handle(event);
+ }
+ }
}
diff --git a/docs/client_guide.md b/docs/client_guide.md
index e34e651b..74415260 100644
--- a/docs/client_guide.md
+++ b/docs/client_guide.md
@@ -234,4 +234,11 @@ This experimental feature allows reduce tasks to spill
data to remote storage (e
|mapreduce.rss.reduce.remote.spill.retries|5| The retry number to spill data
to Hadoop FS |
Notice: this feature requires the MEMORY_LOCAL_HADOOP mode.
+
+
+### Tez Specialized Setting
+
+| Property Name | Default | Description
|
+|--------------------------------|---------|-------------------------------------------------------------------------|
+| tez.rss.avoid.recompute.succeeded.task | false | Whether to avoid
recompute succeeded task when node is unhealthy or black-listed |
\ No newline at end of file
diff --git
a/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java
b/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java
new file mode 100644
index 00000000..9637a2f8
--- /dev/null
+++
b/integration-test/tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterForWordCountWithFailures.java
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import java.lang.reflect.Field;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.util.ShutdownHookManager;
+import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.hadoop.yarn.api.records.NodeReport;
+import org.apache.hadoop.yarn.api.records.NodeState;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.event.EventHandler;
+import org.apache.hadoop.yarn.util.Clock;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.hadoop.yarn.util.SystemClock;
+import org.apache.tez.common.AsyncDispatcher;
+import org.apache.tez.common.TezClassLoader;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.common.VersionInfo;
+import org.apache.tez.dag.api.TezConstants;
+import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
+import org.apache.tez.dag.api.records.DAGProtos;
+import org.apache.tez.dag.app.dag.DAG;
+import org.apache.tez.dag.app.dag.Task;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEvent;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventAttemptFailed;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventType;
+import org.apache.tez.dag.app.rm.node.AMNodeEventStateChanged;
+import org.apache.tez.dag.records.TaskAttemptTerminationCause;
+import org.apache.tez.runtime.api.TaskFailureType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
+
+import static
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK;
+import static
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT;
+import static
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS;
+import static
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/*
+ * RssDAGAppMasterForWordCountWithFailures is only used for
TezWordCountWithFailuresTest.
+ * We want to simulate that some task have succeeded, but the node which these
task have run is label as black list.
+ * Then we will verify whether these task is recompute or not.
+ *
+ * Two test mode are supported:
+ * (a) testMode 0
+ * The test example is WordCount. The parallelism of Tokenizer is 5, it means
at lease one node run more than one
+ * container. Here if a task succeeded in node1, will kill the next container
runs on node1. Because
+ * maxtaskfailures.per.node is set to 1, so the node1 will labeled as black
list, then verify whether the succeeded
+ * task is recomputed or not.
+ *
+ * (b) testMode 1
+ * The test example is WordCount. The parallelism of Tokenizer is 5, it means
at lease one node run more than one
+ * container. Here if a task succeeded in node1, then will label node1 as
DECOMMISSIONED. Then verify whether
+ * the succeeded task is recomputed or not.
+ * */
+public class RssDAGAppMasterForWordCountWithFailures extends RssDAGAppMaster {
+
+ private static final Logger LOG =
+ LoggerFactory.getLogger(RssDAGAppMasterForWordCountWithFailures.class);
+
+ private final int testMode;
+
+ public RssDAGAppMasterForWordCountWithFailures(
+ ApplicationAttemptId applicationAttemptId,
+ ContainerId containerId,
+ String nmHost,
+ int nmPort,
+ int nmHttpPort,
+ Clock clock,
+ long appSubmitTime,
+ boolean isSession,
+ String workingDirectory,
+ String[] localDirs,
+ String[] logDirs,
+ String clientVersion,
+ Credentials credentials,
+ String jobUserName,
+ DAGProtos.AMPluginDescriptorProto pluginDescriptorProto,
+ int testMode) {
+ super(
+ applicationAttemptId,
+ containerId,
+ nmHost,
+ nmPort,
+ nmHttpPort,
+ clock,
+ appSubmitTime,
+ isSession,
+ workingDirectory,
+ localDirs,
+ logDirs,
+ clientVersion,
+ credentials,
+ jobUserName,
+ pluginDescriptorProto);
+ this.testMode = testMode;
+ }
+
+ @Override
+ public synchronized void serviceInit(Configuration conf) throws Exception {
+ super.serviceInit(conf);
+ overrideTaskAttemptEventDispatcher();
+ }
+
+ public static void main(String[] args) {
+ int testMode = 0;
+ try {
+ // We use trick way to introduce RssDAGAppMaster by the config
tez.am.launch.cmd-opts.
+ // It means some property which is set by command line will be ingored,
so we must reload it.
+ boolean sessionModeCliOption = false;
+ for (int i = 0; i < args.length; i++) {
+ if (args[i].startsWith("-D")) {
+ String[] property = args[i].split("=");
+ if (property.length < 2) {
+ System.setProperty(property[0].substring(2), "");
+ } else {
+ System.setProperty(property[0].substring(2), property[1]);
+ }
+ } else if (args[i].contains("--session") || args[i].contains("-s")) {
+ sessionModeCliOption = true;
+ } else if (args[i].startsWith("--testMode")) {
+ testMode = Integer.parseInt(args[i].substring(10));
+ }
+ }
+ // Load the log4j config is only init in static code block of
LogManager, so we must
+ // reconfigure.
+ reconfigureLog4j();
+
+ // Install the tez class loader, which can be used add new resources
+ TezClassLoader.setupTezClassLoader();
+ Thread.setDefaultUncaughtExceptionHandler(new
YarnUncaughtExceptionHandler());
+ final String pid = System.getenv().get("JVM_PID");
+ String containerIdStr =
System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name());
+ String appSubmitTimeStr =
System.getenv(ApplicationConstants.APP_SUBMIT_TIME_ENV);
+ String clientVersion =
System.getenv(TezConstants.TEZ_CLIENT_VERSION_ENV);
+ if (clientVersion == null) {
+ clientVersion = VersionInfo.UNKNOWN;
+ }
+
+ Objects.requireNonNull(
+ appSubmitTimeStr, ApplicationConstants.APP_SUBMIT_TIME_ENV + " is
null");
+
+ ContainerId containerId = ConverterUtils.toContainerId(containerIdStr);
+ ApplicationAttemptId applicationAttemptId =
containerId.getApplicationAttemptId();
+
+ String jobUserName =
System.getenv(ApplicationConstants.Environment.USER.name());
+
+ LOG.info(
+ "Creating RssDAGAppMaster for "
+ + "applicationId="
+ + applicationAttemptId.getApplicationId()
+ + ", attemptNum="
+ + applicationAttemptId.getAttemptId()
+ + ", AMContainerId="
+ + containerId
+ + ", jvmPid="
+ + pid
+ + ", userFromEnv="
+ + jobUserName
+ + ", cliSessionOption="
+ + sessionModeCliOption
+ + ", pwd="
+ + System.getenv(ApplicationConstants.Environment.PWD.name())
+ + ", localDirs="
+ +
System.getenv(ApplicationConstants.Environment.LOCAL_DIRS.name())
+ + ", logDirs="
+ +
System.getenv(ApplicationConstants.Environment.LOG_DIRS.name()));
+
+ Configuration conf = new Configuration(new YarnConfiguration());
+
+ DAGProtos.ConfigurationProto confProto =
+ TezUtilsInternal.readUserSpecifiedTezConfiguration(
+ System.getenv(ApplicationConstants.Environment.PWD.name()));
+ TezUtilsInternal.addUserSpecifiedTezConfiguration(conf,
confProto.getConfKeyValuesList());
+
+ DAGProtos.AMPluginDescriptorProto amPluginDescriptorProto = null;
+ if (confProto.hasAmPluginDescriptor()) {
+ amPluginDescriptorProto = confProto.getAmPluginDescriptor();
+ }
+
+ UserGroupInformation.setConfiguration(conf);
+ Credentials credentials =
UserGroupInformation.getCurrentUser().getCredentials();
+
+ TezUtilsInternal.setSecurityUtilConfigration(LOG, conf);
+
+ String nodeHostString =
System.getenv(ApplicationConstants.Environment.NM_HOST.name());
+ String nodePortString =
System.getenv(ApplicationConstants.Environment.NM_PORT.name());
+ String nodeHttpPortString =
+ System.getenv(ApplicationConstants.Environment.NM_HTTP_PORT.name());
+ long appSubmitTime = Long.parseLong(appSubmitTimeStr);
+ RssDAGAppMasterForWordCountWithFailures appMaster =
+ new RssDAGAppMasterForWordCountWithFailures(
+ applicationAttemptId,
+ containerId,
+ nodeHostString,
+ Integer.parseInt(nodePortString),
+ Integer.parseInt(nodeHttpPortString),
+ new SystemClock(),
+ appSubmitTime,
+ sessionModeCliOption,
+ System.getenv(ApplicationConstants.Environment.PWD.name()),
+ TezCommonUtils.getTrimmedStrings(
+
System.getenv(ApplicationConstants.Environment.LOCAL_DIRS.name())),
+ TezCommonUtils.getTrimmedStrings(
+
System.getenv(ApplicationConstants.Environment.LOG_DIRS.name())),
+ clientVersion,
+ credentials,
+ jobUserName,
+ amPluginDescriptorProto,
+ testMode);
+ ShutdownHookManager.get()
+ .addShutdownHook(
+ new RssDAGAppMaster.RssDAGAppMasterShutdownHook(appMaster),
SHUTDOWN_HOOK_PRIORITY);
+
+ // log the system properties
+ if (LOG.isInfoEnabled()) {
+ String systemPropsToLog =
TezCommonUtils.getSystemPropertiesToLog(conf);
+ if (systemPropsToLog != null) {
+ LOG.info(systemPropsToLog);
+ }
+ }
+
+ LOG.info(
+ "recompute is {}, reschedule is {}",
+ conf.getBoolean(
+ RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK,
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT),
+ conf.getBoolean(
+ TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS,
+ TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT));
+ if (conf.getBoolean(
+ RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK,
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT)
+ && conf.getBoolean(
+ TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS,
+ TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS_DEFAULT)) {
+ LOG.info(
+ "When rss.avoid.recompute.succeeded.task is enable, "
+ + "we can not rescheduler succeeded task on unhealthy node");
+ conf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, false);
+ }
+ initAndStartAppMaster(appMaster, conf);
+ } catch (Throwable t) {
+ LOG.error("Error starting RssDAGAppMaster", t);
+ System.exit(1);
+ }
+ }
+
+ public void overrideTaskAttemptEventDispatcher()
+ throws NoSuchFieldException, IllegalAccessException {
+ AsyncDispatcher dispatcher = (AsyncDispatcher) this.getDispatcher();
+ Field field = dispatcher.getClass().getDeclaredField("eventHandlers");
+ field.setAccessible(true);
+ Map<Class<? extends Enum>, EventHandler> eventHandlers =
+ (Map<Class<? extends Enum>, EventHandler>) field.get(dispatcher);
+ eventHandlers.put(
+ TaskAttemptEventType.class, new
RssTaskAttemptEventDispatcher(this.getConfig()));
+ }
+
+ private class RssTaskAttemptEventDispatcher implements
EventHandler<TaskAttemptEvent> {
+
+ Map<NodeId, Integer> succeed = new HashMap<>();
+ boolean killed = false;
+ boolean avoidRecompute;
+
+ RssTaskAttemptEventDispatcher(Configuration conf) {
+ avoidRecompute =
+ conf.getBoolean(
+ RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK,
RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK_DEFAULT);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void handle(TaskAttemptEvent event) {
+ DAG dag = getContext().getCurrentDAG();
+ int eventDagIndex =
event.getTaskAttemptID().getTaskID().getVertexID().getDAGId().getId();
+ if (dag == null || eventDagIndex != dag.getID().getId()) {
+ return; // event not relevant any more
+ }
+ Task task =
+ dag.getVertex(event.getTaskAttemptID().getTaskID().getVertexID())
+ .getTask(event.getTaskAttemptID().getTaskID());
+ TaskAttempt attempt = task.getAttempt(event.getTaskAttemptID());
+
+ LOG.info("handle task attempt event: {}", event);
+ if (avoidRecompute) {
+ if (attempt.getState() == TaskAttemptState.SUCCEEDED
+ && event.getType() == TaskAttemptEventType.TA_NODE_FAILED) {
+ LOG.info(
+ "We should not recompute the succeeded task attempt, though
taskattempt {} recieved event {}",
+ attempt,
+ event);
+ return;
+ }
+ }
+ ((EventHandler<TaskAttemptEvent>) attempt).handle(event);
+ // For Tokenizer, record the first succeeded task and its node. When
next task runs on this
+ // node, will kill this task or label this node as unhealthy.
+ int vertexId = attempt.getVertexID().getId();
+ if (vertexId == 0) {
+ if (attempt.getState() == TaskAttemptState.SUCCEEDED) {
+ NodeId nodeId = attempt.getAssignedContainer().getNodeId();
+ if (!succeed.containsKey(nodeId)) {
+ succeed.put(nodeId, 1);
+ } else {
+ succeed.put(nodeId, succeed.get(nodeId) + 1);
+ }
+ } else if (attempt.getState() == TaskAttemptState.RUNNING) {
+ NodeId nodeId = attempt.getAssignedContainer().getNodeId();
+ if (succeed.getOrDefault(nodeId, 0) == 1 && !killed) {
+ if (testMode == 0) {
+ TaskAttemptEventAttemptFailed eventAttemptFailed =
+ new TaskAttemptEventAttemptFailed(
+ attempt.getID(),
+ TaskAttemptEventType.TA_FAILED,
+ TaskFailureType.NON_FATAL,
+ "Triggerd by " + this.getClass().getName(),
+ TaskAttemptTerminationCause.CONTAINER_LAUNCH_FAILED);
+ LOG.info(
+ "Killing running task attempt: {} at node: {}",
+ attempt,
+ attempt.getAssignedContainer().getNodeId());
+ ((EventHandler<TaskAttemptEvent>)
attempt).handle(eventAttemptFailed);
+ dag.getEventHandler().handle(eventAttemptFailed);
+ } else if (testMode == 1) {
+ NodeReport nodeReport = mock(NodeReport.class);
+
when(nodeReport.getNodeState()).thenReturn(NodeState.DECOMMISSIONED);
+ when(nodeReport.getNodeId()).thenReturn(nodeId);
+ LOG.info(
+ "Label the node {} as DECOMMISSIONED."
+ + attempt.getAssignedContainer().getNodeId());
+ dag.getEventHandler().handle(new
AMNodeEventStateChanged(nodeReport, 0));
+ } else {
+ throw new RssException("testMode " + testMode + " is not
supported!");
+ }
+ killed = true;
+ }
+ }
+ }
+ }
+ }
+}
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
index f9584efd..b5219efe 100644
---
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezIntegrationTestBase.java
@@ -61,7 +61,7 @@ public class TezIntegrationTestBase extends
IntegrationTestBase {
private static final Logger LOG =
LoggerFactory.getLogger(TezIntegrationTestBase.class);
private static String TEST_ROOT_DIR =
- "target" + Path.SEPARATOR + TezWordCountTest.class.getName() + "-tmpDir";
+ "target" + Path.SEPARATOR + TezIntegrationTestBase.class.getName() +
"-tmpDir";
private Path remoteStagingDir = null;
protected static MiniTezCluster miniTezCluster;
diff --git
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
new file mode 100644
index 00000000..14957df4
--- /dev/null
+++
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.io.IOException;
+import java.lang.reflect.Method;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.client.CallerContext;
+import org.apache.tez.client.TezClient;
+import org.apache.tez.client.TezClientUtils;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.dag.api.DAG;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezException;
+import org.apache.tez.dag.api.client.DAGClient;
+import org.apache.tez.dag.api.client.DAGStatus;
+import org.apache.tez.dag.api.client.Progress;
+import org.apache.tez.dag.api.client.StatusGetOpts;
+import org.apache.tez.dag.app.RssDAGAppMasterForWordCountWithFailures;
+import org.apache.tez.examples.WordCount;
+import org.apache.tez.hadoop.shim.HadoopShim;
+import org.apache.tez.hadoop.shim.HadoopShimsLoader;
+import org.apache.tez.test.MiniTezCluster;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static
org.apache.tez.common.RssTezConfig.RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK;
+import static
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_MAX_TASK_FAILURES_PER_NODE;
+import static
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_BLACKLISTING_ENABLED;
+import static
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_BLACKLISTING_IGNORE_THRESHOLD;
+import static
org.apache.tez.dag.api.TezConfiguration.TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class TezWordCountWithFailuresTest extends IntegrationTestBase {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(TezIntegrationTestBase.class);
+ private static String TEST_ROOT_DIR =
+ "target" + Path.SEPARATOR + TezWordCountWithFailuresTest.class.getName()
+ "-tmpDir";
+
+ private Path remoteStagingDir = null;
+ private String inputPath = "word_count_input";
+ private String outputPath = "word_count_output";
+ private List<String> wordTable =
+ Lists.newArrayList(
+ "apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan",
"tomato");
+
+ protected static MiniTezCluster miniTezCluster;
+
+ @BeforeAll
+ public static void beforeClass() throws Exception {
+ LOG.info("Starting mini tez clusters");
+ if (miniTezCluster == null) {
+ miniTezCluster = new
MiniTezCluster(TezIntegrationTestBase.class.getName(), 3, 1, 1);
+ miniTezCluster.init(conf);
+ miniTezCluster.start();
+ }
+ LOG.info("Starting corrdinators and shuffer servers");
+ CoordinatorConf coordinatorConf = getCoordinatorConf();
+ Map<String, String> dynamicConf = new HashMap();
+ dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(),
HDFS_URI + "rss/test");
+ dynamicConf.put(RssTezConfig.RSS_STORAGE_TYPE,
StorageType.MEMORY_LOCALFILE_HDFS.name());
+ addDynamicConf(coordinatorConf, dynamicConf);
+ createCoordinatorServer(coordinatorConf);
+ ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+ createShuffleServer(shuffleServerConf);
+ startServers();
+ }
+
+ @AfterAll
+ public static void tearDown() throws Exception {
+ if (miniTezCluster != null) {
+ LOG.info("Stopping MiniTezCluster");
+ miniTezCluster.stop();
+ miniTezCluster = null;
+ }
+ }
+
+ @BeforeEach
+ public void setup() throws Exception {
+ remoteStagingDir =
+ fs.makeQualified(new Path(TEST_ROOT_DIR, String.valueOf(new
Random().nextInt(100000))));
+ TezClientUtils.ensureStagingDirExists(conf, remoteStagingDir);
+ generateInputFile();
+ }
+
+ private void generateInputFile() throws Exception {
+ assertTrue(fs.mkdirs(new Path(inputPath)));
+ for (int j = 0; j < 5; j++) {
+ FSDataOutputStream outputStream = fs.create(new Path(inputPath +
"/file." + j));
+ Random random = new Random();
+ for (int i = 0; i < 100; i++) {
+ int index = random.nextInt(wordTable.size());
+ String str = wordTable.get(index) + "\n";
+ outputStream.writeBytes(str);
+ }
+ outputStream.close();
+ }
+ FileStatus[] fileStatus = fs.listStatus(new Path(inputPath));
+ for (FileStatus status : fileStatus) {
+ System.out.println("status is " + status);
+ }
+ }
+
+ @AfterEach
+ public void tearDownEach() throws Exception {
+ if (this.remoteStagingDir != null) {
+ fs.delete(this.remoteStagingDir, true);
+ }
+ for (int j = 0; j < 5; j++) {
+ fs.delete(new Path(inputPath + "/file." + j), true);
+ }
+ }
+
+ @Test
+ public void wordCountTestWithTaskFailureWhenAvoidRecomputeEnable() throws
Exception {
+ // 1 Run Tez examples based on rss
+ TezConfiguration appConf = new
TezConfiguration(miniTezCluster.getConfig());
+ updateRssConfiguration(appConf, 0, true, false, 1);
+ TezIntegrationTestBase.appendAndUploadRssJars(appConf);
+ runTezApp(appConf, getTestArgs("rss"), 0);
+ final String rssPath = getOutputDir("rss");
+
+ // 2 Run original Tez examples
+ appConf = new TezConfiguration(miniTezCluster.getConfig());
+ updateCommonConfiguration(appConf);
+ runTezApp(appConf, getTestArgs("origin"), -1);
+ final String originPath = getOutputDir("origin");
+
+ // 3 verify the results
+ TezIntegrationTestBase.verifyResultEqual(originPath, rssPath);
+ }
+
+ @Test
+ public void wordCountTestWithTaskFailureWhenAvoidRecomputeDisable() throws
Exception {
+ // 1 Run Tez examples based on rss
+ TezConfiguration appConf = new
TezConfiguration(miniTezCluster.getConfig());
+ updateRssConfiguration(appConf, 0, false, false, 1);
+ TezIntegrationTestBase.appendAndUploadRssJars(appConf);
+ runTezApp(appConf, getTestArgs("rss"), 1);
+ final String rssPath = getOutputDir("rss");
+
+ // 2 Run original Tez examples
+ appConf = new TezConfiguration(miniTezCluster.getConfig());
+ updateCommonConfiguration(appConf);
+ runTezApp(appConf, getTestArgs("origin"), -1);
+ final String originPath = getOutputDir("origin");
+
+ // 3 verify the results
+ TezIntegrationTestBase.verifyResultEqual(originPath, rssPath);
+ }
+
+ @Test
+ public void wordCountTestWithNodeUnhealthyWhenAvoidRecomputeEnable() throws
Exception {
+ // 1 Run Tez examples based on rss
+ TezConfiguration appConf = new
TezConfiguration(miniTezCluster.getConfig());
+ updateRssConfiguration(appConf, 1, true, true, 100);
+ TezIntegrationTestBase.appendAndUploadRssJars(appConf);
+ runTezApp(appConf, getTestArgs("rss"), 0);
+ final String rssPath = getOutputDir("rss");
+
+ // 2 Run original Tez examples
+ appConf = new TezConfiguration(miniTezCluster.getConfig());
+ updateCommonConfiguration(appConf);
+ runTezApp(appConf, getTestArgs("origin"), -1);
+ final String originPath = getOutputDir("origin");
+
+ // 3 verify the results
+ TezIntegrationTestBase.verifyResultEqual(originPath, rssPath);
+ }
+
+ @Test
+ public void wordCountTestWithNodeUnhealthyWhenAvoidRecomputeDisable() throws
Exception {
+ // 1 Run Tez examples based on rss
+ TezConfiguration appConf = new
TezConfiguration(miniTezCluster.getConfig());
+ updateRssConfiguration(appConf, 1, false, true, 100);
+ TezIntegrationTestBase.appendAndUploadRssJars(appConf);
+ runTezApp(appConf, getTestArgs("rss"), 1);
+ final String rssPath = getOutputDir("rss");
+
+ // 2 Run original Tez examples
+ appConf = new TezConfiguration(miniTezCluster.getConfig());
+ updateCommonConfiguration(appConf);
+ runTezApp(appConf, getTestArgs("origin"), -1);
+ final String originPath = getOutputDir("origin");
+
+ // 3 verify the results
+ TezIntegrationTestBase.verifyResultEqual(originPath, rssPath);
+ }
+
+ /*
+ * Two verify mode are supported:
+ * (a) verifyMode 0
+ * tez.rss.avoid.recompute.succeeded.task is enable, should not
recompute the task when this node is
+ * blacke-listed for unhealthy.
+ *
+ * (b) verifyMode 1
+ * tez.rss.avoid.recompute.succeeded.task is disable, will recompute the
task when this node is
+ * blacke-listed for unhealthy.
+ * */
+ protected void runTezApp(TezConfiguration tezConf, String[] args, int
verifyMode)
+ throws Exception {
+ assertEquals(
+ 0,
+ ToolRunner.run(tezConf, new WordCountWithFailures(verifyMode), args),
+ "WordCountWithFailures failed");
+ }
+
+ public String[] getTestArgs(String uniqueOutputName) {
+ return new String[] {
+ "-disableSplitGrouping", inputPath, outputPath + "/" + uniqueOutputName,
"2"
+ };
+ }
+
+ public String getOutputDir(String uniqueOutputName) {
+ return outputPath + "/" + uniqueOutputName;
+ }
+
  /*
   * In this integration test the mini cluster has three NMs with 4G each
   * (YarnConfiguration.DEFAULT_YARN_MINICLUSTER_NM_PMEM_MB). The AM requests 4G, so a node can
   * only run a single AM container; that node therefore never accumulates task failures and is
   * never black-listed or marked unhealthy because of them.
   * NOTE(review): the original comment said tasks request 2G, but TEZ_TASK_RESOURCE_MEMORY_MB is
   * set to 4096 below — confirm which value is intended.
   * */
  public void updateRssConfiguration(
      Configuration appConf,
      int testMode,
      boolean avoidRecompute,
      boolean rescheduleWhenUnhealthy,
      int maxFailures)
      throws Exception {
    appConf.set(TezConfiguration.TEZ_AM_STAGING_DIR, remoteStagingDir.toString());
    appConf.setInt(TezConfiguration.TEZ_AM_RESOURCE_MEMORY_MB, 4096);
    appConf.setInt(TezConfiguration.TEZ_TASK_RESOURCE_MEMORY_MB, 4096);
    // Enable node black-listing with a high ignore threshold and a configurable per-node failure
    // limit so the injected failures can actually black-list a node.
    appConf.setBoolean(TEZ_AM_NODE_BLACKLISTING_ENABLED, true);
    appConf.setInt(TEZ_AM_NODE_BLACKLISTING_IGNORE_THRESHOLD, 99);
    appConf.setInt(TEZ_AM_MAX_TASK_FAILURES_PER_NODE, maxFailures);
    appConf.set(RssTezConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM);
    appConf.set(RssTezConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name());
    // Append the failure-injecting AM class and "--testMode<N>" to the AM launch command options;
    // presumably parsed by RssDAGAppMasterForWordCountWithFailures (note: no space between the
    // flag and the value) — confirm against that class's argument parsing.
    appConf.set(
        TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS,
        TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS_DEFAULT
            + " "
            + RssDAGAppMasterForWordCountWithFailures.class.getName()
            + " --testMode"
            + testMode);
    appConf.setBoolean(RSS_AVOID_RECOMPUTE_SUCCEEDED_TASK, avoidRecompute);
    appConf.setBoolean(TEZ_AM_NODE_UNHEALTHY_RESCHEDULE_TASKS, rescheduleWhenUnhealthy);
  }
+
+ public void updateCommonConfiguration(Configuration appConf) {
+ appConf.set(TezConfiguration.TEZ_AM_STAGING_DIR,
remoteStagingDir.toString());
+ appConf.setInt(TezConfiguration.TEZ_AM_RESOURCE_MEMORY_MB, 512);
+ appConf.set(TezConfiguration.TEZ_AM_LAUNCH_CMD_OPTS, " -Xmx384m");
+ appConf.setInt(TezConfiguration.TEZ_TASK_RESOURCE_MEMORY_MB, 512);
+ appConf.set(TezConfiguration.TEZ_TASK_LAUNCH_CMD_OPTS, " -Xmx384m");
+ }
+
  /**
   * WordCount variant driven by {@link ToolRunner} which, after the DAG completes, verifies the
   * killed-task-attempt count of the "Tokenizer" vertex according to {@code verifyMode}.
   */
  public class WordCountWithFailures extends WordCount {

    TezClient tezClientInternal = null;
    private HadoopShim hadoopShim;
    // -1 skips verification; 0 asserts no killed attempts, 1 asserts some attempts were killed
    int verifyMode = -1;

    WordCountWithFailures(int assertMode) {
      this.verifyMode = assertMode;
    }

    @Override
    protected int runJob(String[] args, TezConfiguration tezConf, TezClient tezClient)
        throws Exception {
      this.tezClientInternal = tezClient;
      // createDAG is private in WordCount, so it must be invoked reflectively.
      Method method =
          WordCount.class.getDeclaredMethod(
              "createDAG", TezConfiguration.class, String.class, String.class, int.class);
      method.setAccessible(true);
      DAG dag =
          (DAG)
              method.invoke(
                  this,
                  tezConf,
                  args[0],
                  args[1],
                  args.length == 3 ? Integer.parseInt(args[2]) : 1);
      LOG.info("Running WordCountWithFailures");
      return runDag(dag, isCountersLog(), LOG);
    }

    /**
     * Submits the DAG, waits for completion, then checks the "Tokenizer" vertex progress.
     *
     * @param dag the DAG to submit
     * @param printCounters whether to request counters in the status updates
     * @param logger logger used for diagnostics on failure
     * @return 0 when the DAG succeeded (and assertions passed), -1 otherwise
     */
    public int runDag(DAG dag, boolean printCounters, Logger logger)
        throws TezException, InterruptedException, IOException {
      tezClientInternal.waitTillReady();

      CallerContext callerContext =
          CallerContext.create("TezExamples", "Tez Example DAG: " + dag.getName());
      ApplicationId appId = tezClientInternal.getAppMasterApplicationId();
      if (hadoopShim == null) {
        // Lazily load the shim from the tool's configuration (empty config as a fallback).
        Configuration conf = (getConf() == null ? new Configuration(false) : getConf());
        hadoopShim = new HadoopShimsLoader(conf).getHadoopShim();
      }

      if (appId != null) {
        TezUtilsInternal.setHadoopCallerContext(hadoopShim, appId);
        callerContext.setCallerIdAndType(appId.toString(), "TezExampleApplication");
      }
      dag.setCallerContext(callerContext);

      DAGClient dagClient = tezClientInternal.submitDAG(dag);
      Set<StatusGetOpts> getOpts = Sets.newHashSet();
      if (printCounters) {
        getOpts.add(StatusGetOpts.GET_COUNTERS);
      }

      DAGStatus dagStatus = dagClient.waitForCompletionWithStatusUpdates(getOpts);
      if (dagStatus.getState() != DAGStatus.State.SUCCEEDED) {
        logger.info("DAG diagnostics: " + dagStatus.getDiagnostics());
        return -1;
      }

      Map<String, Progress> progressMap = dagStatus.getVertexProgress();
      if (verifyMode == 0) {
        // verifyMode is 0: avoid recompute succeeded task is enabled, so no attempt was killed
        Assertions.assertEquals(0, progressMap.get("Tokenizer").getKilledTaskAttemptCount());
      } else if (verifyMode == 1) {
        // verifyMode is 1: avoid recompute succeeded task is disabled, so attempts were killed
        Assertions.assertTrue(progressMap.get("Tokenizer").getKilledTaskAttemptCount() > 0);
      }
      return 0;
    }
  }
+}