This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 5dd884e15 [GLUTEN-5757][CORE] Remove unnecessary
ProjectExecTransformer for Generate (#5782)
5dd884e15 is described below
commit 5dd884e150839a6377087d302cfb8f242948bd9c
Author: James Xu <[email protected]>
AuthorDate: Thu May 23 14:43:03 2024 +0800
[GLUTEN-5757][CORE] Remove unnecessary ProjectExecTransformer for Generate
(#5782)
If the generator function's input is already an Attribute reference, we omit
the introduction of the ProjectExec.
The previous implementation always assumed there is a Project under Generate.
In the new implementation we added metadata (injectedProject) to the Substrait
plan to tell us whether there is a dedicated Project under Generate.
---
.../gluten/execution/GenerateExecTransformer.scala | 69 +++++++++++++++-------
.../org/apache/gluten/execution/TestOperator.scala | 30 ++++++++--
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 13 ++--
.../execution/WholeStageTransformerSuite.scala | 30 +++++++++-
4 files changed, 112 insertions(+), 30 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
index 23addb89e..c9b0abd6f 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
@@ -95,6 +95,21 @@ case class GenerateExecTransformer(
operatorId)
}
+ /**
+ * Is the specified expression an Attribute reference?
+ * @param expr
+ * @param replaceBoundReference
+ * @return
+ */
+ private def isAttributeReference(
+ expr: Expression,
+ replaceBoundReference: Boolean = false): Boolean =
+ expr match {
+ case _: Attribute => true
+ case _: BoundReference if !replaceBoundReference => true
+ case _ => false
+ }
+
private def getExtensionNode(validation: Boolean): AdvancedExtensionNode = {
if (!validation) {
// Start with "GenerateParameters:"
@@ -111,14 +126,26 @@ case class GenerateExecTransformer(
.append("\n")
// isStack: 1 for Stack, 0 for others.
- val isStack = if (generator.isInstanceOf[Stack]) {
- "1"
+ val isStack = generator.isInstanceOf[Stack]
+ parametersStr
+ .append("isStack=")
+ .append(if (isStack) "1" else "0")
+ .append("\n")
+
+ val injectProject = if (isStack) {
+ // We always need to inject a Project for stack because we organize
+ // stack's flat params into arrays, e.g. stack(2, 1, 2, 3) is
+ // organized into two arrays: [1, 2] and [3, null].
+ true
} else {
- "0"
+ // Other generator function only have one param, so we just check
whether
+ // the only param(generator.children.head) is attribute reference or
not.
+ !isAttributeReference(generator.children.head, true);
}
+
parametersStr
- .append("isStack=")
- .append(isStack)
+ .append("injectedProject=")
+ .append(if (injectProject) "1" else "0")
.append("\n")
val message = StringValue
@@ -158,27 +185,27 @@ object PullOutGenerateProjectHelper extends
PullOutProjectHelper {
val expressionMap = new mutable.HashMap[Expression,
NamedExpression]()
// The new child should be either the original Attribute,
// or an Alias to other expressions.
- val generatorAttr = replaceExpressionWithAttribute(
+ replaceExpressionWithAttribute(
generate.generator.asInstanceOf[UnaryExpression].child,
expressionMap,
replaceBoundReference = true)
- val newGeneratorChild = if (expressionMap.isEmpty) {
- // generator.child is Attribute
- generatorAttr.asInstanceOf[Attribute]
+
+ if (!expressionMap.isEmpty) {
+ // generator.child is not an Attribute reference, e.g
Literal/CreateArray/CreateMap.
+ // We plug in a Project to make it an Attribute reference.
+ // NOTE: DO NOT use eliminateProjectList to create the project
list because
+ // newGeneratorChild can be a duplicated Attribute in
generate.child.output. The native
+ // side identifies the last field of projection as generator's
input.
+ val newGeneratorChildren = Seq(expressionMap.values.head)
+ generate.copy(
+ generator =
+
generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator],
+ child = ProjectExec(generate.child.output ++
newGeneratorChildren, generate.child)
+ )
} else {
- // generator.child is other expression, e.g
Literal/CreateArray/CreateMap
- expressionMap.values.head
+ // generator.child is Attribute, no need to introduce a Project.
+ generate
}
- val newGeneratorChildren = Seq(newGeneratorChild)
-
- // Avoid using eliminateProjectList to create the project list
- // because newGeneratorChild can be a duplicated Attribute in
generate.child.output.
- // The native side identifies the last field of projection as
generator's input.
- generate.copy(
- generator =
-
generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator],
- child = ProjectExec(generate.child.output ++ newGeneratorChildren,
generate.child)
- )
case stack: Stack =>
val numRows = stack.children.head.eval().asInstanceOf[Int]
val numFields = Math.ceil((stack.children.size - 1.0) /
numRows).toInt
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index 287bf1e9b..657039572 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -23,9 +23,9 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.SparkConf
import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.execution.{ArrowFileSourceScanExec,
ColumnarToRowExec, FilterExec, GenerateExec, ProjectExec, RDDScanExec}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{avg, col, lit, to_date, udf}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, StringType, StructField,
StructType}
@@ -787,7 +787,8 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
runQueryAndCompare(s"""
|SELECT $func(a) from t2;
|""".stripMargin) {
- checkGlutenOperatorMatch[GenerateExecTransformer]
+ // No ProjectExecTransformer is introduced.
+ checkSparkOperatorChainMatch[GenerateExecTransformer,
FilterExecTransformer]
}
sql("""select * from values
| map(1, 'a', 2, 'b', 3, null),
@@ -797,7 +798,8 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
runQueryAndCompare(s"""
|SELECT $func(a) from t2;
|""".stripMargin) {
- checkGlutenOperatorMatch[GenerateExecTransformer]
+ // No ProjectExecTransformer is introduced.
+ checkSparkOperatorChainMatch[GenerateExecTransformer,
FilterExecTransformer]
}
}
}
@@ -908,6 +910,26 @@ class TestOperator extends VeloxWholeStageTransformerSuite
{
checkGlutenOperatorMatch[GenerateExecTransformer]
}
}
+
+ // More complex case which might cause projection name conflict.
+ withTempView("script_trans") {
+ sql("""SELECT * FROM VALUES
+ |(1, 2, 3),
+ |(4, 5, 6),
+ |(7, 8, 9)
+ |AS script_trans(a, b, c)
+ """.stripMargin).createOrReplaceTempView("script_trans")
+ runQueryAndCompare(s"""SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS
STRING), myCol, myCol2)
+ | USING 'cat' AS (a STRING, b STRING, c STRING, d
ARRAY<INT>, e STRING)
+ |FROM script_trans
+ |LATERAL VIEW explode(array(array(1,2,3))) myTable
AS myCol
+ |LATERAL VIEW explode(myTable.myCol) myTable2 AS
myCol2
+ |WHERE a <= 4
+ |GROUP BY b, myCol, myCol2
+ |HAVING max(a) > 1""".stripMargin) {
+ checkSparkOperatorChainMatch[GenerateExecTransformer,
FilterExecTransformer]
+ }
+ }
}
test("test array functions") {
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 236203066..b82eead2c 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -747,12 +747,17 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
replicated.emplace_back(std::dynamic_pointer_cast<const
core::FieldAccessTypedExpr>(expression));
}
- auto projNode = std::dynamic_pointer_cast<const
core::ProjectNode>(childNode);
+ auto injectedProject = generateRel.has_advanced_extension() &&
+
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(),
"injectedProject=");
- bool isStack = generateRel.has_advanced_extension() &&
-
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(),
"isStack=");
+ if (injectedProject) {
+ auto projNode = std::dynamic_pointer_cast<const
core::ProjectNode>(childNode);
+ VELOX_CHECK(
+ projNode != nullptr && projNode->names().size() >
requiredChildOutput.size(),
+ "injectedProject is true, but the Project is missing or does not have
the corresponding projection field")
- if (projNode != nullptr && projNode->names().size() >
requiredChildOutput.size()) {
+ bool isStack = generateRel.has_advanced_extension() &&
+
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(),
"isStack=");
// Generator function's input is NOT a field reference.
if (!isStack) {
// For generator function which is not stack, e.g.
explode(array(1,2,3)), a sample
diff --git
a/gluten-core/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
b/gluten-core/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index 5f60de27b..7d2d48828 100644
---
a/gluten-core/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++
b/gluten-core/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -24,7 +24,7 @@ import org.apache.gluten.utils.Arm
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row}
-import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan}
+import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan,
UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec,
AdaptiveSparkPlanHelper, ShuffleQueryStageExec}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.DoubleType
@@ -254,6 +254,34 @@ abstract class WholeStageTransformerSuite
assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)))
}
+ /**
+ * Check whether the executed plan of a dataframe contains the expected plan
chain.
+ *
+ * @param df
+ * : the input dataframe.
+ * @param tag
+ * : class of the expected plan.
+ * @param childTag
+ * : class of the expected plan's child.
+ * @tparam T
+ * : type of the expected plan.
+ * @tparam PT
+ * : type of the expected plan's child.
+ */
+ def checkSparkOperatorChainMatch[T <: UnaryExecNode, PT <: UnaryExecNode](
+ df: DataFrame)(implicit tag: ClassTag[T], childTag: ClassTag[PT]): Unit
= {
+ val executedPlan = getExecutedPlan(df)
+ assert(
+ executedPlan.exists(
+ plan =>
+ tag.runtimeClass.isInstance(plan)
+ && childTag.runtimeClass.isInstance(plan.children.head)),
+ s"Expect an operator chain of [${tag.runtimeClass.getSimpleName} ->"
+ + s"${childTag.runtimeClass.getSimpleName}] exists in executedPlan: \n"
+ + s"${executedPlan.last}"
+ )
+ }
+
/**
* run a query with native engine as well as vanilla spark then compare the
result set for
* correctness check
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]