This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new 77e8e283b0 [GLUTEN-3839][CH] Extend nested column pruning in vanilla spark (#7992)
77e8e283b0 is described below

commit 77e8e283b0913601d2a1cf270e674b3d3f57308b
Author: 李扬 <654010...@qq.com>
AuthorDate: Wed Nov 27 14:13:29 2024 +0800

    [GLUTEN-3839][CH] Extend nested column pruning in vanilla spark (#7992)

    * support column pruning on generator
    * rename config name
    * fix style
    * fix style
    * support multiple fields
    * fix failed uts
    * fix building issues in spark3.2
    * fix style
    * fix building
    * Update NormalFileWriter.cpp
    * Update GlutenConfig.scala
    * change as reeust
---
 .../gluten/backendsapi/clickhouse/CHRuleApi.scala  |   2 +-
 .../gluten/extension/ExtendedColumnPruning.scala   | 366 +++++++++++++++++++++
 .../ExtendedGeneratorNestedColumnAliasing.scala    | 126 -------
 .../hive/GlutenClickHouseHiveTableSuite.scala      | 128 +++++--
 .../Storages/Output/NormalFileWriter.cpp           |   1 +
 .../scala/org/apache/gluten/GlutenConfig.scala     |  10 +-
 6 files changed, 480 insertions(+), 153 deletions(-)
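Before the diff itself, a quick illustration of the plan shape this commit targets. The sketch below is illustrative only and not part of the commit: the session setup is assumed, while the ajog table and the query are adapted from the test suite added further down. A LATERAL VIEW explode over an array of structs parses into a Generate node, the WHERE clause on the exploded struct into a Filter above it, and the SELECT into a Project on top; vanilla Spark's ColumnPruning rule skips this Project(Filter(Generate)) shape, which is exactly what the new rule handles.

import org.apache.spark.sql.SparkSession

// Hedged sketch: the session setup is illustrative; the ajog table and the
// query mirror the ORC table and the sql2 case in the test suite below.
object PruningShapeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("pruning-shape").enableHiveSupport().getOrCreate()

    // SELECT -> Project, WHERE on the exploded struct -> Filter,
    // LATERAL VIEW explode(...) -> Generate. With the rule added here, the
    // scan keeps only event_id and event_info out of the struct's 7 fields
    // (the suite asserts a pruned field count of 2 for this query).
    val df = spark.sql(
      """
        |SELECT event.event_id,
        |       CASE WHEN event.event_info['tab_type'] IN (5) THEN '1' ELSE '0' END AS entrance
        |FROM ajog
        |LATERAL VIEW explode(events) AS event
        |WHERE event.event_info['action'] IN (13)
      """.stripMargin)
    df.explain(true)
  }
}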
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index f6d2b85d9d..02ec91496e 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -63,7 +63,7 @@ object CHRuleApi {
     injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark))
     injector.injectOptimizerRule(spark => new CommonSubexpressionEliminateRule(spark))
-    injector.injectOptimizerRule(spark => new ExtendedGeneratorNestedColumnAliasing(spark))
+    injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark))
     injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
     injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
     injector.injectOptimizerRule(_ => EqualToRewrite)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala
new file mode 100644
index 0000000000..f2a0a549bc
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.GlutenConfig
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
+import org.apache.spark.sql.catalyst.optimizer.GeneratorNestedColumnAliasing.canPruneGenerator
+import org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasing
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable
+
+object ExtendedGeneratorNestedColumnAliasing {
+  def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+    case pj @ Project(projectList, f @ Filter(condition, g: Generate))
+        if canPruneGenerator(g.generator) &&
+          GlutenConfig.getConf.enableExtendedColumnPruning &&
+          (SQLConf.get.nestedPruningOnExpressions || SQLConf.get.nestedSchemaPruningEnabled) =>
+      val attrToExtractValues =
+        getAttributeToExtractValues(projectList ++ g.generator.children :+ condition, Seq.empty)
+      if (attrToExtractValues.isEmpty) {
+        return None
+      }
+
+      val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
+      var (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
+        attrToExtractValues.partition {
+          case (attr, _) =>
+            attr.references.subsetOf(generatorOutputSet)
+        }
+
+      val pushedThrough = rewritePlanWithAliases(pj, attrToExtractValuesNotOnGenerator)
+
+      // We cannot push through if the child of the generator is `MapType`.
+      g.generator.children.head.dataType match {
+        case _: MapType => return Some(pushedThrough)
+        case _ =>
+      }
+
+      if (!g.generator.isInstanceOf[ExplodeBase]) {
+        return Some(pushedThrough)
+      }
+
+      // In spark3.2, we could not reuse [[NestedColumnAliasing.getAttributeToExtractValues]],
+      // which only accepts 2 arguments. Instead we redefine it in the current file to avoid
+      // moving this rule to gluten-shims.
+      attrToExtractValuesOnGenerator = getAttributeToExtractValues(
+        attrToExtractValuesOnGenerator.flatMap(_._2).toSeq,
+        Seq.empty,
+        collectNestedGetStructFields)
+
+      val nestedFieldsOnGenerator = attrToExtractValuesOnGenerator.values.flatten.toSet
+      if (nestedFieldsOnGenerator.isEmpty) {
+        return Some(pushedThrough)
+      }
+
+      // Multiple or single nested column accessors.
+      // E.g. df.select(explode($"items").as("item")).select($"item.a", $"item.b")
+      pushedThrough match {
+        case p2 @ Project(_, f2 @ Filter(_, g2: Generate)) =>
+          val nestedFieldsOnGeneratorSeq = nestedFieldsOnGenerator.toSeq
+          val nestedFieldToOrdinal = nestedFieldsOnGeneratorSeq.zipWithIndex.toMap
+          val rewrittenG = g2.transformExpressions {
+            case e: ExplodeBase =>
+              val extractors = nestedFieldsOnGeneratorSeq.map(replaceGenerator(e, _))
+              val names = extractors.map {
+                case g: GetStructField => Literal(g.extractFieldName)
+                case ga: GetArrayStructFields => Literal(ga.field.name)
+                case other =>
+                  throw new IllegalStateException(
+                    s"Unreasonable extractor after replaceGenerator: $other")
+              }
+              val zippedArray = ArraysZip(extractors, names)
+              e.withNewChildren(Seq(zippedArray))
+          }
+          // As we change the child of the generator, its output data type must be updated.
+          val updatedGeneratorOutput = rewrittenG.generatorOutput
+            .zip(
+              rewrittenG.generator.elementSchema.map(
+                f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()))
+            .map {
+              case (oldAttr, newAttr) =>
+                newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
+            }
+          assert(
+            updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
+            "Updated generator output must have the same length as the original generator output."
+          )
+          val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
+
+          // Replace nested column accessors with the generator output.
+          val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
+          val updatedFilter = f2.withNewChildren(Seq(updatedGenerate)).transformExpressions {
+            case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
+              replaceGetStructField(
+                f,
+                updatedGenerate.output,
+                attrExprIdsOnGenerator,
+                nestedFieldToOrdinal)
+          }
+
+          val updatedProject = p2.withNewChildren(Seq(updatedFilter)).transformExpressions {
+            case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
+              replaceGetStructField(
+                f,
+                updatedFilter.output,
+                attrExprIdsOnGenerator,
+                nestedFieldToOrdinal)
+          }
+
+          Some(updatedProject)
+        case other =>
+          throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
+      }
+    case _ =>
+      None
+  }
+
+  /**
+   * Returns two types of expressions:
+   *   - Root references that are individually accessed
+   *   - [[GetStructField]] or [[GetArrayStructFields]] on top of other [[ExtractValue]]s or
+   *     special expressions.
+   */
+  private def collectRootReferenceAndExtractValue(e: Expression): Seq[Expression] = e match {
+    case _: AttributeReference => Seq(e)
+    case GetStructField(_: ExtractValue | _: AttributeReference, _, _) => Seq(e)
+    case GetArrayStructFields(
+          _: MapValues | _: MapKeys | _: ExtractValue | _: AttributeReference,
+          _,
+          _,
+          _,
+          _) =>
+      Seq(e)
+    case es if es.children.nonEmpty => es.children.flatMap(collectRootReferenceAndExtractValue)
+    case _ => Seq.empty
+  }
+
+  /**
+   * Creates a map from root [[Attribute]]s to non-redundant nested [[ExtractValue]]s. Nested
+   * field accessors of `exclusiveAttrs` are not considered in nested fields aliasing.
+   */
+  private def getAttributeToExtractValues(
+      exprList: Seq[Expression],
+      exclusiveAttrs: Seq[Attribute],
+      extractor: (Expression) => Seq[Expression] = collectRootReferenceAndExtractValue)
+      : Map[Attribute, Seq[ExtractValue]] = {
+
+    val nestedFieldReferences = new mutable.ArrayBuffer[ExtractValue]()
+    val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
+    exprList.foreach {
+      e =>
+        extractor(e).foreach {
+          // We cannot alias an attr coming from a lambda variable, whose expr id is not
+          // available.
+          case ev: ExtractValue if ev.find(_.isInstanceOf[NamedLambdaVariable]).isEmpty =>
+            if (ev.references.size == 1) {
+              nestedFieldReferences.append(ev)
+            }
+          case ar: AttributeReference => otherRootReferences.append(ar)
+          case _ => // ignore
+        }
+    }
+    val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
+
+    // Remove cosmetic variations when we group extractors by their references.
+    nestedFieldReferences
+      .filter(!_.references.subsetOf(exclusiveAttrSet))
+      .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
+      .flatMap {
+        case (attr: Attribute, nestedFields: collection.Seq[ExtractValue]) =>
+          // Check if `ExtractValue` expressions contain any aggregate functions in their tree.
+          // Those that do should not have an alias generated, as it can lead to pushing the
+          // aggregate down into a projection.
+          def containsAggregateFunction(ev: ExtractValue): Boolean =
+            ev.find(_.isInstanceOf[AggregateFunction]).isDefined
+
+          // Remove redundant [[ExtractValue]]s if they share the same parent nested field.
+          // For example, when `a.b` and `a.b.c` are in the project list, we only need to
+          // alias `a.b`. Because `a.b` requires all of the inner fields of `b`, we cannot
+          // prune `a.b.c`.
+          val dedupNestedFields = nestedFields
+            .filter {
+              // See [[collectExtractValue]]: we only need to deal with
+              // [[GetArrayStructFields]] and [[GetStructField]].
+              case e @ (_: GetStructField | _: GetArrayStructFields) =>
+                val child = e.children.head
+                nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
+              case _ => true
+            }
+            .distinct
+            // Discard [[ExtractValue]]s that contain aggregate functions.
+            .filterNot(containsAggregateFunction)
+
+          // If all nested fields of `attr` are used, we don't need to introduce new aliases.
+          // By default, the [[ColumnPruning]] rule uses `attr` already.
+          // Note that we need to remove cosmetic variations first, so we only count a
+          // nested field once.
+          val numUsedNestedFields = dedupNestedFields
+            .map(_.canonicalized)
+            .distinct
+            .map(nestedField => totalFieldNum(nestedField.dataType))
+            .sum
+          if (dedupNestedFields.nonEmpty && numUsedNestedFields < totalFieldNum(attr.dataType)) {
+            Some((attr, dedupNestedFields.toSeq))
+          } else {
+            None
+          }
+      }
+  }
+
+  /**
+   * Returns the total number of fields of this type. This is used as a threshold for nested
+   * column pruning; it's okay to underestimate. If the number of references is bigger than
+   * this, the parent reference is used instead of nested field references.
+   */
+  private def totalFieldNum(dataType: DataType): Int = dataType match {
+    case StructType(fields) => fields.map(f => totalFieldNum(f.dataType)).sum
+    case ArrayType(elementType, _) => totalFieldNum(elementType)
+    case MapType(keyType, valueType, _) => totalFieldNum(keyType) + totalFieldNum(valueType)
+    case _ => 1 // UDT and others
+  }
+
+  private def replaceGetStructField(
+      g: GetStructField,
+      input: Seq[Attribute],
+      attrExprIdsOnGenerator: Set[ExprId],
+      nestedFieldToOrdinal: Map[ExtractValue, Int]): Expression = {
+    val attr = input.find(a => attrExprIdsOnGenerator.contains(a.exprId))
+    attr match {
+      case Some(a) =>
+        val ordinal = nestedFieldToOrdinal(g)
+        GetStructField(a, ordinal, g.name)
+      case None => g
+    }
+  }
+
+  /** Replaces the reference attribute of an extractor expression with the generator input. */
+  private def replaceGenerator(generator: ExplodeBase, expr: Expression): Expression = {
+    expr match {
+      case a: Attribute if expr.references.contains(a) =>
+        generator.child
+      case g: GetStructField =>
+        // We cannot simply do a transformUp here, because replacing the attribute first
+        // could make `extractFieldName` throw a `ClassCastException`. We need to get the
+        // field name before replacing the attribute/other extractors further down.
+        val fieldName = g.extractFieldName
+        val newChild = replaceGenerator(generator, g.child)
+        ExtractValue(newChild, Literal(fieldName), SQLConf.get.resolver)
+      case other =>
+        other.mapChildren(replaceGenerator(generator, _))
+    }
+  }
+
+  // This function collects all GetStructField*(attribute) from the passed-in expression.
+  // GetStructField* means arbitrary levels of nesting.
+  private def collectNestedGetStructFields(e: Expression): Seq[Expression] = {
+    // The helper function returns a tuple of
+    // (nested GetStructFields including the current level, all other nested GetStructFields).
+    def helper(e: Expression): (Seq[Expression], Seq[Expression]) = e match {
+      case _: AttributeReference => (Seq(e), Seq.empty)
+      case gsf: GetStructField =>
+        val child_res = helper(gsf.child)
+        (child_res._1.map(p => gsf.withNewChildren(Seq(p))), child_res._2)
+      case other =>
+        val child_res = other.children.map(helper)
+        val child_res_combined = (child_res.flatMap(_._1), child_res.flatMap(_._2))
+        (Seq.empty, child_res_combined._1 ++ child_res_combined._2)
+    }
+
+    val res = helper(e)
+    (res._1 ++ res._2).filterNot(_.isInstanceOf[Attribute])
+  }
+
+  private def rewritePlanWithAliases(
+      plan: LogicalPlan,
+      attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]): LogicalPlan = {
+    val attributeToExtractValuesAndAliases =
+      attributeToExtractValues.map {
+        case (attr, evSeq) =>
+          val evAliasSeq = evSeq.map {
+            ev =>
+              val fieldName = ev match {
+                case g: GetStructField => g.extractFieldName
+                case g: GetArrayStructFields => g.field.name
+              }
+              ev -> Alias(ev, s"_extract_$fieldName")()
+          }
+
+          attr -> evAliasSeq
+      }
+
+    val nestedFieldToAlias = attributeToExtractValuesAndAliases.values.flatten.map {
+      case (field, alias) => field.canonicalized -> alias
+    }.toMap
+
+    // A reference attribute can have multiple aliases for nested fields.
+    val attrToAliases =
+      AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)).toSeq)
+
+    plan match {
+      // Project(Filter(Generate))
+      case Project(projectList, f @ Filter(condition, g: Generate)) =>
+        val newProjectList = NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias)
+        val newCondition = getNewExpression(condition, nestedFieldToAlias)
+        val newGenerator = getNewExpression(g.generator, nestedFieldToAlias).asInstanceOf[Generator]
+
+        val tmpG = NestedColumnAliasing
+          .replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
+          .asInstanceOf[Generate]
+        val newG = Generate(
+          newGenerator,
+          tmpG.unrequiredChildIndex,
+          tmpG.outer,
+          tmpG.qualifier,
+          tmpG.generatorOutput,
+          tmpG.children.head)
+        val newF = Filter(newCondition, newG)
+        val newP = Project(newProjectList, newF)
+        newP
+      case _ => plan
+    }
+  }
+
+  private def getNewExpression(
+      expr: Expression,
+      nestedFieldToAlias: Map[Expression, Alias]): Expression = {
+    expr.transform {
+      case f: ExtractValue if nestedFieldToAlias.contains(f.canonicalized) =>
+        nestedFieldToAlias(f.canonicalized).toAttribute
+    }
+  }
+}
+
+// ExtendedColumnPruning processes Project(Filter(Generate)), which is ignored by vanilla
+// Spark in the optimization rule ColumnPruning.
+class ExtendedColumnPruning(spark: SparkSession) extends Rule[LogicalPlan] with Logging {
+
+  override def apply(plan: LogicalPlan): LogicalPlan =
+    plan.transformWithPruning(AlwaysProcess.fn) {
+      case ExtendedGeneratorNestedColumnAliasing(rewrittenPlan) => rewrittenPlan
+      case p => p
+    }
+}
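The ArraysZip rewrite above can be hard to see through the plan machinery, so here is a hedged, self-contained sketch, not part of the commit: the data and names are invented, and it uses only public Spark APIs (arrays_zip, explode) to mirror what the rule effectively does to a query that explodes an array of structs but reads only some of the struct's fields.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{arrays_zip, col, explode}

// Hedged sketch of the rewrite in DataFrame terms; table and field names are
// illustrative. items is array<struct<_1: bigint, _2: string, _3: boolean>>;
// only _1 and _2 are read below, so _3 is prunable.
object ArraysZipRewriteSketch {
  def run(spark: SparkSession): Unit = {
    import spark.implicits._
    val df = Seq((1, Seq((10L, "x", true), (20L, "y", false)))).toDF("id", "items")

    // Before: explode the whole struct, then pick two fields. The generator
    // carries every field of the element type through the plan.
    val before = df
      .select(explode(col("items")).as("item"))
      .select(col("item._1"), col("item._2"))

    // After, conceptually: zip only the needed field arrays into a narrower
    // struct and explode that, mirroring e.withNewChildren(Seq(ArraysZip(extractors, names))).
    val after = df
      .select(explode(arrays_zip(col("items._1").as("a"), col("items._2").as("b"))).as("item"))
      .select(col("item.a"), col("item.b"))

    before.explain(true)
    after.explain(true)
  }
}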
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala
deleted file mode 100644
index a97e625ae6..0000000000
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension
-
-import org.apache.gluten.GlutenConfig
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.GeneratorNestedColumnAliasing.canPruneGenerator
-import org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasing
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.AlwaysProcess
-import org.apache.spark.sql.internal.SQLConf
-
-// ExtendedGeneratorNestedColumnAliasing process Project(Filter(Generate)),
-// which is ignored by vanilla spark in optimization rule: ColumnPruning
-class ExtendedGeneratorNestedColumnAliasing(spark: SparkSession)
-  extends Rule[LogicalPlan]
-  with Logging {
-
-  override def apply(plan: LogicalPlan): LogicalPlan =
-    plan.transformWithPruning(AlwaysProcess.fn) {
-      case pj @ Project(projectList, f @ Filter(condition, g: Generate))
-          if canPruneGenerator(g.generator) &&
-            GlutenConfig.getConf.enableExtendedGeneratorNestedColumnAliasing &&
-            (SQLConf.get.nestedPruningOnExpressions || SQLConf.get.nestedSchemaPruningEnabled) =>
-        val attrToExtractValues = NestedColumnAliasing.getAttributeToExtractValues(
-          projectList ++ g.generator.children :+ condition,
-          Seq.empty)
-        if (attrToExtractValues.isEmpty) {
-          pj
-        } else {
-          val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
-          val (_, attrToExtractValuesNotOnGenerator) =
-            attrToExtractValues.partition {
-              case (attr, _) =>
-                attr.references.subsetOf(generatorOutputSet)
-            }
-
-          val pushedThrough = rewritePlanWithAliases(pj, attrToExtractValuesNotOnGenerator)
-          pushedThrough
-        }
-      case p =>
-        p
-    }
-
-  private def rewritePlanWithAliases(
-      plan: LogicalPlan,
-      attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]): LogicalPlan = {
-    val attributeToExtractValuesAndAliases =
-      attributeToExtractValues.map {
-        case (attr, evSeq) =>
-          val evAliasSeq = evSeq.map {
-            ev =>
-              val fieldName = ev match {
-                case g: GetStructField => g.extractFieldName
-                case g: GetArrayStructFields => g.field.name
-              }
-              ev -> Alias(ev, s"_extract_$fieldName")()
-          }
-
-          attr -> evAliasSeq
-      }
-
-    val nestedFieldToAlias = attributeToExtractValuesAndAliases.values.flatten.map {
-      case (field, alias) => field.canonicalized -> alias
-    }.toMap
-
-    // A reference attribute can have multiple aliases for nested fields.
-    val attrToAliases =
-      AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)).toSeq)
-
-    plan match {
-      // Project(Filter(Generate))
-      case p @ Project(projectList, child)
-          if child
-            .isInstanceOf[Filter] && child.asInstanceOf[Filter].child.isInstanceOf[Generate] =>
-        val f = child.asInstanceOf[Filter]
-        val g = f.child.asInstanceOf[Generate]
-
-        val newProjectList = NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias)
-        val newCondition = getNewExpression(f.condition, nestedFieldToAlias)
-        val newGenerator = getNewExpression(g.generator, nestedFieldToAlias).asInstanceOf[Generator]
-
-        val tmpG = NestedColumnAliasing
-          .replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
-          .asInstanceOf[Generate]
-        val newG = Generate(
-          newGenerator,
-          tmpG.unrequiredChildIndex,
-          tmpG.outer,
-          tmpG.qualifier,
-          tmpG.generatorOutput,
-          tmpG.children.head)
-        val newF = Filter(newCondition, newG)
-        val newP = Project(newProjectList, newF)
-        newP
-      case _ => plan
-    }
-  }
-
-  private def getNewExpression(
-      expr: Expression,
-      nestedFieldToAlias: Map[Expression, Alias]): Expression = {
-    expr.transform {
-      case f: ExtractValue if nestedFieldToAlias.contains(f.canonicalized) =>
-        nestedFieldToAlias(f.canonicalized).toAttribute
-    }
-  }
-}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
index 7e31e73040..14d6ff5364 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, StructType}
 
 import org.apache.hadoop.fs.Path
 
@@ -1473,30 +1473,116 @@ class GlutenClickHouseHiveTableSuite
        | 'event_info', map('tab_type', '4', 'action', '12')))
       """.stripMargin)
 
-    val df =
-      spark.sql("""
-                  | SELECT * FROM (
-                  |   SELECT
-                  |   game_name,
-                  |   CASE WHEN
-                  |   event.event_info['tab_type'] IN (5) THEN '1' ELSE '0' END AS entrance
-                  |   FROM aj
-                  |   LATERAL VIEW explode(split(nvl(event.event_info['game_name'],'0'),','))
-                  |   as game_name
-                  |   WHERE event.event_info['action'] IN (13)
-                  |) WHERE game_name = 'xxx'
+    val sql = """
+               | SELECT * FROM (
+               |   SELECT
+               |   game_name,
+               |   CASE WHEN
+               |   event.event_info['tab_type'] IN (5) THEN '1' ELSE '0' END AS entrance
+               |   FROM aj
+               |   LATERAL VIEW explode(split(nvl(event.event_info['game_name'],'0'),','))
+               |   as game_name
+               |   WHERE event.event_info['action'] IN (13)
+               |) WHERE game_name = 'xxx'
+      """.stripMargin
+
+    compareResultsAgainstVanillaSpark(
+      sql,
+      compareResult = true,
+      df => {
+        val scan = df.queryExecution.executedPlan.collect {
+          case scan: FileSourceScanExecTransformer => scan
+        }.head
+        val fieldType = scan.schema.fields.head.dataType.asInstanceOf[StructType]
+        assert(fieldType.size == 1)
+      }
+    )
+
+    spark.sql("drop table if exists aj")
+  }
+
+  test("Nested column pruning for Project(Filter(Generate)) on generator") {
+    def assertFieldSizeAfterPruning(sql: String, expectSize: Int): Unit = {
+      compareResultsAgainstVanillaSpark(
+        sql,
+        compareResult = true,
+        df => {
+          val scan = df.queryExecution.executedPlan.collect {
+            case scan: FileSourceScanExecTransformer => scan
+          }.head
+
+          val fieldType =
+            scan.schema.fields.head.dataType
+              .asInstanceOf[ArrayType]
+              .elementType
+              .asInstanceOf[StructType]
+          assert(fieldType.size == expectSize)
+        }
+      )
+    }
+
+    spark.sql("drop table if exists ajog")
+    spark.sql(
+      """
+        |CREATE TABLE if not exists ajog (
+        |  country STRING,
+        |  events ARRAY<STRUCT<time:BIGINT, lng:BIGINT, lat:BIGINT, net:STRING,
+        |  log_extra:MAP<STRING, STRING>, event_id:STRING, event_info:MAP<STRING, STRING>>>
+        |)
+        |USING orc
+      """.stripMargin)
-    val scan = df.queryExecution.executedPlan.collect {
-      case scan: FileSourceScanExecTransformer => scan
-    }.head
+    spark.sql("""
+                |INSERT INTO ajog VALUES
+                |  ('USA', array(named_struct('time', 1622547800, 'lng', -122, 'lat', 37, 'net',
+                |  'wifi', 'log_extra', map('key1', 'value1'), 'event_id', 'event1',
+                |  'event_info', map('tab_type', '5', 'action', '13')))),
+                |  ('Canada', array(named_struct('time', 1622547801, 'lng', -79, 'lat', 43, 'net',
+                |  '4g', 'log_extra', map('key2', 'value2'), 'event_id', 'event2',
+                |  'event_info', map('tab_type', '4', 'action', '12'))))
+      """.stripMargin)
-    val schema = scan.schema
-    assert(schema.size == 1)
-    val fieldType = schema.fields.head.dataType.asInstanceOf[StructType]
-    assert(fieldType.size == 1)
+
+    // Test nested column pruning on generator with a single field extracted
+    val sql1 = """
+                |select
+                |case when event.event_info['tab_type'] in (5) then '1' else '0' end as entrance
+                |from ajog
+                |lateral view explode(events) as event
+                |where event.event_info['action'] in (13)
+      """.stripMargin
+    assertFieldSizeAfterPruning(sql1, 1)
+
+    // Test nested column pruning on generator with multiple fields extracted,
+    // which resolves SPARK-34956
+    val sql2 = """
+                |select event.event_id,
+                |case when event.event_info['tab_type'] in (5) then '1' else '0' end as entrance
+                |from ajog
+                |lateral view explode(events) as event
+                |where event.event_info['action'] in (13)
+      """.stripMargin
+    assertFieldSizeAfterPruning(sql2, 2)
+
+    // Test nested column pruning with two adjacent generate operators
+    val sql3 = """
+                |SELECT
+                |abflag,
+                |event.event_info,
+                |event.log_extra
+                |FROM
+                |ajog
+                |LATERAL VIEW EXPLODE(events) AS event
+                |LATERAL VIEW EXPLODE(split(event.log_extra['key1'], ',')) AS abflag
+                |WHERE
+                |event.event_id = 'event1'
+                |AND event.event_info['tab_type'] IS NOT NULL
+                |AND event.event_info['tab_type'] != ''
+                |AND event.log_extra['key1'] = 'value1'
+                |LIMIT 100;
+      """.stripMargin
+    assertFieldSizeAfterPruning(sql3, 3)
 
-    spark.sql("drop table if exists aj")
+    spark.sql("drop table if exists ajog")
   }
 
   test("test hive table scan nested column pruning") {
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
index d572b85385..e9fd4f358a 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
@@ -18,6 +18,7 @@
 #include <QueryPipeline/QueryPipeline.h>
 
 #include <Poco/URI.h>
+#include <Common/DebugUtils.h>
 
 namespace local_engine
 {
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 2ccdcae99b..c4c67f49a5 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -112,8 +112,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def enableCountDistinctWithoutExpand: Boolean =
     conf.getConf(ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND)
 
-  def enableExtendedGeneratorNestedColumnAliasing: Boolean =
-    conf.getConf(ENABLE_EXTENDED_GENERATOR_NESTED_COLUMN_ALIASING)
+  def enableExtendedColumnPruning: Boolean =
+    conf.getConf(ENABLE_EXTENDED_COLUMN_PRUNING)
 
   def veloxOrcScanEnabled: Boolean =
     conf.getConf(VELOX_ORC_SCAN_ENABLED)
@@ -1980,10 +1980,10 @@ object GlutenConfig {
       .booleanConf
       .createWithDefault(false)
 
-  val ENABLE_EXTENDED_GENERATOR_NESTED_COLUMN_ALIASING =
-    buildConf("spark.gluten.sql.extendedGeneratorNestedColumnAliasing")
+  val ENABLE_EXTENDED_COLUMN_PRUNING =
+    buildConf("spark.gluten.sql.extendedColumnPruning.enabled")
       .internal()
-      .doc("Do nested column aliasing for Project(Filter(Generator))")
+      .doc("Do extended nested column pruning for cases ignored by vanilla Spark.")
       .booleanConf
      .createWithDefault(true)
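For completeness, a hedged usage sketch for the renamed option; the session setup is illustrative, and only the config key spark.gluten.sql.extendedColumnPruning.enabled comes from this commit. Since the flag is internal and defaults to true, setting it is only needed to opt out, for example when bisecting a plan regression back to this rule.

import org.apache.spark.sql.SparkSession

// Hedged sketch: disabling the extended pruning rule for one session.
object ExtendedColumnPruningToggle {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("toggle-extended-pruning").getOrCreate()
    // Key introduced by this commit; default is true, so this opts out.
    spark.conf.set("spark.gluten.sql.extendedColumnPruning.enabled", "false")
    spark.sql("SELECT 1").show()
  }
}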
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org