mridulm commented on code in PR #2921:
URL: https://github.com/apache/celeborn/pull/2921#discussion_r1894340973
##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -43,11 +54,12 @@
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;
public class SparkUtils {
- private static final Logger logger =
LoggerFactory.getLogger(SparkUtils.class);
+ private static final Logger LOG = LoggerFactory.getLogger(SparkUtils.class);
Review Comment:
super nit: revert
While I would prefer `LOG` or `LOGGER` (given it is an immutable constant) -
let us do it in a separate PR for the entire codebase, if we want to go down
that path.
##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String
reason) {
scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
}
} else {
- logger.error("Can not get active SparkContext, skip cancelShuffle.");
+ LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+ }
+ }
+
+ private static final DynFields.UnboundField<ConcurrentHashMap<Long,
TaskSetManager>>
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+ .defaultAlwaysNull()
+ .build();
+ private static final
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+ TASK_INFOS_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSetManager.class, "taskInfos")
+ .defaultAlwaysNull()
+ .build();
+
+ protected static TaskSetManager getTaskSetManager(long taskId) {
+ if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+ TaskSchedulerImpl taskScheduler =
+ (TaskSchedulerImpl)
SparkContext$.MODULE$.getActive().get().taskScheduler();
+ ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+ return taskIdToTaskSetManager.get(taskId);
+ } else {
+ LOG.error("Can not get active SparkContext.");
+ return null;
+ }
+ }
+
+ protected static List<TaskInfo> getTaskAttempts(TaskSetManager
taskSetManager, long taskId) {
+ if (taskSetManager != null) {
+ scala.Option<TaskInfo> taskInfoOption =
+ TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
+ if (taskInfoOption.isDefined()) {
+ int taskIndex = taskInfoOption.get().index();
+ return scala.collection.JavaConverters.asJavaCollectionConverter(
+ taskSetManager.taskAttempts()[taskIndex])
+ .asJavaCollection().stream()
+ .collect(Collectors.toList());
+ } else {
+ LOG.error("Can not get TaskInfo for taskId: {}", taskId);
+ return Collections.emptyList();
+ }
+ } else {
+ LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
+ return Collections.emptyList();
+ }
+ }
+
+ public static Map<Integer, Set<Long>>
reportedStageShuffleFetchFailureTaskIds =
+ JavaUtils.newConcurrentHashMap();
+
+ /**
+ * Only check for the shuffle fetch failure task whether another attempt is
running or successful.
+ * If another attempt(excluding the reported shuffle fetch failure tasks in
current stage) is
+ * running or successful, return true. Otherwise, return false.
+ */
+ public static synchronized boolean
taskAnotherAttemptRunningOrSuccessful(long taskId) {
+ TaskSetManager taskSetManager = getTaskSetManager(taskId);
+ if (taskSetManager != null) {
+ int stageId = taskSetManager.stageId();
+ Set<Long> reportedStageTaskIds =
+ reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(stageId, k
-> new HashSet<>());
Review Comment:
We need to index by stage id + stage attempt id - this will make it easier
to introduce cleanup.
##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String
reason) {
scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
}
} else {
- logger.error("Can not get active SparkContext, skip cancelShuffle.");
+ LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+ }
+ }
+
+ private static final DynFields.UnboundField<ConcurrentHashMap<Long,
TaskSetManager>>
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+ .defaultAlwaysNull()
+ .build();
+ private static final
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+ TASK_INFOS_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSetManager.class, "taskInfos")
+ .defaultAlwaysNull()
+ .build();
+
+ protected static TaskSetManager getTaskSetManager(long taskId) {
+ if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+ TaskSchedulerImpl taskScheduler =
+ (TaskSchedulerImpl)
SparkContext$.MODULE$.getActive().get().taskScheduler();
+ ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+ return taskIdToTaskSetManager.get(taskId);
+ } else {
+ LOG.error("Can not get active SparkContext.");
+ return null;
+ }
+ }
+
+ protected static List<TaskInfo> getTaskAttempts(TaskSetManager
taskSetManager, long taskId) {
+ if (taskSetManager != null) {
+ scala.Option<TaskInfo> taskInfoOption =
+ TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
+ if (taskInfoOption.isDefined()) {
+ int taskIndex = taskInfoOption.get().index();
+ return scala.collection.JavaConverters.asJavaCollectionConverter(
+ taskSetManager.taskAttempts()[taskIndex])
+ .asJavaCollection().stream()
+ .collect(Collectors.toList());
+ } else {
+ LOG.error("Can not get TaskInfo for taskId: {}", taskId);
+ return Collections.emptyList();
+ }
+ } else {
+ LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
+ return Collections.emptyList();
+ }
+ }
+
+ public static Map<Integer, Set<Long>>
reportedStageShuffleFetchFailureTaskIds =
+ JavaUtils.newConcurrentHashMap();
+
+ /**
+ * Only check for the shuffle fetch failure task whether another attempt is
running or successful.
+ * If another attempt(excluding the reported shuffle fetch failure tasks in
current stage) is
+ * running or successful, return true. Otherwise, return false.
+ */
+ public static synchronized boolean
taskAnotherAttemptRunningOrSuccessful(long taskId) {
+ TaskSetManager taskSetManager = getTaskSetManager(taskId);
+ if (taskSetManager != null) {
+ int stageId = taskSetManager.stageId();
+ Set<Long> reportedStageTaskIds =
+ reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(stageId, k
-> new HashSet<>());
+ reportedStageTaskIds.add(taskId);
+
+ List<TaskInfo> taskAttempts = getTaskAttempts(taskSetManager, taskId);
+ Optional<TaskInfo> taskInfoOpt =
+ taskAttempts.stream().filter(ti -> ti.taskId() ==
taskId).findFirst();
Review Comment:
Return `Tuple2<TaskInfo, List<TaskInfo>>` from `getTaskAttempts`? We already
fetch the taskInfo for taskId at the start of `getTaskAttempts`.
##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String
reason) {
scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
}
} else {
- logger.error("Can not get active SparkContext, skip cancelShuffle.");
+ LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+ }
+ }
+
+ private static final DynFields.UnboundField<ConcurrentHashMap<Long,
TaskSetManager>>
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+ .defaultAlwaysNull()
+ .build();
+ private static final
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+ TASK_INFOS_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSetManager.class, "taskInfos")
+ .defaultAlwaysNull()
+ .build();
+
+ protected static TaskSetManager getTaskSetManager(long taskId) {
+ if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+ TaskSchedulerImpl taskScheduler =
+ (TaskSchedulerImpl)
SparkContext$.MODULE$.getActive().get().taskScheduler();
+ ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+ return taskIdToTaskSetManager.get(taskId);
+ } else {
+ LOG.error("Can not get active SparkContext.");
+ return null;
+ }
+ }
+
+ protected static List<TaskInfo> getTaskAttempts(TaskSetManager
taskSetManager, long taskId) {
+ if (taskSetManager != null) {
+ scala.Option<TaskInfo> taskInfoOption =
+ TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
+ if (taskInfoOption.isDefined()) {
+ int taskIndex = taskInfoOption.get().index();
+ return scala.collection.JavaConverters.asJavaCollectionConverter(
+ taskSetManager.taskAttempts()[taskIndex])
+ .asJavaCollection().stream()
+ .collect(Collectors.toList());
+ } else {
+ LOG.error("Can not get TaskInfo for taskId: {}", taskId);
+ return Collections.emptyList();
+ }
+ } else {
+ LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
+ return Collections.emptyList();
+ }
+ }
+
+ public static Map<Integer, Set<Long>>
reportedStageShuffleFetchFailureTaskIds =
+ JavaUtils.newConcurrentHashMap();
+
+ /**
+ * Only check for the shuffle fetch failure task whether another attempt is
running or successful.
+ * If another attempt(excluding the reported shuffle fetch failure tasks in
current stage) is
+ * running or successful, return true. Otherwise, return false.
+ */
+ public static synchronized boolean
taskAnotherAttemptRunningOrSuccessful(long taskId) {
Review Comment:
This has an inherent race condition - where both tasks end up not raising a
fetch failure.
The lock here is on `SparkUtils` - which does not prevent TSM state from
changing from under us ...
Since we have to lock TaskSchedulerImpl anyway for `getTaskAttempts` - I
would instead suggest locking `taskAnotherAttemptRunningOrSuccessful` on
`TaskSchedulerImpl` - which prevents the TSM from getting mutated under us.
##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String
reason) {
scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
}
} else {
- logger.error("Can not get active SparkContext, skip cancelShuffle.");
+ LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+ }
+ }
+
+ private static final DynFields.UnboundField<ConcurrentHashMap<Long,
TaskSetManager>>
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+ .defaultAlwaysNull()
+ .build();
+ private static final
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+ TASK_INFOS_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSetManager.class, "taskInfos")
+ .defaultAlwaysNull()
+ .build();
+
+ protected static TaskSetManager getTaskSetManager(long taskId) {
+ if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+ TaskSchedulerImpl taskScheduler =
+ (TaskSchedulerImpl)
SparkContext$.MODULE$.getActive().get().taskScheduler();
+ ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+ return taskIdToTaskSetManager.get(taskId);
+ } else {
+ LOG.error("Can not get active SparkContext.");
+ return null;
+ }
+ }
+
+ protected static List<TaskInfo> getTaskAttempts(TaskSetManager
taskSetManager, long taskId) {
+ if (taskSetManager != null) {
+ scala.Option<TaskInfo> taskInfoOption =
+ TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
Review Comment:
This should be accessed within the TaskSchedulerImpl instance lock
##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String
reason) {
scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
}
} else {
- logger.error("Can not get active SparkContext, skip cancelShuffle.");
+ LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+ }
+ }
+
+ private static final DynFields.UnboundField<ConcurrentHashMap<Long,
TaskSetManager>>
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+ .defaultAlwaysNull()
+ .build();
+ private static final
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+ TASK_INFOS_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSetManager.class, "taskInfos")
+ .defaultAlwaysNull()
+ .build();
+
+ protected static TaskSetManager getTaskSetManager(long taskId) {
+ if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+ TaskSchedulerImpl taskScheduler =
+ (TaskSchedulerImpl)
SparkContext$.MODULE$.getActive().get().taskScheduler();
+ ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+ return taskIdToTaskSetManager.get(taskId);
+ } else {
+ LOG.error("Can not get active SparkContext.");
+ return null;
+ }
+ }
+
+ protected static List<TaskInfo> getTaskAttempts(TaskSetManager
taskSetManager, long taskId) {
+ if (taskSetManager != null) {
+ scala.Option<TaskInfo> taskInfoOption =
+ TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
+ if (taskInfoOption.isDefined()) {
+ int taskIndex = taskInfoOption.get().index();
+ return scala.collection.JavaConverters.asJavaCollectionConverter(
+ taskSetManager.taskAttempts()[taskIndex])
+ .asJavaCollection().stream()
+ .collect(Collectors.toList());
+ } else {
+ LOG.error("Can not get TaskInfo for taskId: {}", taskId);
+ return Collections.emptyList();
+ }
+ } else {
+ LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
+ return Collections.emptyList();
+ }
+ }
+
+ public static Map<Integer, Set<Long>>
reportedStageShuffleFetchFailureTaskIds =
Review Comment:
`private final` ? or at least package private if we want this to be
accessible from tests ?
Also, we are not cleaning up from this `Map` - this will lead to a memory leak
##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String
reason) {
scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
}
} else {
- logger.error("Can not get active SparkContext, skip cancelShuffle.");
+ LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+ }
+ }
+
+ private static final DynFields.UnboundField<ConcurrentHashMap<Long,
TaskSetManager>>
+ TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+ .defaultAlwaysNull()
+ .build();
+ private static final
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+ TASK_INFOS_FIELD =
+ DynFields.builder()
+ .hiddenImpl(TaskSetManager.class, "taskInfos")
+ .defaultAlwaysNull()
+ .build();
+
+ protected static TaskSetManager getTaskSetManager(long taskId) {
Review Comment:
We should be very careful with the use of `TaskSetManager` - it is not
designed to be used outside of the Spark scheduler. I have added a specific
comment below to mitigate the immediate issue - but within Celeborn we should
be careful if we are using this method in the future.
Please add a note that there are concerns with the use of TSM which need to
be analyzed carefully if the code evolves and other usages come up.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]