This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new b9fbb4702b [GLUTEN-7358][CH] Optimize the strategy of the partition split according to the files count (#7361)
b9fbb4702b is described below

commit b9fbb4702b3a7eb8cfa1d39595e45b576d0ffa4f
Author: Zhichao Zhang <[email protected]>
AuthorDate: Fri Sep 27 16:45:14 2024 +0800

    [GLUTEN-7358][CH] Optimize the strategy of the partition split according to the files count (#7361)
    
    Optimize the strategy of the partition split according to the files count
    
    Close #7358.
---
 .../gluten/utils/CHInputPartitionsUtil.scala       |  10 +-
 .../utils/MergeTreePartsPartitionsUtil.scala       | 182 ++++++++++++++++-----
 .../GlutenClickHouseMergeTreeWriteSuite.scala      |  53 ++++++
 ...kHouseTPCDSParquetColumnarShuffleAQESuite.scala |  24 +++
 4 files changed, 221 insertions(+), 48 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHInputPartitionsUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHInputPartitionsUtil.scala
index 0f35ff66d4..412d8773b9 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHInputPartitionsUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHInputPartitionsUtil.scala
@@ -91,7 +91,7 @@ case class CHInputPartitionsUtil(
       .sortBy(_.length)(implicitly[Ordering[Long]].reverse)
 
     val totalCores = SparkResourceUtil.getTotalCores(relation.sparkSession.sessionState.conf)
-    val fileCntPerPartition = math.ceil((splitFiles.size * 1.0) / totalCores).toInt
+    val isAllSmallFiles = splitFiles.forall(_.length < maxSplitBytes)
     val fileCntThreshold = relation.sparkSession.sessionState.conf
       .getConfString(
         CHBackendSettings.GLUTEN_CLICKHOUSE_FILES_PER_PARTITION_THRESHOLD,
@@ -99,8 +99,12 @@ case class CHInputPartitionsUtil(
       )
       .toInt
 
-    if (fileCntThreshold > 0 && fileCntPerPartition > fileCntThreshold) {
-      getFilePartitionsByFileCnt(splitFiles, fileCntPerPartition)
+    // calculate the file count for each partition according to the parameter
+    val totalFilesThreshold = totalCores * fileCntThreshold
+    if (fileCntThreshold > 0 && isAllSmallFiles && splitFiles.size <= totalFilesThreshold) {
+      var fileCnt = math.round((splitFiles.size * 1.0) / totalCores).toInt
+      if (fileCnt < 1) fileCnt = 1
+      getFilePartitionsByFileCnt(splitFiles, fileCnt)
     } else {
       FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)
     }
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala
index 6fc5e198a6..c2eb338261 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution.datasources.utils
 
-import org.apache.gluten.backendsapi.clickhouse.CHConf
+import org.apache.gluten.backendsapi.clickhouse.{CHBackendSettings, CHConf}
 import org.apache.gluten.execution.{GlutenMergeTreePartition, MergeTreePartRange, MergeTreePartSplit}
 import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
 import org.apache.gluten.softaffinity.SoftAffinityManager
@@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.datasources.clickhouse.{ExtensionTableBuil
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat
 import org.apache.spark.sql.types.BooleanType
+import org.apache.spark.util.SparkResourceUtil
 import org.apache.spark.util.collection.BitSet
 
 import com.fasterxml.jackson.core.`type`.TypeReference
@@ -114,6 +115,7 @@ object MergeTreePartsPartitionsUtil extends Logging {
       )
     } else {
       genInputPartitionSeq(
+        relation,
         engine,
         database,
         tableName,
@@ -135,6 +137,7 @@ object MergeTreePartsPartitionsUtil extends Logging {
   }
 
   def genInputPartitionSeq(
+      relation: HadoopFsRelation,
       engine: String,
       database: String,
       tableName: String,
@@ -213,53 +216,142 @@ object MergeTreePartsPartitionsUtil extends Logging {
     }
 
     val maxSplitBytes = getMaxSplitBytes(sparkSession, selectRanges)
-    val total_marks = selectRanges.map(p => p.marks).sum
-    val total_Bytes = selectRanges.map(p => p.size).sum
-    // maxSplitBytes / (total_Bytes / total_marks) + 1
-    val markCntPerPartition = maxSplitBytes * total_marks / total_Bytes + 1
-
-    logInfo(s"Planning scan with bin packing, max mark: $markCntPerPartition")
-    val splitFiles = selectRanges
-      .flatMap {
-        part =>
-          val end = part.marks + part.start
-          (part.start until end by markCntPerPartition).map {
-            offset =>
-              val remaining = end - offset
-              val size = if (remaining > markCntPerPartition) markCntPerPartition else remaining
-              MergeTreePartSplit(
-                part.name,
-                part.dirName,
-                part.targetNode,
-                offset,
-                size,
-                size * part.size / part.marks)
-          }
-      }
+    val totalCores = SparkResourceUtil.getTotalCores(relation.sparkSession.sessionState.conf)
+    val isAllSmallFiles = selectRanges.forall(_.size < maxSplitBytes)
+    val fileCntThreshold = relation.sparkSession.sessionState.conf
+      .getConfString(
+        CHBackendSettings.GLUTEN_CLICKHOUSE_FILES_PER_PARTITION_THRESHOLD,
+        CHBackendSettings.GLUTEN_CLICKHOUSE_FILES_PER_PARTITION_THRESHOLD_DEFAULT
+      )
+      .toInt
+    val totalMarksThreshold = totalCores * fileCntThreshold
+    if (fileCntThreshold > 0 && isAllSmallFiles && selectRanges.size <= totalMarksThreshold) {
+      var fileCnt = math.round((selectRanges.size * 1.0) / totalCores).toInt
+      if (fileCnt < 1) fileCnt = 1
+      val splitFiles = selectRanges
+        .map {
+          part =>
+            MergeTreePartSplit(part.name, part.dirName, part.targetNode, 0, part.marks, part.size)
+        }
+      genInputPartitionSeqByFileCnt(
+        engine,
+        database,
+        tableName,
+        snapshotId,
+        relativeTablePath,
+        absoluteTablePath,
+        tableSchemaJson,
+        partitions,
+        table,
+        clickhouseTableConfigs,
+        splitFiles,
+        fileCnt
+      )
+    } else {
+      val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
+      val totalMarks = selectRanges.map(p => p.marks).sum
+      val totalBytes = selectRanges.map(p => p.size).sum
+      // maxSplitBytes / (total_Bytes / total_marks) + 1
+      val markCntPerPartition = maxSplitBytes * totalMarks / totalBytes + 1
+
+      logInfo(s"Planning scan with bin packing, max mark: $markCntPerPartition")
+      val splitFiles = selectRanges
+        .flatMap {
+          part =>
+            val end = part.marks + part.start
+            (part.start until end by markCntPerPartition).map {
+              offset =>
+                val remaining = end - offset
+                val size = if (remaining > markCntPerPartition) markCntPerPartition else remaining
+                MergeTreePartSplit(
+                  part.name,
+                  part.dirName,
+                  part.targetNode,
+                  offset,
+                  size,
+                  size * part.size / part.marks)
+            }
+        }
 
-    val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
-    val (partNameWithLocation, locationDistinct) =
-      calculatedLocationForSoftAffinity(splitFiles, relativeTablePath)
-
-    genInputPartitionSeqBySplitFiles(
-      engine,
-      database,
-      tableName,
-      snapshotId,
-      relativeTablePath,
-      absoluteTablePath,
-      tableSchemaJson,
-      partitions,
-      table,
-      clickhouseTableConfigs,
-      splitFiles,
-      openCostInBytes,
-      maxSplitBytes,
-      partNameWithLocation,
-      locationDistinct
-    )
+      val (partNameWithLocation, locationDistinct) =
+        calculatedLocationForSoftAffinity(splitFiles, relativeTablePath)
+
+      genInputPartitionSeqBySplitFiles(
+        engine,
+        database,
+        tableName,
+        snapshotId,
+        relativeTablePath,
+        absoluteTablePath,
+        tableSchemaJson,
+        partitions,
+        table,
+        clickhouseTableConfigs,
+        splitFiles,
+        openCostInBytes,
+        maxSplitBytes,
+        partNameWithLocation,
+        locationDistinct
+      )
+    }
   }
 
+  def genInputPartitionSeqByFileCnt(
+      engine: String,
+      database: String,
+      tableName: String,
+      snapshotId: String,
+      relativeTablePath: String,
+      absoluteTablePath: String,
+      tableSchemaJson: String,
+      partitions: ArrayBuffer[InputPartition],
+      table: ClickHouseTableV2,
+      clickhouseTableConfigs: Map[String, String],
+      splitFiles: Seq[MergeTreePartSplit],
+      fileCnt: Int): Unit = {
+    val currentFiles = new ArrayBuffer[MergeTreePartSplit]
+    var currentFileCnt = 0L
+
+    /** Close the current partition and move to the next. */
+    def closePartition(): Unit = {
+      if (currentFiles.nonEmpty) {
+        // Copy to a new Array.
+        val newPartition = GlutenMergeTreePartition(
+          partitions.size,
+          engine,
+          database,
+          tableName,
+          snapshotId,
+          relativeTablePath,
+          absoluteTablePath,
+          table.orderByKey(),
+          table.lowCardKey(),
+          table.minmaxIndexKey(),
+          table.bfIndexKey(),
+          table.setIndexKey(),
+          table.primaryKey(),
+          currentFiles.toArray,
+          tableSchemaJson,
+          clickhouseTableConfigs
+        )
+        partitions += newPartition
+      }
+      currentFiles.clear()
+      currentFileCnt = 0L
+    }
+
+    splitFiles.foreach {
+      file =>
+        if (currentFileCnt >= fileCnt) {
+          closePartition()
+        }
+        // Add the given file to the current partition.
+        currentFileCnt += 1L
+        currentFiles += file
+    }
+    closePartition()
+    partitions.toSeq
+  }
   def genInputPartitionSeqBySplitFiles(
       engine: String,
       database: String,
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteSuite.scala
index d09c28f9a2..b4eca622c8 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteSuite.scala
@@ -2051,4 +2051,57 @@ class GlutenClickHouseMergeTreeWriteSuite
          |""".stripMargin
     runSql(sqlStr) { _ => }
   }
+
+  test("GLUTEN-7358: Optimize the strategy of the partition split according to the files count") {
+    spark.sql(s"""
+                 |DROP TABLE IF EXISTS lineitem_split;
+                 |""".stripMargin)
+    spark.sql(s"""
+                 |CREATE TABLE IF NOT EXISTS lineitem_split
+                 |(
+                 | l_orderkey      bigint,
+                 | l_partkey       bigint,
+                 | l_suppkey       bigint,
+                 | l_linenumber    bigint,
+                 | l_quantity      double,
+                 | l_extendedprice double,
+                 | l_discount      double,
+                 | l_tax           double,
+                 | l_returnflag    string,
+                 | l_linestatus    string,
+                 | l_shipdate      date,
+                 | l_commitdate    date,
+                 | l_receiptdate   date,
+                 | l_shipinstruct  string,
+                 | l_shipmode      string,
+                 | l_comment       string
+                 |)
+                 |USING clickhouse
+                 |LOCATION '$basePath/lineitem_split'
+                 |""".stripMargin)
+    spark.sql(s"""
+                 | insert into table lineitem_split
+                 | select * from lineitem
+                 |""".stripMargin)
+    Seq(("-1", 3), ("3", 3), ("6", 1)).foreach(
+      conf => {
+        withSQLConf(
+          ("spark.gluten.sql.columnar.backend.ch.files.per.partition.threshold" -> conf._1)) {
+          val sql =
+            s"""
+               |select count(1), min(l_returnflag) from lineitem_split
+               |""".stripMargin
+          runSql(sql) {
+            df =>
+              val result = df.collect()
+              assertResult(1)(result.length)
+              assertResult("600572")(result(0).getLong(0).toString)
+              val scanExec = collect(df.queryExecution.executedPlan) {
+                case f: FileSourceScanExecTransformer => f
+              }
+              assert(scanExec(0).getPartitions.size == conf._2)
+          }
+        }
+      })
+  }
 }
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetColumnarShuffleAQESuite.scala
index 3e1507bf17..3e965c67ea 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetColumnarShuffleAQESuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetColumnarShuffleAQESuite.scala
@@ -241,4 +241,28 @@ class GlutenClickHouseTPCDSParquetColumnarShuffleAQESuite
     // There are some BroadcastHashJoin with NOT condition
     compareResultsAgainstVanillaSpark(sql, true, { df => })
   }
+
+  test("GLUTEN-7358: Optimize the strategy of the partition split according to the files count") {
+    Seq(("-1", 8), ("100", 8), ("2000", 1)).foreach(
+      conf => {
+        withSQLConf(
+          ("spark.gluten.sql.columnar.backend.ch.files.per.partition.threshold" -> conf._1)) {
+          val sql =
+            s"""
+               |select count(1) from store_sales
+               |""".stripMargin
+          compareResultsAgainstVanillaSpark(
+            sql,
+            true,
+            {
+              df =>
+                val scanExec = collect(df.queryExecution.executedPlan) {
+                  case f: FileSourceScanExecTransformer => f
+                }
+                assert(scanExec(0).getPartitions.size == conf._2)
+            }
+          )
+        }
+      })
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to