This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 59ff500b5  [CORE] Pullout pre/post project for generate (#4952)
59ff500b5 is described below

commit 59ff500b572ddccf5e563f599fc2d7b0dd1fa8d0
Author: Rong Ma <[email protected]>
AuthorDate: Mon Mar 25 17:07:03 2024 +0800

     [CORE] Pullout pre/post project for generate (#4952)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   4 +
 .../backendsapi/velox/SparkPlanExecApiImpl.scala   |  10 +-
 .../execution/GenerateExecTransformer.scala        | 227 +++++++++------------
 .../io/glutenproject/execution/TestOperator.scala  |  38 +++-
 .../execution/VeloxLiteralSuite.scala              |   1 +
 cpp/velox/substrait/SubstraitToVeloxExpr.cc        |  30 ++-
 .../backendsapi/SparkPlanExecApi.scala             |   6 +-
 .../execution/GenerateExecTransformerBase.scala    |   1 +
 .../extension/columnar/PullOutPostProject.scala    |   5 +-
 .../extension/columnar/PullOutPreProject.scala     |   7 +-
 .../columnar/RewriteSparkPlanRulesManager.scala    |   3 +-
 .../glutenproject/utils/PullOutProjectHelper.scala |   9 +-
 12 files changed, 191 insertions(+), 150 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 4b6ee1909..781884ad5 100644
--- 
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -766,4 +766,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
   ): GenerateExecTransformerBase = {
     CHGenerateExecTransformer(generator, requiredChildOutput, outer, 
generatorOutput, child)
   }
+
+  override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = 
generate
+
+  override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = 
generate
 }
diff --git 
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
 
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
index 61ea50695..d7045c4e5 100644
--- 
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
+++ 
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
HashPartitioning, Partitioning, RoundRobinPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{BroadcastUtils, 
ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, 
VeloxColumnarWriteFilesExec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -677,4 +677,12 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
   ): GenerateExecTransformerBase = {
     GenerateExecTransformer(generator, requiredChildOutput, outer, 
generatorOutput, child)
   }
+
+  override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = {
+    PullOutGenerateProjectHelper.pullOutPreProject(generate)
+  }
+
+  override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = {
+    PullOutGenerateProjectHelper.pullOutPostProject(generate)
+  }
 }
diff --git 
a/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
 
b/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
index 5bdfba200..b865f3104 100644
--- 
a/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
+++ 
b/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
@@ -17,25 +17,24 @@
 package io.glutenproject.execution
 
 import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, 
ExpressionNames}
-import io.glutenproject.expression.ConverterUtils.FunctionConfig
+import io.glutenproject.execution.GenerateExecTransformer.supportsGenerate
 import io.glutenproject.extension.ValidationResult
 import io.glutenproject.metrics.{GenerateMetricsUpdater, MetricsUpdater}
-import io.glutenproject.substrait.`type`.TypeBuilder
 import io.glutenproject.substrait.SubstraitContext
-import io.glutenproject.substrait.expression.{ExpressionBuilder, 
ExpressionNode}
+import io.glutenproject.substrait.expression.ExpressionNode
 import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, 
ExtensionBuilder}
 import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
+import io.glutenproject.utils.PullOutProjectHelper
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructType}
+import org.apache.spark.sql.types.IntegerType
 
