This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 59ff500b5 [CORE] Pullout pre/post project for generate (#4952)
59ff500b5 is described below
commit 59ff500b572ddccf5e563f599fc2d7b0dd1fa8d0
Author: Rong Ma <[email protected]>
AuthorDate: Mon Mar 25 17:07:03 2024 +0800
[CORE] Pullout pre/post project for generate (#4952)
---
.../clickhouse/CHSparkPlanExecApi.scala | 4 +
.../backendsapi/velox/SparkPlanExecApiImpl.scala | 10 +-
.../execution/GenerateExecTransformer.scala | 227 +++++++++------------
.../io/glutenproject/execution/TestOperator.scala | 38 +++-
.../execution/VeloxLiteralSuite.scala | 1 +
cpp/velox/substrait/SubstraitToVeloxExpr.cc | 30 ++-
.../backendsapi/SparkPlanExecApi.scala | 6 +-
.../execution/GenerateExecTransformerBase.scala | 1 +
.../extension/columnar/PullOutPostProject.scala | 5 +-
.../extension/columnar/PullOutPreProject.scala | 7 +-
.../columnar/RewriteSparkPlanRulesManager.scala | 3 +-
.../glutenproject/utils/PullOutProjectHelper.scala | 9 +-
12 files changed, 191 insertions(+), 150 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 4b6ee1909..781884ad5 100644
---
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -766,4 +766,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
): GenerateExecTransformerBase = {
CHGenerateExecTransformer(generator, requiredChildOutput, outer,
generatorOutput, child)
}
+
+ override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan =
generate
+
+ override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan =
generate
}
diff --git
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
index 61ea50695..d7045c4e5 100644
---
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
+++
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode,
HashPartitioning, Partitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{BroadcastUtils,
ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan,
VeloxColumnarWriteFilesExec}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -677,4 +677,12 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
): GenerateExecTransformerBase = {
GenerateExecTransformer(generator, requiredChildOutput, outer,
generatorOutput, child)
}
+
+ override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = {
+ PullOutGenerateProjectHelper.pullOutPreProject(generate)
+ }
+
+ override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = {
+ PullOutGenerateProjectHelper.pullOutPostProject(generate)
+ }
}
diff --git
a/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
b/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
index 5bdfba200..b865f3104 100644
---
a/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
+++
b/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
@@ -17,25 +17,24 @@
package io.glutenproject.execution
import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.expression.{ConverterUtils, ExpressionConverter,
ExpressionNames}
-import io.glutenproject.expression.ConverterUtils.FunctionConfig
+import io.glutenproject.execution.GenerateExecTransformer.supportsGenerate
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.{GenerateMetricsUpdater, MetricsUpdater}
-import io.glutenproject.substrait.`type`.TypeBuilder
import io.glutenproject.substrait.SubstraitContext
-import io.glutenproject.substrait.expression.{ExpressionBuilder,
ExpressionNode}
+import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.extensions.{AdvancedExtensionNode,
ExtensionBuilder}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
+import io.glutenproject.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructType}
+import org.apache.spark.sql.types.IntegerType
-import com.google.common.collect.Lists
import com.google.protobuf.StringValue
import scala.collection.JavaConverters._
+import scala.collection.mutable
case class GenerateExecTransformer(
generator: Generator,
@@ -62,26 +61,12 @@ case class GenerateExecTransformer(
override protected def doGeneratorValidate(
generator: Generator,
outer: Boolean): ValidationResult = {
- if (outer) {
- return ValidationResult.notOk(s"Velox backend does not support outer")
- }
- generator match {
- case _: JsonTuple =>
- ValidationResult.notOk(s"Velox backend does not support this
json_tuple")
- case _: ExplodeBase =>
- ValidationResult.ok
- case Inline(child) =>
- child match {
- case AttributeReference(_, ArrayType(_: StructType, _), _, _) =>
- ValidationResult.ok
- case _ =>
- // TODO: Support Literal/CreateArray.
- ValidationResult.notOk(
- s"Velox backend does not support inline with expression " +
- s"${child.getClass.getSimpleName}.")
- }
- case _ =>
- ValidationResult.ok
+ if (!supportsGenerate(generator, outer)) {
+ ValidationResult.notOk(
+ s"Velox backend does not support this generator:
${generator.getClass.getSimpleName}" +
+ s", outer: $outer")
+ } else {
+ ValidationResult.ok
}
}
@@ -91,30 +76,13 @@ case class GenerateExecTransformer(
generatorNode: ExpressionNode,
validation: Boolean): RelNode = {
val operatorId = context.nextOperatorId(this.nodeName)
-
- val newInput = if (!validation) {
- applyPreProject(inputRel, context, operatorId)
- } else {
- // No need to validate the pre-projection. The generator output has been
validated in
- // doGeneratorValidate.
- inputRel
- }
-
- val generateRel = RelBuilder.makeGenerateRel(
- newInput,
+ RelBuilder.makeGenerateRel(
+ inputRel,
generatorNode,
requiredChildOutputNodes.asJava,
getExtensionNode(validation),
context,
operatorId)
-
- if (!validation) {
- applyPostProject(generateRel, context, operatorId)
- } else {
- // No need to validate the post-projection on the native side as
- // it only flattens the generator's output.
- generateRel
- }
}
private def getExtensionNode(validation: Boolean): AdvancedExtensionNode = {
@@ -141,92 +109,95 @@ case class GenerateExecTransformer(
getExtensionNodeForValidation
}
}
+}
- // Select child outputs and append generator input.
- private def applyPreProject(
- inputRel: RelNode,
- context: SubstraitContext,
- operatorId: Long
- ): RelNode = {
- val projectExpressions: Seq[ExpressionNode] =
- child.output.indices
- .map(ExpressionBuilder.makeSelection(_)) :+
- ExpressionConverter
- .replaceWithExpressionTransformer(
- generator.asInstanceOf[UnaryExpression].child,
- child.output)
- .doTransform(context.registeredFunction)
+object GenerateExecTransformer {
+ def supportsGenerate(generator: Generator, outer: Boolean): Boolean = {
+ // TODO: support outer and remove this param.
+ if (outer) {
+ false
+ } else {
+ generator match {
+ case _: Inline | _: ExplodeBase =>
+ true
+ case _ =>
+ false
+ }
+ }
+ }
+}
- RelBuilder.makeProjectRel(
- inputRel,
- projectExpressions.asJava,
- context,
- operatorId,
- child.output.size)
+object PullOutGenerateProjectHelper extends PullOutProjectHelper {
+ def pullOutPreProject(generate: GenerateExec): SparkPlan = {
+ if (GenerateExecTransformer.supportsGenerate(generate.generator,
generate.outer)) {
+ val newGeneratorChildren = generate.generator match {
+ case _: Inline | _: ExplodeBase =>
+ val expressionMap = new mutable.HashMap[Expression,
NamedExpression]()
+ // The new child should be either the original Attribute,
+ // or an Alias to other expressions.
+ val generatorAttr = replaceExpressionWithAttribute(
+ generate.generator.asInstanceOf[UnaryExpression].child,
+ expressionMap,
+ replaceBoundReference = true)
+ val newGeneratorChild = if (expressionMap.isEmpty) {
+ // generator.child is Attribute
+ generatorAttr.asInstanceOf[Attribute]
+ } else {
+ // generator.child is some other expression, e.g.
Literal/CreateArray/CreateMap
+ expressionMap.values.head
+ }
+ Seq(newGeneratorChild)
+ case _ =>
+ // Unreachable.
+ throw new IllegalStateException(
+ s"Generator ${generate.generator.getClass.getSimpleName} is not
supported.")
+ }
+ // Avoid using eliminateProjectList to create the project list
+ // because newGeneratorChild can be a duplicated Attribute in
generate.child.output.
+ // The native side identifies the last field of projection as
generator's input.
+ generate.copy(
+ generator =
+
generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator],
+ child = ProjectExec(generate.child.output ++ newGeneratorChildren,
generate.child)
+ )
+ } else {
+ generate
+ }
}
- // There are 3 types of CollectionGenerator in spark: Explode, PosExplode
and Inline.
- // Adds postProject for PosExplode and Inline.
- private def applyPostProject(
- generateRel: RelNode,
- context: SubstraitContext,
- operatorId: Long): RelNode = {
- generator match {
- case Inline(_) =>
- val requiredOutput = requiredChildOutputNodes.indices.map {
- ExpressionBuilder.makeSelection(_)
- }
- val flattenStruct: Seq[ExpressionNode] = generatorOutput.indices.map {
- i =>
- val selectionNode =
ExpressionBuilder.makeSelection(requiredOutput.size)
- selectionNode.addNestedChildIdx(i)
- }
- RelBuilder.makeProjectRel(
- generateRel,
- (requiredOutput ++ flattenStruct).asJava,
- context,
- operatorId,
- 1 + requiredOutput.size // 1 stands for the inner struct field from
array.
- )
- case PosExplode(posExplodeChild) =>
- // Ordinal populated by Velox UnnestNode starts with 1.
- // Need to substract 1 to align with Spark's output.
- val unnestedSize = posExplodeChild.dataType match {
- case _: MapType => 2
- case _: ArrayType => 1
- }
- val subFunctionName = ConverterUtils.makeFuncName(
- ExpressionNames.SUBTRACT,
- Seq(LongType, LongType),
- FunctionConfig.OPT)
- val functionMap = context.registeredFunction
- val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap,
subFunctionName)
- val literalNode = ExpressionBuilder.makeLiteral(1L, LongType, false)
- val ordinalNode = ExpressionBuilder.makeCast(
- TypeBuilder.makeI32(false),
- ExpressionBuilder.makeScalarFunction(
- addFunctionId,
- Lists.newArrayList(
- ExpressionBuilder.makeSelection(requiredChildOutputNodes.size +
unnestedSize),
- literalNode),
- ConverterUtils.getTypeNode(LongType,
generator.elementSchema.head.nullable)
- ),
- true // Generated ordinal column shouldn't have null.
- )
- val requiredChildNodes =
-
requiredChildOutputNodes.indices.map(ExpressionBuilder.makeSelection(_))
- val unnestColumns = (0 until unnestedSize)
- .map(i => ExpressionBuilder.makeSelection(i +
requiredChildOutputNodes.size))
- val generatorOutput: Seq[ExpressionNode] =
- (requiredChildNodes :+ ordinalNode) ++ unnestColumns
- RelBuilder.makeProjectRel(
- generateRel,
- generatorOutput.asJava,
- context,
- operatorId,
- generatorOutput.size
- )
- case _ => generateRel
+ def pullOutPostProject(generate: GenerateExec): SparkPlan = {
+ if (GenerateExecTransformer.supportsGenerate(generate.generator,
generate.outer)) {
+ generate.generator match {
+ case PosExplode(_) =>
+ val originalOrdinal = generate.generatorOutput.head
+ val ordinal = {
+ val subtract = Subtract(Cast(originalOrdinal, IntegerType),
Literal(1))
+ Alias(subtract, generatePostAliasName)(
+ originalOrdinal.exprId,
+ originalOrdinal.qualifier)
+ }
+ val newGenerate =
+ generate.copy(generatorOutput = generate.generatorOutput.tail :+
originalOrdinal)
+ ProjectExec(
+ (generate.requiredChildOutput :+ ordinal) ++
generate.generatorOutput.tail,
+ newGenerate)
+ case Inline(_) =>
+ val unnestOutput = {
+ val struct = CreateStruct(generate.generatorOutput)
+ val alias = Alias(struct, generatePostAliasName)()
+ alias.toAttribute
+ }
+ val newGenerate = generate.copy(generatorOutput = Seq(unnestOutput))
+ val newOutput = generate.generatorOutput.zipWithIndex.map {
+ case (attr, i) =>
+ val getStructField = GetStructField(unnestOutput, i,
Some(attr.name))
+ Alias(getStructField, generatePostAliasName)(attr.exprId,
attr.qualifier)
+ }
+ ProjectExec(generate.requiredChildOutput ++ newOutput, newGenerate)
+ case _ => generate
+ }
+ } else {
+ generate
}
}
}
diff --git
a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index c81a60430..55f8ee6f5 100644
---
a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -787,8 +787,32 @@ class TestOperator extends VeloxWholeStageTransformerSuite
{
}
test("test inline function") {
+ // Literal: func(literal)
+ runQueryAndCompare(s"""
+ |SELECT inline(array(
+ | named_struct('c1', 0, 'c2', 1),
+ | named_struct('c1', 2, 'c2', null)));
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+
+ // CreateArray: func(array(col))
withTempView("t1") {
- sql("""select * from values
+ sql("""SELECT * from values
+ | (named_struct('c1', 0, 'c2', 1)),
+ | (named_struct('c1', 2, 'c2', null)),
+ | (null)
+ |as tbl(a)
+ """.stripMargin).createOrReplaceTempView("t1")
+ runQueryAndCompare(s"""
+ |SELECT inline(array(a)) from t1;
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ }
+
+ withTempView("t2") {
+ sql("""SELECT * from values
| array(
| named_struct('c1', 0, 'c2', 1),
| null,
@@ -800,13 +824,21 @@ class TestOperator extends
VeloxWholeStageTransformerSuite {
| named_struct('c1', 2, 'c2', 3)
| )
|as tbl(a)
- """.stripMargin).createOrReplaceTempView("t1")
+ """.stripMargin).createOrReplaceTempView("t2")
runQueryAndCompare("""
- |SELECT inline(a) from t1;
+ |SELECT inline(a) from t2;
|""".stripMargin) {
checkOperatorMatch[GenerateExecTransformer]
}
}
+
+ // Fallback for array(struct(...), null) literal.
+ runQueryAndCompare(s"""
+ |SELECT inline(array(
+ | named_struct('c1', 0, 'c2', 1),
+ | named_struct('c1', 2, 'c2', null),
+ | null));
+ |""".stripMargin)(_)
}
test("test array functions") {
diff --git
a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala
b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala
index 557681558..52e122c2b 100644
---
a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala
+++
b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala
@@ -136,5 +136,6 @@ class VeloxLiteralSuite extends
VeloxWholeStageTransformerSuite {
test("Literal Fallback") {
validateFallbackResult("SELECT struct(cast(null as struct<a: string>))")
+ validateFallbackResult("SELECT array(struct(1, 'a'), null)")
}
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxExpr.cc
b/cpp/velox/substrait/SubstraitToVeloxExpr.cc
index f795f9c9e..8699907de 100644
--- a/cpp/velox/substrait/SubstraitToVeloxExpr.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxExpr.cc
@@ -49,15 +49,18 @@ MapVectorPtr makeMapVector(const VectorPtr& keyVector,
const VectorPtr& valueVec
valueVector);
}
-RowVectorPtr makeRowVector(const std::vector<VectorPtr>& children) {
+RowVectorPtr makeRowVector(
+ const std::vector<VectorPtr>& children,
+ std::vector<std::string>&& names,
+ size_t length,
+ facebook::velox::memory::MemoryPool* pool) {
std::vector<std::shared_ptr<const Type>> types;
types.resize(children.size());
for (int i = 0; i < children.size(); i++) {
types[i] = children[i]->type();
}
- const size_t vectorSize = children.empty() ? 0 : children.front()->size();
- auto rowType = ROW(std::move(types));
- return std::make_shared<RowVector>(children[0]->pool(), rowType,
BufferPtr(nullptr), vectorSize, children);
+ auto rowType = ROW(std::move(names), std::move(types));
+ return std::make_shared<RowVector>(pool, rowType, BufferPtr(nullptr),
length, children);
}
ArrayVectorPtr makeEmptyArrayVector(memory::MemoryPool* pool, const TypePtr&
elementType) {
@@ -73,7 +76,7 @@ MapVectorPtr makeEmptyMapVector(memory::MemoryPool* pool,
const TypePtr& keyType
}
RowVectorPtr makeEmptyRowVector(memory::MemoryPool* pool) {
- return makeRowVector({});
+ return makeRowVector({}, {}, 0, pool);
}
template <typename T>
@@ -485,13 +488,20 @@ VectorPtr SubstraitVeloxExprConverter::literalsToVector(
}
RowVectorPtr SubstraitVeloxExprConverter::literalsToRowVector(const
::substrait::Expression::Literal& structLiteral) {
- auto childSize = structLiteral.struct_().fields().size();
- if (childSize == 0) {
+ if (structLiteral.has_null()) {
+ VELOX_NYI("NULL for struct type is not supported.");
+ }
+ auto numFields = structLiteral.struct_().fields().size();
+ if (numFields == 0) {
return makeEmptyRowVector(pool_);
}
std::vector<VectorPtr> vectors;
- vectors.reserve(structLiteral.struct_().fields().size());
- for (const auto& child : structLiteral.struct_().fields()) {
+ std::vector<std::string> names;
+ vectors.reserve(numFields);
+ names.reserve(numFields);
+ for (auto i = 0; i < numFields; ++i) {
+ names.push_back("col_" + std::to_string(i));
+ const auto& child = structLiteral.struct_().fields(i);
auto typeCase = child.literal_type_case();
switch (typeCase) {
case
::substrait::Expression_Literal::LiteralTypeCase::kIntervalDayToSecond: {
@@ -530,7 +540,7 @@ RowVectorPtr
SubstraitVeloxExprConverter::literalsToRowVector(const ::substrait:
}
}
}
- return makeRowVector(vectors);
+ return makeRowVector(vectors, std::move(names), 1, pool_);
}
core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr(
diff --git
a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
index 1b892ff59..759f7cfad 100644
---
a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode,
Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{FileSourceScanExec, LeafExecNode,
SparkPlan}
+import org.apache.spark.sql.execution.{FileSourceScanExec, GenerateExec,
LeafExecNode, SparkPlan}
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -653,4 +653,8 @@ trait SparkPlanExecApi {
generatorOutput: Seq[Attribute],
child: SparkPlan
): GenerateExecTransformerBase
+
+ def genPreProjectForGenerate(generate: GenerateExec): SparkPlan
+
+ def genPostProjectForGenerate(generate: GenerateExec): SparkPlan
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
index 285734f38..f3e31346c 100644
---
a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
@@ -39,6 +39,7 @@ abstract class GenerateExecTransformerBase(
generatorOutput: Seq[Attribute],
child: SparkPlan)
extends UnaryTransformSupport {
+
protected def doGeneratorValidate(generator: Generator, outer: Boolean):
ValidationResult
protected def getRelNode(
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala
index 0a39ef819..a77e063c5 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala
@@ -21,7 +21,7 @@ import io.glutenproject.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
NamedExpression, WindowExpression}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.window.WindowExec
@@ -104,6 +104,9 @@ object PullOutPostProject extends Rule[SparkPlan] with
PullOutProjectHelper {
window.copy(windowExpression =
newWindowExpressions.asInstanceOf[Seq[NamedExpression]])
ProjectExec(window.child.output ++ postWindowExpressions, newWindow)
+ case generate: GenerateExec =>
+
BackendsApiManager.getSparkPlanExecApiInstance.genPostProjectForGenerate(generate)
+
case _ => plan
}
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
index 440f609de..92d6c9fab 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
@@ -16,12 +16,13 @@
*/
package io.glutenproject.extension.columnar
+import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Complete, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ExpandExec, ProjectExec, SortExec,
SparkPlan, TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution.{ExpandExec, GenerateExec, ProjectExec,
SortExec, SparkPlan, TakeOrderedAndProjectExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec,
TypedAggregateExpression}
import org.apache.spark.sql.execution.window.WindowExec
@@ -189,6 +190,10 @@ object PullOutPreProject extends Rule[SparkPlan] with
PullOutProjectHelper {
child = ProjectExec(
eliminateProjectList(expand.child.outputSet,
expressionMap.values.toSeq),
expand.child))
+
+ case generate: GenerateExec =>
+
BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForGenerate(generate)
+
case _ => plan
}
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
index 8f3f01f95..b2591f048 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
@@ -54,6 +54,7 @@ class RewriteSparkPlanRulesManager(rewriteRules:
Seq[Rule[SparkPlan]]) extends R
case _: FilterExec => true
case _: FileSourceScanExec => true
case _: ExpandExec => true
+ case _: GenerateExec => true
case _ => false
}
}
@@ -77,7 +78,7 @@ class RewriteSparkPlanRulesManager(rewriteRules:
Seq[Rule[SparkPlan]]) extends R
// Some rewrite rules may generate new parent plan node, we should
use transform to
// rewrite the original plan. For example, PullOutPreProject and
PullOutPostProject
// will generate post-project plan node.
- plan.transform { case p => rule.apply(p) }
+ plan.transformUp { case p => rule.apply(p) }
}
(rewrittenPlan, None)
} catch {
diff --git
a/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
b/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
index a519772fc..543a5413c 100644
---
a/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
@@ -30,8 +30,8 @@ trait PullOutProjectHelper {
private val generatedNameIndex = new AtomicInteger(0)
- protected def generatePreAliasName =
s"_pre_${generatedNameIndex.getAndIncrement()}"
- protected def generatePostAliasName =
s"_post_${generatedNameIndex.getAndIncrement()}"
+ protected def generatePreAliasName: String =
s"_pre_${generatedNameIndex.getAndIncrement()}"
+ protected def generatePostAliasName: String =
s"_post_${generatedNameIndex.getAndIncrement()}"
/**
* The majority of Expressions only support Attribute and BoundReference
when converting them into
@@ -57,12 +57,13 @@ trait PullOutProjectHelper {
protected def replaceExpressionWithAttribute(
expr: Expression,
- projectExprsMap: mutable.HashMap[Expression, NamedExpression]):
Expression =
+ projectExprsMap: mutable.HashMap[Expression, NamedExpression],
+ replaceBoundReference: Boolean = false): Expression =
expr match {
case alias: Alias =>
projectExprsMap.getOrElseUpdate(alias.child.canonicalized,
alias).toAttribute
case attr: Attribute => attr
- case e: BoundReference => e
+ case e: BoundReference if !replaceBoundReference => e
case other =>
projectExprsMap
.getOrElseUpdate(other.canonicalized, Alias(other,
generatePreAliasName)())
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]