This is an automated email from the ASF dual-hosted git repository.
yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new b1fe7a33d [VL] Support posexplode function and code refactoring on
GenerateExecTransformer (#4901)
b1fe7a33d is described below
commit b1fe7a33d8d3822150abd56441115d8f8cab21fc
Author: Rong Ma <[email protected]>
AuthorDate: Wed Mar 13 13:12:40 2024 +0800
[VL] Support posexplode function and code refactoring on
GenerateExecTransformer (#4901)
This patch contains 2 major changes:
Support posexplode for Velox backend.
Separate GenerateExecTransformer for Velox/CH backends. For Velox backend,
the RelNode creation requires pre/projection for different generators. For CH
backend, the RelNode creation is kept clean.
---
.../backendsapi/clickhouse/CHMetricsApi.scala | 16 --
.../clickhouse/CHSparkPlanExecApi.scala | 10 +
.../execution/CHGenerateExecTransformer.scala | 91 ++++++++
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 8 +-
.../metrics/GlutenClickHouseTPCHMetricsSuite.scala | 2 +-
.../backendsapi/velox/MetricsApiImpl.scala | 7 -
.../backendsapi/velox/SparkPlanExecApiImpl.scala | 21 +-
.../backendsapi/velox/ValidatorApiImpl.scala | 26 +--
.../execution/GenerateExecTransformer.scala | 232 +++++++++++++++++++++
.../io/glutenproject/execution/TestOperator.scala | 91 ++++++--
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 10 +-
.../substrait/SubstraitToVeloxPlanValidator.cc | 1 -
.../io/glutenproject/backendsapi/MetricsApi.scala | 4 -
.../backendsapi/SparkPlanExecApi.scala | 8 +
.../glutenproject/backendsapi/ValidatorApi.scala | 7 +-
.../execution/GenerateExecTransformer.scala | 224 --------------------
.../execution/GenerateExecTransformerBase.scala | 97 +++++++++
.../extension/columnar/MiscColumnarRules.scala | 2 +-
.../extension/columnar/TransformHintRule.scala | 2 +-
19 files changed, 543 insertions(+), 316 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala
index 3d47063bf..cda406872 100644
---
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala
+++
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala
@@ -356,21 +356,6 @@ class CHMetricsApi extends MetricsApi with Logging with
LogLevelUtil {
throw new UnsupportedOperationException(
s"NestedLoopJoinTransformer metrics update is not supported in CH
backend")
}
- override def genGenerateTransformerMetrics(sparkContext: SparkContext):
Map[String, SQLMetric] =
- Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of
output rows"),
- "outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of
output vectors"),
- "outputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of
output bytes"),
- "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input
rows"),
- "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of
input bytes"),
- "extraTime" -> SQLMetrics.createTimingMetric(sparkContext, "extra
operators time"),
- "inputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of
waiting for data"),
- "outputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of
waiting for output"),
- "totalTime" -> SQLMetrics.createTimingMetric(sparkContext, "total time")
- )
-
- override def genGenerateTransformerMetricsUpdater(
- metrics: Map[String, SQLMetric]): MetricsUpdater = new
GenerateMetricsUpdater(metrics)
def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String,
SQLMetric] = {
throw new UnsupportedOperationException(
@@ -381,5 +366,4 @@ class CHMetricsApi extends MetricsApi with Logging with
LogLevelUtil {
throw new UnsupportedOperationException(
s"WriteFilesTransformer metrics update is not supported in CH backend")
}
-
}
diff --git
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index f32116728..17a9e8d67 100644
---
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -735,4 +735,14 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
case _ => super.postProcessPushDownFilter(extraFilters, sparkExecNode)
}
}
+
+ override def genGenerateTransformer(
+ generator: Generator,
+ requiredChildOutput: Seq[Attribute],
+ outer: Boolean,
+ generatorOutput: Seq[Attribute],
+ child: SparkPlan
+ ): GenerateExecTransformerBase = {
+ CHGenerateExecTransformer(generator, requiredChildOutput, outer,
generatorOutput, child)
+ }
}
diff --git
a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHGenerateExecTransformer.scala
b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHGenerateExecTransformer.scala
new file mode 100644
index 000000000..b4d84141e
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHGenerateExecTransformer.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.execution
+
+import io.glutenproject.extension.ValidationResult
+import io.glutenproject.metrics.{GenerateMetricsUpdater, MetricsUpdater}
+import io.glutenproject.substrait.SubstraitContext
+import io.glutenproject.substrait.expression.ExpressionNode
+import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+import scala.collection.JavaConverters._
+
+// Transformer for GenerateExec, which applies a [[Generator]] to a stream of
input rows.
+// For clickhouse backend, it will transform Spark explode lateral view to CH
array join.
+case class CHGenerateExecTransformer(
+ generator: Generator,
+ requiredChildOutput: Seq[Attribute],
+ outer: Boolean,
+ generatorOutput: Seq[Attribute],
+ child: SparkPlan)
+ extends GenerateExecTransformerBase(
+ generator,
+ requiredChildOutput,
+ outer,
+ generatorOutput,
+ child) {
+
+ @transient
+ override lazy val metrics =
+ Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of
output rows"),
+ "outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of
output vectors"),
+ "outputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of
output bytes"),
+ "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input
rows"),
+ "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of
input bytes"),
+ "extraTime" -> SQLMetrics.createTimingMetric(sparkContext, "extra
operators time"),
+ "inputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of
waiting for data"),
+ "outputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of
waiting for output"),
+ "totalTime" -> SQLMetrics.createTimingMetric(sparkContext, "total time")
+ )
+
+ override def metricsUpdater(): MetricsUpdater = new
GenerateMetricsUpdater(metrics)
+ override protected def withNewChildInternal(newChild: SparkPlan):
CHGenerateExecTransformer =
+ copy(generator, requiredChildOutput, outer, generatorOutput, newChild)
+
+ override protected def doGeneratorValidate(
+ generator: Generator,
+ outer: Boolean): ValidationResult =
+ ValidationResult.ok
+
+ override protected def getRelNode(
+ context: SubstraitContext,
+ inputRel: RelNode,
+ generatorNode: ExpressionNode,
+ validation: Boolean): RelNode = {
+ if (!validation) {
+ RelBuilder.makeGenerateRel(
+ inputRel,
+ generatorNode,
+ requiredChildOutputNodes.asJava,
+ context,
+ context.nextOperatorId(this.nodeName))
+ } else {
+ RelBuilder.makeGenerateRel(
+ inputRel,
+ generatorNode,
+ requiredChildOutputNodes.asJava,
+ getExtensionNodeForValidation,
+ context,
+ context.nextOperatorId(this.nodeName))
+ }
+ }
+}
diff --git
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 92ba4cb13..83945dbf2 100644
---
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -1278,14 +1278,14 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
val sql = """
| select id from test_1767 lateral view
| posexplode(split(data['k'], ',')) tx as a, b""".stripMargin
- runQueryAndCompare(sql)(checkOperatorMatch[GenerateExecTransformer])
+ runQueryAndCompare(sql)(checkOperatorMatch[CHGenerateExecTransformer])
spark.sql("drop table test_1767")
}
test("test posexplode issue:
https://github.com/oap-project/gluten/issues/2492") {
val sql = "select posexplode(split(n_comment, ' ')) from nation where
n_comment is null"
- runQueryAndCompare(sql)(checkOperatorMatch[GenerateExecTransformer])
+ runQueryAndCompare(sql)(checkOperatorMatch[CHGenerateExecTransformer])
}
test("test posexplode issue:
https://github.com/oap-project/gluten/issues/2454") {
@@ -1297,7 +1297,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
)
for (sql <- sqls) {
- runQueryAndCompare(sql)(checkOperatorMatch[GenerateExecTransformer])
+ runQueryAndCompare(sql)(checkOperatorMatch[CHGenerateExecTransformer])
}
}
@@ -1306,7 +1306,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
spark.sql("insert into test_3124 values (31, null, 'm'), (32, 'a,b,c',
'f')")
val sql = "select id, flag from test_3124 lateral view explode(split(name,
',')) as flag"
- runQueryAndCompare(sql)(checkOperatorMatch[GenerateExecTransformer])
+ runQueryAndCompare(sql)(checkOperatorMatch[CHGenerateExecTransformer])
spark.sql("drop table test_3124")
}
diff --git
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
index 6ed720ca6..49846280e 100644
---
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
+++
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
@@ -85,7 +85,7 @@ class GlutenClickHouseTPCHMetricsSuite extends
GlutenClickHouseTPCHAbstractSuite
runQueryAndCompare(sql) {
df =>
val plans = df.queryExecution.executedPlan.collect {
- case generate: GenerateExecTransformer => generate
+ case generate: CHGenerateExecTransformer => generate
}
assert(plans.size == 1)
assert(plans.head.metrics("numInputRows").value == 25)
diff --git
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/MetricsApiImpl.scala
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/MetricsApiImpl.scala
index 71c2642ff..9b49b9282 100644
---
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/MetricsApiImpl.scala
+++
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/MetricsApiImpl.scala
@@ -534,11 +534,4 @@ class MetricsApiImpl extends MetricsApi with Logging {
override def genNestedLoopJoinTransformerMetricsUpdater(
metrics: Map[String, SQLMetric]): MetricsUpdater = new
NestedLoopJoinMetricsUpdater(metrics)
-
- override def genGenerateTransformerMetrics(sparkContext: SparkContext):
Map[String, SQLMetric] = {
- Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of
output rows"))
- }
-
- override def genGenerateTransformerMetricsUpdater(
- metrics: Map[String, SQLMetric]): MetricsUpdater = new
GenerateMetricsUpdater(metrics)
}
diff --git
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
index 2ab61f271..503463fcb 100644
---
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
+++
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -36,7 +36,7 @@ import
org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast,
CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem,
GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression,
NaNvl, Round, StringSplit, StringTrim}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast,
CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator,
GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash,
NamedExpression, NaNvl, PosExplode, Round, StringSplit, StringTrim}
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
HLLAdapter}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
@@ -136,6 +136,15 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(child), expr)
}
+ /** Transform posexplode to Substrait. */
+ override def genPosExplodeTransformer(
+ substraitExprName: String,
+ child: ExpressionTransformer,
+ original: PosExplode,
+ attrSeq: Seq[Attribute]): ExpressionTransformer = {
+ GenericExpressionTransformer(substraitExprName, Seq(child), attrSeq.head)
+ }
+
/** Transform inline to Substrait. */
override def genInlineTransformer(
substraitExprName: String,
@@ -623,4 +632,14 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
throw new IllegalStateException(s"Unsupported fs: $other")
}
}
+
+ override def genGenerateTransformer(
+ generator: Generator,
+ requiredChildOutput: Seq[Attribute],
+ outer: Boolean,
+ generatorOutput: Seq[Attribute],
+ child: SparkPlan
+ ): GenerateExecTransformerBase = {
+ GenerateExecTransformer(generator, requiredChildOutput, outer,
generatorOutput, child)
+ }
}
diff --git
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/ValidatorApiImpl.scala
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/ValidatorApiImpl.scala
index b8e22becc..a2394e6c9 100644
---
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/ValidatorApiImpl.scala
+++
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/ValidatorApiImpl.scala
@@ -17,12 +17,11 @@
package io.glutenproject.backendsapi.velox
import io.glutenproject.backendsapi.ValidatorApi
-import io.glutenproject.extension.ValidationResult
import io.glutenproject.substrait.plan.PlanNode
import io.glutenproject.validate.NativePlanValidationInfo
import io.glutenproject.vectorized.NativePlanEvaluator
-import org.apache.spark.sql.catalyst.expressions.{CreateMap, Explode,
Expression, Generator, JsonTuple, PosExplode}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types._
@@ -80,27 +79,4 @@ class ValidatorApiImpl extends ValidatorApi {
child: SparkPlan): Option[String] = {
doSchemaValidate(child.schema)
}
-
- override def doGeneratorValidate(generator: Generator, outer: Boolean):
ValidationResult = {
- if (outer) {
- return ValidationResult.notOk(s"Velox backend does not support outer")
- }
- generator match {
- case _: JsonTuple =>
- ValidationResult.notOk(s"Velox backend does not support this
json_tuple")
- case _: PosExplode =>
- // TODO(yuan): support posexplode and remove this check
- ValidationResult.notOk(s"Velox backend does not support this
posexplode")
- case explode: Explode =>
- explode.child match {
- case _: CreateMap =>
- // explode(MAP(col1, col2))
- ValidationResult.notOk(s"Velox backend does not support MAP
datatype")
- case _ =>
- ValidationResult.ok
- }
- case _ =>
- ValidationResult.ok
- }
- }
}
diff --git
a/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
b/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
new file mode 100644
index 000000000..5bdfba200
--- /dev/null
+++
b/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.execution
+
+import io.glutenproject.backendsapi.BackendsApiManager
+import io.glutenproject.expression.{ConverterUtils, ExpressionConverter,
ExpressionNames}
+import io.glutenproject.expression.ConverterUtils.FunctionConfig
+import io.glutenproject.extension.ValidationResult
+import io.glutenproject.metrics.{GenerateMetricsUpdater, MetricsUpdater}
+import io.glutenproject.substrait.`type`.TypeBuilder
+import io.glutenproject.substrait.SubstraitContext
+import io.glutenproject.substrait.expression.{ExpressionBuilder,
ExpressionNode}
+import io.glutenproject.substrait.extensions.{AdvancedExtensionNode,
ExtensionBuilder}
+import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructType}
+
+import com.google.common.collect.Lists
+import com.google.protobuf.StringValue
+
+import scala.collection.JavaConverters._
+
+case class GenerateExecTransformer(
+ generator: Generator,
+ requiredChildOutput: Seq[Attribute],
+ outer: Boolean,
+ generatorOutput: Seq[Attribute],
+ child: SparkPlan)
+ extends GenerateExecTransformerBase(
+ generator,
+ requiredChildOutput,
+ outer,
+ generatorOutput,
+ child) {
+
+ @transient
+ override lazy val metrics =
+ Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of
output rows"))
+
+ override def metricsUpdater(): MetricsUpdater = new
GenerateMetricsUpdater(metrics)
+
+ override protected def withNewChildInternal(newChild: SparkPlan):
GenerateExecTransformer =
+ copy(generator, requiredChildOutput, outer, generatorOutput, newChild)
+
+ override protected def doGeneratorValidate(
+ generator: Generator,
+ outer: Boolean): ValidationResult = {
+ if (outer) {
+ return ValidationResult.notOk(s"Velox backend does not support outer")
+ }
+ generator match {
+ case _: JsonTuple =>
+ ValidationResult.notOk(s"Velox backend does not support this
json_tuple")
+ case _: ExplodeBase =>
+ ValidationResult.ok
+ case Inline(child) =>
+ child match {
+ case AttributeReference(_, ArrayType(_: StructType, _), _, _) =>
+ ValidationResult.ok
+ case _ =>
+ // TODO: Support Literal/CreateArray.
+ ValidationResult.notOk(
+ s"Velox backend does not support inline with expression " +
+ s"${child.getClass.getSimpleName}.")
+ }
+ case _ =>
+ ValidationResult.ok
+ }
+ }
+
+ override protected def getRelNode(
+ context: SubstraitContext,
+ inputRel: RelNode,
+ generatorNode: ExpressionNode,
+ validation: Boolean): RelNode = {
+ val operatorId = context.nextOperatorId(this.nodeName)
+
+ val newInput = if (!validation) {
+ applyPreProject(inputRel, context, operatorId)
+ } else {
+ // No need to validate the pre-projection. The generator output has been
validated in
+ // doGeneratorValidate.
+ inputRel
+ }
+
+ val generateRel = RelBuilder.makeGenerateRel(
+ newInput,
+ generatorNode,
+ requiredChildOutputNodes.asJava,
+ getExtensionNode(validation),
+ context,
+ operatorId)
+
+ if (!validation) {
+ applyPostProject(generateRel, context, operatorId)
+ } else {
+ // No need to validate the post-projection on the native side as
+ // it only flattens the generator's output.
+ generateRel
+ }
+ }
+
+ private def getExtensionNode(validation: Boolean): AdvancedExtensionNode = {
+ if (!validation) {
+ // Start with "GenerateParameters:"
+ val parametersStr = new StringBuffer("GenerateParameters:")
+ // isPosExplode: 1 for PosExplode, 0 for others.
+ val isPosExplode = if (generator.isInstanceOf[PosExplode]) {
+ "1"
+ } else {
+ "0"
+ }
+ parametersStr
+ .append("isPosExplode=")
+ .append(isPosExplode)
+ .append("\n")
+ val message = StringValue
+ .newBuilder()
+ .setValue(parametersStr.toString)
+ .build()
+ val optimization =
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
+ ExtensionBuilder.makeAdvancedExtension(optimization, null)
+ } else {
+ getExtensionNodeForValidation
+ }
+ }
+
+ // Select child outputs and append generator input.
+ private def applyPreProject(
+ inputRel: RelNode,
+ context: SubstraitContext,
+ operatorId: Long
+ ): RelNode = {
+ val projectExpressions: Seq[ExpressionNode] =
+ child.output.indices
+ .map(ExpressionBuilder.makeSelection(_)) :+
+ ExpressionConverter
+ .replaceWithExpressionTransformer(
+ generator.asInstanceOf[UnaryExpression].child,
+ child.output)
+ .doTransform(context.registeredFunction)
+
+ RelBuilder.makeProjectRel(
+ inputRel,
+ projectExpressions.asJava,
+ context,
+ operatorId,
+ child.output.size)
+ }
+
+ // There are 3 types of CollectionGenerator in spark: Explode, PosExplode
and Inline.
+ // Adds postProject for PosExplode and Inline.
+ private def applyPostProject(
+ generateRel: RelNode,
+ context: SubstraitContext,
+ operatorId: Long): RelNode = {
+ generator match {
+ case Inline(_) =>
+ val requiredOutput = requiredChildOutputNodes.indices.map {
+ ExpressionBuilder.makeSelection(_)
+ }
+ val flattenStruct: Seq[ExpressionNode] = generatorOutput.indices.map {
+ i =>
+ val selectionNode =
ExpressionBuilder.makeSelection(requiredOutput.size)
+ selectionNode.addNestedChildIdx(i)
+ }
+ RelBuilder.makeProjectRel(
+ generateRel,
+ (requiredOutput ++ flattenStruct).asJava,
+ context,
+ operatorId,
+ 1 + requiredOutput.size // 1 stands for the inner struct field from
array.
+ )
+ case PosExplode(posExplodeChild) =>
+ // Ordinal populated by Velox UnnestNode starts with 1.
+ // Need to subtract 1 to align with Spark's output.
+ val unnestedSize = posExplodeChild.dataType match {
+ case _: MapType => 2
+ case _: ArrayType => 1
+ }
+ val subFunctionName = ConverterUtils.makeFuncName(
+ ExpressionNames.SUBTRACT,
+ Seq(LongType, LongType),
+ FunctionConfig.OPT)
+ val functionMap = context.registeredFunction
+ val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap,
subFunctionName)
+ val literalNode = ExpressionBuilder.makeLiteral(1L, LongType, false)
+ val ordinalNode = ExpressionBuilder.makeCast(
+ TypeBuilder.makeI32(false),
+ ExpressionBuilder.makeScalarFunction(
+ addFunctionId,
+ Lists.newArrayList(
+ ExpressionBuilder.makeSelection(requiredChildOutputNodes.size +
unnestedSize),
+ literalNode),
+ ConverterUtils.getTypeNode(LongType,
generator.elementSchema.head.nullable)
+ ),
+ true // Generated ordinal column shouldn't have null.
+ )
+ val requiredChildNodes =
+
requiredChildOutputNodes.indices.map(ExpressionBuilder.makeSelection(_))
+ val unnestColumns = (0 until unnestedSize)
+ .map(i => ExpressionBuilder.makeSelection(i +
requiredChildOutputNodes.size))
+ val generatorOutput: Seq[ExpressionNode] =
+ (requiredChildNodes :+ ordinalNode) ++ unnestColumns
+ RelBuilder.makeProjectRel(
+ generateRel,
+ generatorOutput.asJava,
+ context,
+ operatorId,
+ generatorOutput.size
+ )
+ case _ => generateRel
+ }
+ }
+}
diff --git
a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index 961a0c3de..486a3c057 100644
---
a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -714,32 +714,79 @@ class TestOperator extends
VeloxWholeStageTransformerSuite {
}
}
- test("test explode function") {
- runQueryAndCompare("""
- |SELECT explode(array(1, 2, 3));
- |""".stripMargin) {
- checkOperatorMatch[GenerateExecTransformer]
- }
- runQueryAndCompare("""
- |SELECT explode(map(1, 'a', 2, 'b'));
- |""".stripMargin) {
- checkOperatorMatch[GenerateExecTransformer]
- }
- runQueryAndCompare(
- """
- |SELECT explode(array(map(1, 'a', 2, 'b'), map(3, 'c', 4, 'd'), map(5,
'e', 6, 'f')));
- |""".stripMargin) {
- checkOperatorMatch[GenerateExecTransformer]
- }
- runQueryAndCompare("""
- |SELECT explode(map(1, array(1, 2), 2, array(3, 4)));
- |""".stripMargin) {
- checkOperatorMatch[GenerateExecTransformer]
+ test("test explode/posexplode function") {
+ Seq("explode", "posexplode").foreach {
+ func =>
+ // Literal: func(literal)
+ runQueryAndCompare(s"""
+ |SELECT $func(array(1, 2, 3));
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ runQueryAndCompare(s"""
+ |SELECT $func(map(1, 'a', 2, 'b'));
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ runQueryAndCompare(
+ s"""
+ |SELECT $func(array(map(1, 'a', 2, 'b'), map(3, 'c', 4, 'd'),
map(5, 'e', 6, 'f')));
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ runQueryAndCompare(s"""
+ |SELECT $func(map(1, array(1, 2), 2, array(3,
4)));
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+
+ // CreateArray/CreateMap: func(array(col)), func(map(k, v))
+ withTempView("t1") {
+ sql("""select * from values (1), (2), (3), (4)
+ |as tbl(a)
+ """.stripMargin).createOrReplaceTempView("t1")
+ runQueryAndCompare(s"""
+ |SELECT $func(array(a)) from t1;
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ sql("""select * from values (1, 'a'), (2, 'b'), (3, null), (4, null)
+ |as tbl(a, b)
+ """.stripMargin).createOrReplaceTempView("t1")
+ runQueryAndCompare(s"""
+ |SELECT $func(map(a, b)) from t1;
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ }
+
+ // AttributeReference: func(col)
+ withTempView("t2") {
+ sql("""select * from values
+ | array(1, 2, 3),
+ | array(4, null)
+ |as tbl(a)
+ """.stripMargin).createOrReplaceTempView("t2")
+ runQueryAndCompare(s"""
+ |SELECT $func(a) from t2;
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ sql("""select * from values
+ | map(1, 'a', 2, 'b', 3, null),
+ | map(4, null)
+ |as tbl(a)
+ """.stripMargin).createOrReplaceTempView("t2")
+ runQueryAndCompare(s"""
+ |SELECT $func(a) from t2;
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ }
}
}
test("test inline function") {
-
withTempView("t1") {
sql("""select * from values
| array(
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index d9ebe2902..b8ca25a43 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -766,10 +766,14 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
}
}
- auto node = std::make_shared<core::UnnestNode>(
- nextPlanNodeId(), replicated, unnest, std::move(unnestNames),
std::nullopt, childNode);
+ std::optional<std::string> ordinalityName = std::nullopt;
+ if (generateRel.has_advanced_extension() &&
+
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(),
"isPosExplode=")) {
+ ordinalityName = std::make_optional<std::string>("pos");
+ }
- return node;
+ return std::make_shared<core::UnnestNode>(
+ nextPlanNodeId(), replicated, unnest, std::move(unnestNames),
ordinalityName, childNode);
}
const core::WindowNode::Frame createWindowFrame(
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index b15b18872..0e81ba91a 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -67,7 +67,6 @@ static const std::unordered_set<std::string> kBlackList = {
"repeat",
"trunc",
"sequence",
- "posexplode",
"arrays_overlap",
"approx_percentile",
"get_array_struct_fields"};
diff --git
a/gluten-core/src/main/scala/io/glutenproject/backendsapi/MetricsApi.scala
b/gluten-core/src/main/scala/io/glutenproject/backendsapi/MetricsApi.scala
index c0448af78..46f5c82b9 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/MetricsApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/MetricsApi.scala
@@ -111,10 +111,6 @@ trait MetricsApi extends Serializable {
def genNestedLoopJoinTransformerMetricsUpdater(metrics: Map[String,
SQLMetric]): MetricsUpdater
- def genGenerateTransformerMetrics(sparkContext: SparkContext): Map[String,
SQLMetric]
-
- def genGenerateTransformerMetricsUpdater(metrics: Map[String, SQLMetric]):
MetricsUpdater
-
def genColumnarInMemoryTableMetrics(sparkContext: SparkContext): Map[String,
SQLMetric] =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of
output rows"))
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
index 4379745cc..8eec0e38d 100644
---
a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
@@ -650,4 +650,12 @@ trait SparkPlanExecApi {
s"${sparkExecNode.getClass.toString} is not supported.")
}
}
+
+ def genGenerateTransformer(
+ generator: Generator,
+ requiredChildOutput: Seq[Attribute],
+ outer: Boolean,
+ generatorOutput: Seq[Attribute],
+ child: SparkPlan
+ ): GenerateExecTransformerBase
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/backendsapi/ValidatorApi.scala
b/gluten-core/src/main/scala/io/glutenproject/backendsapi/ValidatorApi.scala
index 7a11268d3..04d80e42e 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/ValidatorApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/ValidatorApi.scala
@@ -16,11 +16,10 @@
*/
package io.glutenproject.backendsapi
-import io.glutenproject.extension.ValidationResult
import io.glutenproject.substrait.plan.PlanNode
import io.glutenproject.validate.NativePlanValidationInfo
-import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.DataType
@@ -64,8 +63,4 @@ trait ValidatorApi {
def doColumnarShuffleExchangeExecValidate(
outputPartitioning: Partitioning,
child: SparkPlan): Option[String]
-
- /** Validate against Generator expression. */
- def doGeneratorValidate(generator: Generator, outer: Boolean):
ValidationResult =
- ValidationResult.ok
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
deleted file mode 100644
index 29e13bd7b..000000000
---
a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala
+++ /dev/null
@@ -1,224 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.glutenproject.execution
-
-import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.exception.GlutenException
-import io.glutenproject.expression.{ConverterUtils, ExpressionConverter}
-import io.glutenproject.extension.ValidationResult
-import io.glutenproject.metrics.MetricsUpdater
-import io.glutenproject.substrait.`type`.TypeBuilder
-import io.glutenproject.substrait.SubstraitContext
-import io.glutenproject.substrait.expression.{ExpressionBuilder,
ExpressionNode}
-import io.glutenproject.substrait.extensions.ExtensionBuilder
-import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkPlan
-
-import java.util.{ArrayList => JArrayList, List => JList}
-
-import scala.collection.JavaConverters._
-
-// Transformer for GeneratorExec, which Applies a [[Generator]] to a stream of
input rows.
-// For clickhouse backend, it will transform Spark explode lateral view to CH
array join.
-case class GenerateExecTransformer(
- generator: Generator,
- requiredChildOutput: Seq[Attribute],
- outer: Boolean,
- generatorOutput: Seq[Attribute],
- child: SparkPlan)
- extends UnaryTransformSupport {
-
- @transient
- override lazy val metrics =
-
BackendsApiManager.getMetricsApiInstance.genGenerateTransformerMetrics(sparkContext)
-
- override def output: Seq[Attribute] = requiredChildOutput ++ generatorOutput
-
- override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)
-
- override protected def withNewChildInternal(newChild: SparkPlan):
GenerateExecTransformer =
- copy(generator, requiredChildOutput, outer, generatorOutput, newChild)
-
- override protected def doValidateInternal(): ValidationResult = {
- val validationResult =
-
BackendsApiManager.getValidatorApiInstance.doGeneratorValidate(generator, outer)
- if (!validationResult.isValid) {
- return validationResult
- }
- val context = new SubstraitContext
- val args = context.registeredFunction
-
- val operatorId = context.nextOperatorId(this.nodeName)
- val generatorExpr =
- ExpressionConverter.replaceWithExpressionTransformer(generator,
child.output)
- val generatorNode = generatorExpr.doTransform(args)
- val childOutputNodes = new java.util.ArrayList[ExpressionNode]
- for (target <- requiredChildOutput) {
- val found = child.output.zipWithIndex.filter(_._1.name == target.name)
- if (found.nonEmpty) {
- val exprNode = ExpressionBuilder.makeSelection(found.head._2)
- childOutputNodes.add(exprNode)
- } else {
- throw new GlutenException(s"Can't found column ${target.name} in child
output")
- }
- }
-
- val relNode =
- getRelNode(
- context,
- operatorId,
- child.output,
- null,
- generatorNode,
- childOutputNodes,
- validation = true)
-
- doNativeValidation(context, relNode)
- }
-
- override def doTransform(context: SubstraitContext): TransformContext = {
- val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
- val args = context.registeredFunction
- val operatorId = context.nextOperatorId(this.nodeName)
- val generatorExpr =
- ExpressionConverter.replaceWithExpressionTransformer(generator,
child.output)
- val generatorNode = generatorExpr.doTransform(args)
- val requiredChildOutputNodes = new JArrayList[ExpressionNode]
- for (target <- requiredChildOutput) {
- val found = child.output.zipWithIndex.filter(_._1.name == target.name)
- if (found.nonEmpty) {
- val exprNode = ExpressionBuilder.makeSelection(found.head._2)
- requiredChildOutputNodes.add(exprNode)
- } else {
- throw new GlutenException(s"Can't found column ${target.name} in child
output")
- }
- }
-
- val inputRel = childCtx.root
- val projRel =
- if (BackendsApiManager.getSettings.insertPostProjectForGenerate()) {
- // need to insert one projection node for velox backend
- val projectExpressions = new JArrayList[ExpressionNode]()
- val childOutputNodes = child.output.indices
- .map(i =>
ExpressionBuilder.makeSelection(i).asInstanceOf[ExpressionNode])
- .asJava
- projectExpressions.addAll(childOutputNodes)
- val projectExprNode = ExpressionConverter
- .replaceWithExpressionTransformer(
- generator.asInstanceOf[UnaryExpression].child,
- child.output)
- .doTransform(args)
-
- projectExpressions.add(projectExprNode)
-
- RelBuilder.makeProjectRel(
- inputRel,
- projectExpressions,
- context,
- operatorId,
- childOutputNodes.size)
- } else {
- inputRel
- }
-
- val relNode = getRelNode(
- context,
- operatorId,
- child.output,
- projRel,
- generatorNode,
- requiredChildOutputNodes,
- validation = false)
-
- TransformContext(child.output, output, relNode)
- }
-
- def getRelNode(
- context: SubstraitContext,
- operatorId: Long,
- inputAttributes: Seq[Attribute],
- input: RelNode,
- generatorNode: ExpressionNode,
- childOutput: JList[ExpressionNode],
- validation: Boolean): RelNode = {
- val generateRel = if (!validation) {
- RelBuilder.makeGenerateRel(input, generatorNode, childOutput, context,
operatorId)
- } else {
- // Use a extension node to send the input types through Substrait plan
for validation.
- val inputTypeNodeList =
- inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType,
attr.nullable)).asJava
- val extensionNode = ExtensionBuilder.makeAdvancedExtension(
- BackendsApiManager.getTransformerApiInstance.packPBMessage(
- TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
- RelBuilder.makeGenerateRel(
- input,
- generatorNode,
- childOutput,
- extensionNode,
- context,
- operatorId)
- }
- applyPostProjectOnGenerator(generateRel, context, operatorId, childOutput,
validation)
- }
-
- // There are 3 types of CollectionGenerator in spark: Explode, PosExplode
and Inline.
- // Only Inline needs the post projection.
- private def applyPostProjectOnGenerator(
- generateRel: RelNode,
- context: SubstraitContext,
- operatorId: Long,
- childOutput: JList[ExpressionNode],
- validation: Boolean): RelNode = {
- generator match {
- case Inline(inlineChild) =>
- inlineChild match {
- case _: AttributeReference =>
- case _ =>
- throw new UnsupportedOperationException("Child of Inline is not
AttributeReference.")
- }
- val requiredOutput = (0 until childOutput.size).map {
- ExpressionBuilder.makeSelection(_)
- }
- val flattenStruct: Seq[ExpressionNode] = generatorOutput.indices.map {
- i =>
- val selectionNode =
ExpressionBuilder.makeSelection(requiredOutput.size)
- selectionNode.addNestedChildIdx(i)
- }
- val postProjectRel = RelBuilder.makeProjectRel(
- generateRel,
- (requiredOutput ++ flattenStruct).asJava,
- context,
- operatorId,
- 1 + requiredOutput.size // 1 stands for the inner struct field from
array.
- )
- if (validation) {
- // No need to validate the project rel on the native side as
- // it only flattens the generator's output.
- generateRel
- } else {
- postProjectRel
- }
- case _ => generateRel
- }
- }
-
- override def metricsUpdater(): MetricsUpdater = {
-
BackendsApiManager.getMetricsApiInstance.genGenerateTransformerMetricsUpdater(metrics)
- }
-}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
new file mode 100644
index 000000000..285734f38
--- /dev/null
+++
b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.execution
+
+import io.glutenproject.backendsapi.BackendsApiManager
+import io.glutenproject.exception.GlutenException
+import io.glutenproject.expression.{ConverterUtils, ExpressionConverter}
+import io.glutenproject.extension.ValidationResult
+import io.glutenproject.substrait.`type`.TypeBuilder
+import io.glutenproject.substrait.SubstraitContext
+import io.glutenproject.substrait.expression.{ExpressionBuilder,
ExpressionNode}
+import io.glutenproject.substrait.extensions.{AdvancedExtensionNode,
ExtensionBuilder}
+import io.glutenproject.substrait.rel.RelNode
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+
+import scala.collection.JavaConverters._
+
+// Transformer for GenerateExec, which applies a [[Generator]] to a stream of
input rows.
+abstract class GenerateExecTransformerBase(
+ generator: Generator,
+ requiredChildOutput: Seq[Attribute],
+ outer: Boolean,
+ generatorOutput: Seq[Attribute],
+ child: SparkPlan)
+ extends UnaryTransformSupport {
+ protected def doGeneratorValidate(generator: Generator, outer: Boolean):
ValidationResult
+
+ protected def getRelNode(
+ context: SubstraitContext,
+ inputRel: RelNode,
+ generatorNode: ExpressionNode,
+ validation: Boolean): RelNode
+
+ protected lazy val requiredChildOutputNodes: Seq[ExpressionNode] = {
+ requiredChildOutput.map {
+ target =>
+ val childIndex = child.output.zipWithIndex
+ .collectFirst {
+ case (attr, i) if attr.name == target.name => i
+ }
+ .getOrElse(
+ throw new GlutenException(s"Can't found column ${target.name} in
child output"))
+ ExpressionBuilder.makeSelection(childIndex)
+ }
+ }
+
+ override def output: Seq[Attribute] = requiredChildOutput ++ generatorOutput
+
+ override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)
+
+ override protected def doValidateInternal(): ValidationResult = {
+ val validationResult = doGeneratorValidate(generator, outer)
+ if (!validationResult.isValid) {
+ return validationResult
+ }
+ val context = new SubstraitContext
+ val relNode =
+ getRelNode(context, null, getGeneratorNode(context), validation = true)
+ doNativeValidation(context, relNode)
+ }
+
+ override def doTransform(context: SubstraitContext): TransformContext = {
+ val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
+ val relNode = getRelNode(context, childCtx.root,
getGeneratorNode(context), validation = false)
+ TransformContext(child.output, output, relNode)
+ }
+
+ protected def getExtensionNodeForValidation: AdvancedExtensionNode = {
+ // Use an extension node to send the input types through the Substrait plan
for validation.
+ val inputTypeNodeList =
+ child.output.map(attr => ConverterUtils.getTypeNode(attr.dataType,
attr.nullable)).asJava
+ ExtensionBuilder.makeAdvancedExtension(
+ BackendsApiManager.getTransformerApiInstance.packPBMessage(
+ TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+ }
+
+ private def getGeneratorNode(context: SubstraitContext): ExpressionNode =
+ ExpressionConverter
+ .replaceWithExpressionTransformer(generator, child.output)
+ .doTransform(context.registeredFunction)
+}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
index b368b158f..ec8e1d43e 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
@@ -366,7 +366,7 @@ object MiscColumnarRules {
case plan: GenerateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
val child = replaceWithTransformerPlan(plan.child)
- GenerateExecTransformer(
+
BackendsApiManager.getSparkPlanExecApiInstance.genGenerateTransformer(
plan.generator,
plan.requiredChildOutput,
plan.outer,
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
index 27a250566..5827ae618 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
@@ -704,7 +704,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
plan,
"columnar generate is not enabled in GenerateExec")
} else {
- val transformer = GenerateExecTransformer(
+ val transformer =
BackendsApiManager.getSparkPlanExecApiInstance.genGenerateTransformer(
plan.generator,
plan.requiredChildOutput,
plan.outer,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]