This is an automated email from the ASF dual-hosted git repository.
zhli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 23048fb6a [GLUTEN-4884][VL] Call getPartitions once in WholeStageTransformer.
23048fb6a is described below
commit 23048fb6a3d1520dd1336371dc67ec6dc5fbd55a
Author: Ankita Victor <[email protected]>
AuthorDate: Wed Mar 27 13:30:59 2024 +0530
[GLUTEN-4884][VL] Call getPartitions once in WholeStageTransformer.
Call getPartitions once in WholeStageTransformer.
---
.../execution/BasicScanExecTransformer.scala | 7 ++++++-
.../glutenproject/execution/WholeStageTransformer.scala | 16 +++++++++++-----
2 files changed, 17 insertions(+), 6 deletions(-)
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
index 997395c0c..30b5b4cd5 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
@@ -28,6 +28,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -65,7 +66,11 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
/** Returns the split infos that will be processed by the underlying native engine. */
def getSplitInfos: Seq[SplitInfo] = {
- getPartitions.map(
+ getSplitInfosFromPartitions(getPartitions)
+ }
+
+ def getSplitInfosFromPartitions(partitions: Seq[InputPartition]): Seq[SplitInfo] = {
+ partitions.map(
BackendsApiManager.getIteratorApiInstance
.genSplitInfo(_, getPartitionSchema, fileFormat, getMetadataColumns.map(_.name)))
}
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
index 24957d1c2..b4409d98d 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
@@ -34,6 +34,7 @@ import org.apache.spark.softaffinity.SoftAffinity
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -267,7 +268,9 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
* care of SCAN there won't be any other RDD for SCAN. As a result, genFirstStageIterator
* rather than genFinalStageIterator will be invoked
*/
- val allScanSplitInfos = getSplitInfosFromScanTransformer(basicScanExecTransformers)
+ val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
+ val allScanSplitInfos =
+ getSplitInfosFromPartitions(basicScanExecTransformers, allScanPartitions)
val (wsCtx, inputPartitions) = GlutenTimeMetric.withMillisTime {
val wsCtx = doWholeStageTransform()
@@ -297,7 +300,6 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
wsCtx.substraitContext.registeredAggregationParams
)
)
- val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
(0 until allScanPartitions.head.size).foreach(
i => {
val currentPartitions = allScanPartitions.map(_(i))
@@ -361,8 +363,9 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
copy(child = newChild, materializeInput = materializeInput)(transformStageId)
- private def getSplitInfosFromScanTransformer(
- basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[SplitInfo]] = {
+ private def getSplitInfosFromPartitions(
+ basicScanExecTransformers: Seq[BasicScanExecTransformer],
+ allScanPartitions: Seq[Seq[InputPartition]]): Seq[Seq[SplitInfo]] = {
// If these are two scan transformers, they must have same partitions,
// otherwise, exchange will be inserted. We should combine the two scan
// transformers' partitions with same index, and set them together in
@@ -378,7 +381,10 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
// p14 | p24
// ...
// p1n | p2n => substraitContext.setSplitInfo([p1n, p2n])
- val allScanSplitInfos = basicScanExecTransformers.map(_.getSplitInfos)
+ val allScanSplitInfos =
+ allScanPartitions.zip(basicScanExecTransformers).map {
case (partition, transformer) => transformer.getSplitInfosFromPartitions(partition)
+ }
val partitionLength = allScanSplitInfos.head.size
if (allScanSplitInfos.exists(_.size != partitionLength)) {
throw new GlutenException(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]