mridulm commented on code in PR #2921:
URL: https://github.com/apache/celeborn/pull/2921#discussion_r1894340973


##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -43,11 +54,12 @@
 
 import org.apache.celeborn.client.ShuffleClient;
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.util.JavaUtils;
 import org.apache.celeborn.common.util.Utils;
 import org.apache.celeborn.reflect.DynFields;
 
 public class SparkUtils {
-  private static final Logger logger = 
LoggerFactory.getLogger(SparkUtils.class);
+  private static final Logger LOG = LoggerFactory.getLogger(SparkUtils.class);

Review Comment:
   super nit: revert
   While I would prefer `LOG` or `LOGGER` (given it is an immutable constant) - 
let us do it in a separate PR for the entire codebase, if we want to go down 
that path.



##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String 
reason) {
         scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
       }
     } else {
-      logger.error("Can not get active SparkContext, skip cancelShuffle.");
+      LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+    }
+  }
+
+  private static final DynFields.UnboundField<ConcurrentHashMap<Long, 
TaskSetManager>>
+      TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+              .defaultAlwaysNull()
+              .build();
+  private static final 
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+      TASK_INFOS_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSetManager.class, "taskInfos")
+              .defaultAlwaysNull()
+              .build();
+
+  protected static TaskSetManager getTaskSetManager(long taskId) {
+    if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+      TaskSchedulerImpl taskScheduler =
+          (TaskSchedulerImpl) 
SparkContext$.MODULE$.getActive().get().taskScheduler();
+      ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+          TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+      return taskIdToTaskSetManager.get(taskId);
+    } else {
+      LOG.error("Can not get active SparkContext.");
+      return null;
+    }
+  }
+
+  protected static List<TaskInfo> getTaskAttempts(TaskSetManager 
taskSetManager, long taskId) {
+    if (taskSetManager != null) {
+      scala.Option<TaskInfo> taskInfoOption =
+          TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
+      if (taskInfoOption.isDefined()) {
+        int taskIndex = taskInfoOption.get().index();
+        return scala.collection.JavaConverters.asJavaCollectionConverter(
+                taskSetManager.taskAttempts()[taskIndex])
+            .asJavaCollection().stream()
+            .collect(Collectors.toList());
+      } else {
+        LOG.error("Can not get TaskInfo for taskId: {}", taskId);
+        return Collections.emptyList();
+      }
+    } else {
+      LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
+      return Collections.emptyList();
+    }
+  }
+
+  public static Map<Integer, Set<Long>> 
reportedStageShuffleFetchFailureTaskIds =
+      JavaUtils.newConcurrentHashMap();
+
+  /**
+   * Only check for the shuffle fetch failure task whether another attempt is 
running or successful.
+   * If another attempt(excluding the reported shuffle fetch failure tasks in 
current stage) is
+   * running or successful, return true. Otherwise, return false.
+   */
+  public static synchronized boolean 
taskAnotherAttemptRunningOrSuccessful(long taskId) {
+    TaskSetManager taskSetManager = getTaskSetManager(taskId);
+    if (taskSetManager != null) {
+      int stageId = taskSetManager.stageId();
+      Set<Long> reportedStageTaskIds =
+          reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(stageId, k 
-> new HashSet<>());

Review Comment:
   We need to index by stage id + stage attempt id - this will make it easier 
to introduce cleanup.



##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String 
reason) {
         scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
       }
     } else {
-      logger.error("Can not get active SparkContext, skip cancelShuffle.");
+      LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+    }
+  }
+
+  private static final DynFields.UnboundField<ConcurrentHashMap<Long, 
TaskSetManager>>
+      TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+              .defaultAlwaysNull()
+              .build();
+  private static final 
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+      TASK_INFOS_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSetManager.class, "taskInfos")
+              .defaultAlwaysNull()
+              .build();
+
+  protected static TaskSetManager getTaskSetManager(long taskId) {
+    if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+      TaskSchedulerImpl taskScheduler =
+          (TaskSchedulerImpl) 
SparkContext$.MODULE$.getActive().get().taskScheduler();
+      ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+          TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+      return taskIdToTaskSetManager.get(taskId);
+    } else {
+      LOG.error("Can not get active SparkContext.");
+      return null;
+    }
+  }
+
+  protected static List<TaskInfo> getTaskAttempts(TaskSetManager 
taskSetManager, long taskId) {
+    if (taskSetManager != null) {
+      scala.Option<TaskInfo> taskInfoOption =
+          TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
+      if (taskInfoOption.isDefined()) {
+        int taskIndex = taskInfoOption.get().index();
+        return scala.collection.JavaConverters.asJavaCollectionConverter(
+                taskSetManager.taskAttempts()[taskIndex])
+            .asJavaCollection().stream()
+            .collect(Collectors.toList());
+      } else {
+        LOG.error("Can not get TaskInfo for taskId: {}", taskId);
+        return Collections.emptyList();
+      }
+    } else {
+      LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
+      return Collections.emptyList();
+    }
+  }
+
+  public static Map<Integer, Set<Long>> 
reportedStageShuffleFetchFailureTaskIds =
+      JavaUtils.newConcurrentHashMap();
+
+  /**
+   * Only check for the shuffle fetch failure task whether another attempt is 
running or successful.
+   * If another attempt(excluding the reported shuffle fetch failure tasks in 
current stage) is
+   * running or successful, return true. Otherwise, return false.
+   */
+  public static synchronized boolean 
taskAnotherAttemptRunningOrSuccessful(long taskId) {
+    TaskSetManager taskSetManager = getTaskSetManager(taskId);
+    if (taskSetManager != null) {
+      int stageId = taskSetManager.stageId();
+      Set<Long> reportedStageTaskIds =
+          reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(stageId, k 
-> new HashSet<>());
+      reportedStageTaskIds.add(taskId);
+
+      List<TaskInfo> taskAttempts = getTaskAttempts(taskSetManager, taskId);
+      Optional<TaskInfo> taskInfoOpt =
+          taskAttempts.stream().filter(ti -> ti.taskId() == 
taskId).findFirst();

Review Comment:
   How about returning `Tuple2<TaskInfo, List<TaskInfo>>` from 
`getTaskAttempts`? We already fetch the `TaskInfo` for `taskId` at the start of 
`getTaskAttempts`, so the caller would not need to search for it again.



##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String 
reason) {
         scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
       }
     } else {
-      logger.error("Can not get active SparkContext, skip cancelShuffle.");
+      LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+    }
+  }
+
+  private static final DynFields.UnboundField<ConcurrentHashMap<Long, 
TaskSetManager>>
+      TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+              .defaultAlwaysNull()
+              .build();
+  private static final 
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+      TASK_INFOS_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSetManager.class, "taskInfos")
+              .defaultAlwaysNull()
+              .build();
+
+  protected static TaskSetManager getTaskSetManager(long taskId) {
+    if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+      TaskSchedulerImpl taskScheduler =
+          (TaskSchedulerImpl) 
SparkContext$.MODULE$.getActive().get().taskScheduler();
+      ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+          TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+      return taskIdToTaskSetManager.get(taskId);
+    } else {
+      LOG.error("Can not get active SparkContext.");
+      return null;
+    }
+  }
+
+  protected static List<TaskInfo> getTaskAttempts(TaskSetManager 
taskSetManager, long taskId) {
+    if (taskSetManager != null) {
+      scala.Option<TaskInfo> taskInfoOption =
+          TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
+      if (taskInfoOption.isDefined()) {
+        int taskIndex = taskInfoOption.get().index();
+        return scala.collection.JavaConverters.asJavaCollectionConverter(
+                taskSetManager.taskAttempts()[taskIndex])
+            .asJavaCollection().stream()
+            .collect(Collectors.toList());
+      } else {
+        LOG.error("Can not get TaskInfo for taskId: {}", taskId);
+        return Collections.emptyList();
+      }
+    } else {
+      LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
+      return Collections.emptyList();
+    }
+  }
+
+  public static Map<Integer, Set<Long>> 
reportedStageShuffleFetchFailureTaskIds =
+      JavaUtils.newConcurrentHashMap();
+
+  /**
+   * Only check for the shuffle fetch failure task whether another attempt is 
running or successful.
+   * If another attempt(excluding the reported shuffle fetch failure tasks in 
current stage) is
+   * running or successful, return true. Otherwise, return false.
+   */
+  public static synchronized boolean 
taskAnotherAttemptRunningOrSuccessful(long taskId) {

Review Comment:
   This has an inherent race condition - where both tasks end up not raising a 
fetch failure.
   The lock here is on `SparkUtils` - which does not prevent TSM state from 
changing out from under us ... 
   Since we have to lock `TaskSchedulerImpl` anyway for `getTaskAttempts` - I 
would instead suggest locking `taskAnotherAttemptRunningOrSuccessful` on 
`TaskSchedulerImpl` - which prevents the TSM from getting mutated under us.



##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String 
reason) {
         scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
       }
     } else {
-      logger.error("Can not get active SparkContext, skip cancelShuffle.");
+      LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+    }
+  }
+
+  private static final DynFields.UnboundField<ConcurrentHashMap<Long, 
TaskSetManager>>
+      TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+              .defaultAlwaysNull()
+              .build();
+  private static final 
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+      TASK_INFOS_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSetManager.class, "taskInfos")
+              .defaultAlwaysNull()
+              .build();
+
+  protected static TaskSetManager getTaskSetManager(long taskId) {
+    if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+      TaskSchedulerImpl taskScheduler =
+          (TaskSchedulerImpl) 
SparkContext$.MODULE$.getActive().get().taskScheduler();
+      ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+          TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+      return taskIdToTaskSetManager.get(taskId);
+    } else {
+      LOG.error("Can not get active SparkContext.");
+      return null;
+    }
+  }
+
+  protected static List<TaskInfo> getTaskAttempts(TaskSetManager 
taskSetManager, long taskId) {
+    if (taskSetManager != null) {
+      scala.Option<TaskInfo> taskInfoOption =
+          TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);

Review Comment:
   This should be accessed within the TaskSchedulerImpl instance lock



##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String 
reason) {
         scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
       }
     } else {
-      logger.error("Can not get active SparkContext, skip cancelShuffle.");
+      LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+    }
+  }
+
+  private static final DynFields.UnboundField<ConcurrentHashMap<Long, 
TaskSetManager>>
+      TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+              .defaultAlwaysNull()
+              .build();
+  private static final 
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+      TASK_INFOS_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSetManager.class, "taskInfos")
+              .defaultAlwaysNull()
+              .build();
+
+  protected static TaskSetManager getTaskSetManager(long taskId) {
+    if (SparkContext$.MODULE$.getActive().nonEmpty()) {
+      TaskSchedulerImpl taskScheduler =
+          (TaskSchedulerImpl) 
SparkContext$.MODULE$.getActive().get().taskScheduler();
+      ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
+          TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
+      return taskIdToTaskSetManager.get(taskId);
+    } else {
+      LOG.error("Can not get active SparkContext.");
+      return null;
+    }
+  }
+
+  protected static List<TaskInfo> getTaskAttempts(TaskSetManager 
taskSetManager, long taskId) {
+    if (taskSetManager != null) {
+      scala.Option<TaskInfo> taskInfoOption =
+          TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
+      if (taskInfoOption.isDefined()) {
+        int taskIndex = taskInfoOption.get().index();
+        return scala.collection.JavaConverters.asJavaCollectionConverter(
+                taskSetManager.taskAttempts()[taskIndex])
+            .asJavaCollection().stream()
+            .collect(Collectors.toList());
+      } else {
+        LOG.error("Can not get TaskInfo for taskId: {}", taskId);
+        return Collections.emptyList();
+      }
+    } else {
+      LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
+      return Collections.emptyList();
+    }
+  }
+
+  public static Map<Integer, Set<Long>> 
reportedStageShuffleFetchFailureTaskIds =

