This is an automated email from the ASF dual-hosted git repository.

felixybw pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 4dd513bb0 Bucket join support for Iceberg tables (#4859)
4dd513bb0 is described below

commit 4dd513bb0ab47389ee44ae42b2363116c1db1dc2
Author: Ashish Singh <[email protected]>
AuthorDate: Wed Mar 13 16:20:46 2024 -0700

    Bucket join support for Iceberg tables (#4859)
---
 .../execution/BatchScanExecTransformer.scala       |  19 +-
 .../execution/IcebergScanTransformer.scala         |  24 +-
 .../execution/VeloxIcebergSuite.scala              | 261 +++++++++++++++++++--
 .../io/glutenproject/sql/shims/SparkShims.scala    |  14 +-
 .../sql/shims/spark32/Spark32Shims.scala           |   4 +
 .../sql/shims/spark33/Spark33Shims.scala           |   6 +
 .../datasources/v2/BatchScanExecShim.scala         |   1 +
 .../sql/shims/spark34/Spark34Shims.scala           |  45 +++-
 .../datasources/v2/BatchScanExecShim.scala         |  19 +-
 .../sql/shims/spark35/Spark35Shims.scala           |   5 +
 .../datasources/v2/BatchScanExecShim.scala         |   2 +
 11 files changed, 362 insertions(+), 38 deletions(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala
index 693e014df..afa0ce0e2 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BatchScanExecTransformer.scala
@@ -32,13 +32,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-/**
- * Columnar Based BatchScanExec. Although keyGroupedPartitioning is not used, it cannot be deleted,
- * it can make BatchScanExecTransformer contain a constructor with the same parameters as
- * Spark-3.3's BatchScanExec. Otherwise, the corresponding constructor will not be found when
- * calling TreeNode.makeCopy and will fail to copy this node during transformation.
- */
-
+/** Columnar Based BatchScanExec. */
 case class BatchScanExecTransformer(
     override val output: Seq[AttributeReference],
     @transient override val scan: Scan,
@@ -80,7 +74,16 @@ abstract class BatchScanExecTransformerBase(
     override val commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
     override val applyPartialClustering: Boolean = false,
     override val replicatePartitions: Boolean = false)
-  extends BatchScanExecShim(output, scan, runtimeFilters, table = table)
+  extends BatchScanExecShim(
+    output,
+    scan,
+    runtimeFilters,
+    keyGroupedPartitioning,
+    ordering,
+    table,
+    commonPartitionValues,
+    applyPartialClustering,
+    replicatePartitions)
   with BasicScanExecTransformer {
 
   // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
diff --git a/gluten-iceberg/src/main/scala/io/glutenproject/execution/IcebergScanTransformer.scala b/gluten-iceberg/src/main/scala/io/glutenproject/execution/IcebergScanTransformer.scala
index 1c2c189ed..fdb82f23e 100644
--- a/gluten-iceberg/src/main/scala/io/glutenproject/execution/IcebergScanTransformer.scala
+++ b/gluten-iceberg/src/main/scala/io/glutenproject/execution/IcebergScanTransformer.scala
@@ -20,6 +20,7 @@ import io.glutenproject.sql.shims.SparkShimLoader
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 import io.glutenproject.substrait.rel.SplitInfo
 
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DynamicPruningExpression, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.connector.catalog.Table
@@ -33,12 +34,17 @@ case class IcebergScanTransformer(
     override val output: Seq[AttributeReference],
     @transient override val scan: Scan,
     override val runtimeFilters: Seq[Expression],
-    @transient override val table: Table)
+    @transient override val table: Table,
+    override val keyGroupedPartitioning: Option[Seq[Expression]] = None,
+    override val commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None)
   extends BatchScanExecTransformerBase(
     output = output,
     scan = scan,
     runtimeFilters = runtimeFilters,
-    table = table) {
+    table = table,
+    keyGroupedPartitioning = keyGroupedPartitioning,
+    commonPartitionValues = commonPartitionValues
+  ) {
 
   override def filterExprs(): Seq[Expression] = pushdownFilters.getOrElse(Seq.empty)
 
@@ -51,7 +57,12 @@ case class IcebergScanTransformer(
   override lazy val fileFormat: ReadFileFormat = GlutenIcebergSourceUtil.getFileFormat(scan)
 
   override def getSplitInfos: Seq[SplitInfo] = {
-    getPartitions.zipWithIndex.map {
+    val groupedPartitions = SparkShimLoader.getSparkShims.orderPartitions(
+      scan,
+      keyGroupedPartitioning,
+      filteredPartitions,
+      outputPartitioning)
+    groupedPartitions.zipWithIndex.map {
       case (p, index) => GlutenIcebergSourceUtil.genSplitInfo(p, index)
     }
   }
@@ -64,6 +75,8 @@ case class IcebergScanTransformer(
         output)
     )
   }
+  // Test-only accessor for the key-grouped partitioning expressions of this scan.
+  private[execution] def getKeyGroupPartitioning: Option[Seq[Expression]] = keyGroupedPartitioning
 }
 
 object IcebergScanTransformer {
@@ -74,6 +87,9 @@ object IcebergScanTransformer {
       batchScan.output,
       batchScan.scan,
       newPartitionFilters,
-      table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScan))
+      table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScan),
+      keyGroupedPartitioning = SparkShimLoader.getSparkShims.getKeyGroupedPartitioning(batchScan),
+      commonPartitionValues = SparkShimLoader.getSparkShims.getCommonPartitionValues(batchScan)
+    )
   }
 }
