This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a97885b  [SPARK-33350][SHUFFLE] Add support to DiskBlockManager to 
create merge directory and to get the local shuffle merged data
a97885b is described below

commit a97885bb2c81563a18c99df233fb9e99ff368c9c
Author: Ye Zhou <yez...@linkedin.com>
AuthorDate: Thu Jun 10 16:57:46 2021 -0500

    [SPARK-33350][SHUFFLE] Add support to DiskBlockManager to create merge 
directory and to get the local shuffle merged data
    
    ### What changes were proposed in this pull request?
    This is one of the patches for SPIP SPARK-30602, which is needed for
    push-based shuffle.
    
    ### Summary of changes:
    Executors create the merge directories under the application temp
    directory provided by YARN. The directory permissions are set to 770 so
    that the external shuffle service can create merged shuffle files and
    write merged shuffle data into them.
    
    Read local merged shuffle blocks (their data, index, and meta files) so
    that fetch requests for merged shuffle blocks can be served.
    
    ### Why are the changes needed?
    Refer to the SPIP in SPARK-30602.
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests.
    The reference PR with the consolidated changes covering the complete 
implementation is also provided in SPARK-30602.
    We have already verified the functionality and the improved performance as 
documented in the SPIP doc.
    
    Lead-authored-by: Min Shen mshenlinkedin.com
    Co-authored-by: Chandni Singh chsinghlinkedin.com
    Co-authored-by: Ye Zhou yezhoulinkedin.com
    
    Closes #32007 from zhouyejoe/SPARK-33350.
    
    Lead-authored-by: Ye Zhou <yez...@linkedin.com>
    Co-authored-by: Chandni Singh <chsi...@linkedin.com>
    Co-authored-by: Min Shen <ms...@linkedin.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../network/shuffle/RemoteBlockPushResolver.java   |   8 +-
 .../scala/org/apache/spark/MapOutputTracker.scala  |   2 +-
 .../spark/shuffle/IndexShuffleBlockResolver.scala  |  73 +++++++++++++-
 .../spark/shuffle/ShuffleBlockResolver.scala       |  13 ++-
 .../scala/org/apache/spark/storage/BlockId.scala   |  37 +++++++
 .../org/apache/spark/storage/BlockManager.scala    |  26 ++++-
 .../apache/spark/storage/DiskBlockManager.scala    | 112 +++++++++++++++++++++
 .../main/scala/org/apache/spark/util/Utils.scala   |  28 +++++-
 .../shuffle/HostLocalShuffleReadingSuite.scala     |   9 ++
 .../sort/IndexShuffleBlockResolverSuite.scala      |  93 ++++++++++++++++-
 .../org/apache/spark/storage/BlockIdSuite.scala    |  36 +++++++
 .../spark/storage/DiskBlockManagerSuite.scala      |  39 +++++++
 .../scala/org/apache/spark/util/UtilsSuite.scala   |  12 ++-
 13 files changed, 470 insertions(+), 18 deletions(-)