Review Comment:
   `private final`? Or at least package-private, if we want this to be 
accessible from tests?
   
   Also, we are not cleaning up entries from this `Map` - that will lead to a 
memory leak.



##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -200,7 +212,116 @@ public static void cancelShuffle(int shuffleId, String 
reason) {
         scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
       }
     } else {
-      logger.error("Can not get active SparkContext, skip cancelShuffle.");
+      LOG.error("Can not get active SparkContext, skip cancelShuffle.");
+    }
+  }
+
+  private static final DynFields.UnboundField<ConcurrentHashMap<Long, 
TaskSetManager>>
+      TASK_ID_TO_TASK_SET_MANAGER_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
+              .defaultAlwaysNull()
+              .build();
+  private static final 
DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
+      TASK_INFOS_FIELD =
+          DynFields.builder()
+              .hiddenImpl(TaskSetManager.class, "taskInfos")
+              .defaultAlwaysNull()
+              .build();
+
+  protected static TaskSetManager getTaskSetManager(long taskId) {

Review Comment:
   We should be very careful with the use of `TaskSetManager` - it is not 
designed to be used outside of the Spark scheduler. I have added a specific 
comment below to mitigate the immediate issue - but within Celeborn we should 
be careful if we use this method in the future.
   Please add a note that there are concerns with the use of TSM which need to 
be analyzed carefully if the code evolves and other usages come up.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to