-import com.google.common.collect.Lists
 import com.google.protobuf.StringValue
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 case class GenerateExecTransformer(
     generator: Generator,
@@ -62,26 +61,12 @@ case class GenerateExecTransformer(
   override protected def doGeneratorValidate(
       generator: Generator,
       outer: Boolean): ValidationResult = {
-    if (outer) {
-      return ValidationResult.notOk(s"Velox backend does not support outer")
-    }
-    generator match {
-      case _: JsonTuple =>
-        ValidationResult.notOk(s"Velox backend does not support this 
json_tuple")
-      case _: ExplodeBase =>
-        ValidationResult.ok
-      case Inline(child) =>
-        child match {
-          case AttributeReference(_, ArrayType(_: StructType, _), _, _) =>
-            ValidationResult.ok
-          case _ =>
-            // TODO: Support Literal/CreateArray.
-            ValidationResult.notOk(
-              s"Velox backend does not support inline with expression " +
-                s"${child.getClass.getSimpleName}.")
-        }
-      case _ =>
-        ValidationResult.ok
+    if (!supportsGenerate(generator, outer)) {
+      ValidationResult.notOk(
+        s"Velox backend does not support this generator: 
${generator.getClass.getSimpleName}" +
+          s", outer: $outer")
+    } else {
+      ValidationResult.ok
     }
   }
 
@@ -91,30 +76,13 @@ case class GenerateExecTransformer(
       generatorNode: ExpressionNode,
       validation: Boolean): RelNode = {
     val operatorId = context.nextOperatorId(this.nodeName)
-
-    val newInput = if (!validation) {
-      applyPreProject(inputRel, context, operatorId)
-    } else {
-      // No need to validate the pre-projection. The generator output has been 
validated in
-      // doGeneratorValidate.
-      inputRel
-    }
-
-    val generateRel = RelBuilder.makeGenerateRel(
-      newInput,
+    RelBuilder.makeGenerateRel(
+      inputRel,
       generatorNode,
       requiredChildOutputNodes.asJava,
       getExtensionNode(validation),
       context,
       operatorId)
-
-    if (!validation) {
-      applyPostProject(generateRel, context, operatorId)
-    } else {
-      // No need to validate the post-projection on the native side as
-      // it only flattens the generator's output.
-      generateRel
-    }
   }
 
   private def getExtensionNode(validation: Boolean): AdvancedExtensionNode = {
@@ -141,92 +109,95 @@ case class GenerateExecTransformer(
       getExtensionNodeForValidation
     }
   }
+}
 
-  // Select child outputs and append generator input.
-  private def applyPreProject(
-      inputRel: RelNode,
-      context: SubstraitContext,
-      operatorId: Long
-  ): RelNode = {
-    val projectExpressions: Seq[ExpressionNode] =
-      child.output.indices
-        .map(ExpressionBuilder.makeSelection(_)) :+
-        ExpressionConverter
-          .replaceWithExpressionTransformer(
-            generator.asInstanceOf[UnaryExpression].child,
-            child.output)
-          .doTransform(context.registeredFunction)
+object GenerateExecTransformer {
+  def supportsGenerate(generator: Generator, outer: Boolean): Boolean = {
+    // TODO: support outer and remove this param.
+    if (outer) {
+      false
+    } else {
+      generator match {
+        case _: Inline | _: ExplodeBase =>
+          true
+        case _ =>
+          false
+      }
+    }
+  }
+}
 
-    RelBuilder.makeProjectRel(
-      inputRel,
-      projectExpressions.asJava,
-      context,
-      operatorId,
-      child.output.size)
+object PullOutGenerateProjectHelper extends PullOutProjectHelper {
+  def pullOutPreProject(generate: GenerateExec): SparkPlan = {
+    if (GenerateExecTransformer.supportsGenerate(generate.generator, 
generate.outer)) {
+      val newGeneratorChildren = generate.generator match {
+        case _: Inline | _: ExplodeBase =>
+          val expressionMap = new mutable.HashMap[Expression, 
NamedExpression]()
+          // The new child should be either the original Attribute,
+          // or an Alias to other expressions.
+          val generatorAttr = replaceExpressionWithAttribute(
+            generate.generator.asInstanceOf[UnaryExpression].child,
+            expressionMap,
+            replaceBoundReference = true)
+          val newGeneratorChild = if (expressionMap.isEmpty) {
+            // generator.child is Attribute
+            generatorAttr.asInstanceOf[Attribute]
+          } else {
+            // generator.child is other expression, e.g. 
Literal/CreateArray/CreateMap
+            expressionMap.values.head
+          }
+          Seq(newGeneratorChild)
+        case _ =>
+          // Unreachable.
+          throw new IllegalStateException(
+            s"Generator ${generate.generator.getClass.getSimpleName} is not 
supported.")
+      }
+      // Avoid using eliminateProjectList to create the project list
+      // because newGeneratorChild can be a duplicated Attribute in 
generate.child.output.
+      // The native side identifies the last field of projection as 
generator's input.
+      generate.copy(
+        generator =
+          
generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator],
+        child = ProjectExec(generate.child.output ++ newGeneratorChildren, 
generate.child)
+      )
+    } else {
+      generate
+    }
   }
 
-  // There are 3 types of CollectionGenerator in spark: Explode, PosExplode 
and Inline.
-  // Adds postProject for PosExplode and Inline.
-  private def applyPostProject(
-      generateRel: RelNode,
-      context: SubstraitContext,
-      operatorId: Long): RelNode = {
-    generator match {
-      case Inline(_) =>
-        val requiredOutput = requiredChildOutputNodes.indices.map {
-          ExpressionBuilder.makeSelection(_)
-        }
-        val flattenStruct: Seq[ExpressionNode] = generatorOutput.indices.map {
-          i =>
-            val selectionNode = 
ExpressionBuilder.makeSelection(requiredOutput.size)
-            selectionNode.addNestedChildIdx(i)
-        }
-        RelBuilder.makeProjectRel(
-          generateRel,
-          (requiredOutput ++ flattenStruct).asJava,
-          context,
-          operatorId,
-          1 + requiredOutput.size // 1 stands for the inner struct field from 
array.
-        )
-      case PosExplode(posExplodeChild) =>
-        // Ordinal populated by Velox UnnestNode starts with 1.
-        // Need to substract 1 to align with Spark's output.
-        val unnestedSize = posExplodeChild.dataType match {
-          case _: MapType => 2
-          case _: ArrayType => 1
-        }
-        val subFunctionName = ConverterUtils.makeFuncName(
-          ExpressionNames.SUBTRACT,
-          Seq(LongType, LongType),
-          FunctionConfig.OPT)
-        val functionMap = context.registeredFunction
-        val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap, 
subFunctionName)
-        val literalNode = ExpressionBuilder.makeLiteral(1L, LongType, false)
-        val ordinalNode = ExpressionBuilder.makeCast(
-          TypeBuilder.makeI32(false),
-          ExpressionBuilder.makeScalarFunction(
-            addFunctionId,
-            Lists.newArrayList(
-              ExpressionBuilder.makeSelection(requiredChildOutputNodes.size + 
unnestedSize),
-              literalNode),
-            ConverterUtils.getTypeNode(LongType, 
generator.elementSchema.head.nullable)
-          ),
-          true // Generated ordinal column shouldn't have null.
-        )
-        val requiredChildNodes =
-          
requiredChildOutputNodes.indices.map(ExpressionBuilder.makeSelection(_))
-        val unnestColumns = (0 until unnestedSize)
-          .map(i => ExpressionBuilder.makeSelection(i + 
requiredChildOutputNodes.size))
-        val generatorOutput: Seq[ExpressionNode] =
-          (requiredChildNodes :+ ordinalNode) ++ unnestColumns
-        RelBuilder.makeProjectRel(
-          generateRel,
-          generatorOutput.asJava,
-          context,
-          operatorId,
-          generatorOutput.size
-        )
-      case _ => generateRel
+  def pullOutPostProject(generate: GenerateExec): SparkPlan = {
+    if (GenerateExecTransformer.supportsGenerate(generate.generator, 
generate.outer)) {
+      generate.generator match {
+        case PosExplode(_) =>
+          val originalOrdinal = generate.generatorOutput.head
+          val ordinal = {
+            val subtract = Subtract(Cast(originalOrdinal, IntegerType), 
Literal(1))
+            Alias(subtract, generatePostAliasName)(
+              originalOrdinal.exprId,
+              originalOrdinal.qualifier)
+          }
+          val newGenerate =
+            generate.copy(generatorOutput = generate.generatorOutput.tail :+ 
originalOrdinal)
+          ProjectExec(
+            (generate.requiredChildOutput :+ ordinal) ++ 
generate.generatorOutput.tail,
+            newGenerate)
+        case Inline(_) =>
+          val unnestOutput = {
+            val struct = CreateStruct(generate.generatorOutput)
+            val alias = Alias(struct, generatePostAliasName)()
+            alias.toAttribute
+          }
+          val newGenerate = generate.copy(generatorOutput = Seq(unnestOutput))
+          val newOutput = generate.generatorOutput.zipWithIndex.map {
+            case (attr, i) =>
+              val getStructField = GetStructField(unnestOutput, i, 
Some(attr.name))
+              Alias(getStructField, generatePostAliasName)(attr.exprId, 
attr.qualifier)
+          }
+          ProjectExec(generate.requiredChildOutput ++ newOutput, newGenerate)
+        case _ => generate
+      }
+    } else {
+      generate
     }
   }
 }
diff --git 
a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala 
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index c81a60430..55f8ee6f5 100644
--- 
a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++ 
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -787,8 +787,32 @@ class TestOperator extends VeloxWholeStageTransformerSuite 
{
   }
 
   test("test inline function") {
+    // Literal: func(literal)
+    runQueryAndCompare(s"""
+                          |SELECT inline(array(
+                          |  named_struct('c1', 0, 'c2', 1),
+                          |  named_struct('c1', 2, 'c2', null)));
+                          |""".stripMargin) {
+      checkOperatorMatch[GenerateExecTransformer]
+    }
+
+    // CreateArray: func(array(col))
     withTempView("t1") {
-      sql("""select * from values
+      sql("""SELECT * from values
+            |  (named_struct('c1', 0, 'c2', 1)),
+            |  (named_struct('c1', 2, 'c2', null)),
+            |  (null)
+            |as tbl(a)
+         """.stripMargin).createOrReplaceTempView("t1")
+      runQueryAndCompare(s"""
+                            |SELECT inline(array(a)) from t1;
+                            |""".stripMargin) {
+        checkOperatorMatch[GenerateExecTransformer]
+      }
+    }
+
+    withTempView("t2") {
+      sql("""SELECT * from values
             |  array(
             |    named_struct('c1', 0, 'c2', 1),
             |    null,
@@ -800,13 +824,21 @@ class TestOperator extends 
VeloxWholeStageTransformerSuite {
             |    named_struct('c1', 2, 'c2', 3)
             |  )
             |as tbl(a)
-         """.stripMargin).createOrReplaceTempView("t1")
+         """.stripMargin).createOrReplaceTempView("t2")
       runQueryAndCompare("""
-                           |SELECT inline(a) from t1;
+                           |SELECT inline(a) from t2;
                            |""".stripMargin) {
         checkOperatorMatch[GenerateExecTransformer]
       }
     }
+
+    // Fallback for array(struct(...), null) literal.
+    runQueryAndCompare(s"""
+                          |SELECT inline(array(
+                          |  named_struct('c1', 0, 'c2', 1),
+                          |  named_struct('c1', 2, 'c2', null),
+                          |  null));
+                          |""".stripMargin)(_)
   }
 
   test("test array functions") {
diff --git 
a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala
 
b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala
index 557681558..52e122c2b 100644
--- 
a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala
+++ 
b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala
@@ -136,5 +136,6 @@ class VeloxLiteralSuite extends 
VeloxWholeStageTransformerSuite {
 
   test("Literal Fallback") {
     validateFallbackResult("SELECT struct(cast(null as struct<a: string>))")
+    validateFallbackResult("SELECT array(struct(1, 'a'), null)")
   }
 }
diff --git a/cpp/velox/substrait/SubstraitToVeloxExpr.cc 
b/cpp/velox/substrait/SubstraitToVeloxExpr.cc
index f795f9c9e..8699907de 100644
--- a/cpp/velox/substrait/SubstraitToVeloxExpr.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxExpr.cc
@@ -49,15 +49,18 @@ MapVectorPtr makeMapVector(const VectorPtr& keyVector, 
const VectorPtr& valueVec
       valueVector);
 }
 
-RowVectorPtr makeRowVector(const std::vector<VectorPtr>& children) {
+RowVectorPtr makeRowVector(
+    const std::vector<VectorPtr>& children,
+    std::vector<std::string>&& names,
+    size_t length,
+    facebook::velox::memory::MemoryPool* pool) {
   std::vector<std::shared_ptr<const Type>> types;
   types.resize(children.size());
   for (int i = 0; i < children.size(); i++) {
     types[i] = children[i]->type();
   }
-  const size_t vectorSize = children.empty() ? 0 : children.front()->size();
-  auto rowType = ROW(std::move(types));
-  return std::make_shared<RowVector>(children[0]->pool(), rowType, 
BufferPtr(nullptr), vectorSize, children);
+  auto rowType = ROW(std::move(names), std::move(types));
+  return std::make_shared<RowVector>(pool, rowType, BufferPtr(nullptr), 
length, children);
 }
 
 ArrayVectorPtr makeEmptyArrayVector(memory::MemoryPool* pool, const TypePtr& 
elementType) {
@@ -73,7 +76,7 @@ MapVectorPtr makeEmptyMapVector(memory::MemoryPool* pool, 
const TypePtr& keyType
 }
 
 RowVectorPtr makeEmptyRowVector(memory::MemoryPool* pool) {
-  return makeRowVector({});
+  return makeRowVector({}, {}, 0, pool);
 }
 
 template <typename T>
@@ -485,13 +488,20 @@ VectorPtr SubstraitVeloxExprConverter::literalsToVector(
 }
 
 RowVectorPtr SubstraitVeloxExprConverter::literalsToRowVector(const 
::substrait::Expression::Literal& structLiteral) {
-  auto childSize = structLiteral.struct_().fields().size();
-  if (childSize == 0) {
+  if (structLiteral.has_null()) {
+    VELOX_NYI("NULL for struct type is not supported.");
+  }
+  auto numFields = structLiteral.struct_().fields().size();
+  if (numFields == 0) {
     return makeEmptyRowVector(pool_);
   }
   std::vector<VectorPtr> vectors;
-  vectors.reserve(structLiteral.struct_().fields().size());
-  for (const auto& child : structLiteral.struct_().fields()) {
+  std::vector<std::string> names;
+  vectors.reserve(numFields);
+  names.reserve(numFields);
+  for (auto i = 0; i < numFields; ++i) {
+    names.push_back("col_" + std::to_string(i));
+    const auto& child = structLiteral.struct_().fields(i);
     auto typeCase = child.literal_type_case();
     switch (typeCase) {
       case 
::substrait::Expression_Literal::LiteralTypeCase::kIntervalDayToSecond: {
@@ -530,7 +540,7 @@ RowVectorPtr 
SubstraitVeloxExprConverter::literalsToRowVector(const ::substrait:
         }
     }
   }
-  return makeRowVector(vectors);
+  return makeRowVector(vectors, std::move(names), 1, pool_);
 }
 
 core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr(
diff --git 
a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
 
b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
index 1b892ff59..759f7cfad 100644
--- 
a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{FileSourceScanExec, LeafExecNode, 
SparkPlan}
+import org.apache.spark.sql.execution.{FileSourceScanExec, GenerateExec, 
LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -653,4 +653,8 @@ trait SparkPlanExecApi {
       generatorOutput: Seq[Attribute],
       child: SparkPlan
   ): GenerateExecTransformerBase
+
+  def genPreProjectForGenerate(generate: GenerateExec): SparkPlan
+
+  def genPostProjectForGenerate(generate: GenerateExec): SparkPlan
 }
diff --git 
a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
 
b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
index 285734f38..f3e31346c 100644
--- 
a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
+++ 
b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
@@ -39,6 +39,7 @@ abstract class GenerateExecTransformerBase(
     generatorOutput: Seq[Attribute],
     child: SparkPlan)
   extends UnaryTransformSupport {
+
   protected def doGeneratorValidate(generator: Generator, outer: Boolean): 
ValidationResult
 
   protected def getRelNode(
diff --git 
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala
 
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala
index 0a39ef819..a77e063c5 100644
--- 
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala
+++ 
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala
@@ -21,7 +21,7 @@ import io.glutenproject.utils.PullOutProjectHelper
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
NamedExpression, WindowExpression}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.window.WindowExec
 
@@ -104,6 +104,9 @@ object PullOutPostProject extends Rule[SparkPlan] with 
PullOutProjectHelper {
         window.copy(windowExpression = 
newWindowExpressions.asInstanceOf[Seq[NamedExpression]])
       ProjectExec(window.child.output ++ postWindowExpressions, newWindow)
 
+    case generate: GenerateExec =>
+      
BackendsApiManager.getSparkPlanExecApiInstance.genPostProjectForGenerate(generate)
+
     case _ => plan
   }
 }
diff --git 
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
 
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
index 440f609de..92d6c9fab 100644
--- 
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
+++ 
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
@@ -16,12 +16,13 @@
  */
 package io.glutenproject.extension.columnar
 
+import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.utils.PullOutProjectHelper
 
 import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Complete, Partial}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ExpandExec, ProjectExec, SortExec, 
SparkPlan, TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution.{ExpandExec, GenerateExec, ProjectExec, 
SortExec, SparkPlan, TakeOrderedAndProjectExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, 
TypedAggregateExpression}
 import org.apache.spark.sql.execution.window.WindowExec
 
@@ -189,6 +190,10 @@ object PullOutPreProject extends Rule[SparkPlan] with 
PullOutProjectHelper {
         child = ProjectExec(
           eliminateProjectList(expand.child.outputSet, 
expressionMap.values.toSeq),
           expand.child))
+
+    case generate: GenerateExec =>
+      
BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForGenerate(generate)
+
     case _ => plan
   }
 }
diff --git 
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
 
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
index 8f3f01f95..b2591f048 100644
--- 
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
+++ 
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
@@ -54,6 +54,7 @@ class RewriteSparkPlanRulesManager(rewriteRules: 
Seq[Rule[SparkPlan]]) extends R
         case _: FilterExec => true
         case _: FileSourceScanExec => true
         case _: ExpandExec => true
+        case _: GenerateExec => true
         case _ => false
       }
     }
@@ -77,7 +78,7 @@ class RewriteSparkPlanRulesManager(rewriteRules: 
Seq[Rule[SparkPlan]]) extends R
           // Some rewrite rules may generate new parent plan node, we should 
use transform to
           // rewrite the original plan. For example, PullOutPreProject and 
PullOutPostProject
           // will generate post-project plan node.
-          plan.transform { case p => rule.apply(p) }
+          plan.transformUp { case p => rule.apply(p) }
       }
       (rewrittenPlan, None)
     } catch {
diff --git 
a/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala 
b/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
index a519772fc..543a5413c 100644
--- 
a/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
+++ 
b/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
@@ -30,8 +30,8 @@ trait PullOutProjectHelper {
 
   private val generatedNameIndex = new AtomicInteger(0)
 
-  protected def generatePreAliasName = 
s"_pre_${generatedNameIndex.getAndIncrement()}"
-  protected def generatePostAliasName = 
s"_post_${generatedNameIndex.getAndIncrement()}"
+  protected def generatePreAliasName: String = 
s"_pre_${generatedNameIndex.getAndIncrement()}"
+  protected def generatePostAliasName: String = 
s"_post_${generatedNameIndex.getAndIncrement()}"
 
   /**
    * The majority of Expressions only support Attribute and BoundReference 
when converting them into
@@ -57,12 +57,13 @@ trait PullOutProjectHelper {
 
   protected def replaceExpressionWithAttribute(
       expr: Expression,
-      projectExprsMap: mutable.HashMap[Expression, NamedExpression]): 
Expression =
+      projectExprsMap: mutable.HashMap[Expression, NamedExpression],
+      replaceBoundReference: Boolean = false): Expression =
     expr match {
       case alias: Alias =>
         projectExprsMap.getOrElseUpdate(alias.child.canonicalized, 
alias).toAttribute
       case attr: Attribute => attr
-      case e: BoundReference => e
+      case e: BoundReference if !replaceBoundReference => e
       case other =>
         projectExprsMap
           .getOrElseUpdate(other.canonicalized, Alias(other, 
generatePreAliasName)())


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to