This is an automated email from the ASF dual-hosted git repository.
chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 98a624f016 [GLUTEN-8811][VL] Fix bucket scan when some partitionValue
is empty (#8834)
98a624f016 is described below
commit 98a624f0167b07b769efe5f4884d9f05eb168c2c
Author: Jin Chengcheng <[email protected]>
AuthorDate: Fri Feb 28 08:41:35 2025 +0000
[GLUTEN-8811][VL] Fix bucket scan when some partitionValue is empty (#8834)
---
.../backendsapi/velox/VeloxIteratorApi.scala | 34 ++++
.../gluten/execution/IcebergScanTransformer.scala | 14 +-
.../spark/source/GlutenIcebergSourceUtil.scala | 57 +++++-
.../org/apache/gluten/execution/IcebergSuite.scala | 156 +++++++++++++++
.../apache/gluten/backendsapi/IteratorApi.scala | 8 +
.../execution/BatchScanExecTransformer.scala | 30 +++
.../gluten/execution/WholeStageTransformer.scala | 214 +++++++++++++++------
.../org/apache/gluten/sql/shims/SparkShims.scala | 13 +-
.../gluten/sql/shims/spark32/Spark32Shims.scala | 5 -
.../gluten/sql/shims/spark33/Spark33Shims.scala | 2 -
.../gluten/sql/shims/spark34/Spark34Shims.scala | 119 ++++++++++--
.../gluten/sql/shims/spark35/Spark35Shims.scala | 120 ++++++++++--
12 files changed, 661 insertions(+), 111 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala
index 27b721d1cd..1be7b8d735 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala
@@ -87,6 +87,40 @@ class VeloxIteratorApi extends IteratorApi with Logging {
}
}
+ override def genSplitInfoForPartitions(
+ partitionIndex: Int,
+ partitions: Seq[InputPartition],
+ partitionSchema: StructType,
+ fileFormat: ReadFileFormat,
+ metadataColumnNames: Seq[String],
+ properties: Map[String, String]): SplitInfo = {
+ val partitionFiles = partitions.flatMap {
+ p =>
+ if (!p.isInstanceOf[FilePartition]) {
+ throw new UnsupportedOperationException(
+ s"Unsupported input partition ${p.getClass.getName}.")
+ }
+ p.asInstanceOf[FilePartition].files
+ }.toArray
+ val locations =
+ partitions.flatMap(p =>
SoftAffinity.getFilePartitionLocations(p.asInstanceOf[FilePartition]))
+ val (paths, starts, lengths, fileSizes, modificationTimes,
partitionColumns, metadataColumns) =
+ constructSplitInfo(partitionSchema, partitionFiles, metadataColumnNames)
+ LocalFilesBuilder.makeLocalFiles(
+ partitionIndex,
+ paths,
+ starts,
+ lengths,
+ fileSizes,
+ modificationTimes,
+ partitionColumns,
+ metadataColumns,
+ fileFormat,
+ locations.toList.asJava,
+ mapAsJavaMap(properties)
+ )
+ }
+
/** Generate native row partition. */
override def genPartitions(
wsCtx: WholeStageTransformContext,
diff --git
a/gluten-iceberg/src-iceberg/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
b/gluten-iceberg/src-iceberg/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
index 63ab9eb206..5cb095bb7d 100644
---
a/gluten-iceberg/src-iceberg/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
+++
b/gluten-iceberg/src-iceberg/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
@@ -19,7 +19,6 @@ package org.apache.gluten.execution
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
import org.apache.gluten.substrait.rel.SplitInfo
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
DynamicPruningExpression, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.QueryPlan
@@ -27,7 +26,6 @@ import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read.{InputPartition, Scan}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.types.StructType
-
import org.apache.iceberg.spark.source.GlutenIcebergSourceUtil
case class IcebergScanTransformer(
@@ -60,14 +58,22 @@ case class IcebergScanTransformer(
override lazy val fileFormat: ReadFileFormat =
GlutenIcebergSourceUtil.getFileFormat(scan)
+ override def getSplitInfosWithIndex: Seq[SplitInfo] = {
+ getPartitionsWithIndex.zipWithIndex.map {
+ case (partitions, index) =>
+ GlutenIcebergSourceUtil.genSplitInfo(partitions, index,
getPartitionSchema)
+ }
+ }
+
override def getSplitInfosFromPartitions(partitions: Seq[InputPartition]):
Seq[SplitInfo] = {
val groupedPartitions = SparkShimLoader.getSparkShims.orderPartitions(
+ this,
scan,
keyGroupedPartitioning,
filteredPartitions,
- outputPartitioning)
+ outputPartitioning, commonPartitionValues, applyPartialClustering,
replicatePartitions).flatten
groupedPartitions.zipWithIndex.map {
- case (p, index) => GlutenIcebergSourceUtil.genSplitInfo(p, index,
getPartitionSchema)
+ case (p, index) => GlutenIcebergSourceUtil.genSplitInfoForPartition(p,
index, getPartitionSchema)
}
}
diff --git
a/gluten-iceberg/src-iceberg/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
b/gluten-iceberg/src-iceberg/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
index a53c464b65..b423dee5b7 100644
---
a/gluten-iceberg/src-iceberg/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
+++
b/gluten-iceberg/src-iceberg/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
@@ -36,10 +36,9 @@ import scala.collection.JavaConverters._
object GlutenIcebergSourceUtil {
- def genSplitInfo(
- inputPartition: InputPartition,
- index: Int,
- readPartitionSchema: StructType): SplitInfo = inputPartition match {
+ def genSplitInfoForPartition(inputPartition: InputPartition,
+ index: Int,
+ readPartitionSchema: StructType): SplitInfo =
inputPartition match {
case partition: SparkInputPartition =>
val paths = new JArrayList[String]()
val starts = new JArrayList[JLong]()
@@ -84,6 +83,54 @@ object GlutenIcebergSourceUtil {
throw new UnsupportedOperationException("Only support iceberg
SparkInputPartition.")
}
+ def genSplitInfo(inputPartitions: Seq[InputPartition],
+ index: Int,
+ readPartitionSchema: StructType): SplitInfo = {
+ val paths = new JArrayList[String]()
+ val starts = new JArrayList[JLong]()
+ val lengths = new JArrayList[JLong]()
+ val partitionColumns = new JArrayList[JMap[String, String]]()
+ val deleteFilesList = new JArrayList[JList[DeleteFile]]()
+ val preferredLocs = new JArrayList[String]()
+ var fileFormat = ReadFileFormat.UnknownFormat
+
+ inputPartitions.foreach {
+ case partition: SparkInputPartition =>
+ val tasks = partition.taskGroup[ScanTask]().tasks().asScala
+ asFileScanTask(tasks.toList).foreach {
+ task =>
+ paths.add(
+ BackendsApiManager.getTransformerApiInstance
+ .encodeFilePathIfNeed(task.file().path().toString))
+ starts.add(task.start())
+ lengths.add(task.length())
+ partitionColumns.add(getPartitionColumns(task,
readPartitionSchema))
+ deleteFilesList.add(task.deletes())
+ val currentFileFormat = convertFileFormat(task.file().format())
+ if (fileFormat == ReadFileFormat.UnknownFormat) {
+ fileFormat = currentFileFormat
+ } else if (fileFormat != currentFileFormat) {
+ throw new UnsupportedOperationException(
+ s"Only one file format is supported, " +
+ s"find different file format $fileFormat and
$currentFileFormat")
+ }
+ }
+ preferredLocs.addAll(partition.preferredLocations().toList.asJava)
+ }
+ IcebergLocalFilesBuilder.makeIcebergLocalFiles(
+ index,
+ paths,
+ starts,
+ lengths,
+ partitionColumns,
+ fileFormat,
+ SoftAffinity.getFilePartitionLocations(
+ paths.asScala.toArray,
+ preferredLocs.asScala.toArray).toList.asJava,
+ deleteFilesList
+ )
+ }
+
def getFileFormat(sparkScan: Scan): ReadFileFormat = sparkScan match {
case scan: SparkBatchQueryScan =>
val tasks = scan.tasks().asScala
@@ -186,7 +233,7 @@ object GlutenIcebergSourceUtil {
partitionColumns
}
- def convertFileFormat(icebergFileFormat: FileFormat): ReadFileFormat =
+ private def convertFileFormat(icebergFileFormat: FileFormat): ReadFileFormat
=
icebergFileFormat match {
case FileFormat.PARQUET => ReadFileFormat.ParquetReadFormat
case FileFormat.ORC => ReadFileFormat.OrcReadFormat
diff --git
a/gluten-iceberg/src-iceberg/test/scala/org/apache/gluten/execution/IcebergSuite.scala
b/gluten-iceberg/src-iceberg/test/scala/org/apache/gluten/execution/IcebergSuite.scala
index 459b332c7e..82aebad5f4 100644
---
a/gluten-iceberg/src-iceberg/test/scala/org/apache/gluten/execution/IcebergSuite.scala
+++
b/gluten-iceberg/src-iceberg/test/scala/org/apache/gluten/execution/IcebergSuite.scala
@@ -218,6 +218,162 @@ abstract class IcebergSuite extends
WholeStageTransformerSuite {
}
}
+ testWithSpecifiedSparkVersion("iceberg bucketed join partition value not
exists",
+ Array("3.4", "3.5")) {
+ val leftTable = "p_str_tb"
+ val rightTable = "p_int_tb"
+ withTable(leftTable, rightTable) {
+ withSQLConf(GlutenConfig.GLUTEN_ENABLED.key -> "false") {
+ // Gluten does not support write iceberg table.
+ spark.sql(s"""
+ |create table $leftTable(id int, name string, p string)
+ |using iceberg
+ |partitioned by (bucket(4, id));
+ |""".stripMargin)
+ spark.sql(
+ s"""
+ |insert into table $leftTable values
+ |(4, 'a5', 'p4'),
+ |(1, 'a1', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(2, 'a3', 'p2'),
+ |(1, 'a2', 'p1'),
+ |(3, 'a4', 'p3'),
+ |(10, 'a4', 'p3');
+ |""".stripMargin
+ )
+ spark.sql(s"""
+ |create table $rightTable(id int, name string, p int)
+ |using iceberg
+ |partitioned by (bucket(4, id));
+ |""".stripMargin)
+ spark.sql(
+ s"""
+ |insert into table $rightTable values
+ |(3, 'b4', 23),
+ |(1, 'b1', 21);
+ |""".stripMargin
+ )
+ }
+
+ withSQLConf(
+ "spark.sql.sources.v2.bucketing.enabled" -> "true",
+ "spark.sql.requireAllClusterKeysForCoPartition" -> "false",
+ "spark.sql.adaptive.enabled" -> "false",
+ "spark.sql.iceberg.planning.preserve-data-grouping" -> "true",
+ "spark.sql.autoBroadcastJoinThreshold" -> "-1",
+ "spark.sql.sources.v2.bucketing.pushPartValues.enabled" -> "true",
+
"spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled" ->
"false"
+ ) {
+ runQueryAndCompare(s"""
+ |select s.id, s.name, i.name, i.p
+ | from $leftTable s inner join $rightTable i
+ | on s.id = i.id;
+ |""".stripMargin) {
+ df =>
+ {
+ assert(
+ getExecutedPlan(df).count(
+ plan => {
+ plan.isInstanceOf[IcebergScanTransformer]
+ }) == 2)
+ getExecutedPlan(df).map {
+ case plan : IcebergScanTransformer =>
+ assert(plan.getKeyGroupPartitioning.isDefined)
+ assert(plan.getSplitInfosWithIndex.length == 3)
+ case _ => // do nothing
+ }
+ }
+ }
+ }
+ }
+ }
+
+ testWithSpecifiedSparkVersion("iceberg bucketed join partition value not
exists partial cluster",
+ Array("3.4", "3.5")) {
+ val leftTable = "p_str_tb"
+ val rightTable = "p_int_tb"
+ withTable(leftTable, rightTable) {
+ withSQLConf(GlutenConfig.GLUTEN_ENABLED.key -> "false") {
+ // Gluten does not support write iceberg table.
+ spark.sql(s"""
+ |create table $leftTable(id int, name string, p string)
+ |using iceberg
+ |partitioned by (bucket(4, id));
+ |""".stripMargin)
+ spark.sql(
+ s"""
+ |insert into table $leftTable values
+ |(4, 'a5', 'p4'),
+ |(1, 'a1', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(1, 'a2', 'p1'),
+ |(2, 'a3', 'p2'),
+ |(1, 'a2', 'p1'),
+ |(3, 'a4', 'p3'),
+ |(10, 'a4', 'p3');
+ |""".stripMargin
+ )
+ spark.sql(s"""
+ |create table $rightTable(id int, name string, p int)
+ |using iceberg
+ |partitioned by (bucket(4, id));
+ |""".stripMargin)
+ spark.sql(
+ s"""
+ |insert into table $rightTable values
+ |(3, 'b4', 23),
+ |(1, 'b1', 21);
+ |""".stripMargin
+ )
+ }
+
+ withSQLConf(
+ "spark.sql.sources.v2.bucketing.enabled" -> "true",
+ "spark.sql.requireAllClusterKeysForCoPartition" -> "false",
+ "spark.sql.adaptive.enabled" -> "false",
+ "spark.sql.iceberg.planning.preserve-data-grouping" -> "true",
+ "spark.sql.autoBroadcastJoinThreshold" -> "-1",
+ "spark.sql.sources.v2.bucketing.pushPartValues.enabled" -> "true",
+
"spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled" ->
"true"
+ ) {
+ runQueryAndCompare(s"""
+ |select s.id, s.name, i.name, i.p
+ | from $leftTable s inner join $rightTable i
+ | on s.id = i.id;
+ |""".stripMargin) {
+ df =>
+ {
+ assert(
+ getExecutedPlan(df).count(
+ plan => {
+ plan.isInstanceOf[IcebergScanTransformer]
+ }) == 2)
+ getExecutedPlan(df).map {
+ case plan : IcebergScanTransformer =>
+ assert(plan.getKeyGroupPartitioning.isDefined)
+ assert(plan.getSplitInfosWithIndex.length == 3)
+ case _ => // do nothing
+ }
+ }
+ }
+ }
+ }
+ }
+
testWithSpecifiedSparkVersion("iceberg bucketed join with partition filter",
Some("3.4")) {
val leftTable = "p_str_tb"
val rightTable = "p_int_tb"
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/IteratorApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/IteratorApi.scala
index 11c86b11f0..efb793839f 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/IteratorApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/IteratorApi.scala
@@ -39,6 +39,14 @@ trait IteratorApi {
metadataColumnNames: Seq[String],
properties: Map[String, String]): SplitInfo
+ def genSplitInfoForPartitions(
+ partitionIndex: Int,
+ partition: Seq[InputPartition],
+ partitionSchema: StructType,
+ fileFormat: ReadFileFormat,
+ metadataColumnNames: Seq[String],
+ properties: Map[String, String]): SplitInfo = throw new
UnsupportedOperationException()
+
/** Generate native row partition. */
def genPartitions(
wsCtx: WholeStageTransformContext,
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
index 958f4ee6c8..2b1f90b726 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
+import org.apache.gluten.substrait.rel.SplitInfo
import org.apache.gluten.utils.FileIndexUtil
import org.apache.spark.sql.catalyst.InternalRow
@@ -124,6 +125,24 @@ abstract class BatchScanExecTransformerBase(
override def outputAttributes(): Seq[Attribute] = output
+ // With storage partition join, the returned partition type is changed, and so is
the SplitInfo
+ def getPartitionsWithIndex: Seq[Seq[InputPartition]] = finalPartitions
+
+ def getSplitInfosWithIndex: Seq[SplitInfo] = {
+ getPartitionsWithIndex.zipWithIndex.map {
+ case (partitions, index) =>
+ BackendsApiManager.getIteratorApiInstance
+ .genSplitInfoForPartitions(
+ index,
+ partitions,
+ getPartitionSchema,
+ fileFormat,
+ getMetadataColumns().map(_.name),
+ getProperties)
+ }
+ }
+
+ // May not be callable for bucket scan
override def getPartitions: Seq[InputPartition] = filteredFlattenPartitions
override def getPartitionSchema: StructType = scan match {
@@ -175,6 +194,17 @@ abstract class BatchScanExecTransformerBase(
@transient protected lazy val filteredFlattenPartitions: Seq[InputPartition]
=
filteredPartitions.flatten
+ @transient protected lazy val finalPartitions: Seq[Seq[InputPartition]] =
+ SparkShimLoader.getSparkShims.orderPartitions(
+ this,
+ scan,
+ keyGroupedPartitioning,
+ filteredPartitions,
+ outputPartitioning,
+ commonPartitionValues,
+ applyPartialClustering,
+ replicatePartitions)
+
@transient override lazy val fileFormat: ReadFileFormat =
BackendsApiManager.getSettings.getSubstraitReadFileFormatV2(scan)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
index fa99a418e9..1865a38c52 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
@@ -372,6 +372,153 @@ case class WholeStageTransformer(child: SparkPlan,
materializeInput: Boolean = f
allLeafTransformers.toSeq
}
+ private def generateWholeStageRDD(
+ leafTransformers: Seq[LeafTransformSupport],
+ wsCtx: WholeStageTransformContext,
+ inputRDDs: ColumnarInputRDDsWrapper,
+ pipelineTime: SQLMetric): RDD[ColumnarBatch] = {
+
+ /**
+ * If containing leaf exec transformer this "whole stage" generates a RDD
which itself takes
+ * care of [[LeafTransformSupport]] there won't be any other RDD for leaf
operator. As a result,
+ * genFirstStageIterator rather than genFinalStageIterator will be invoked
+ */
+ val allInputPartitions = leafTransformers.map(_.getPartitions.toIndexedSeq)
+ val allSplitInfos = getSplitInfosFromPartitions(leafTransformers)
+
+ if (GlutenConfig.get.enableHdfsViewfs) {
+ val viewfsToHdfsCache: mutable.Map[String, String] = mutable.Map.empty
+ allSplitInfos.foreach {
+ splitInfos =>
+ splitInfos.foreach {
+ case splitInfo: LocalFilesNode =>
+ val newPaths = ViewFileSystemUtils.convertViewfsToHdfs(
+ splitInfo.getPaths.asScala.toSeq,
+ viewfsToHdfsCache,
+ serializableHadoopConf.value)
+ splitInfo.setPaths(newPaths.asJava)
+ }
+ }
+ }
+
+ val inputPartitions =
+ BackendsApiManager.getIteratorApiInstance.genPartitions(
+ wsCtx,
+ allSplitInfos,
+ leafTransformers)
+
+ val rdd = new GlutenWholeStageColumnarRDD(
+ sparkContext,
+ inputPartitions,
+ inputRDDs,
+ pipelineTime,
+ leafInputMetricsUpdater(),
+ BackendsApiManager.getMetricsApiInstance.metricsUpdatingFunction(
+ child,
+ wsCtx.substraitContext.registeredRelMap,
+ wsCtx.substraitContext.registeredJoinParams,
+ wsCtx.substraitContext.registeredAggregationParams
+ )
+ )
+
+ allInputPartitions.head.indices.foreach(
+ i => {
+ val currentPartitions = allInputPartitions.map(_(i))
+ currentPartitions.indices.foreach(
+ i =>
+ currentPartitions(i) match {
+ case f: FilePartition =>
+ SoftAffinity.updateFilePartitionLocations(f, rdd.id)
+ case _ =>
+ })
+ })
+
+ rdd
+ }
+
+ private def getSplitInfosFromPartitionSeqs(
+ leafTransformers: Seq[BatchScanExecTransformerBase]):
Seq[Seq[SplitInfo]] = {
+ // If there are two batch scan transformers with keyGroupedPartitioning,
+ // they have the same partitionValues,
+ // but some partitions may be empty for those partition values that are not
present,
+ // otherwise, an exchange will be inserted. We should combine the two leaf
+ // transformers' partitions with the same index, and set them together in
+ // the substraitContext. We use transpose to do that; you can refer to
+ // the diagram below.
+ // leaf1 Seq(p11) Seq(p12, p13) Seq(p14) ... Seq(p1n)
+ // leaf2 Seq(p21) Seq(p22) Seq() ... Seq(p2n)
+ // transpose =>
+ // leaf1 | leaf2
+ // Seq(p11) | Seq(p21) =>
substraitContext.setSplitInfo([Seq(p11), Seq(p21)])
+ // Seq(p12, p13) | Seq(p22) =>
substraitContext.setSplitInfo([Seq(p12, p13), Seq(p22)])
+ // Seq(p14) | Seq() ...
+ // ...
+ // Seq(p1n) | Seq(p2n) =>
substraitContext.setSplitInfo([Seq(p1n), Seq(p2n)])
+
+ val allSplitInfos = leafTransformers.map(_.getSplitInfosWithIndex)
+ val partitionLength = allSplitInfos.head.size
+ if (allSplitInfos.exists(_.size != partitionLength)) {
+ throw new GlutenException(
+ "The partition length of all the leaf transformer are not the same.")
+ }
+ if (GlutenConfig.get.enableHdfsViewfs) {
+ val viewfsToHdfsCache: mutable.Map[String, String] = mutable.Map.empty
+ allSplitInfos.foreach {
+ case splitInfo: LocalFilesNode =>
+ val newPaths = ViewFileSystemUtils.convertViewfsToHdfs(
+ splitInfo.getPaths.asScala.toSeq,
+ viewfsToHdfsCache,
+ serializableHadoopConf.value)
+ splitInfo.setPaths(newPaths.asJava)
+ }
+ }
+
+ allSplitInfos.transpose
+ }
+
+ private def generateWholeStageDatasourceRDD(
+ leafTransformers: Seq[BatchScanExecTransformerBase],
+ wsCtx: WholeStageTransformContext,
+ inputRDDs: ColumnarInputRDDsWrapper,
+ pipelineTime: SQLMetric): RDD[ColumnarBatch] = {
+
+ /**
+ * If containing leaf exec transformer this "whole stage" generates a RDD
which itself takes
+ * care of [[LeafTransformSupport]] there won't be any other RDD for leaf
operator. As a result,
+ * genFirstStageIterator rather than genFinalStageIterator will be invoked
+ */
+ val allInputPartitions = leafTransformers.map(_.getPartitionsWithIndex)
+ val allSplitInfos = getSplitInfosFromPartitionSeqs(leafTransformers)
+
+ val inputPartitions =
+ BackendsApiManager.getIteratorApiInstance.genPartitions(
+ wsCtx,
+ allSplitInfos,
+ leafTransformers)
+
+ val rdd = new GlutenWholeStageColumnarRDD(
+ sparkContext,
+ inputPartitions,
+ inputRDDs,
+ pipelineTime,
+ leafInputMetricsUpdater(),
+ BackendsApiManager.getMetricsApiInstance.metricsUpdatingFunction(
+ child,
+ wsCtx.substraitContext.registeredRelMap,
+ wsCtx.substraitContext.registeredJoinParams,
+ wsCtx.substraitContext.registeredAggregationParams
+ )
+ )
+
+ allInputPartitions.foreach(_.foreach(_.foreach {
+ case f: FilePartition =>
+ SoftAffinity.updateFilePartitionLocations(f, rdd.id)
+ case _ =>
+ }))
+
+ rdd
+ }
+
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
assert(child.isInstanceOf[TransformSupport])
val pipelineTime: SQLMetric = longMetric("pipelineTime")
@@ -387,63 +534,20 @@ case class WholeStageTransformer(child: SparkPlan,
materializeInput: Boolean = f
val leafTransformers = findAllLeafTransformers()
if (leafTransformers.nonEmpty) {
-
- /**
- * If containing leaf exec transformer this "whole stage" generates a
RDD which itself takes
- * care of [[LeafTransformSupport]] there won't be any other RDD for
leaf operator. As a
- * result, genFirstStageIterator rather than genFinalStageIterator will
be invoked
- */
- val allInputPartitions =
leafTransformers.map(_.getPartitions.toIndexedSeq)
- val allSplitInfos = getSplitInfosFromPartitions(leafTransformers)
-
- if (GlutenConfig.get.enableHdfsViewfs) {
- val viewfsToHdfsCache: mutable.Map[String, String] = mutable.Map.empty
- allSplitInfos.foreach {
- splitInfos =>
- splitInfos.foreach {
- case splitInfo: LocalFilesNode =>
- val newPaths = ViewFileSystemUtils.convertViewfsToHdfs(
- splitInfo.getPaths.asScala.toSeq,
- viewfsToHdfsCache,
- serializableHadoopConf.value)
- splitInfo.setPaths(newPaths.asJava)
- }
- }
+ val isKeyGroupPartition: Boolean = leafTransformers.exists {
+ // TODO: May can apply to BatchScanExecTransformer without key group
partitioning
+ case b: BatchScanExecTransformerBase if
b.keyGroupedPartitioning.isDefined => true
+ case _ => false
}
-
- val inputPartitions =
- BackendsApiManager.getIteratorApiInstance.genPartitions(
+ if (!isKeyGroupPartition) {
+ generateWholeStageRDD(leafTransformers, wsCtx, inputRDDs, pipelineTime)
+ } else {
+ generateWholeStageDatasourceRDD(
+ leafTransformers.map(_.asInstanceOf[BatchScanExecTransformerBase]),
wsCtx,
- allSplitInfos,
- leafTransformers)
-
- val rdd = new GlutenWholeStageColumnarRDD(
- sparkContext,
- inputPartitions,
- inputRDDs,
- pipelineTime,
- leafInputMetricsUpdater(),
- BackendsApiManager.getMetricsApiInstance.metricsUpdatingFunction(
- child,
- wsCtx.substraitContext.registeredRelMap,
- wsCtx.substraitContext.registeredJoinParams,
- wsCtx.substraitContext.registeredAggregationParams
- )
- )
-
- allInputPartitions.head.indices.foreach(
- i => {
- val currentPartitions = allInputPartitions.map(_(i))
- currentPartitions.indices.foreach(
- i =>
- currentPartitions(i) match {
- case f: FilePartition =>
- SoftAffinity.updateFilePartitionLocations(f, rdd.id)
- case _ =>
- })
- })
-
- rdd
+ inputRDDs,
+ pipelineTime)
+ }
} else {
/**
diff --git
a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index 681e0f583d..2fe6dddeb3 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.{FileSourceScanExec,
GlobalLimitExec, Spar
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec,
DataSourceV2ScanExecBase}
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike,
ShuffleExchangeLike}
import org.apache.spark.sql.internal.SQLConf
@@ -223,15 +223,20 @@ trait SparkShims {
// For compatibility with Spark-3.5.
def getAnalysisExceptionPlan(ae: AnalysisException): Option[LogicalPlan]
- def getKeyGroupedPartitioning(batchScan: BatchScanExec):
Option[Seq[Expression]]
+ def getKeyGroupedPartitioning(batchScan: BatchScanExec):
Option[Seq[Expression]] = Option(Seq())
- def getCommonPartitionValues(batchScan: BatchScanExec):
Option[Seq[(InternalRow, Int)]]
+ def getCommonPartitionValues(batchScan: BatchScanExec):
Option[Seq[(InternalRow, Int)]] =
+ Option(Seq())
def orderPartitions(
+ batchScan: DataSourceV2ScanExecBase,
scan: Scan,
keyGroupedPartitioning: Option[Seq[Expression]],
filteredPartitions: Seq[Seq[InputPartition]],
- outputPartitioning: Partitioning): Seq[InputPartition] =
filteredPartitions.flatten
+ outputPartitioning: Partitioning,
+ commonPartitionValues: Option[Seq[(InternalRow, Int)]],
+ applyPartialClustering: Boolean,
+ replicatePartitions: Boolean): Seq[Seq[InputPartition]] =
filteredPartitions
def extractExpressionTimestampAddUnit(timestampAdd: Expression):
Option[Seq[String]] =
Option.empty
diff --git
a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
index 123f74770b..1f94b47e22 100644
---
a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
+++
b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
@@ -252,11 +252,6 @@ class Spark32Shims extends SparkShims {
ae.plan
}
- override def getKeyGroupedPartitioning(batchScan: BatchScanExec):
Option[Seq[Expression]] = null
-
- override def getCommonPartitionValues(batchScan: BatchScanExec):
Option[Seq[(InternalRow, Int)]] =
- null
-
override def dateTimestampFormatInReadIsDefaultValue(
csvOptions: CSVOptions,
timeZone: String): Boolean = {
diff --git
a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
index aba794a161..694c719c9b 100644
---
a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
+++
b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
@@ -332,8 +332,6 @@ class Spark33Shims extends SparkShims {
override def getKeyGroupedPartitioning(batchScan: BatchScanExec):
Option[Seq[Expression]] = {
batchScan.keyGroupedPartitioning
}
- override def getCommonPartitionValues(batchScan: BatchScanExec):
Option[Seq[(InternalRow, Int)]] =
- null
override def extractExpressionTimestampAddUnit(exp: Expression):
Option[Seq[String]] = {
exp match {
diff --git
a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index 62a4afe106..5087b02ff8 100644
---
a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++
b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec,
DataSourceV2ScanExecBase}
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
@@ -418,34 +418,117 @@ class Spark34Shims extends SparkShims {
}
override def orderPartitions(
+ batchScan: DataSourceV2ScanExecBase,
scan: Scan,
keyGroupedPartitioning: Option[Seq[Expression]],
filteredPartitions: Seq[Seq[InputPartition]],
- outputPartitioning: Partitioning): Seq[InputPartition] = {
+ outputPartitioning: Partitioning,
+ commonPartitionValues: Option[Seq[(InternalRow, Int)]],
+ applyPartialClustering: Boolean,
+ replicatePartitions: Boolean): Seq[Seq[InputPartition]] = {
scan match {
case _ if keyGroupedPartitioning.isDefined =>
- var newPartitions = filteredPartitions
+ var finalPartitions = filteredPartitions
+
outputPartitioning match {
case p: KeyGroupedPartitioning =>
- val partitionMapping = newPartitions
- .map(
- s =>
- InternalRowComparableWrapper(
- s.head.asInstanceOf[HasPartitionKey],
- p.expressions) -> s)
- .toMap
- newPartitions = p.partitionValues.map {
- partValue =>
- // Use empty partition for those partition values that are not
present
- partitionMapping.getOrElse(
- InternalRowComparableWrapper(partValue, p.expressions),
- Seq.empty)
+ if (
+ SQLConf.get.v2BucketingPushPartValuesEnabled &&
+ SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled
+ ) {
+ assert(
+ filteredPartitions.forall(_.size == 1),
+ "Expect partitions to be not grouped when " +
+
s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+ "is enabled"
+ )
+
+ val groupedPartitions = batchScan
+ .groupPartitions(finalPartitions.map(_.head), true)
+ .getOrElse(Seq.empty)
+
+ // This means the input partitions are not grouped by partition
values. We'll need to
+ // check `groupByPartitionValues` and decide whether to group
and replicate splits
+ // within a partition.
+ if (commonPartitionValues.isDefined && applyPartialClustering) {
+ // A mapping from the common partition values to how many
splits the partition
+ // should contain. Note this no longer maintain the partition
key ordering.
+ val commonPartValuesMap = commonPartitionValues.get
+ .map(t => (InternalRowComparableWrapper(t._1,
p.expressions), t._2))
+ .toMap
+ val nestGroupedPartitions = groupedPartitions.map {
+ case (partValue, splits) =>
+ // `commonPartValuesMap` should contain the part value
since it's the super set.
+ val numSplits = commonPartValuesMap
+ .get(InternalRowComparableWrapper(partValue,
p.expressions))
+ assert(
+ numSplits.isDefined,
+ s"Partition value $partValue does not exist in " +
+ "common partition values from Spark plan")
+
+ val newSplits = if (replicatePartitions) {
+ // We need to also replicate partitions according to the
other side of join
+ Seq.fill(numSplits.get)(splits)
+ } else {
+ // Not grouping by partition values: this could be the
side with partially
+ // clustered distribution. Because of dynamic filtering,
we'll need to check
+ // if the final number of splits of a partition is
smaller than the original
+ // number, and fill with empty splits if so. This is
necessary so that both
+ // sides of a join will have the same number of
partitions & splits.
+ splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+ }
+ (InternalRowComparableWrapper(partValue, p.expressions),
newSplits)
+ }
+
+ // Now fill missing partition keys with empty partitions
+ val partitionMapping = nestGroupedPartitions.toMap
+ finalPartitions = commonPartitionValues.get.flatMap {
+ case (partValue, numSplits) =>
+ // Use empty partition for those partition values that are
not present.
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.fill(numSplits)(Seq.empty))
+ }
+ } else {
+ // either `commonPartitionValues` is not defined, or it is
defined but
+ // `applyPartialClustering` is false.
+ val partitionMapping = groupedPartitions.map {
+ case (row, parts) =>
+ InternalRowComparableWrapper(row, p.expressions) -> parts
+ }.toMap
+
+ // In case `commonPartitionValues` is not defined (e.g., SPJ
is not used), there
+ // could exist duplicated partition values, as partition
grouping is not done
+ // at the beginning and postponed to this method. It is
important to use unique
+ // partition values here so that grouped partitions won't get
duplicated.
+ finalPartitions = p.uniquePartitionValues.map {
+ partValue =>
+ // Use empty partition for those partition values that are
not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.empty)
+ }
+ }
+ } else {
+ val partitionMapping = finalPartitions.map {
+ parts =>
+ val row =
parts.head.asInstanceOf[HasPartitionKey].partitionKey()
+ InternalRowComparableWrapper(row, p.expressions) -> parts
+ }.toMap
+ finalPartitions = p.partitionValues.map {
+ partValue =>
+ // Use empty partition for those partition values that are
not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.empty)
+ }
}
+
case _ =>
}
- newPartitions.flatten
+ finalPartitions
case _ =>
- filteredPartitions.flatten
+ filteredPartitions
}
}
diff --git
a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index b4960cf5f2..a5898b0a96 100644
---
a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++
b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat,
ParquetFilters}
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec,
DataSourceV2ScanExecBase}
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike,
ShuffleExchangeLike}
@@ -446,36 +446,120 @@ class Spark35Shims extends SparkShims {
}
override def orderPartitions(
+ batchScan: DataSourceV2ScanExecBase,
scan: Scan,
keyGroupedPartitioning: Option[Seq[Expression]],
filteredPartitions: Seq[Seq[InputPartition]],
- outputPartitioning: Partitioning): Seq[InputPartition] = {
+ outputPartitioning: Partitioning,
+ commonPartitionValues: Option[Seq[(InternalRow, Int)]],
+ applyPartialClustering: Boolean,
+ replicatePartitions: Boolean): Seq[Seq[InputPartition]] = {
scan match {
case _ if keyGroupedPartitioning.isDefined =>
- var newPartitions = filteredPartitions
+ var finalPartitions = filteredPartitions
+
outputPartitioning match {
case p: KeyGroupedPartitioning =>
- val partitionMapping = newPartitions
- .map(
- s =>
- InternalRowComparableWrapper(
- s.head.asInstanceOf[HasPartitionKey],
- p.expressions) -> s)
- .toMap
- newPartitions = p.partitionValues.map {
- partValue =>
- // Use empty partition for those partition values that are not
present
- partitionMapping.getOrElse(
- InternalRowComparableWrapper(partValue, p.expressions),
- Seq.empty)
+ if (
+ SQLConf.get.v2BucketingPushPartValuesEnabled &&
+ SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled
+ ) {
+ assert(
+ filteredPartitions.forall(_.size == 1),
+ "Expect partitions to be not grouped when " +
+
s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+ "is enabled"
+ )
+
+ val groupedPartitions = batchScan
+ .groupPartitions(finalPartitions.map(_.head), true)
+ .getOrElse(Seq.empty)
+
+ // This means the input partitions are not grouped by partition
values. We'll need to
+ // check `groupByPartitionValues` and decide whether to group
and replicate splits
+ // within a partition.
+ if (commonPartitionValues.isDefined && applyPartialClustering) {
+ // A mapping from the common partition values to how many
splits the partition
+ // should contain. Note this no longer maintain the partition
key ordering.
+ val commonPartValuesMap = commonPartitionValues.get
+ .map(t => (InternalRowComparableWrapper(t._1,
p.expressions), t._2))
+ .toMap
+ val nestGroupedPartitions = groupedPartitions.map {
+ case (partValue, splits) =>
+ // `commonPartValuesMap` should contain the part value
since it's the super set.
+ val numSplits = commonPartValuesMap
+ .get(InternalRowComparableWrapper(partValue,
p.expressions))
+ assert(
+ numSplits.isDefined,
+ s"Partition value $partValue does not exist in " +
+ "common partition values from Spark plan")
+
+ val newSplits = if (replicatePartitions) {
+ // We need to also replicate partitions according to the
other side of join
+ Seq.fill(numSplits.get)(splits)
+ } else {
+ // Not grouping by partition values: this could be the
side with partially
+ // clustered distribution. Because of dynamic filtering,
we'll need to check
+ // if the final number of splits of a partition is
smaller than the original
+ // number, and fill with empty splits if so. This is
necessary so that both
+ // sides of a join will have the same number of
partitions & splits.
+ splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+ }
+ (InternalRowComparableWrapper(partValue, p.expressions),
newSplits)
+ }
+
+ // Now fill missing partition keys with empty partitions
+ val partitionMapping = nestGroupedPartitions.toMap
+ finalPartitions = commonPartitionValues.get.flatMap {
+ case (partValue, numSplits) =>
+ // Use empty partition for those partition values that are
not present.
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.fill(numSplits)(Seq.empty))
+ }
+ } else {
+ // either `commonPartitionValues` is not defined, or it is
defined but
+ // `applyPartialClustering` is false.
+ val partitionMapping = groupedPartitions.map {
+ case (row, parts) =>
+ InternalRowComparableWrapper(row, p.expressions) -> parts
+ }.toMap
+
+ // In case `commonPartitionValues` is not defined (e.g., SPJ
is not used), there
+ // could exist duplicated partition values, as partition
grouping is not done
+ // at the beginning and postponed to this method. It is
important to use unique
+ // partition values here so that grouped partitions won't get
duplicated.
+ finalPartitions = p.uniquePartitionValues.map {
+ partValue =>
+ // Use empty partition for those partition values that are
not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.empty)
+ }
+ }
+ } else {
+ val partitionMapping = finalPartitions.map {
+ parts =>
+ val row =
parts.head.asInstanceOf[HasPartitionKey].partitionKey()
+ InternalRowComparableWrapper(row, p.expressions) -> parts
+ }.toMap
+ finalPartitions = p.partitionValues.map {
+ partValue =>
+ // Use empty partition for those partition values that are
not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.empty)
+ }
}
+
case _ =>
}
- newPartitions.flatten
+ finalPartitions
case _ =>
- filteredPartitions.flatten
+ filteredPartitions
}
}
+
override def supportsRowBased(plan: SparkPlan): Boolean =
plan.supportsRowBased
override def withTryEvalMode(expr: Expression): Boolean = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]