This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 77e8e283b0 [GLUTEN-3839][CH] Extend nested column pruning in vanilla spark (#7992)
77e8e283b0 is described below
commit 77e8e283b0913601d2a1cf270e674b3d3f57308b
Author: 李扬 <[email protected]>
AuthorDate: Wed Nov 27 14:13:29 2024 +0800
[GLUTEN-3839][CH] Extend nested column pruning in vanilla spark (#7992)
* support column pruning on generator
* rename config name
* fix style
* fix style
* support multiple fields
* fix failed uts
* fix building issues in spark3.2
* fix style
* fix building
* Update NormalFileWriter.cpp
* Update GlutenConfig.scala
* change as requested
---
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 2 +-
.../gluten/extension/ExtendedColumnPruning.scala | 366 +++++++++++++++++++++
.../ExtendedGeneratorNestedColumnAliasing.scala | 126 -------
.../hive/GlutenClickHouseHiveTableSuite.scala | 128 +++++--
.../Storages/Output/NormalFileWriter.cpp | 1 +
.../scala/org/apache/gluten/GlutenConfig.scala | 10 +-
6 files changed, 480 insertions(+), 153 deletions(-)
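For readers skimming the patch: the new rule targets plans of the shape
Project(Filter(Generate)), which vanilla Spark's ColumnPruning skips. A minimal
sketch of a query shape that benefits (table and column names are hypothetical,
modeled on the test suite below):

    // Hypothetical schema: t(items ARRAY<STRUCT<a:STRING, b:STRING, c:STRING>>).
    // With the rule enabled, the scan should read only items.a rather than the
    // whole struct.
    spark.sql(
      """
        |SELECT item.a
        |FROM t
        |LATERAL VIEW explode(items) AS item
        |WHERE item.a = 'x'
      """.stripMargin)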
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index f6d2b85d9d..02ec91496e 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -63,7 +63,7 @@ object CHRuleApi {
     injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark))
     injector.injectOptimizerRule(spark => new CommonSubexpressionEliminateRule(spark))
-    injector.injectOptimizerRule(spark => new ExtendedGeneratorNestedColumnAliasing(spark))
+    injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark))
     injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
     injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
     injector.injectOptimizerRule(_ => EqualToRewrite)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala
new file mode 100644
index 0000000000..f2a0a549bc
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.GlutenConfig
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
+import org.apache.spark.sql.catalyst.optimizer.GeneratorNestedColumnAliasing.canPruneGenerator
+import org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasing
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable
+
+object ExtendedGeneratorNestedColumnAliasing {
+ def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+ case pj @ Project(projectList, f @ Filter(condition, g: Generate))
+ if canPruneGenerator(g.generator) &&
+ GlutenConfig.getConf.enableExtendedColumnPruning &&
+ (SQLConf.get.nestedPruningOnExpressions || SQLConf.get.nestedSchemaPruningEnabled) =>
+ val attrToExtractValues =
+ getAttributeToExtractValues(projectList ++ g.generator.children :+ condition, Seq.empty)
+ if (attrToExtractValues.isEmpty) {
+ return None
+ }
+
+ val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
+ var (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
+ attrToExtractValues.partition {
+ case (attr, _) =>
+ attr.references.subsetOf(generatorOutputSet)
+ }
+
+ val pushedThrough = rewritePlanWithAliases(pj, attrToExtractValuesNotOnGenerator)
+
+ // We cannot push through if the child of generator is `MapType`.
+ g.generator.children.head.dataType match {
+ case _: MapType => return Some(pushedThrough)
+ case _ =>
+ }
+
+ if (!g.generator.isInstanceOf[ExplodeBase]) {
+ return Some(pushedThrough)
+ }
+
+ // In Spark 3.2 we cannot reuse [[NestedColumnAliasing.getAttributeToExtractValues]],
+ // which only accepts 2 arguments. Instead we redefine it in the current file to avoid
+ // moving this rule to gluten-shims.
+ attrToExtractValuesOnGenerator = getAttributeToExtractValues(
+ attrToExtractValuesOnGenerator.flatMap(_._2).toSeq,
+ Seq.empty,
+ collectNestedGetStructFields)
+
+ val nestedFieldsOnGenerator = attrToExtractValuesOnGenerator.values.flatten.toSet
+ if (nestedFieldsOnGenerator.isEmpty) {
+ return Some(pushedThrough)
+ }
+
+ // Multiple or single nested column accessors.
+ // E.g. df.select(explode($"items").as("item")).select($"item.a", $"item.b")
+ pushedThrough match {
+ case p2 @ Project(_, f2 @ Filter(_, g2: Generate)) =>
+ val nestedFieldsOnGeneratorSeq = nestedFieldsOnGenerator.toSeq
+ val nestedFieldToOrdinal = nestedFieldsOnGeneratorSeq.zipWithIndex.toMap
+ val rewrittenG = g2.transformExpressions {
+ case e: ExplodeBase =>
+ val extractors = nestedFieldsOnGeneratorSeq.map(replaceGenerator(e, _))
+ val names = extractors.map {
+ case g: GetStructField => Literal(g.extractFieldName)
+ case ga: GetArrayStructFields => Literal(ga.field.name)
+ case other =>
+ throw new IllegalStateException(
+ s"Unreasonable extractor after replaceGenerator: $other")
+ }
+ val zippedArray = ArraysZip(extractors, names)
+ e.withNewChildren(Seq(zippedArray))
+ }
+ // As we change the child of the generator, its output data type must be updated.
+ val updatedGeneratorOutput = rewrittenG.generatorOutput
+ .zip(
+ rewrittenG.generator.elementSchema.map(
+ f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()))
+ .map {
+ case (oldAttr, newAttr) =>
+ newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
+ }
+ assert(
+ updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
+ "Updated generator output must have the same length " +
+ "with original generator output."
+ )
+ val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
+
+ // Replace nested column accessor with generator output.
+ val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
+ val updatedFilter = f2.withNewChildren(Seq(updatedGenerate)).transformExpressions {
+ case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
+ replaceGetStructField(
+ f,
+ updatedGenerate.output,
+ attrExprIdsOnGenerator,
+ nestedFieldToOrdinal)
+ }
+
+ val updatedProject = p2.withNewChildren(Seq(updatedFilter)).transformExpressions {
+ case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
+ replaceGetStructField(
+ f,
+ updatedFilter.output,
+ attrExprIdsOnGenerator,
+ nestedFieldToOrdinal)
+ }
+
+ Some(updatedProject)
+ case other =>
+ throw new IllegalStateException(s"Unreasonable plan after
optimization: $other")
+ }
+ case _ =>
+ None
+ }
+
+ /**
+ * Returns two types of expressions:
+ * - Root references that are individually accessed
+ * - [[GetStructField]] or [[GetArrayStructFields]] on top of other [[ExtractValue]]s or special
+ * expressions.
+ */
+ private def collectRootReferenceAndExtractValue(e: Expression): Seq[Expression] = e match {
+ case _: AttributeReference => Seq(e)
+ case GetStructField(_: ExtractValue | _: AttributeReference, _, _) => Seq(e)
+ case GetArrayStructFields(
+ _: MapValues | _: MapKeys | _: ExtractValue | _: AttributeReference,
+ _,
+ _,
+ _,
+ _) =>
+ Seq(e)
+ case es if es.children.nonEmpty => es.children.flatMap(collectRootReferenceAndExtractValue)
+ case _ => Seq.empty
+ }
+
+ /**
+ * Creates a map from root [[Attribute]]s to non-redundant nested [[ExtractValue]]s.
+ * Nested field accessors of `exclusiveAttrs` are not considered in nested fields aliasing.
+ */
+ private def getAttributeToExtractValues(
+ exprList: Seq[Expression],
+ exclusiveAttrs: Seq[Attribute],
+ extractor: (Expression) => Seq[Expression] = collectRootReferenceAndExtractValue)
+ : Map[Attribute, Seq[ExtractValue]] = {
+
+ val nestedFieldReferences = new mutable.ArrayBuffer[ExtractValue]()
+ val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
+ exprList.foreach {
+ e =>
+ extractor(e).foreach {
+ // We cannot alias the attribute from a lambda variable whose expression ID is unavailable.
+ case ev: ExtractValue if ev.find(_.isInstanceOf[NamedLambdaVariable]).isEmpty =>
+ if (ev.references.size == 1) {
+ nestedFieldReferences.append(ev)
+ }
+ case ar: AttributeReference => otherRootReferences.append(ar)
+ case _ => // ignore
+ }
+ }
+ val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
+
+ // Remove cosmetic variations when we group extractors by their references
+ nestedFieldReferences
+ .filter(!_.references.subsetOf(exclusiveAttrSet))
+ .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
+ .flatMap {
+ case (attr: Attribute, nestedFields: collection.Seq[ExtractValue]) =>
+ // Check if `ExtractValue` expressions contain any aggregate functions in their tree.
+ // Those that do should not have an alias generated as it can lead to pushing the
+ // aggregate down into a projection.
+ def containsAggregateFunction(ev: ExtractValue): Boolean =
+ ev.find(_.isInstanceOf[AggregateFunction]).isDefined
+
+ // Remove redundant [[ExtractValue]]s if they share the same parent nested field.
+ // For example, when `a.b` and `a.b.c` are in the project list, we only need to alias `a.b`.
+ // Because `a.b` requires all of the inner fields of `b`, we cannot prune `a.b.c`.
+ val dedupNestedFields = nestedFields
+ .filter {
+ // See [[collectExtractValue]]: we only need to deal with [[GetArrayStructFields]] and
+ // [[GetStructField]]
+ case e @ (_: GetStructField | _: GetArrayStructFields) =>
+ val child = e.children.head
+ nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
+ case _ => true
+ }
+ .distinct
+ // Discard [[ExtractValue]]s that contain aggregate functions.
+ .filterNot(containsAggregateFunction)
+
+ // If all nested fields of `attr` are used, we don't need to introduce new aliases.
+ // By default, the [[ColumnPruning]] rule uses `attr` already.
+ // Note that we need to remove cosmetic variations first, so we only count a
+ // nested field once.
+ val numUsedNestedFields = dedupNestedFields
+ .map(_.canonicalized)
+ .distinct
+ .map(nestedField => totalFieldNum(nestedField.dataType))
+ .sum
+ if (dedupNestedFields.nonEmpty && numUsedNestedFields < totalFieldNum(attr.dataType)) {
+ Some((attr, dedupNestedFields.toSeq))
+ } else {
+ None
+ }
+ }
+ }
+
+ /**
+ * Returns the total number of fields of this type. This is used as a threshold for nested
+ * column pruning. It's okay to underestimate. If the number of references is bigger than
+ * this, the parent reference is used instead of nested field references.
+ */
+ private def totalFieldNum(dataType: DataType): Int = dataType match {
+ case StructType(fields) => fields.map(f => totalFieldNum(f.dataType)).sum
+ case ArrayType(elementType, _) => totalFieldNum(elementType)
+ case MapType(keyType, valueType, _) => totalFieldNum(keyType) + totalFieldNum(valueType)
+ case _ => 1 // UDT and others
+ }
+
+ private def replaceGetStructField(
+ g: GetStructField,
+ input: Seq[Attribute],
+ attrExprIdsOnGenerator: Set[ExprId],
+ nestedFieldToOrdinal: Map[ExtractValue, Int]): Expression = {
+ val attr = input.find(a => attrExprIdsOnGenerator.contains(a.exprId))
+ attr match {
+ case Some(a) =>
+ val ordinal = nestedFieldToOrdinal(g)
+ GetStructField(a, ordinal, g.name)
+ case None => g
+ }
+ }
+
+ /** Replace the reference attribute of extractor expression with generator input. */
+ private def replaceGenerator(generator: ExplodeBase, expr: Expression): Expression = {
+ expr match {
+ case a: Attribute if expr.references.contains(a) =>
+ generator.child
+ case g: GetStructField =>
+ // We cannot simply do a transformUp, because replacing the attribute first could make
+ // `extractFieldName` throw a `ClassCastException`. We need to get the field name before
+ // replacing the attribute or other extractor further down.
+ val fieldName = g.extractFieldName
+ val newChild = replaceGenerator(generator, g.child)
+ ExtractValue(newChild, Literal(fieldName), SQLConf.get.resolver)
+ case other =>
+ other.mapChildren(replaceGenerator(generator, _))
+ }
+ }
+
+ // This function collects all GetStructField*(attribute) from the passed in expression.
+ // GetStructField* means arbitrary levels of nesting.
+ private def collectNestedGetStructFields(e: Expression): Seq[Expression] = {
+ // The helper function returns a tuple of
+ // (nested GetStructField including the current level, all other nested GetStructField)
+ def helper(e: Expression): (Seq[Expression], Seq[Expression]) = e match {
+ case _: AttributeReference => (Seq(e), Seq.empty)
+ case gsf: GetStructField =>
+ val child_res = helper(gsf.child)
+ (child_res._1.map(p => gsf.withNewChildren(Seq(p))), child_res._2)
+ case other =>
+ val child_res = other.children.map(helper)
+ val child_res_combined = (child_res.flatMap(_._1), child_res.flatMap(_._2))
+ (Seq.empty, child_res_combined._1 ++ child_res_combined._2)
+ }
+
+ val res = helper(e)
+ (res._1 ++ res._2).filterNot(_.isInstanceOf[Attribute])
+ }
+
+ private def rewritePlanWithAliases(
+ plan: LogicalPlan,
+ attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]): LogicalPlan = {
+ val attributeToExtractValuesAndAliases =
+ attributeToExtractValues.map {
+ case (attr, evSeq) =>
+ val evAliasSeq = evSeq.map {
+ ev =>
+ val fieldName = ev match {
+ case g: GetStructField => g.extractFieldName
+ case g: GetArrayStructFields => g.field.name
+ }
+ ev -> Alias(ev, s"_extract_$fieldName")()
+ }
+
+ attr -> evAliasSeq
+ }
+
+ val nestedFieldToAlias = attributeToExtractValuesAndAliases.values.flatten.map {
+ case (field, alias) => field.canonicalized -> alias
+ }.toMap
+
+ // A reference attribute can have multiple aliases for nested fields.
+ val attrToAliases =
+ AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)).toSeq)
+
+ plan match {
+ // Project(Filter(Generate))
+ case Project(projectList, f @ Filter(condition, g: Generate)) =>
+ val newProjectList = NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias)
+ val newCondition = getNewExpression(condition, nestedFieldToAlias)
+ val newGenerator = getNewExpression(g.generator, nestedFieldToAlias).asInstanceOf[Generator]
+
+ val tmpG = NestedColumnAliasing
+ .replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
+ .asInstanceOf[Generate]
+ val newG = Generate(
+ newGenerator,
+ tmpG.unrequiredChildIndex,
+ tmpG.outer,
+ tmpG.qualifier,
+ tmpG.generatorOutput,
+ tmpG.children.head)
+ val newF = Filter(newCondition, newG)
+ val newP = Project(newProjectList, newF)
+ newP
+ case _ => plan
+ }
+ }
+
+ private def getNewExpression(
+ expr: Expression,
+ nestedFieldToAlias: Map[Expression, Alias]): Expression = {
+ expr.transform {
+ case f: ExtractValue if nestedFieldToAlias.contains(f.canonicalized) =>
+ nestedFieldToAlias(f.canonicalized).toAttribute
+ }
+ }
+}
+
+// ExtendedColumnPruning processes Project(Filter(Generate)),
+// which is ignored by vanilla Spark in the optimizer rule ColumnPruning.
+class ExtendedColumnPruning(spark: SparkSession) extends Rule[LogicalPlan] with Logging {
+
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ plan.transformWithPruning(AlwaysProcess.fn) {
+ case ExtendedGeneratorNestedColumnAliasing(rewrittenPlan) => rewrittenPlan
+ case p => p
+ }
+}
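To make the rewrite above concrete: when several nested fields are read from a
generator output, the explode child is conceptually replaced by an arrays_zip of
just those fields, so everything else can be pruned at the scan. A hand-written
before/after sketch (illustrative DataFrame code, not output of the rule;
it assumes a DataFrame df whose hypothetical items column is an array of structs):

    import org.apache.spark.sql.functions.{arrays_zip, col, explode}

    // Before: the whole struct flows through explode.
    df.select(explode(col("items")).as("item"))
      .select(col("item.a"), col("item.b"))

    // After (what the rule effectively produces): only fields a and b are
    // zipped back into the generator input, so the scan can drop the rest.
    df.select(explode(arrays_zip(col("items.a"), col("items.b"))).as("item"))
      .select(col("item.a"), col("item.b"))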
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala
deleted file mode 100644
index a97e625ae6..0000000000
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension
-
-import org.apache.gluten.GlutenConfig
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.GeneratorNestedColumnAliasing.canPruneGenerator
-import org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasing
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.AlwaysProcess
-import org.apache.spark.sql.internal.SQLConf
-
-// ExtendedGeneratorNestedColumnAliasing process Project(Filter(Generate)),
-// which is ignored by vanilla spark in optimization rule: ColumnPruning
-class ExtendedGeneratorNestedColumnAliasing(spark: SparkSession)
- extends Rule[LogicalPlan]
- with Logging {
-
- override def apply(plan: LogicalPlan): LogicalPlan =
- plan.transformWithPruning(AlwaysProcess.fn) {
- case pj @ Project(projectList, f @ Filter(condition, g: Generate))
- if canPruneGenerator(g.generator) &&
- GlutenConfig.getConf.enableExtendedGeneratorNestedColumnAliasing &&
- (SQLConf.get.nestedPruningOnExpressions || SQLConf.get.nestedSchemaPruningEnabled) =>
- val attrToExtractValues = NestedColumnAliasing.getAttributeToExtractValues(
- projectList ++ g.generator.children :+ condition,
- Seq.empty)
- if (attrToExtractValues.isEmpty) {
- pj
- } else {
- val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
- val (_, attrToExtractValuesNotOnGenerator) =
- attrToExtractValues.partition {
- case (attr, _) =>
- attr.references.subsetOf(generatorOutputSet)
- }
-
- val pushedThrough = rewritePlanWithAliases(pj, attrToExtractValuesNotOnGenerator)
- pushedThrough
- }
- case p =>
- p
- }
-
- private def rewritePlanWithAliases(
- plan: LogicalPlan,
- attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]): LogicalPlan = {
- val attributeToExtractValuesAndAliases =
- attributeToExtractValues.map {
- case (attr, evSeq) =>
- val evAliasSeq = evSeq.map {
- ev =>
- val fieldName = ev match {
- case g: GetStructField => g.extractFieldName
- case g: GetArrayStructFields => g.field.name
- }
- ev -> Alias(ev, s"_extract_$fieldName")()
- }
-
- attr -> evAliasSeq
- }
-
- val nestedFieldToAlias = attributeToExtractValuesAndAliases.values.flatten.map {
- case (field, alias) => field.canonicalized -> alias
- }.toMap
-
- // A reference attribute can have multiple aliases for nested fields.
- val attrToAliases =
- AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)).toSeq)
-
- plan match {
- // Project(Filter(Generate))
- case p @ Project(projectList, child)
- if child
- .isInstanceOf[Filter] && child.asInstanceOf[Filter].child.isInstanceOf[Generate] =>
- val f = child.asInstanceOf[Filter]
- val g = f.child.asInstanceOf[Generate]
-
- val newProjectList = NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias)
- val newCondition = getNewExpression(f.condition, nestedFieldToAlias)
- val newGenerator = getNewExpression(g.generator, nestedFieldToAlias).asInstanceOf[Generator]
-
- val tmpG = NestedColumnAliasing
- .replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
- .asInstanceOf[Generate]
- val newG = Generate(
- newGenerator,
- tmpG.unrequiredChildIndex,
- tmpG.outer,
- tmpG.qualifier,
- tmpG.generatorOutput,
- tmpG.children.head)
- val newF = Filter(newCondition, newG)
- val newP = Project(newProjectList, newF)
- newP
- case _ => plan
- }
- }
-
- private def getNewExpression(
- expr: Expression,
- nestedFieldToAlias: Map[Expression, Alias]): Expression = {
- expr.transform {
- case f: ExtractValue if nestedFieldToAlias.contains(f.canonicalized) =>
- nestedFieldToAlias(f.canonicalized).toAttribute
- }
- }
-}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
index 7e31e73040..14d6ff5364 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, StructType}
import org.apache.hadoop.fs.Path
@@ -1473,30 +1473,116 @@ class GlutenClickHouseHiveTableSuite
| 'event_info', map('tab_type', '4', 'action', '12')))
""".stripMargin)
- val df =
- spark.sql("""
- | SELECT * FROM (
- | SELECT
- | game_name,
- | CASE WHEN
- | event.event_info['tab_type'] IN (5) THEN '1' ELSE '0' END AS entrance
- | FROM aj
- | LATERAL VIEW explode(split(nvl(event.event_info['game_name'],'0'),','))
- | as game_name
- | WHERE event.event_info['action'] IN (13)
- |) WHERE game_name = 'xxx'
+ val sql = """
+ | SELECT * FROM (
+ | SELECT
+ | game_name,
+ | CASE WHEN
+ | event.event_info['tab_type'] IN (5) THEN '1' ELSE '0' END AS entrance
+ | FROM aj
+ | LATERAL VIEW explode(split(nvl(event.event_info['game_name'],'0'),','))
+ | as game_name
+ | WHERE event.event_info['action'] IN (13)
+ |) WHERE game_name = 'xxx'
+ """.stripMargin
+
+ compareResultsAgainstVanillaSpark(
+ sql,
+ compareResult = true,
+ df => {
+ val scan = df.queryExecution.executedPlan.collect {
+ case scan: FileSourceScanExecTransformer => scan
+ }.head
+ val fieldType = scan.schema.fields.head.dataType.asInstanceOf[StructType]
+ assert(fieldType.size == 1)
+ }
+ )
+
+ spark.sql("drop table if exists aj")
+ }
+
+ test("Nested column pruning for Project(Filter(Generate)) on generator") {
+ def assertFieldSizeAfterPruning(sql: String, expectSize: Int): Unit = {
+ compareResultsAgainstVanillaSpark(
+ sql,
+ compareResult = true,
+ df => {
+ val scan = df.queryExecution.executedPlan.collect {
+ case scan: FileSourceScanExecTransformer => scan
+ }.head
+
+ val fieldType =
+ scan.schema.fields.head.dataType
+ .asInstanceOf[ArrayType]
+ .elementType
+ .asInstanceOf[StructType]
+ assert(fieldType.size == expectSize)
+ }
+ )
+ }
+
+ spark.sql("drop table if exists ajog")
+ spark.sql(
+ """
+ |CREATE TABLE if not exists ajog (
+ | country STRING,
+ | events ARRAY<STRUCT<time:BIGINT, lng:BIGINT, lat:BIGINT, net:STRING,
+ | log_extra:MAP<STRING, STRING>, event_id:STRING, event_info:MAP<STRING, STRING>>>
+ |)
+ |USING orc
""".stripMargin)
- val scan = df.queryExecution.executedPlan.collect {
- case scan: FileSourceScanExecTransformer => scan
- }.head
+ spark.sql("""
+ |INSERT INTO ajog VALUES
+ | ('USA', array(named_struct('time', 1622547800, 'lng', -122, 'lat', 37, 'net',
+ | 'wifi', 'log_extra', map('key1', 'value1'), 'event_id', 'event1',
+ | 'event_info', map('tab_type', '5', 'action', '13')))),
+ | ('Canada', array(named_struct('time', 1622547801, 'lng', -79, 'lat', 43, 'net',
+ | '4g', 'log_extra', map('key2', 'value2'), 'event_id', 'event2',
+ | 'event_info', map('tab_type', '4', 'action', '12'))))
+ """.stripMargin)
- val schema = scan.schema
- assert(schema.size == 1)
- val fieldType = schema.fields.head.dataType.asInstanceOf[StructType]
- assert(fieldType.size == 1)
+ // Test nested column pruning on generator with a single field extracted
+ val sql1 = """
+ |select
+ |case when event.event_info['tab_type'] in (5) then '1' else '0' end as entrance
+ |from ajog
+ |lateral view explode(events) as event
+ |where event.event_info['action'] in (13)
+ """.stripMargin
+ assertFieldSizeAfterPruning(sql1, 1)
+
+ // Test nested column pruning on generator with multiple fields extracted,
+ // which resolves SPARK-34956
+ val sql2 = """
+ |select event.event_id,
+ |case when event.event_info['tab_type'] in (5) then '1' else '0' end as entrance
+ |from ajog
+ |lateral view explode(events) as event
+ |where event.event_info['action'] in (13)
+ """.stripMargin
+ assertFieldSizeAfterPruning(sql2, 2)
+
+ // Test nested column pruning with two adjacent generate operators
+ val sql3 = """
+ |SELECT
+ |abflag,
+ |event.event_info,
+ |event.log_extra
+ |FROM
+ |ajog
+ |LATERAL VIEW EXPLODE(events) AS event
+ |LATERAL VIEW EXPLODE(split(event.log_extra['key1'], ',')) AS abflag
+ |WHERE
+ |event.event_id = 'event1'
+ |AND event.event_info['tab_type'] IS NOT NULL
+ |AND event.event_info['tab_type'] != ''
+ |AND event.log_extra['key1'] = 'value1'
+ |LIMIT 100;
+ """.stripMargin
+ assertFieldSizeAfterPruning(sql3, 3)
- spark.sql("drop table if exists aj")
+ spark.sql("drop table if exists ajog")
}
test("test hive table scan nested column pruning") {
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
index d572b85385..e9fd4f358a 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
@@ -18,6 +18,7 @@
#include <QueryPipeline/QueryPipeline.h>
#include <Poco/URI.h>
+#include <Common/DebugUtils.h>
namespace local_engine
{
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 2ccdcae99b..c4c67f49a5 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -112,8 +112,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def enableCountDistinctWithoutExpand: Boolean =
conf.getConf(ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND)
- def enableExtendedGeneratorNestedColumnAliasing: Boolean =
- conf.getConf(ENABLE_EXTENDED_GENERATOR_NESTED_COLUMN_ALIASING)
+ def enableExtendedColumnPruning: Boolean =
+ conf.getConf(ENABLE_EXTENDED_COLUMN_PRUNING)
def veloxOrcScanEnabled: Boolean =
conf.getConf(VELOX_ORC_SCAN_ENABLED)
@@ -1980,10 +1980,10 @@ object GlutenConfig {
.booleanConf
.createWithDefault(false)
- val ENABLE_EXTENDED_GENERATOR_NESTED_COLUMN_ALIASING =
- buildConf("spark.gluten.sql.extendedGeneratorNestedColumnAliasing")
+ val ENABLE_EXTENDED_COLUMN_PRUNING =
+ buildConf("spark.gluten.sql.extendedColumnPruning.enabled")
.internal()
- .doc("Do nested column aliasing for Project(Filter(Generator))")
+ .doc("Do extended nested column pruning for cases ignored by vanilla
Spark.")
.booleanConf
.createWithDefault(true)
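As a usage note, the renamed flag defaults to true, so the rule is active out of
the box. A minimal sketch of toggling it per session (the key is taken from the
diff above):

    // Disable the extended pruning rule for the current Spark session.
    spark.conf.set("spark.gluten.sql.extendedColumnPruning.enabled", "false")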
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]