This is an automated email from the ASF dual-hosted git repository.
ulyssesyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 3ee108f97 [CORE] Pullout pre-project for ExpandExec (#5066)
3ee108f97 is described below
commit 3ee108f97d103364facfde371bdd2af2d5013d7e
Author: Joey <[email protected]>
AuthorDate: Thu Mar 21 19:13:26 2024 +0800
[CORE] Pullout pre-project for ExpandExec (#5066)
---
.../execution/ExpandExecTransformer.scala | 126 ++++-----------------
.../extension/columnar/PullOutPreProject.scala | 12 +-
.../columnar/RewriteSparkPlanRulesManager.scala | 1 +
3 files changed, 35 insertions(+), 104 deletions(-)
diff --git
a/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala
b/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala
index 4d547f771..daa195b68 100644
---
a/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala
@@ -17,12 +17,12 @@
package io.glutenproject.execution
import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.expression.{ConverterUtils, ExpressionConverter,
LiteralTransformer}
+import io.glutenproject.expression.{ConverterUtils, ExpressionConverter}
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.SubstraitContext
-import io.glutenproject.substrait.expression.{ExpressionBuilder,
ExpressionNode}
+import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
@@ -32,9 +32,6 @@ import org.apache.spark.sql.execution._
import java.util.{ArrayList => JArrayList, List => JList}
-import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
-
case class ExpandExecTransformer(
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
@@ -66,110 +63,33 @@ case class ExpandExecTransformer(
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
- def needsPreProjection(projections: Seq[Seq[Expression]]): Boolean = {
- projections
- .exists(set => set.exists(p => !p.isInstanceOf[Attribute] &&
!p.isInstanceOf[Literal]))
- }
- if (needsPreProjection(projections)) {
- // if there is not literal and attribute expression in project sets, add
a project op
- // to calculate them before expand op.
- val preExprs = ArrayBuffer.empty[Expression]
- val selectionMaps = ArrayBuffer.empty[Seq[Int]]
- var preExprIndex = 0
- for (i <- projections.indices) {
- val selections = ArrayBuffer.empty[Int]
- for (j <- projections(i).indices) {
- val proj = projections(i)(j)
- if (!proj.isInstanceOf[Literal]) {
- val exprIdx = preExprs.indexWhere(expr =>
expr.semanticEquals(proj))
- if (exprIdx != -1) {
- selections += exprIdx
- } else {
- preExprs += proj
- selections += preExprIndex
- preExprIndex = preExprIndex + 1
- }
- } else {
- selections += -1
- }
- }
- selectionMaps += selections
- }
- // make project
- val preExprNodes = preExprs
- .map(
- ExpressionConverter
- .replaceWithExpressionTransformer(_, originalInputAttributes)
- .doTransform(args))
- .asJava
-
- val emitStartIndex = originalInputAttributes.size
- val inputRel = if (!validation) {
- RelBuilder.makeProjectRel(input, preExprNodes, context, operatorId,
emitStartIndex)
- } else {
- // Use a extension node to send the input types through Substrait plan
for a validation.
- val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
- for (attr <- originalInputAttributes) {
- inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType,
attr.nullable))
- }
- val extensionNode = ExtensionBuilder.makeAdvancedExtension(
- BackendsApiManager.getTransformerApiInstance.packPBMessage(
- TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
- RelBuilder.makeProjectRel(
- input,
- preExprNodes,
- extensionNode,
- context,
- operatorId,
- emitStartIndex)
- }
-
- // make expand
- val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
- for (i <- projections.indices) {
+ val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
+ projections.foreach {
+ projectSet =>
val projectExprNodes = new JArrayList[ExpressionNode]()
- for (j <- projections(i).indices) {
- val projectExprNode = projections(i)(j) match {
- case l: Literal =>
- LiteralTransformer(l).doTransform(args)
- case _ =>
- ExpressionBuilder.makeSelection(selectionMaps(i)(j))
- }
-
- projectExprNodes.add(projectExprNode)
+ projectSet.foreach {
+ project =>
+ val projectExprNode = ExpressionConverter
+ .replaceWithExpressionTransformer(project,
originalInputAttributes)
+ .doTransform(args)
+ projectExprNodes.add(projectExprNode)
}
projectSetExprNodes.add(projectExprNodes)
- }
- RelBuilder.makeExpandRel(inputRel, projectSetExprNodes, context,
operatorId)
+ }
+
+ if (!validation) {
+ RelBuilder.makeExpandRel(input, projectSetExprNodes, context, operatorId)
} else {
- val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
- projections.foreach {
- projectSet =>
- val projectExprNodes = new JArrayList[ExpressionNode]()
- projectSet.foreach {
- project =>
- val projectExprNode = ExpressionConverter
- .replaceWithExpressionTransformer(project,
originalInputAttributes)
- .doTransform(args)
- projectExprNodes.add(projectExprNode)
- }
- projectSetExprNodes.add(projectExprNodes)
+ // Use a extension node to send the input types through Substrait plan
for a validation.
+ val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
+ for (attr <- originalInputAttributes) {
+ inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType,
attr.nullable))
}
- if (!validation) {
- RelBuilder.makeExpandRel(input, projectSetExprNodes, context,
operatorId)
- } else {
- // Use a extension node to send the input types through Substrait plan
for a validation.
- val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
- for (attr <- originalInputAttributes) {
- inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType,
attr.nullable))
- }
-
- val extensionNode = ExtensionBuilder.makeAdvancedExtension(
- BackendsApiManager.getTransformerApiInstance.packPBMessage(
- TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
- RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode,
context, operatorId)
- }
+ val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+ BackendsApiManager.getTransformerApiInstance.packPBMessage(
+ TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+ RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode,
context, operatorId)
}
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
index 5bf70597c..440f609de 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
@@ -21,7 +21,7 @@ import io.glutenproject.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Complete, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan,
TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution.{ExpandExec, ProjectExec, SortExec,
SparkPlan, TakeOrderedAndProjectExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec,
TypedAggregateExpression}
import org.apache.spark.sql.execution.window.WindowExec
@@ -74,6 +74,7 @@ object PullOutPreProject extends Rule[SparkPlan] with
PullOutProjectHelper {
}
case _ => false
}.isDefined)
+ case expand: ExpandExec =>
expand.projections.flatten.exists(isNotAttributeAndLiteral)
case _ => false
}
}
@@ -179,6 +180,15 @@ object PullOutPreProject extends Rule[SparkPlan] with
PullOutProjectHelper {
ProjectExec(window.output, newWindow)
+ case expand: ExpandExec if needsPreProject(expand) =>
+ val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
+ val newProjections =
+ expand.projections.map(_.map(replaceExpressionWithAttribute(_,
expressionMap)))
+ expand.copy(
+ projections = newProjections,
+ child = ProjectExec(
+ eliminateProjectList(expand.child.outputSet,
expressionMap.values.toSeq),
+ expand.child))
case _ => plan
}
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
index 892e5eeef..8f3f01f95 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
@@ -53,6 +53,7 @@ class RewriteSparkPlanRulesManager(rewriteRules:
Seq[Rule[SparkPlan]]) extends R
case _: WindowExec => true
case _: FilterExec => true
case _: FileSourceScanExec => true
+ case _: ExpandExec => true
case _ => false
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]