diff --git a/gluten-iceberg/src/test/scala/io/glutenproject/execution/VeloxIcebergSuite.scala b/gluten-iceberg/src/test/scala/io/glutenproject/execution/VeloxIcebergSuite.scala
index fcb95ace9..019c60295 100644
--- a/gluten-iceberg/src/test/scala/io/glutenproject/execution/VeloxIcebergSuite.scala
+++ b/gluten-iceberg/src/test/scala/io/glutenproject/execution/VeloxIcebergSuite.scala
@@ -43,6 +43,21 @@ class VeloxIcebergSuite extends WholeStageTransformerSuite {
       .set("spark.sql.catalog.spark_catalog.warehouse", s"file://$rootPath/tpch-data-iceberg-velox")
   }
 
+  private def isSparkVersionAtleast(version: String): Boolean = {
+    // Compares version components numerically, left to right, over the components present
+    // in both strings; an equal common prefix counts as "at least". Non-numeric suffixes
+    // such as "-SNAPSHOT" are ignored instead of throwing NumberFormatException.
+    def components(v: String): Seq[Int] =
+      v.split("\\.").toSeq.map(_.takeWhile(_.isDigit)).map {
+        case "" => 0
+        case digits => digits.toInt
+      }
+    components(spark.version)
+      .zip(components(version))
+      .find { case (current, required) => current != required }
+      .forall { case (current, required) => current > required }
+  }
+
   test("iceberg transformer exists") {
     spark.sql("""
                 |create table iceberg_tb using iceberg as
@@ -56,44 +71,246 @@ class VeloxIcebergSuite extends WholeStageTransformerSuite {
     }
   }
 
-  test("iceberg partitioned table") {
-    withTable("p_str_tb", "p_int_tb") {
+  test("iceberg bucketed join") {
+    assume(isSparkVersionAtleast("3.4"))
+    val leftTable = "p_str_tb"
+    val rightTable = "p_int_tb"
+    withTable(leftTable, rightTable) {
       // Partition key of string type.
       withSQLConf(GlutenConfig.GLUTEN_ENABLE_KEY -> "false") {
         // Gluten does not support write iceberg table.
+        spark.sql(s"""
+                     |create table $leftTable(id int, name string, p string)
+                     |using iceberg
+                     |partitioned by (bucket(4, id));
+                     |""".stripMargin)
         spark.sql(
-          """
-            |create table p_str_tb(id int, name string, p string) using iceberg partitioned by (p);
-            |""".stripMargin)
+          s"""
+             |insert into table $leftTable values
+             |(4, 'a5', 'p4'),
+             |(1, 'a1', 'p1'),
+             |(2, 'a3', 'p2'),
+             |(1, 'a2', 'p1'),
+             |(3, 'a4', 'p3');
+             |""".stripMargin
+        )
+      }
+
+      // Partition key of integer type.
+      withSQLConf(
+        GlutenConfig.GLUTEN_ENABLE_KEY -> "false"
+      ) {
+        // Gluten does not support write iceberg table.
+        spark.sql(s"""
+                     |create table $rightTable(id int, name string, p int)
+                     |using iceberg
+                     |partitioned by (bucket(4, id));
+                     |""".stripMargin)
         spark.sql(
-          """
-            |insert into table p_str_tb values(1, 'a1', 'p1'), (2, 'a2', 'p1'), (3, 'a3', 'p2');
-            |""".stripMargin
+          s"""
+             |insert into table $rightTable values
+             |(3, 'b4', 23),
+             |(1, 'b2', 21),
+             |(4, 'b5', 24),
+             |(2, 'b3', 22),
+             |(1, 'b1', 21);
+             |""".stripMargin
         )
       }
-      runQueryAndCompare("""
-                           |select * from p_str_tb;
-                           |""".stripMargin) {
-        checkOperatorMatch[IcebergScanTransformer]
+
+      withSQLConf(
+        "spark.sql.sources.v2.bucketing.enabled" -> "true",
+        "spark.sql.requireAllClusterKeysForCoPartition" -> "false",
+        "spark.sql.adaptive.enabled" -> "false",
+        "spark.sql.iceberg.planning.preserve-data-grouping" -> "true",
+        "spark.sql.autoBroadcastJoinThreshold" -> "-1",
+        "spark.sql.sources.v2.bucketing.pushPartValues.enabled" -> "true"
+      ) {
+        runQueryAndCompare(s"""
+                              |select s.id, s.name, i.name, i.p
+                              | from $leftTable s inner join $rightTable i
+                              | on s.id = i.id;
+                              |""".stripMargin) {
+          df =>
+            {
+              assert(
+                getExecutedPlan(df).count(
+                  plan => {
+                    plan.isInstanceOf[IcebergScanTransformer]
+                  }) == 2)
+              getExecutedPlan(df).foreach {
+                case plan: IcebergScanTransformer =>
+                  // Each Iceberg scan must keep its key-grouped partitioning.
+                  assert(plan.getKeyGroupPartitioning.isDefined)
+                  assert(plan.getSplitInfos.length == 3)
+                case _ => // do nothing
+              }
+              checkLengthAndPlan(df, 7)
+            }
+        }
+      }
+    }
+  }
+
+  test("iceberg bucketed join with partition") {
+    assume(isSparkVersionAtleast("3.4"))
+    val leftTable = "p_str_tb"
+    val rightTable = "p_int_tb"
+    withTable(leftTable, rightTable) {
+      // Left table: bucketed on id and additionally partitioned by integer column p.
+      withSQLConf(GlutenConfig.GLUTEN_ENABLE_KEY -> "false") {
+        // Gluten does not support write iceberg table.
+        spark.sql(s"""
+                     |create table $leftTable(id int, name string, p int)
+                     |using iceberg
+                     |partitioned by (bucket(4, id), p);
+                     |""".stripMargin)
+        spark.sql(
+          s"""
+             |insert into table $leftTable values
+             |(4, 'a5', 2),
+             |(1, 'a1', 1),
+             |(2, 'a3', 1),
+             |(1, 'a2', 1),
+             |(3, 'a4', 2);
+             |""".stripMargin
+        )
       }
 
       // Partition key of integer type.
+      withSQLConf(
+        GlutenConfig.GLUTEN_ENABLE_KEY -> "false"
+      ) {
+        // Gluten does not support write iceberg table.
+        spark.sql(s"""
+                     |create table $rightTable(id int, name string, p int)
+                     |using iceberg
+                     |partitioned by (bucket(4, id), p);
+                     |""".stripMargin)
+        spark.sql(
+          s"""
+             |insert into table $rightTable values
+             |(3, 'b4', 2),
+             |(1, 'b2', 1),
+             |(4, 'b5', 2),
+             |(2, 'b3', 1),
+             |(1, 'b1', 1);
+             |""".stripMargin
+        )
+      }
+
+      withSQLConf(
+        "spark.sql.sources.v2.bucketing.enabled" -> "true",
+        "spark.sql.requireAllClusterKeysForCoPartition" -> "false",
+        "spark.sql.adaptive.enabled" -> "false",
+        "spark.sql.iceberg.planning.preserve-data-grouping" -> "true",
+        "spark.sql.autoBroadcastJoinThreshold" -> "-1",
+        "spark.sql.sources.v2.bucketing.pushPartValues.enabled" -> "true"
+      ) {
+        runQueryAndCompare(s"""
+                              |select s.id, s.name, i.name, i.p
+                              | from $leftTable s inner join $rightTable i
+                              | on s.id = i.id and s.p = i.p;
+                              |""".stripMargin) {
+          df =>
+            {
+              assert(
+                getExecutedPlan(df).count(
+                  plan => {
+                    plan.isInstanceOf[IcebergScanTransformer]
+                  }) == 2)
+              getExecutedPlan(df).foreach {
+                case plan: IcebergScanTransformer =>
+                  // Each Iceberg scan must keep its key-grouped partitioning.
+                  assert(plan.getKeyGroupPartitioning.isDefined)
+                  assert(plan.getSplitInfos.length == 3)
+                case _ => // do nothing
+              }
+              checkLengthAndPlan(df, 7)
+            }
+        }
+      }
+    }
+  }
+
+  test("iceberg bucketed join with partition filter") {
+    assume(isSparkVersionAtleast("3.4"))
+    val leftTable = "p_str_tb"
+    val rightTable = "p_int_tb"
+    withTable(leftTable, rightTable) {
+      // Left table: bucketed on id and additionally partitioned by integer column p.
       withSQLConf(GlutenConfig.GLUTEN_ENABLE_KEY -> "false") {
         // Gluten does not support write iceberg table.
+        spark.sql(s"""
+                     |create table $leftTable(id int, name string, p int)
+                     |using iceberg
+                     |partitioned by (bucket(4, id), p);
+                     |""".stripMargin)
         spark.sql(
-          """
-            |create table p_int_tb(id int, name string, p int) using iceberg partitioned by (p);
-            |""".stripMargin)
+          s"""
+             |insert into table $leftTable values
+             |(4, 'a5', 2),
+             |(1, 'a1', 1),
+             |(2, 'a3', 1),
+             |(1, 'a2', 1),
+             |(3, 'a4', 2);
+             |""".stripMargin
+        )
+      }
+
+      // Partition key of integer type.
+      withSQLConf(
+        GlutenConfig.GLUTEN_ENABLE_KEY -> "false"
+      ) {
+        // Gluten does not support write iceberg table.
+        spark.sql(s"""
+                     |create table $rightTable(id int, name string, p int)
+                     |using iceberg
+                     |partitioned by (bucket(4, id), p);
+                     |""".stripMargin)
         spark.sql(
-          """
-            |insert into table p_int_tb values(1, 'a1', 1), (2, 'a2', 1), (3, 'a3', 2);
-            |""".stripMargin
+          s"""
+             |insert into table $rightTable values
+             |(3, 'b4', 2),
+             |(1, 'b2', 1),
+             |(4, 'b5', 2),
+             |(2, 'b3', 1),
+             |(1, 'b1', 1);
+             |""".stripMargin
         )
       }
-      runQueryAndCompare("""
-                           |select * from p_int_tb;
-                           |""".stripMargin) {
-        checkOperatorMatch[IcebergScanTransformer]
+
+      withSQLConf(
+        "spark.sql.sources.v2.bucketing.enabled" -> "true",
+        "spark.sql.requireAllClusterKeysForCoPartition" -> "false",
+        "spark.sql.adaptive.enabled" -> "false",
+        "spark.sql.iceberg.planning.preserve-data-grouping" -> "true",
+        "spark.sql.autoBroadcastJoinThreshold" -> "-1",
+        "spark.sql.sources.v2.bucketing.pushPartValues.enabled" -> "true"
+      ) {
+        runQueryAndCompare(s"""
+                              |select s.id, s.name, i.name, i.p
+                              | from $leftTable s inner join $rightTable i
+                              | on s.id = i.id
+                              | where s.p = 1 and i.p = 1;
+                              |""".stripMargin) {
+          df =>
+            {
+              assert(
+                getExecutedPlan(df).count(
+                  plan => {
+                    plan.isInstanceOf[IcebergScanTransformer]
+                  }) == 2)
+              getExecutedPlan(df).foreach {
+                case plan: IcebergScanTransformer =>
+                  // Each Iceberg scan must keep its key-grouped partitioning.
+                  assert(plan.getKeyGroupPartitioning.isDefined)
+                  assert(plan.getSplitInfos.length == 1)
+                case _ => // do nothing
+              }
+              checkLengthAndPlan(df, 5)
+            }
+        }
       }
     }
   }
diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
index 4894ce34e..ab560a060 100644
--- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
@@ -21,16 +21,18 @@ import io.glutenproject.expression.Sig
 import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.scheduler.TaskInfo
-import org.apache.spark.shuffle.{ShuffleHandle, ShuffleReader}
+import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.{InputPartition, Scan}
 import org.apache.spark.sql.execution.{FileSourceScanExec, GlobalLimitExec, SparkPlan, TakeOrderedAndProjectExec}
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex, WriteJobDescription, WriteTaskResult}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -159,4 +161,14 @@ trait SparkShims {
 
   // For compatibility with Spark-3.5.
   def getAnalysisExceptionPlan(ae: AnalysisException): Option[LogicalPlan]
+
+  def getKeyGroupedPartitioning(batchScan: BatchScanExec): Option[Seq[Expression]]
+
+  def getCommonPartitionValues(batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]]
+
+  def orderPartitions(
+      scan: Scan,
+      keyGroupedPartitioning: Option[Seq[Expression]],
+      filteredPartitions: Seq[Seq[InputPartition]],
+      outputPartitioning: Partitioning): Seq[InputPartition] = filteredPartitions.flatten
 }
diff --git a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
index 622335046..5a3579b25 100644
--- a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
@@ -189,4 +189,8 @@ class Spark32Shims extends SparkShims {
     ae.plan
   }
 
+  // This shim exposes no key-grouped partitioning: report None (never null) for both.
+  override def getKeyGroupedPartitioning(batchScan: BatchScanExec): Option[Seq[Expression]] = None
+  override def getCommonPartitionValues(batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] =
+    None
 }
diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
index b580e792b..5f8134f7e 100644
--- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
@@ -231,4 +231,10 @@ class Spark33Shims extends SparkShims {
   def getAnalysisExceptionPlan(ae: AnalysisException): Option[LogicalPlan] = {
     ae.plan
   }
+
+  override def getKeyGroupedPartitioning(batchScan: BatchScanExec): Option[Seq[Expression]] =
+    batchScan.keyGroupedPartitioning
+  // Common partition values are not read from the scan in this shim: None, never null.
+  override def getCommonPartitionValues(batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] =
+    None
 }
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
index cdd571fd0..dcfb5c950 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
@@ -35,6 +35,7 @@ abstract class BatchScanExecShim(
     override val output: Seq[AttributeReference],
     @transient override val scan: Scan,
     override val runtimeFilters: Seq[Expression],
+    val keyGroupedPartitioning: Option[Seq[Expression]] = None,
     val ordering: Option[Seq[SortOrder]] = None,
     @transient val table: Table,
     val commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
diff --git a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
index a1851b276..cd8449bb3 100644
--- a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
@@ -31,10 +31,12 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
 import org.apache.spark.sql.execution.{FileSourceScanExec, GlobalLimitExec, GlutenFileFormatWriter, PartitionedFileUtil, SparkPlan, TakeOrderedAndProjectExec}
 import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex, WriteJobDescription, WriteTaskResult}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -269,4 +271,45 @@ class Spark34Shims extends SparkShims {
   def getAnalysisExceptionPlan(ae: AnalysisException): Option[LogicalPlan] = {
     ae.plan
   }
+
+  override def getKeyGroupedPartitioning(batchScan: BatchScanExec): Option[Seq[Expression]] = {
+    batchScan.keyGroupedPartitioning
+  }
+
+  override def getCommonPartitionValues(
+      batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] = {
+    batchScan.commonPartitionValues
+  }
+
+  /**
+   * Lays out the input-partition groups of a key-grouped scan in the order of the partition
+   * values of its [[KeyGroupedPartitioning]], as used for bucketed joins; a partition value
+   * without data contributes an empty group. When the scan is not key-grouped the groups are
+   * flattened in their existing order. The `scan` argument is not used by this shim.
+   * NOTE(review): each occurrence of a repeated partition value receives the whole group.
+   */
+  override def orderPartitions(
+      scan: Scan,
+      keyGroupedPartitioning: Option[Seq[Expression]],
+      filteredPartitions: Seq[Seq[InputPartition]],
+      outputPartitioning: Partitioning): Seq[InputPartition] = {
+    outputPartitioning match {
+      case p: KeyGroupedPartitioning if keyGroupedPartitioning.isDefined =>
+        // Each group is keyed by its first partition's key; empty groups hold no data and
+        // are skipped instead of failing on `head`.
+        val partitionMapping = filteredPartitions.filter(_.nonEmpty).map {
+          group =>
+            val key = group.head.asInstanceOf[HasPartitionKey]
+            InternalRowComparableWrapper(key, p.expressions) -> group
+        }.toMap
+        p.partitionValues.flatMap {
+          partValue =>
+            // Use empty partition for those partition values that are not present
+            partitionMapping.getOrElse(
+              InternalRowComparableWrapper(partValue, p.expressions),
+              Seq.empty)
+        }
+      case _ => filteredPartitions.flatten
+    }
+  }
 }
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
index 3cc09068e..4c12356d6 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
@@ -34,8 +35,22 @@ abstract class BatchScanExecShim(
     output: Seq[AttributeReference],
     @transient scan: Scan,
     runtimeFilters: Seq[Expression],
-    @transient val table: Table)
-  extends AbstractBatchScanExec(output, scan, runtimeFilters, table = table) {
+    keyGroupedPartitioning: Option[Seq[Expression]] = None,
+    ordering: Option[Seq[SortOrder]] = None,
+    @transient val table: Table,
+    commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+    applyPartialClustering: Boolean = false,
+    replicatePartitions: Boolean = false)
+  extends AbstractBatchScanExec(
+    output,
+    scan,
+    runtimeFilters,
+    keyGroupedPartitioning,
+    ordering,
+    table,
+    commonPartitionValues,
+    applyPartialClustering,
+    replicatePartitions) {
 
   // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
   @transient override lazy val metrics: Map[String, SQLMetric] = Map()
diff --git a/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
index 8468b7a81..a33801653 100644
--- a/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
@@ -272,4 +272,9 @@ class Spark35Shims extends SparkShims {
         None
     }
   }
+
+  // Key-grouped partitioning is not wired up for this shim: report None (never null) for both.
+  override def getKeyGroupedPartitioning(batchScan: BatchScanExec): Option[Seq[Expression]] = None
+  override def getCommonPartitionValues(batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] =
+    None
 }
diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
index 768957cf2..ec12fd33a 100644
--- a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
+++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
@@ -35,6 +35,8 @@ abstract class BatchScanExecShim(
     output: Seq[AttributeReference],
     @transient scan: Scan,
     runtimeFilters: Seq[Expression],
+    keyGroupedPartitioning: Option[Seq[Expression]] = None,
+    ordering: Option[Seq[SortOrder]] = None,
     @transient val table: Table,
     val commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
     val applyPartialClustering: Boolean = false,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to