VindhyaG commented on code in PR #53040:
URL: https://github.com/apache/spark/pull/53040#discussion_r2555454615


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala:
##########
@@ -87,28 +87,116 @@ object FilePartition extends SessionStateHelper with 
Logging {
     partitions.toSeq
   }
 
+  private def getFilePartitionsByFileNum(
+      partitionedFiles: Seq[PartitionedFile],
+      outputPartitions: Int,
+      smallFileThreshold: Double): Seq[FilePartition] = {
+    // Flatten and sort descending by file size.
+    val filesSorted: Seq[(PartitionedFile, Long)] =
+      partitionedFiles
+        .map(f => (f, f.length))
+        .sortBy(_._2)(Ordering.Long.reverse)
+
+    val partitions = 
Seq.fill(outputPartitions)(mutable.ArrayBuffer.empty[PartitionedFile])
+
+    def addToBucket(
+        heap: mutable.PriorityQueue[(Long, Int, Int)],
+        file: PartitionedFile,
+        sz: Long): Unit = {
+      val (load, numFiles, idx) = heap.dequeue()
+      partitions(idx) += file
+      heap.enqueue((load + sz, numFiles + 1, idx))
+    }
+
+    // First by load, then by numFiles.
+    val heapByFileSize =
+      mutable.PriorityQueue.empty[(Long, Int, Int)](
+        Ordering
+          .by[(Long, Int, Int), (Long, Int)] {
+            case (load, numFiles, _) =>
+              (load, numFiles)
+          }
+          .reverse
+      )
+
+    if (smallFileThreshold > 0) {
+      val smallFileTotalSize = filesSorted.map(_._2).sum * smallFileThreshold
+      // First by numFiles, then by load.
+      val heapByFileNum =
+        mutable.PriorityQueue.empty[(Long, Int, Int)](
+          Ordering
+            .by[(Long, Int, Int), (Int, Long)] {
+              case (load, numFiles, _) =>
+                (numFiles, load)
+            }
+            .reverse
+        )
+
+      (0 until outputPartitions).foreach(i => heapByFileNum.enqueue((0L, 0, 
i)))
+
+      var numSmallFiles = 0
+      var smallFileSize = 0L
+      // Enqueue small files to the least number of files and the least load.
+      filesSorted.reverse.takeWhile(f => f._2 + smallFileSize <= 
smallFileTotalSize).foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileNum, file, sz)
+          numSmallFiles += 1
+          smallFileSize += sz
+      }
+
+      // Move buckets from heapByFileNum to heapByFileSize.
+      while (heapByFileNum.nonEmpty) {
+        heapByFileSize.enqueue(heapByFileNum.dequeue())
+      }
+
+      // Finally, enqueue remaining files.
+      filesSorted.take(filesSorted.size - numSmallFiles).foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileSize, file, sz)
+      }
+    } else {
+      (0 until outputPartitions).foreach(i => heapByFileSize.enqueue((0L, 0, 
i)))
+
+      filesSorted.foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileSize, file, sz)
+      }
+    }
+
+    partitions.zipWithIndex.map {
+      case (p, idx) => FilePartition(idx, p.toArray)
+    }
+  }
+
   def getFilePartitions(
       sparkSession: SparkSession,
       partitionedFiles: Seq[PartitionedFile],
       maxSplitBytes: Long): Seq[FilePartition] = {
     val conf = getSqlConf(sparkSession)
     val openCostBytes = conf.filesOpenCostInBytes
     val maxPartNum = conf.filesMaxPartitionNum
-    val partitions = getFilePartitions(partitionedFiles, maxSplitBytes, 
openCostBytes)
-    if (maxPartNum.exists(partitions.size > _)) {
-      val totalSizeInBytes =
-        partitionedFiles.map(_.length + 
openCostBytes).map(BigDecimal(_)).sum[BigDecimal]
-      val desiredSplitBytes =
-        (totalSizeInBytes / BigDecimal(maxPartNum.get)).setScale(0, 
RoundingMode.UP).longValue
-      val desiredPartitions = getFilePartitions(partitionedFiles, 
desiredSplitBytes, openCostBytes)
-      logWarning(log"The number of partitions is ${MDC(NUM_PARTITIONS, 
partitions.size)}, " +
-        log"which exceeds the maximum number configured: " +
-        log"${MDC(MAX_NUM_PARTITIONS, maxPartNum.get)}. Spark rescales it to " 
+
-        log"${MDC(DESIRED_NUM_PARTITIONS, desiredPartitions.size)} by ignoring 
the " +
-        log"configuration of ${MDC(CONFIG, 
SQLConf.FILES_MAX_PARTITION_BYTES.key)}.")
-      desiredPartitions
-    } else {
-      partitions
+    val partitions = getFilePartitionsBySize(partitionedFiles, maxSplitBytes, 
openCostBytes)

Review Comment:
   Since we first compute the partition count with the original size-based algorithm,
could this end up with a wide size mismatch between partitions, even if the number of
files per partition is well balanced? For the kind of uneven distribution this new
algorithm is trying to solve, the number of partitions produced by the original
algorithm may be far above the ideal number.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to