This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 5dd884e15 [GLUTEN-5757][CORE] Remove unnecessary 
ProjectExecTransformer for Generate (#5782)
5dd884e15 is described below

commit 5dd884e150839a6377087d302cfb8f242948bd9c
Author: James Xu <[email protected]>
AuthorDate: Thu May 23 14:43:03 2024 +0800

    [GLUTEN-5757][CORE] Remove unnecessary ProjectExecTransformer for Generate 
(#5782)
    
    If the generator function's input is already an Attribute reference, we omit
    the introduction of the ProjectExec.
    
    The previous implementation always assumed there is a Project under Generate.
    In the new implementation we added a metadata flag (injectedProject) to the
    Substrait plan to tell us whether there is a dedicated Project under Generate.
---
 .../gluten/execution/GenerateExecTransformer.scala | 69 +++++++++++++++-------
 .../org/apache/gluten/execution/TestOperator.scala | 30 ++++++++--
 cpp/velox/substrait/SubstraitToVeloxPlan.cc        | 13 ++--
 .../execution/WholeStageTransformerSuite.scala     | 30 +++++++++-
 4 files changed, 112 insertions(+), 30 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
index 23addb89e..c9b0abd6f 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
@@ -95,6 +95,21 @@ case class GenerateExecTransformer(
       operatorId)
   }
 
+  /**
+   * Is the specified expression an Attribute reference?
+   * @param expr
+   * @param replaceBoundReference
+   * @return
+   */
+  private def isAttributeReference(
+      expr: Expression,
+      replaceBoundReference: Boolean = false): Boolean =
+    expr match {
+      case _: Attribute => true
+      case _: BoundReference if !replaceBoundReference => true
+      case _ => false
+    }
+
   private def getExtensionNode(validation: Boolean): AdvancedExtensionNode = {
     if (!validation) {
       // Start with "GenerateParameters:"
@@ -111,14 +126,26 @@ case class GenerateExecTransformer(
         .append("\n")
 
       // isStack: 1 for Stack, 0 for others.
-      val isStack = if (generator.isInstanceOf[Stack]) {
-        "1"
+      val isStack = generator.isInstanceOf[Stack]
+      parametersStr
+        .append("isStack=")
+        .append(if (isStack) "1" else "0")
+        .append("\n")
+
+      val injectProject = if (isStack) {
+        // We always need to inject a Project for stack because we organize
+        // stack's flat params into arrays, e.g. stack(2, 1, 2, 3) is
+        // organized into two arrays: [1, 2] and [3, null].
+        true
       } else {
-        "0"
+        // Other generator function only have one param, so we just check 
whether
+        // the only param(generator.children.head) is attribute reference or 
not.
+        !isAttributeReference(generator.children.head, true);
       }
+
       parametersStr
-        .append("isStack=")
-        .append(isStack)
+        .append("injectedProject=")
+        .append(if (injectProject) "1" else "0")
         .append("\n")
 
       val message = StringValue
@@ -158,27 +185,27 @@ object PullOutGenerateProjectHelper extends 
PullOutProjectHelper {
           val expressionMap = new mutable.HashMap[Expression, 
NamedExpression]()
           // The new child should be either the original Attribute,
           // or an Alias to other expressions.
-          val generatorAttr = replaceExpressionWithAttribute(
+          replaceExpressionWithAttribute(
             generate.generator.asInstanceOf[UnaryExpression].child,
             expressionMap,
             replaceBoundReference = true)
-          val newGeneratorChild = if (expressionMap.isEmpty) {
-            // generator.child is Attribute
-            generatorAttr.asInstanceOf[Attribute]
+
+          if (!expressionMap.isEmpty) {
+            // generator.child is not an Attribute reference, e.g 
Literal/CreateArray/CreateMap.
+            // We plug in a Project to make it an Attribute reference.
+            // NOTE: DO NOT use eliminateProjectList to create the project 
list because
+            // newGeneratorChild can be a duplicated Attribute in 
generate.child.output. The native
+            // side identifies the last field of projection as generator's 
input.
+            val newGeneratorChildren = Seq(expressionMap.values.head)
+            generate.copy(
+              generator =
+                
generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator],
+              child = ProjectExec(generate.child.output ++ 
newGeneratorChildren, generate.child)
+            )
           } else {
-            // generator.child is other expression, e.g 
Literal/CreateArray/CreateMap
-            expressionMap.values.head
+            // generator.child is Attribute, no need to introduce a Project.
+            generate
           }
-          val newGeneratorChildren = Seq(newGeneratorChild)
-
-          // Avoid using eliminateProjectList to create the project list
-          // because newGeneratorChild can be a duplicated Attribute in 
generate.child.output.
-          // The native side identifies the last field of projection as 
generator's input.
-          generate.copy(
-            generator =
-              
generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator],
-            child = ProjectExec(generate.child.output ++ newGeneratorChildren, 
generate.child)
-          )
         case stack: Stack =>
           val numRows = stack.children.head.eval().asInstanceOf[Int]
           val numFields = Math.ceil((stack.children.size - 1.0) / 
numRows).toInt
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala 
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index 287bf1e9b..657039572 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -23,9 +23,9 @@ import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.execution.{ArrowFileSourceScanExec, 
ColumnarToRowExec, FilterExec, GenerateExec, ProjectExec, RDDScanExec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{avg, col, lit, to_date, udf}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DecimalType, StringType, StructField, 
StructType}
 
@@ -787,7 +787,8 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
           runQueryAndCompare(s"""
                                 |SELECT $func(a) from t2;
                                 |""".stripMargin) {
-            checkGlutenOperatorMatch[GenerateExecTransformer]
+            // No ProjectExecTransformer is introduced.
+            checkSparkOperatorChainMatch[GenerateExecTransformer, 
FilterExecTransformer]
           }
           sql("""select * from values
                 |  map(1, 'a', 2, 'b', 3, null),
@@ -797,7 +798,8 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
           runQueryAndCompare(s"""
                                 |SELECT $func(a) from t2;
                                 |""".stripMargin) {
-            checkGlutenOperatorMatch[GenerateExecTransformer]
+            // No ProjectExecTransformer is introduced.
+            checkSparkOperatorChainMatch[GenerateExecTransformer, 
FilterExecTransformer]
           }
         }
     }
@@ -908,6 +910,26 @@ class TestOperator extends VeloxWholeStageTransformerSuite 
{
         checkGlutenOperatorMatch[GenerateExecTransformer]
       }
     }
+
+    // More complex case which might cause projection name conflict.
+    withTempView("script_trans") {
+      sql("""SELECT * FROM VALUES
+            |(1, 2, 3),
+            |(4, 5, 6),
+            |(7, 8, 9)
+            |AS script_trans(a, b, c)
+         """.stripMargin).createOrReplaceTempView("script_trans")
+      runQueryAndCompare(s"""SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS 
STRING), myCol, myCol2)
+                            |  USING 'cat' AS (a STRING, b STRING, c STRING, d 
ARRAY<INT>, e STRING)
+                            |FROM script_trans
+                            |LATERAL VIEW explode(array(array(1,2,3))) myTable 
AS myCol
+                            |LATERAL VIEW explode(myTable.myCol) myTable2 AS 
myCol2
+                            |WHERE a <= 4
+                            |GROUP BY b, myCol, myCol2
+                            |HAVING max(a) > 1""".stripMargin) {
+        checkSparkOperatorChainMatch[GenerateExecTransformer, 
FilterExecTransformer]
+      }
+    }
   }
 
   test("test array functions") {
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc 
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 236203066..b82eead2c 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -747,12 +747,17 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
     replicated.emplace_back(std::dynamic_pointer_cast<const 
core::FieldAccessTypedExpr>(expression));
   }
 
-  auto projNode = std::dynamic_pointer_cast<const 
core::ProjectNode>(childNode);
+  auto injectedProject = generateRel.has_advanced_extension() &&
+      
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), 
"injectedProject=");
 
-  bool isStack = generateRel.has_advanced_extension() &&
-      
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), 
"isStack=");
+  if (injectedProject) {
+    auto projNode = std::dynamic_pointer_cast<const 
core::ProjectNode>(childNode);
+    VELOX_CHECK(
+        projNode != nullptr && projNode->names().size() > 
requiredChildOutput.size(),
+        "injectedProject is true, but the Project is missing or does not have 
the corresponding projection field")
 
-  if (projNode != nullptr && projNode->names().size() > 
requiredChildOutput.size()) {
+    bool isStack = generateRel.has_advanced_extension() &&
+        
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), 
"isStack=");
     // Generator function's input is NOT a field reference.
     if (!isStack) {
       // For generator function which is not stack, e.g. 
explode(array(1,2,3)), a sample
diff --git 
a/gluten-core/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
 
b/gluten-core/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index 5f60de27b..7d2d48828 100644
--- 
a/gluten-core/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++ 
b/gluten-core/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -24,7 +24,7 @@ import org.apache.gluten.utils.Arm
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row}
-import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan}
+import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan, 
UnaryExecNode}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
AdaptiveSparkPlanHelper, ShuffleQueryStageExec}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.DoubleType
@@ -254,6 +254,34 @@ abstract class WholeStageTransformerSuite
     assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)))
   }
 
+  /**
+   * Check whether the executed plan of a dataframe contains the expected plan 
chain.
+   *
+   * @param df
+   *   : the input dataframe.
+   * @param tag
+   *   : class of the expected plan.
+   * @param childTag
+   *   : class of the expected plan's child.
+   * @tparam T
+   *   : type of the expected plan.
+   * @tparam PT
+   *   : type of the expected plan's child.
+   */
+  def checkSparkOperatorChainMatch[T <: UnaryExecNode, PT <: UnaryExecNode](
+      df: DataFrame)(implicit tag: ClassTag[T], childTag: ClassTag[PT]): Unit 
= {
+    val executedPlan = getExecutedPlan(df)
+    assert(
+      executedPlan.exists(
+        plan =>
+          tag.runtimeClass.isInstance(plan)
+            && childTag.runtimeClass.isInstance(plan.children.head)),
+      s"Expect an operator chain of [${tag.runtimeClass.getSimpleName} ->"
+        + s"${childTag.runtimeClass.getSimpleName}] exists in executedPlan: \n"
+        + s"${executedPlan.last}"
+    )
+  }
+
   /**
    * run a query with native engine as well as vanilla spark then compare the 
result set for
    * correctness check


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to