marin-ma commented on code in PR #53040:
URL: https://github.com/apache/spark/pull/53040#discussion_r2555858014


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala:
##########
@@ -87,28 +87,116 @@ object FilePartition extends SessionStateHelper with 
Logging {
     partitions.toSeq
   }
 
+  private def getFilePartitionsByFileNum(
+      partitionedFiles: Seq[PartitionedFile],
+      outputPartitions: Int,
+      smallFileThreshold: Double): Seq[FilePartition] = {
+    // Flatten and sort descending by file size.
+    val filesSorted: Seq[(PartitionedFile, Long)] =
+      partitionedFiles
+        .map(f => (f, f.length))
+        .sortBy(_._2)(Ordering.Long.reverse)
+
+    val partitions = 
Seq.fill(outputPartitions)(mutable.ArrayBuffer.empty[PartitionedFile])
+
+    def addToBucket(
+        heap: mutable.PriorityQueue[(Long, Int, Int)],
+        file: PartitionedFile,
+        sz: Long): Unit = {
+      val (load, numFiles, idx) = heap.dequeue()
+      partitions(idx) += file
+      heap.enqueue((load + sz, numFiles + 1, idx))
+    }
+
+    // First by load, then by numFiles.
+    val heapByFileSize =
+      mutable.PriorityQueue.empty[(Long, Int, Int)](
+        Ordering
+          .by[(Long, Int, Int), (Long, Int)] {
+            case (load, numFiles, _) =>
+              (load, numFiles)
+          }
+          .reverse
+      )
+
+    if (smallFileThreshold > 0) {
+      val smallFileTotalSize = filesSorted.map(_._2).sum * smallFileThreshold
+      // First by numFiles, then by load.
+      val heapByFileNum =
+        mutable.PriorityQueue.empty[(Long, Int, Int)](
+          Ordering
+            .by[(Long, Int, Int), (Int, Long)] {
+              case (load, numFiles, _) =>
+                (numFiles, load)
+            }
+            .reverse
+        )
+
+      (0 until outputPartitions).foreach(i => heapByFileNum.enqueue((0L, 0, 
i)))
+
+      var numSmallFiles = 0
+      var smallFileSize = 0L
+      // Enqueue small files to the least number of files and the least load.
+      filesSorted.reverse.takeWhile(f => f._2 + smallFileSize <= 
smallFileTotalSize).foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileNum, file, sz)
+          numSmallFiles += 1
+          smallFileSize += sz
+      }
+
+      // Move buckets from heapByFileNum to heapByFileSize.
+      while (heapByFileNum.nonEmpty) {
+        heapByFileSize.enqueue(heapByFileNum.dequeue())
+      }
+
+      // Finally, enqueue remaining files.
+      filesSorted.take(filesSorted.size - numSmallFiles).foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileSize, file, sz)
+      }
+    } else {
+      (0 until outputPartitions).foreach(i => heapByFileSize.enqueue((0L, 0, 
i)))
+
+      filesSorted.foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileSize, file, sz)
+      }
+    }
+
+    partitions.zipWithIndex.map {
+      case (p, idx) => FilePartition(idx, p.toArray)
+    }
+  }
+
   def getFilePartitions(
       sparkSession: SparkSession,
       partitionedFiles: Seq[PartitionedFile],
       maxSplitBytes: Long): Seq[FilePartition] = {
     val conf = getSqlConf(sparkSession)
     val openCostBytes = conf.filesOpenCostInBytes
     val maxPartNum = conf.filesMaxPartitionNum
-    val partitions = getFilePartitions(partitionedFiles, maxSplitBytes, 
openCostBytes)
-    if (maxPartNum.exists(partitions.size > _)) {
-      val totalSizeInBytes =
-        partitionedFiles.map(_.length + 
openCostBytes).map(BigDecimal(_)).sum[BigDecimal]
-      val desiredSplitBytes =
-        (totalSizeInBytes / BigDecimal(maxPartNum.get)).setScale(0, 
RoundingMode.UP).longValue
-      val desiredPartitions = getFilePartitions(partitionedFiles, 
desiredSplitBytes, openCostBytes)
-      logWarning(log"The number of partitions is ${MDC(NUM_PARTITIONS, 
partitions.size)}, " +
-        log"which exceeds the maximum number configured: " +
-        log"${MDC(MAX_NUM_PARTITIONS, maxPartNum.get)}. Spark rescales it to " 
+
-        log"${MDC(DESIRED_NUM_PARTITIONS, desiredPartitions.size)} by ignoring 
the " +
-        log"configuration of ${MDC(CONFIG, 
SQLConf.FILES_MAX_PARTITION_BYTES.key)}.")
-      desiredPartitions
-    } else {
-      partitions
+    val partitions = getFilePartitionsBySize(partitionedFiles, maxSplitBytes, 
openCostBytes)

Review Comment:
   That can happen, but only in very extreme cases; if it does, the user should 
not enable this algorithm. The algorithm is most beneficial in the common 
scenario we observed in a user's production environment, where the input 
contains many small files that are much smaller than most of the input splits. 
In that scenario, there won’t be a wide mismatch.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to