diff --git 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
index 1ac33cd..47d2547 100644
--- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
+++ 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
@@ -75,6 +75,7 @@ public class RemoteBlockPushResolver implements 
MergedShuffleFileManager {
   private static final Logger logger = 
LoggerFactory.getLogger(RemoteBlockPushResolver.class);
   @VisibleForTesting
   static final String MERGE_MANAGER_DIR = "merge_manager";
+  public static final String MERGED_SHUFFLE_FILE_NAME_PREFIX = "shuffleMerged";
 
   private final ConcurrentMap<String, AppPathsInfo> appsPathInfo;
   private final ConcurrentMap<AppShuffleId, Map<Integer, 
AppShufflePartitionInfo>> partitions;
@@ -211,7 +212,8 @@ public class RemoteBlockPushResolver implements 
MergedShuffleFileManager {
 
   /**
    * The logic here is consistent with
-   * org.apache.spark.storage.DiskBlockManager#getMergedShuffleFile
+   * @see [[org.apache.spark.storage.DiskBlockManager#getMergedShuffleFile(
+   *      org.apache.spark.storage.BlockId, scala.Option)]]
    */
   private File getFile(String appId, String filename) {
     // TODO: [SPARK-33236] Change the message when this service is able to 
handle NM restart
@@ -431,8 +433,8 @@ public class RemoteBlockPushResolver implements 
MergedShuffleFileManager {
       executorInfo.subDirsPerLocalDir));
   }
   private static String generateFileName(AppShuffleId appShuffleId, int 
reduceId) {
-    return String.format("mergedShuffle_%s_%d_%d", appShuffleId.appId, 
appShuffleId.shuffleId,
-      reduceId);
+    return String.format("%s_%s_%d_%d", MERGED_SHUFFLE_FILE_NAME_PREFIX, 
appShuffleId.appId,
+      appShuffleId.shuffleId, reduceId);
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 9f2228b..003b10f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -1290,7 +1290,7 @@ private[spark] object MapOutputTracker extends Logging {
   private val DIRECT = 0
   private val BROADCAST = 1
 
-  private val SHUFFLE_PUSH_MAP_ID = -1
+  val SHUFFLE_PUSH_MAP_ID = -1
 
   // Serialize an array of map/merge output locations into an efficient byte 
format so that we can
   // send it to reduce tasks. We do this by compressing the serialized bytes 
using Zstd. They will
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala 
b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 1619a5b..5d1da19 100644
--- 
a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ 
b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -28,7 +28,7 @@ import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
 import org.apache.spark.network.client.StreamCallbackWithID
 import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.shuffle.ExecutorDiskUtils
+import org.apache.spark.network.shuffle.{ExecutorDiskUtils, MergedBlockMeta}
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
@@ -110,6 +110,33 @@ private[spark] class IndexShuffleBlockResolver(
       .getOrElse(blockManager.diskBlockManager.getFile(blockId))
   }
 
+  private def getMergedBlockDataFile(
+      appId: String,
+      shuffleId: Int,
+      reduceId: Int,
+      dirs: Option[Array[String]] = None): File = {
+    blockManager.diskBlockManager.getMergedShuffleFile(
+      ShuffleMergedDataBlockId(appId, shuffleId, reduceId), dirs)
+  }
+
+  private def getMergedBlockIndexFile(
+      appId: String,
+      shuffleId: Int,
+      reduceId: Int,
+      dirs: Option[Array[String]] = None): File = {
+    blockManager.diskBlockManager.getMergedShuffleFile(
+      ShuffleMergedIndexBlockId(appId, shuffleId, reduceId), dirs)
+  }
+
+  private def getMergedBlockMetaFile(
+      appId: String,
+      shuffleId: Int,
+      reduceId: Int,
+      dirs: Option[Array[String]] = None): File = {
+    blockManager.diskBlockManager.getMergedShuffleFile(
+      ShuffleMergedMetaBlockId(appId, shuffleId, reduceId), dirs)
+  }
+
   /**
    * Remove data file and index file that contain the output data from one map.
    */
@@ -343,6 +370,50 @@ private[spark] class IndexShuffleBlockResolver(
     }
   }
 
+  /**
+   * This is only used for reading local merged block data. In such cases, all 
chunks in the
+   * merged shuffle file need to be identified at once, so the 
ShuffleBlockFetcherIterator
+   * knows how to consume local merged shuffle file as multiple chunks.
+   */
+  override def getMergedBlockData(
+      blockId: ShuffleBlockId,
+      dirs: Option[Array[String]]): Seq[ManagedBuffer] = {
+    val indexFile =
+      getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, 
blockId.reduceId, dirs)
+    val dataFile = getMergedBlockDataFile(conf.getAppId, blockId.shuffleId, 
blockId.reduceId, dirs)
+    // Load all the indexes in order to identify all chunks in the specified 
merged shuffle file.
+    val size = indexFile.length.toInt
+    val offsets = Utils.tryWithResource {
+      new DataInputStream(Files.newInputStream(indexFile.toPath))
+    } { dis =>
+      val buffer = ByteBuffer.allocate(size)
+      dis.readFully(buffer.array)
+      buffer.asLongBuffer
+    }
+    // Number of chunks is number of indexes - 1
+    val numChunks = size / 8 - 1
+    for (index <- 0 until numChunks) yield {
+      new FileSegmentManagedBuffer(transportConf, dataFile,
+        offsets.get(index),
+        offsets.get(index + 1) - offsets.get(index))
+    }
+  }
+
+  /**
+   * This is only used for reading local merged block meta data.
+   */
+  override def getMergedBlockMeta(
+      blockId: ShuffleBlockId,
+      dirs: Option[Array[String]]): MergedBlockMeta = {
+    val indexFile =
+      getMergedBlockIndexFile(conf.getAppId, blockId.shuffleId, 
blockId.reduceId, dirs)
+    val size = indexFile.length.toInt
+    val numChunks = (size / 8) - 1
+    val metaFile = getMergedBlockMetaFile(conf.getAppId, blockId.shuffleId, 
blockId.reduceId, dirs)
+    val chunkBitMaps = new FileSegmentManagedBuffer(transportConf, metaFile, 
0L, metaFile.length)
+    new MergedBlockMeta(numChunks, chunkBitMaps)
+  }
+
   override def getBlockData(
       blockId: BlockId,
       dirs: Option[Array[String]]): ManagedBuffer = {
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala 
b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
index 5485cf9..49e5929 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.shuffle
 
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.storage.BlockId
+import org.apache.spark.network.shuffle.MergedBlockMeta
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
 
 private[spark]
 /**
@@ -40,5 +41,15 @@ trait ShuffleBlockResolver {
    */
   def getBlockData(blockId: BlockId, dirs: Option[Array[String]] = None): 
ManagedBuffer
 
+  /**
+   * Retrieve the data for the specified merged shuffle block as multiple 
chunks.
+   */
+  def getMergedBlockData(blockId: ShuffleBlockId, dirs: 
Option[Array[String]]): Seq[ManagedBuffer]
+
+  /**
+   * Retrieve the meta data for the specified merged shuffle block.
+   */
+  def getMergedBlockMeta(blockId: ShuffleBlockId, dirs: 
Option[Array[String]]): MergedBlockMeta
+
   def stop(): Unit
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 73bf809..47c1b96 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -21,6 +21,7 @@ import java.util.UUID
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.network.shuffle.RemoteBlockPushResolver
 
 /**
  * :: DeveloperApi ::
@@ -87,6 +88,33 @@ case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, 
reduceId: Int) exte
   override def name: String = "shufflePush_" + shuffleId + "_" + mapIndex + 
"_" + reduceId
 }
 
+@Since("3.2.0")
+@DeveloperApi
+case class ShuffleMergedDataBlockId(appId: String, shuffleId: Int, reduceId: 
Int) extends BlockId {
+  override def name: String = 
RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
+    appId + "_" + shuffleId + "_" + reduceId + ".data"
+}
+
+@Since("3.2.0")
+@DeveloperApi
+case class ShuffleMergedIndexBlockId(
+    appId: String,
+    shuffleId: Int,
+    reduceId: Int) extends BlockId {
+  override def name: String = 
RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
+    appId + "_" + shuffleId + "_" + reduceId + ".index"
+}
+
+@Since("3.2.0")
+@DeveloperApi
+case class ShuffleMergedMetaBlockId(
+    appId: String,
+    shuffleId: Int,
+    reduceId: Int) extends BlockId {
+  override def name: String = 
RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" +
+    appId + "_" + shuffleId + "_" + reduceId + ".meta"
+}
+
 @DeveloperApi
 case class BroadcastBlockId(broadcastId: Long, field: String = "") extends 
BlockId {
   override def name: String = "broadcast_" + broadcastId + (if (field == "") 
"" else "_" + field)
@@ -129,6 +157,9 @@ object BlockId {
   val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
   val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE_MERGED_DATA = 
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).data".r
+  val SHUFFLE_MERGED_INDEX = 
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).index".r
+  val SHUFFLE_MERGED_META = 
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).meta".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -149,6 +180,12 @@ object BlockId {
       ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
     case SHUFFLE_PUSH(shuffleId, mapIndex, reduceId) =>
       ShufflePushBlockId(shuffleId.toInt, mapIndex.toInt, reduceId.toInt)
+    case SHUFFLE_MERGED_DATA(appId, shuffleId, reduceId) =>
+      ShuffleMergedDataBlockId(appId, shuffleId.toInt, reduceId.toInt)
+    case SHUFFLE_MERGED_INDEX(appId, shuffleId, reduceId) =>
+      ShuffleMergedIndexBlockId(appId, shuffleId.toInt, reduceId.toInt)
+    case SHUFFLE_MERGED_META(appId, shuffleId, reduceId) =>
+      ShuffleMergedMetaBlockId(appId, shuffleId.toInt, reduceId.toInt)
     case BROADCAST(broadcastId, field) =>
       BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 5d1666d..df449fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -503,8 +503,9 @@ private[spark] class BlockManager(
     }
 
     hostLocalDirManager = {
-      if (conf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) &&
-          !conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
+      if ((conf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) &&
+          !conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) ||
+          Utils.isPushBasedShuffleEnabled(conf)) {
         Some(new HostLocalDirManager(
           futureExecutionContext,
           conf.get(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE),
@@ -729,6 +730,27 @@ private[spark] class BlockManager(
   }
 
   /**
+   * Get the local merged shuffle block data for the given block ID as 
multiple chunks.
+   * A merged shuffle file is divided into multiple chunks according to the 
index file.
+   * Instead of reading the entire file as a single block, we split it into 
smaller chunks
+   * which will be memory efficient when performing certain operations.
+   */
+  def getLocalMergedBlockData(
+      blockId: ShuffleBlockId,
+      dirs: Array[String]): Seq[ManagedBuffer] = {
+    shuffleManager.shuffleBlockResolver.getMergedBlockData(blockId, Some(dirs))
+  }
+
+  /**
+   * Get the local merged shuffle block meta data for the given block ID.
+   */
+  def getLocalMergedBlockMeta(
+      blockId: ShuffleBlockId,
+      dirs: Array[String]): MergedBlockMeta = {
+    shuffleManager.shuffleBlockResolver.getMergedBlockMeta(blockId, Some(dirs))
+  }
+
+  /**
    * Get the BlockStatus for the block identified by the given ID, if it 
exists.
    * NOTE: This is mainly for testing.
    */
diff --git 
a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index f5ad4f9..d49f43f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -24,6 +24,8 @@ import java.util.UUID
 import org.apache.spark.SparkConf
 import org.apache.spark.executor.ExecutorExitCode
 import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.network.shuffle.ExecutorDiskUtils
+import org.apache.spark.storage.DiskBlockManager.MERGE_MANAGER_DIR
 import org.apache.spark.util.{ShutdownHookManager, Utils}
 
 /**
@@ -55,6 +57,9 @@ private[spark] class DiskBlockManager(conf: SparkConf, var 
deleteFilesOnStop: Bo
   // of subDirs(i) is protected by the lock of subDirs(i)
   private val subDirs = Array.fill(localDirs.length)(new 
Array[File](subDirsPerLocalDir))
 
+  // Create merge directories
+  createLocalDirsForMergedShuffleBlocks()
+
   private val shutdownHook = addShutdownHook()
 
   /** Looks up a file by hashing it into one of our local subdirectories. */
@@ -86,6 +91,33 @@ private[spark] class DiskBlockManager(conf: SparkConf, var 
deleteFilesOnStop: Bo
 
   def getFile(blockId: BlockId): File = getFile(blockId.name)
 
+  /**
+   * This should be in sync with
+   * @see [[org.apache.spark.network.shuffle.RemoteBlockPushResolver#getFile(
+   *     java.lang.String, java.lang.String)]]
+   */
+  def getMergedShuffleFile(blockId: BlockId, dirs: Option[Array[String]]): 
File = {
+    blockId match {
+      case mergedBlockId: ShuffleMergedDataBlockId =>
+        getMergedShuffleFile(mergedBlockId.name, dirs)
+      case mergedIndexBlockId: ShuffleMergedIndexBlockId =>
+        getMergedShuffleFile(mergedIndexBlockId.name, dirs)
+      case mergedMetaBlockId: ShuffleMergedMetaBlockId =>
+        getMergedShuffleFile(mergedMetaBlockId.name, dirs)
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Only merged block ID is supported, but got $blockId")
+    }
+  }
+
+  private def getMergedShuffleFile(filename: String, dirs: 
Option[Array[String]]): File = {
+    if (!dirs.exists(_.nonEmpty)) {
+      throw new IllegalArgumentException(
+        s"Cannot read $filename because merged shuffle dirs is empty")
+    }
+    ExecutorDiskUtils.getFile(dirs.get, subDirsPerLocalDir, filename)
+  }
+
   /** Check if disk block manager has a block. */
   def containsBlock(blockId: BlockId): Boolean = {
     getFile(blockId.name).exists()
@@ -156,6 +188,82 @@ private[spark] class DiskBlockManager(conf: SparkConf, var 
deleteFilesOnStop: Bo
     }
   }
 
+  /**
+   * Get the list of configured local dirs storing merged shuffle blocks 
created by executors
+   * if push based shuffle is enabled. Note that the files in this directory 
will be created
+   * by the external shuffle services. We only create the merge_manager 
directories and
+   * subdirectories here because currently the external shuffle service 
doesn't have
+   * permission to create directories under application local directories.
+   */
+  private def createLocalDirsForMergedShuffleBlocks(): Unit = {
+    if (Utils.isPushBasedShuffleEnabled(conf)) {
+      // Will create the merge_manager directory only if it doesn't exist 
under the local dir.
+      Utils.getConfiguredLocalDirs(conf).foreach { rootDir =>
+        try {
+          val mergeDir = new File(rootDir, MERGE_MANAGER_DIR)
+          if (!mergeDir.exists()) {
+            // This executor does not find merge_manager directory, it will 
try to create
+            // the merge_manager directory and the sub directories.
+            logDebug(s"Try to create $mergeDir and its sub dirs since the " +
+              s"$MERGE_MANAGER_DIR dir does not exist")
+            for (dirNum <- 0 until subDirsPerLocalDir) {
+              val subDir = new File(mergeDir, "%02x".format(dirNum))
+              if (!subDir.exists()) {
+                // Only one container will create this directory. The 
filesystem will handle
+                // any race conditions.
+                createDirWithPermission770(subDir)
+              }
+            }
+          }
+          logInfo(s"Merge directory and its sub dirs get created at $mergeDir")
+        } catch {
+          case e: IOException =>
+            logError(
+              s"Failed to create $MERGE_MANAGER_DIR dir in $rootDir. Ignoring 
this directory.", e)
+        }
+      }
+    }
+  }
+
+  /**
+   * Create a directory that is writable by the group.
+   * Grant the permission 770 "rwxrwx---" to the directory so the shuffle 
server can
+   * create subdirs/files within the merge folder.
+   * TODO: Find out why can't we create a dir using java api with permission 
770
+   *  Files.createDirectories(mergeDir.toPath, 
PosixFilePermissions.asFileAttribute(
+   *  PosixFilePermissions.fromString("rwxrwx---")))
+   */
+  def createDirWithPermission770(dirToCreate: File): Unit = {
+    var attempts = 0
+    val maxAttempts = Utils.MAX_DIR_CREATION_ATTEMPTS
+    var created: File = null
+    while (created == null) {
+      attempts += 1
+      if (attempts > maxAttempts) {
+        throw new IOException(
+          s"Failed to create directory ${dirToCreate.getAbsolutePath} with 
permission " +
+            s"770 after $maxAttempts attempts!")
+      }
+      try {
+        val builder = new ProcessBuilder().command(
+          "mkdir", "-p", "-m770", dirToCreate.getAbsolutePath)
+        val proc = builder.start()
+        val exitCode = proc.waitFor()
+        if (dirToCreate.exists()) {
+          created = dirToCreate
+        }
+        logDebug(
+          s"Created directory at ${dirToCreate.getAbsolutePath} with 
permission " +
+            s"770 and exitCode $exitCode")
+      } catch {
+        case e: SecurityException =>
+          logWarning(s"Failed to create directory 
${dirToCreate.getAbsolutePath} " +
+            s"with permission 770", e)
+          created = null;
+      }
+    }
+  }
+
   private def addShutdownHook(): AnyRef = {
     logDebug("Adding shutdown hook") // force eager creation of logger
     
ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY
 + 1) { () =>
@@ -193,3 +301,7 @@ private[spark] class DiskBlockManager(conf: SparkConf, var 
deleteFilesOnStop: Bo
     }
   }
 }
+
+private[spark] object DiskBlockManager {
+  private[spark] val MERGE_MANAGER_DIR = "merge_manager"
+}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala 
b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a082442..98565a1 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -96,7 +96,7 @@ private[spark] object Utils extends Logging {
    */
   val DEFAULT_DRIVER_MEM_MB = JavaUtils.DEFAULT_DRIVER_MEM_MB.toInt
 
-  private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+  val MAX_DIR_CREATION_ATTEMPTS: Int = 10
   @volatile private var localRootDirs: Array[String] = null
 
   /** Scheme used for files that are locally available on worker nodes in the 
cluster. */
@@ -2582,11 +2582,33 @@ private[spark] object Utils extends Logging {
   }
 
   /**
-   * Push based shuffle can only be enabled when external shuffle service is 
enabled.
+   * Push based shuffle can only be enabled when the application is submitted
+   * to run in YARN mode, with external shuffle service enabled and
+   * spark.yarn.maxAttempts or the yarn cluster default max attempts is set to 
1.
+   * TODO: Remove the requirement on spark.yarn.maxAttempts after SPARK-35546
+   * Support push based shuffle with multiple app attempts
    */
   def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = {
     conf.get(PUSH_BASED_SHUFFLE_ENABLED) &&
-      (conf.get(IS_TESTING).getOrElse(false) || 
conf.get(SHUFFLE_SERVICE_ENABLED))
+      (conf.get(IS_TESTING).getOrElse(false) ||
+        (conf.get(SHUFFLE_SERVICE_ENABLED) &&
+          conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn" &&
+          getYarnMaxAttempts(conf) == 1))
+  }
+
+  /**
+   * Returns the maximum number of attempts to register the AM in YARN mode.
+   * TODO: Remove this method after SPARK-35546 Support push based shuffle
+   * with multiple app attempts
+   */
+  def getYarnMaxAttempts(conf: SparkConf): Int = {
+    val sparkMaxAttempts = 
conf.getOption("spark.yarn.maxAttempts").map(_.toInt)
+    val yarnMaxAttempts = getSparkOrYarnConfig(conf, 
YarnConfiguration.RM_AM_MAX_ATTEMPTS,
+      YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS.toString).toInt
+    sparkMaxAttempts match {
+      case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
+      case None => yarnMaxAttempts
+    }
   }
 
   /**
diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
 
b/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
index 8f0c4da..33f544a 100644
--- 
a/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
@@ -133,4 +133,13 @@ class HostLocalShuffleReadingSuite extends SparkFunSuite 
with Matchers with Loca
       assert(remoteBytesRead.sum === 0 && remoteBlocksFetched.sum === 0)
     }
   }
+
+  test("Enable host local shuffle reading when push based shuffle is enabled") 
{
+    val conf = new SparkConf()
+      .set(SHUFFLE_SERVICE_ENABLED, true)
+      .set("spark.yarn.maxAttempts", "1")
+      .set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    sc = new SparkContext("local-cluster[2, 1, 1024]", 
"test-host-local-shuffle-reading", conf)
+    sc.env.blockManager.hostLocalDirManager.isDefined should equal(true)
+  }
 }
diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
 
b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index da98ad3..5955d44 100644
--- 
a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -17,21 +17,21 @@
 
 package org.apache.spark.shuffle.sort
 
-import java.io.{DataInputStream, File, FileInputStream, FileOutputStream}
+import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, File, 
FileInputStream, FileOutputStream}
 
 import org.mockito.{Mock, MockitoAnnotations}
 import org.mockito.Answers.RETURNS_SMART_NULLS
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{MapOutputTracker, SparkConf, SparkFunSuite}
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo}
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
 
-
 class IndexShuffleBlockResolverSuite extends SparkFunSuite with 
BeforeAndAfterEach {
 
   @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = 
_
@@ -39,6 +39,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite 
with BeforeAndAfterEa
 
   private var tempDir: File = _
   private val conf: SparkConf = new SparkConf(loadDefaults = false)
+  private val appId = "TESTAPP"
 
   override def beforeEach(): Unit = {
     super.beforeEach()
@@ -48,7 +49,11 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite 
with BeforeAndAfterEa
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
       (invocation: InvocationOnMock) => new File(tempDir, 
invocation.getArguments.head.toString))
+    when(diskBlockManager.getMergedShuffleFile(
+      any[BlockId], any[Option[Array[String]]])).thenAnswer(
+      (invocation: InvocationOnMock) => new File(tempDir, 
invocation.getArguments.head.toString))
     when(diskBlockManager.localDirs).thenReturn(Array(tempDir))
+    conf.set("spark.app.id", appId)
   }
 
   override def afterEach(): Unit = {
@@ -161,4 +166,86 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite 
with BeforeAndAfterEa
     val resolver = new IndexShuffleBlockResolver(conf, blockManager)
     assert(resolver.getMigrationBlocks(ShuffleBlockInfo(Int.MaxValue, 
Long.MaxValue)).isEmpty)
   }
+
+  test("getMergedBlockData should return expected FileSegmentManagedBuffer 
list") {
+    val shuffleId = 1
+    val reduceId = 1
+    val dataFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.data"
+    val dataFile = new File(tempDir.getAbsolutePath, dataFileName)
+    val out = new FileOutputStream(dataFile)
+    Utils.tryWithSafeFinally {
+      out.write(new Array[Byte](30))
+    } {
+      out.close()
+    }
+    val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.index"
+    generateMergedShuffleIndexFile(indexFileName)
+    val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+    val dirs = Some(Array[String](tempDir.getAbsolutePath))
+    val managedBufferList =
+      resolver.getMergedBlockData(ShuffleBlockId(shuffleId, -1, reduceId), 
dirs)
+    assert(managedBufferList.size === 3)
+    assert(managedBufferList(0).size === 10)
+    assert(managedBufferList(1).size === 0)
+    assert(managedBufferList(2).size === 20)
+  }
+
+  test("getMergedBlockMeta should return expected MergedBlockMeta") {
+    val shuffleId = 1
+    val reduceId = 1
+    val metaFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.meta"
+    val metaFile = new File(tempDir.getAbsolutePath, metaFileName)
+    val chunkTracker = new RoaringBitmap()
+    val metaFileOutputStream = new FileOutputStream(metaFile)
+    val outMeta = new DataOutputStream(metaFileOutputStream)
+    Utils.tryWithSafeFinally {
+      chunkTracker.add(1)
+      chunkTracker.add(2)
+      chunkTracker.serialize(outMeta)
+      chunkTracker.clear()
+      chunkTracker.add(3)
+      chunkTracker.add(4)
+      chunkTracker.serialize(outMeta)
+      chunkTracker.clear()
+      chunkTracker.add(5)
+      chunkTracker.add(6)
+      chunkTracker.serialize(outMeta)
+    }{
+      outMeta.close()
+    }
+    val indexFileName = s"shuffleMerged_${appId}_${shuffleId}_$reduceId.index"
+    generateMergedShuffleIndexFile(indexFileName)
+    val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+    val dirs = Some(Array[String](tempDir.getAbsolutePath))
+    val mergedBlockMeta =
+      resolver.getMergedBlockMeta(
+        ShuffleBlockId(shuffleId, MapOutputTracker.SHUFFLE_PUSH_MAP_ID, 
reduceId), dirs)
+    assert(mergedBlockMeta.getNumChunks === 3)
+    assert(mergedBlockMeta.readChunkBitmaps().size === 3)
+    assert(mergedBlockMeta.readChunkBitmaps()(0).contains(1))
+    assert(mergedBlockMeta.readChunkBitmaps()(0).contains(2))
+    assert(!mergedBlockMeta.readChunkBitmaps()(0).contains(3))
+    assert(mergedBlockMeta.readChunkBitmaps()(1).contains(3))
+    assert(mergedBlockMeta.readChunkBitmaps()(1).contains(4))
+    assert(!mergedBlockMeta.readChunkBitmaps()(1).contains(5))
+    assert(mergedBlockMeta.readChunkBitmaps()(2).contains(5))
+    assert(mergedBlockMeta.readChunkBitmaps()(2).contains(6))
+    assert(!mergedBlockMeta.readChunkBitmaps()(2).contains(1))
+  }
+
+  private def generateMergedShuffleIndexFile(indexFileName: String): Unit = {
+    val lengths = Array[Long](10, 0, 20)
+    val indexFile = new File(tempDir.getAbsolutePath, indexFileName)
+    val outIndex = new DataOutputStream(new BufferedOutputStream(new 
FileOutputStream(indexFile)))
+    Utils.tryWithSafeFinally {
+      var offset = 0L
+      outIndex.writeLong(offset)
+      for (length <- lengths) {
+        offset += length
+        outIndex.writeLong(offset)
+      }
+    } {
+      outIndex.close()
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala 
b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index d7009e6..b3138d7 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -104,6 +104,42 @@ class BlockIdSuite extends SparkFunSuite {
     assertSame(id, BlockId(id.toString))
   }
 
+  test("shuffle merged data") {
+    // Equality, name, field accessors and round-trip through BlockId(id.toString).
+    val id = ShuffleMergedDataBlockId("app_000", 8, 9)
+    assertSame(id, ShuffleMergedDataBlockId("app_000", 8, 9))
+    assertDifferent(id, ShuffleMergedDataBlockId("app_000", 9, 9))
+    assert(id.name === "shuffleMerged_app_000_8_9.data")
+    assert(id.asRDDId === None)
+    assert(id.appId === "app_000")
+    assert(id.shuffleId === 8)
+    assert(id.reduceId === 9)
+    assertSame(id, BlockId(id.toString))
+  }
+
+  test("shuffle merged index") {
+    // Equality, name, field accessors and round-trip through BlockId(id.toString).
+    val id = ShuffleMergedIndexBlockId("app_000", 8, 9)
+    assertSame(id, ShuffleMergedIndexBlockId("app_000", 8, 9))
+    assertDifferent(id, ShuffleMergedIndexBlockId("app_000", 9, 9))
+    assert(id.name === "shuffleMerged_app_000_8_9.index")
+    assert(id.asRDDId === None)
+    assert(id.appId === "app_000")
+    assert(id.shuffleId === 8)
+    assert(id.reduceId === 9)
+    assertSame(id, BlockId(id.toString))
+  }
+
+  test("shuffle merged meta") {
+    // Equality, name, field accessors and round-trip through BlockId(id.toString).
+    val id = ShuffleMergedMetaBlockId("app_000", 8, 9)
+    assertSame(id, ShuffleMergedMetaBlockId("app_000", 8, 9))
+    assertDifferent(id, ShuffleMergedMetaBlockId("app_000", 9, 9))
+    assert(id.name === "shuffleMerged_app_000_8_9.meta")
+    assert(id.asRDDId === None)
+    assert(id.appId === "app_000")
+    assert(id.shuffleId === 8)
+    assert(id.reduceId === 9)
+    assertSame(id, BlockId(id.toString))
+  }
+
   test("broadcast") {
     val id = BroadcastBlockId(42)
     assertSame(id, BroadcastBlockId(42))
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index c757dee..6397c96 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -18,12 +18,17 @@
 package org.apache.spark.storage
 
 import java.io.{File, FileWriter}
+import java.nio.file.{Files, Paths}
+import java.nio.file.attribute.PosixFilePermissions
 
+import org.apache.commons.io.FileUtils
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config
 import org.apache.spark.util.Utils
 
+
 class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with 
BeforeAndAfterAll {
   private val testConf = new SparkConf(false)
   private var rootDir0: File = _
@@ -85,6 +90,40 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
     assert(diskBlockManager.getAllBlocks().isEmpty)
   }
 
+  test("should still create merge directories if one already exists under a local dir") {
+    // Pre-create the merge dir under rootDir0 and make sure none exists under rootDir1.
+    val mergeDir0 = new File(rootDir0, DiskBlockManager.MERGE_MANAGER_DIR)
+    if (!mergeDir0.exists()) {
+      Files.createDirectories(mergeDir0.toPath)
+    }
+    val mergeDir1 = new File(rootDir1, DiskBlockManager.MERGE_MANAGER_DIR)
+    if (mergeDir1.exists()) {
+      Utils.deleteRecursively(mergeDir1)
+    }
+    // Configure a copy so the push-based shuffle settings do not leak into other tests
+    // that share the suite-level testConf.
+    val conf = testConf.clone
+      .set("spark.local.dir", rootDirs)
+      .set("spark.shuffle.push.enabled", "true")
+      .set(config.Tests.IS_TESTING, true)
+    diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    assert(Utils.getConfiguredLocalDirs(conf).map(
+      rootDir => new File(rootDir, DiskBlockManager.MERGE_MANAGER_DIR))
+      .count(mergeDir => mergeDir.exists()) === 2)
+    // mergeDir0 already existed beforehand, so no sub directories are created under it
+    assert(mergeDir0.list().length === 0)
+    // Sub directories get created under mergeDir1
+    assert(mergeDir1.list().length === conf.get(config.DISKSTORE_SUB_DIRECTORIES))
+  }
+
+  test("Test dir creation with permission 770") {
+    val testDir = new File("target/testDir")
+    FileUtils.deleteQuietly(testDir)
+    try {
+      diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true)
+      diskBlockManager.createDirWithPermission770(testDir)
+      assert(testDir.exists && testDir.isDirectory)
+      val permission = PosixFilePermissions.toString(
+        Files.getPosixFilePermissions(Paths.get(testDir.getPath)))
+      assert(permission === "rwxrwx---")
+    } finally {
+      // Remove the directory even when an assertion fails so reruns start clean.
+      FileUtils.deleteQuietly(testDir)
+    }
+  }
+
   def writeToFile(file: File, numBytes: Int): Unit = {
     val writer = new FileWriter(file, true)
     for (i <- 0 until numBytes) writer.write(i)
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 404cc34..dba7e39 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.util
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutput, 
DataOutputStream, File,
-  FileOutputStream, PrintStream, SequenceInputStream}
+import java.io._
 import java.lang.reflect.Field
 import java.net.{BindException, ServerSocket, URI}
 import java.nio.{ByteBuffer, ByteOrder}
@@ -42,6 +41,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.Tests.IS_TESTING
+import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.scheduler.SparkListener
 import org.apache.spark.util.io.ChunkedByteBufferInputStream
@@ -1438,15 +1438,19 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     assert(message.contains(expected))
   }
 
-  test("isPushBasedShuffleEnabled when both PUSH_BASED_SHUFFLE_ENABLED" +
-    " and SHUFFLE_SERVICE_ENABLED are true") {
+  test("isPushBasedShuffleEnabled when PUSH_BASED_SHUFFLE_ENABLED " +
+    "and SHUFFLE_SERVICE_ENABLED are both set to true in YARN mode with maxAttempts set to 1") {
     val conf = new SparkConf()
+    // Disabled when nothing has been configured.
     assert(Utils.isPushBasedShuffleEnabled(conf) === false)
     conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    // NOTE(review): IS_TESTING is forced off, presumably so a test-mode shortcut cannot mask
+    // the real checks - confirm against Utils.isPushBasedShuffleEnabled.
     conf.set(IS_TESTING, false)
+    // The push-based shuffle flag alone is not sufficient.
     assert(Utils.isPushBasedShuffleEnabled(conf) === false)
     conf.set(SHUFFLE_SERVICE_ENABLED, true)
+    conf.set(SparkLauncher.SPARK_MASTER, "yarn")
+    conf.set("spark.yarn.maxAttempts", "1")
+    // Shuffle service enabled, YARN master and a single application attempt: enabled.
     assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+    // Allowing more than one YARN application attempt disables it again.
+    conf.set("spark.yarn.maxAttempts", "2")
+    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
   }
 }
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to