This is an automated email from the ASF dual-hosted git repository.

zhli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 23048fb6a [GLUTEN-4884][VL] Call getPartitions once in WholeStageTransformer.
23048fb6a is described below

commit 23048fb6a3d1520dd1336371dc67ec6dc5fbd55a
Author: Ankita Victor <[email protected]>
AuthorDate: Wed Mar 27 13:30:59 2024 +0530

    [GLUTEN-4884][VL] Call getPartitions once in WholeStageTransformer.
    
    Call getPartitions once in WholeStageTransformer.
---
 .../execution/BasicScanExecTransformer.scala             |  7 ++++++-
 .../glutenproject/execution/WholeStageTransformer.scala  | 16 +++++++++++-----
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
index 997395c0c..30b5b4cd5 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
@@ -28,6 +28,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -65,7 +66,11 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
 
   /** Returns the split infos that will be processed by the underlying native engine. */
   def getSplitInfos: Seq[SplitInfo] = {
-    getPartitions.map(
+    getSplitInfosFromPartitions(getPartitions)
+  }
+
+  def getSplitInfosFromPartitions(partitions: Seq[InputPartition]): Seq[SplitInfo] = {
+    partitions.map(
       BackendsApiManager.getIteratorApiInstance
         .genSplitInfo(_, getPartitionSchema, fileFormat, getMetadataColumns.map(_.name)))
   }
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
index 24957d1c2..b4409d98d 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
@@ -34,6 +34,7 @@ import org.apache.spark.softaffinity.SoftAffinity
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -267,7 +268,9 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
        * care of SCAN there won't be any other RDD for SCAN. As a result, genFirstStageIterator
        * rather than genFinalStageIterator will be invoked
        */
-      val allScanSplitInfos = getSplitInfosFromScanTransformer(basicScanExecTransformers)
+      val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
+      val allScanSplitInfos =
+        getSplitInfosFromPartitions(basicScanExecTransformers, allScanPartitions)
 
       val (wsCtx, inputPartitions) = GlutenTimeMetric.withMillisTime {
         val wsCtx = doWholeStageTransform()
@@ -297,7 +300,6 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
           wsCtx.substraitContext.registeredAggregationParams
         )
       )
-      val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
       (0 until allScanPartitions.head.size).foreach(
         i => {
           val currentPartitions = allScanPartitions.map(_(i))
@@ -361,8 +363,9 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
   override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
     copy(child = newChild, materializeInput = materializeInput)(transformStageId)
 
-  private def getSplitInfosFromScanTransformer(
-      basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[SplitInfo]] = {
+  private def getSplitInfosFromPartitions(
+      basicScanExecTransformers: Seq[BasicScanExecTransformer],
+      allScanPartitions: Seq[Seq[InputPartition]]): Seq[Seq[SplitInfo]] = {
     // If these are two scan transformers, they must have same partitions,
     // otherwise, exchange will be inserted. We should combine the two scan
     // transformers' partitions with same index, and set them together in
@@ -378,7 +381,10 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
     //  p14  |  p24
     //      ...
     //  p1n  |  p2n    => substraitContext.setSplitInfo([p1n, p2n])
-    val allScanSplitInfos = basicScanExecTransformers.map(_.getSplitInfos)
+    val allScanSplitInfos =
+      allScanPartitions.zip(basicScanExecTransformers).map {
+        case (partition, transformer) => transformer.getSplitInfosFromPartitions(partition)
+      }
     val partitionLength = allScanSplitInfos.head.size
     if (allScanSplitInfos.exists(_.size != partitionLength)) {
       throw new GlutenException(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to