This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 77e8e283b0 [GLUTEN-3839][CH] Extend nested column pruning in vanilla 
spark (#7992)
77e8e283b0 is described below

commit 77e8e283b0913601d2a1cf270e674b3d3f57308b
Author: 李扬 <654010...@qq.com>
AuthorDate: Wed Nov 27 14:13:29 2024 +0800

    [GLUTEN-3839][CH] Extend nested column pruning in vanilla spark (#7992)
    
    * support column pruning on generator
    
    * rename config name
    
    * fix style
    
    * fix style
    
    * support multiple fields
    
    * fix failed uts
    
    * fix building issues in spark3.2
    
    * fix style
    
    * fix building
    
    * Update NormalFileWriter.cpp
    
    * Update GlutenConfig.scala
    
    * change as requested
---
 .../gluten/backendsapi/clickhouse/CHRuleApi.scala  |   2 +-
 .../gluten/extension/ExtendedColumnPruning.scala   | 366 +++++++++++++++++++++
 .../ExtendedGeneratorNestedColumnAliasing.scala    | 126 -------
 .../hive/GlutenClickHouseHiveTableSuite.scala      | 128 +++++--
 .../Storages/Output/NormalFileWriter.cpp           |   1 +
 .../scala/org/apache/gluten/GlutenConfig.scala     |  10 +-
 6 files changed, 480 insertions(+), 153 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index f6d2b85d9d..02ec91496e 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -63,7 +63,7 @@ object CHRuleApi {
     injector.injectResolutionRule(spark => new 
RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(spark => new 
RewriteDateTimestampComparisonRule(spark))
     injector.injectOptimizerRule(spark => new 
CommonSubexpressionEliminateRule(spark))
-    injector.injectOptimizerRule(spark => new 
ExtendedGeneratorNestedColumnAliasing(spark))
+    injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark))
     injector.injectOptimizerRule(spark => 
CHAggregateFunctionRewriteRule(spark))
     injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
     injector.injectOptimizerRule(_ => EqualToRewrite)
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala
new file mode 100644
index 0000000000..f2a0a549bc
--- /dev/null
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.GlutenConfig
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
+import 
org.apache.spark.sql.catalyst.optimizer.GeneratorNestedColumnAliasing.canPruneGenerator
+import org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasing
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable
+
+object ExtendedGeneratorNestedColumnAliasing {
+  /**
+   * Extractor that rewrites a `Project(Filter(Generate))` plan for nested column pruning.
+   *
+   * The pattern only applies when `canPruneGenerator` accepts the generator, Gluten's
+   * `enableExtendedColumnPruning` is on, and at least one of Spark's
+   * `nestedPruningOnExpressions` / `nestedSchemaPruningEnabled` flags is on.
+   *
+   * Nested field accessors whose root attribute is not produced by the generator are replaced
+   * with aliases through `rewritePlanWithAliases`. When the generator is an [[ExplodeBase]] whose
+   * first child is not a map, the [[GetStructField]] chains read from the generator output are
+   * additionally pushed into the generator: its child becomes an [[ArraysZip]] of the extracted
+   * fields, and the accessors above it are re-pointed at the zipped struct by ordinal.
+   *
+   * @param plan
+   *   the logical plan node to inspect
+   * @return
+   *   `Some(rewritten plan)` when the pattern matches and at least one nested accessor qualifies,
+   *   `None` otherwise
+   * @throws IllegalStateException
+   *   if `replaceGenerator` yields an unexpected extractor, or the aliased plan is no longer a
+   *   `Project(Filter(Generate))`
+   */
+  def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+    case pj @ Project(projectList, f @ Filter(condition, g: Generate))
+        if canPruneGenerator(g.generator) &&
+          GlutenConfig.getConf.enableExtendedColumnPruning &&
+          (SQLConf.get.nestedPruningOnExpressions || SQLConf.get.nestedSchemaPruningEnabled) =>
+      val attrToExtractValues =
+        getAttributeToExtractValues(projectList ++ g.generator.children :+ condition, Seq.empty)
+      if (attrToExtractValues.isEmpty) {
+        return None
+      }
+
+      // Split the accessors by whether their root attribute comes from the generator output.
+      val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
+      val (extractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
+        attrToExtractValues.partition {
+          case (attr, _) =>
+            attr.references.subsetOf(generatorOutputSet)
+        }
+
+      val pushedThrough = rewritePlanWithAliases(pj, attrToExtractValuesNotOnGenerator)
+
+      // We cannot push through if the child of generator is `MapType`.
+      g.generator.children.head.dataType match {
+        case _: MapType => return Some(pushedThrough)
+        case _ =>
+      }
+
+      if (!g.generator.isInstanceOf[ExplodeBase]) {
+        return Some(pushedThrough)
+      }
+
+      // In spark3.2, we could not reuse [[NestedColumnAliasing.getAttributeToExtractValues]]
+      // which only accepts 2 arguments. Instead we redefine it in current file to avoid moving
+      // this rule to gluten-shims
+      val attrToExtractValuesOnGenerator = getAttributeToExtractValues(
+        extractValuesOnGenerator.flatMap(_._2).toSeq,
+        Seq.empty,
+        collectNestedGetStructFields)
+
+      val nestedFieldsOnGenerator = attrToExtractValuesOnGenerator.values.flatten.toSet
+      if (nestedFieldsOnGenerator.isEmpty) {
+        return Some(pushedThrough)
+      }
+
+      // Multiple or single nested column accessors.
+      // E.g. df.select(explode($"items").as("item")).select($"item.a", $"item.b")
+      pushedThrough match {
+        case p2 @ Project(_, f2 @ Filter(_, g2: Generate)) =>
+          // Fix one ordering of the extracted fields; the ordinal of each field in the zipped
+          // struct is its index in this sequence.
+          val nestedFieldsOnGeneratorSeq = nestedFieldsOnGenerator.toSeq
+          val nestedFieldToOrdinal = nestedFieldsOnGeneratorSeq.zipWithIndex.toMap
+          val rewrittenG = g2.transformExpressions {
+            case e: ExplodeBase =>
+              val extractors = nestedFieldsOnGeneratorSeq.map(replaceGenerator(e, _))
+              val names = extractors.map {
+                case g: GetStructField => Literal(g.extractFieldName)
+                case ga: GetArrayStructFields => Literal(ga.field.name)
+                case other =>
+                  // Both halves of the message must be interpolated so `other` is reported.
+                  throw new IllegalStateException(
+                    s"Unreasonable extractor after replaceGenerator: $other")
+              }
+              val zippedArray = ArraysZip(extractors, names)
+              e.withNewChildren(Seq(zippedArray))
+          }
+          // As we change the child of the generator, its output data type must be updated.
+          val updatedGeneratorOutput = rewrittenG.generatorOutput
+            .zip(
+              rewrittenG.generator.elementSchema.map(
+                f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()))
+            .map {
+              case (oldAttr, newAttr) =>
+                newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
+            }
+          assert(
+            updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
+            "Updated generator output must have the same length " +
+              "with original generator output."
+          )
+          val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
+
+          // Replace nested column accessor with generator output.
+          val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
+          val updatedFilter = f2.withNewChildren(Seq(updatedGenerate)).transformExpressions {
+            case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
+              replaceGetStructField(
+                f,
+                updatedGenerate.output,
+                attrExprIdsOnGenerator,
+                nestedFieldToOrdinal)
+          }
+
+          val updatedProject = p2.withNewChildren(Seq(updatedFilter)).transformExpressions {
+            case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
+              replaceGetStructField(
+                f,
+                updatedFilter.output,
+                attrExprIdsOnGenerator,
+                nestedFieldToOrdinal)
+          }
+
+          Some(updatedProject)
+        case other =>
+          throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
+      }
+    case _ =>
+      None
+  }
+
+  /**
+   * Returns two types of expressions:
+   *   - Root references that are individually accessed
+   *   - [[GetStructField]] or [[GetArrayStructFields]] on top of other 
[[ExtractValue]]s or special
+   *     expressions.
+   */
+  private def collectRootReferenceAndExtractValue(e: Expression): 
Seq[Expression] = e match {
+    case _: AttributeReference => Seq(e)
+    case GetStructField(_: ExtractValue | _: AttributeReference, _, _) => 
Seq(e)
+    case GetArrayStructFields(
+          _: MapValues | _: MapKeys | _: ExtractValue | _: AttributeReference,
+          _,
+          _,
+          _,
+          _) =>
+      Seq(e)
+    case es if es.children.nonEmpty => 
es.children.flatMap(collectRootReferenceAndExtractValue)
+    case _ => Seq.empty
+  }
+
+  /**
+   * Creates a map from root [[Attribute]]s to non-redundant nested 
[[ExtractValue]]s. Nested field
+   * accessors of `exclusiveAttrs` are not considered in nested fields 
aliasing.
+   */
+  private def getAttributeToExtractValues(
+      exprList: Seq[Expression],
+      exclusiveAttrs: Seq[Attribute],
+      extractor: (Expression) => Seq[Expression] = 
collectRootReferenceAndExtractValue)
+      : Map[Attribute, Seq[ExtractValue]] = {
+
+    val nestedFieldReferences = new mutable.ArrayBuffer[ExtractValue]()
+    val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
+    exprList.foreach {
+      e =>
+        extractor(e).foreach {
+          // we can not alias the attr from lambda variable whose expr id is 
not available
+          case ev: ExtractValue if 
ev.find(_.isInstanceOf[NamedLambdaVariable]).isEmpty =>
+            if (ev.references.size == 1) {
+              nestedFieldReferences.append(ev)
+            }
+          case ar: AttributeReference => otherRootReferences.append(ar)
+          case _ => // ignore
+        }
+    }
+    val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
+
+    // Remove cosmetic variations when we group extractors by their references
+    nestedFieldReferences
+      .filter(!_.references.subsetOf(exclusiveAttrSet))
+      .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
+      .flatMap {
+        case (attr: Attribute, nestedFields: collection.Seq[ExtractValue]) =>
+          // Check if `ExtractValue` expressions contain any aggregate 
functions in their tree.
+          // Those that do should not have an alias generated as it can lead 
to pushing the
+          // aggregate down into a projection.
+          def containsAggregateFunction(ev: ExtractValue): Boolean =
+            ev.find(_.isInstanceOf[AggregateFunction]).isDefined
+
+          // Remove redundant [[ExtractValue]]s if they share the same parent 
nest field.
+          // For example, when `a.b` and `a.b.c` are in project list, we only 
need to alias `a.b`.
+          // Because `a.b` requires all of the inner fields of `b`, we cannot 
prune `a.b.c`.
+          val dedupNestedFields = nestedFields
+            .filter {
+              // See [[collectExtractValue]]: we only need to deal with 
[[GetArrayStructFields]] and
+              // [[GetStructField]]
+              case e @ (_: GetStructField | _: GetArrayStructFields) =>
+                val child = e.children.head
+                nestedFields.forall(f => 
child.find(_.semanticEquals(f)).isEmpty)
+              case _ => true
+            }
+            .distinct
+            // Discard [[ExtractValue]]s that contain aggregate functions.
+            .filterNot(containsAggregateFunction)
+
+          // If all nested fields of `attr` are used, we don't need to 
introduce new aliases.
+          // By default, the [[ColumnPruning]] rule uses `attr` already.
+          // Note that we need to remove cosmetic variations first, so we only 
count a
+          // nested field once.
+          val numUsedNestedFields = dedupNestedFields
+            .map(_.canonicalized)
+            .distinct
+            .map(nestedField => totalFieldNum(nestedField.dataType))
+            .sum
+          if (dedupNestedFields.nonEmpty && numUsedNestedFields < 
totalFieldNum(attr.dataType)) {
+            Some((attr, dedupNestedFields.toSeq))
+          } else {
+            None
+          }
+      }
+  }
+
+  /**
+   * Return total number of fields of this type. This is used as a threshold 
to use nested column
+   * pruning. It's okay to underestimate. If the number of reference is bigger 
than this, the parent
+   * reference is used instead of nested field references.
+   */
+  private def totalFieldNum(dataType: DataType): Int = dataType match {
+    case StructType(fields) => fields.map(f => totalFieldNum(f.dataType)).sum
+    case ArrayType(elementType, _) => totalFieldNum(elementType)
+    case MapType(keyType, valueType, _) => totalFieldNum(keyType) + 
totalFieldNum(valueType)
+    case _ => 1 // UDT and others
+  }
+
+  private def replaceGetStructField(
+      g: GetStructField,
+      input: Seq[Attribute],
+      attrExprIdsOnGenerator: Set[ExprId],
+      nestedFieldToOrdinal: Map[ExtractValue, Int]): Expression = {
+    val attr = input.find(a => attrExprIdsOnGenerator.contains(a.exprId))
+    attr match {
+      case Some(a) =>
+        val ordinal = nestedFieldToOrdinal(g)
+        GetStructField(a, ordinal, g.name)
+      case None => g
+    }
+  }
+
+  /** Replace the reference attribute of extractor expression with generator 
input. */
+  private def replaceGenerator(generator: ExplodeBase, expr: Expression): 
Expression = {
+    expr match {
+      case a: Attribute if expr.references.contains(a) =>
+        generator.child
+      case g: GetStructField =>
+        // We cannot simply do a transformUp instead because if we replace the 
attribute
+        // `extractFieldName` could cause `ClassCastException` error. We need 
to get the
+        // field name before replacing down the attribute/other extractor.
+        val fieldName = g.extractFieldName
+        val newChild = replaceGenerator(generator, g.child)
+        ExtractValue(newChild, Literal(fieldName), SQLConf.get.resolver)
+      case other =>
+        other.mapChildren(replaceGenerator(generator, _))
+    }
+  }
+
+  // This function collects all GetStructField*(attribute) from the passed in 
expression.
+  // GetStructField* means arbitrary levels of nesting.
+  private def collectNestedGetStructFields(e: Expression): Seq[Expression] = {
+    // The helper function returns a tuple of
+    // (nested GetStructField including the current level, all other nested 
GetStructField)
+    def helper(e: Expression): (Seq[Expression], Seq[Expression]) = e match {
+      case _: AttributeReference => (Seq(e), Seq.empty)
+      case gsf: GetStructField =>
+        val child_res = helper(gsf.child)
+        (child_res._1.map(p => gsf.withNewChildren(Seq(p))), child_res._2)
+      case other =>
+        val child_res = other.children.map(helper)
+        val child_res_combined = (child_res.flatMap(_._1), 
child_res.flatMap(_._2))
+        (Seq.empty, child_res_combined._1 ++ child_res_combined._2)
+    }
+
+    val res = helper(e)
+    (res._1 ++ res._2).filterNot(_.isInstanceOf[Attribute])
+  }
+
+  private def rewritePlanWithAliases(
+      plan: LogicalPlan,
+      attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]): 
LogicalPlan = {
+    val attributeToExtractValuesAndAliases =
+      attributeToExtractValues.map {
+        case (attr, evSeq) =>
+          val evAliasSeq = evSeq.map {
+            ev =>
+              val fieldName = ev match {
+                case g: GetStructField => g.extractFieldName
+                case g: GetArrayStructFields => g.field.name
+              }
+              ev -> Alias(ev, s"_extract_$fieldName")()
+          }
+
+          attr -> evAliasSeq
+      }
+
+    val nestedFieldToAlias = 
attributeToExtractValuesAndAliases.values.flatten.map {
+      case (field, alias) => field.canonicalized -> alias
+    }.toMap
+
+    // A reference attribute can have multiple aliases for nested fields.
+    val attrToAliases =
+      
AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)).toSeq)
+
+    plan match {
+      // Project(Filter(Generate))
+      case Project(projectList, f @ Filter(condition, g: Generate)) =>
+        val newProjectList = 
NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias)
+        val newCondition = getNewExpression(condition, nestedFieldToAlias)
+        val newGenerator = getNewExpression(g.generator, 
nestedFieldToAlias).asInstanceOf[Generator]
+
+        val tmpG = NestedColumnAliasing
+          .replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
+          .asInstanceOf[Generate]
+        val newG = Generate(
+          newGenerator,
+          tmpG.unrequiredChildIndex,
+          tmpG.outer,
+          tmpG.qualifier,
+          tmpG.generatorOutput,
+          tmpG.children.head)
+        val newF = Filter(newCondition, newG)
+        val newP = Project(newProjectList, newF)
+        newP
+      case _ => plan
+    }
+  }
+
+  private def getNewExpression(
+      expr: Expression,
+      nestedFieldToAlias: Map[Expression, Alias]): Expression = {
+    expr.transform {
+      case f: ExtractValue if nestedFieldToAlias.contains(f.canonicalized) =>
+        nestedFieldToAlias(f.canonicalized).toAttribute
+    }
+  }
+}
+
+// ExtendedColumnPruning process Project(Filter(Generate)),
+// which is ignored by vanilla spark in optimization rule: ColumnPruning
+class ExtendedColumnPruning(spark: SparkSession) extends Rule[LogicalPlan] 
with Logging {
+
+  override def apply(plan: LogicalPlan): LogicalPlan =
+    plan.transformWithPruning(AlwaysProcess.fn) {
+      case ExtendedGeneratorNestedColumnAliasing(rewrittenPlan) => 
rewrittenPlan
+      case p => p
+    }
+}
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala
deleted file mode 100644
index a97e625ae6..0000000000
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedGeneratorNestedColumnAliasing.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension
-
-import org.apache.gluten.GlutenConfig
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions._
-import 
org.apache.spark.sql.catalyst.optimizer.GeneratorNestedColumnAliasing.canPruneGenerator
-import org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasing
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.AlwaysProcess
-import org.apache.spark.sql.internal.SQLConf
-
-// ExtendedGeneratorNestedColumnAliasing process Project(Filter(Generate)),
-// which is ignored by vanilla spark in optimization rule: ColumnPruning
-class ExtendedGeneratorNestedColumnAliasing(spark: SparkSession)
-  extends Rule[LogicalPlan]
-  with Logging {
-
-  override def apply(plan: LogicalPlan): LogicalPlan =
-    plan.transformWithPruning(AlwaysProcess.fn) {
-      case pj @ Project(projectList, f @ Filter(condition, g: Generate))
-          if canPruneGenerator(g.generator) &&
-            GlutenConfig.getConf.enableExtendedGeneratorNestedColumnAliasing &&
-            (SQLConf.get.nestedPruningOnExpressions || 
SQLConf.get.nestedSchemaPruningEnabled) =>
-        val attrToExtractValues = 
NestedColumnAliasing.getAttributeToExtractValues(
-          projectList ++ g.generator.children :+ condition,
-          Seq.empty)
-        if (attrToExtractValues.isEmpty) {
-          pj
-        } else {
-          val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
-          val (_, attrToExtractValuesNotOnGenerator) =
-            attrToExtractValues.partition {
-              case (attr, _) =>
-                attr.references.subsetOf(generatorOutputSet)
-            }
-
-          val pushedThrough = rewritePlanWithAliases(pj, 
attrToExtractValuesNotOnGenerator)
-          pushedThrough
-        }
-      case p =>
-        p
-    }
-
-  private def rewritePlanWithAliases(
-      plan: LogicalPlan,
-      attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]): 
LogicalPlan = {
-    val attributeToExtractValuesAndAliases =
-      attributeToExtractValues.map {
-        case (attr, evSeq) =>
-          val evAliasSeq = evSeq.map {
-            ev =>
-              val fieldName = ev match {
-                case g: GetStructField => g.extractFieldName
-                case g: GetArrayStructFields => g.field.name
-              }
-              ev -> Alias(ev, s"_extract_$fieldName")()
-          }
-
-          attr -> evAliasSeq
-      }
-
-    val nestedFieldToAlias = 
attributeToExtractValuesAndAliases.values.flatten.map {
-      case (field, alias) => field.canonicalized -> alias
-    }.toMap
-
-    // A reference attribute can have multiple aliases for nested fields.
-    val attrToAliases =
-      
AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)).toSeq)
-
-    plan match {
-      // Project(Filter(Generate))
-      case p @ Project(projectList, child)
-          if child
-            .isInstanceOf[Filter] && 
child.asInstanceOf[Filter].child.isInstanceOf[Generate] =>
-        val f = child.asInstanceOf[Filter]
-        val g = f.child.asInstanceOf[Generate]
-
-        val newProjectList = 
NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias)
-        val newCondition = getNewExpression(f.condition, nestedFieldToAlias)
-        val newGenerator = getNewExpression(g.generator, 
nestedFieldToAlias).asInstanceOf[Generator]
-
-        val tmpG = NestedColumnAliasing
-          .replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
-          .asInstanceOf[Generate]
-        val newG = Generate(
-          newGenerator,
-          tmpG.unrequiredChildIndex,
-          tmpG.outer,
-          tmpG.qualifier,
-          tmpG.generatorOutput,
-          tmpG.children.head)
-        val newF = Filter(newCondition, newG)
-        val newP = Project(newProjectList, newF)
-        newP
-      case _ => plan
-    }
-  }
-
-  private def getNewExpression(
-      expr: Expression,
-      nestedFieldToAlias: Map[Expression, Alias]): Expression = {
-    expr.transform {
-      case f: ExtractValue if nestedFieldToAlias.contains(f.canonicalized) =>
-        nestedFieldToAlias(f.canonicalized).toAttribute
-    }
-  }
-}
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
index 7e31e73040..14d6ff5364 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
@@ -28,7 +28,7 @@ import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import 
org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, StructType}
 
 import org.apache.hadoop.fs.Path
 
@@ -1473,30 +1473,116 @@ class GlutenClickHouseHiveTableSuite
                 |    'event_info', map('tab_type', '4', 'action', '12')))
        """.stripMargin)
 
-    val df =
-      spark.sql("""
-                  | SELECT * FROM (
-                  |  SELECT
-                  |    game_name,
-                  |    CASE WHEN
-                  |       event.event_info['tab_type'] IN (5) THEN '1' ELSE 
'0' END AS entrance
-                  |  FROM aj
-                  |  LATERAL VIEW 
explode(split(nvl(event.event_info['game_name'],'0'),','))
-                  |    as game_name
-                  |  WHERE event.event_info['action'] IN (13)
-                  |) WHERE game_name = 'xxx'
+    val sql = """
+                | SELECT * FROM (
+                |  SELECT
+                |    game_name,
+                |    CASE WHEN
+                |       event.event_info['tab_type'] IN (5) THEN '1' ELSE '0' 
END AS entrance
+                |  FROM aj
+                |  LATERAL VIEW 
explode(split(nvl(event.event_info['game_name'],'0'),','))
+                |    as game_name
+                |  WHERE event.event_info['action'] IN (13)
+                |) WHERE game_name = 'xxx'
+      """.stripMargin
+
+    compareResultsAgainstVanillaSpark(
+      sql,
+      compareResult = true,
+      df => {
+        val scan = df.queryExecution.executedPlan.collect {
+          case scan: FileSourceScanExecTransformer => scan
+        }.head
+        val fieldType = 
scan.schema.fields.head.dataType.asInstanceOf[StructType]
+        assert(fieldType.size == 1)
+      }
+    )
+
+    spark.sql("drop table if exists aj")
+  }
+
+  test("Nested column pruning for Project(Filter(Generate)) on generator") {
+    def assertFieldSizeAfterPruning(sql: String, expectSize: Int): Unit = {
+      compareResultsAgainstVanillaSpark(
+        sql,
+        compareResult = true,
+        df => {
+          val scan = df.queryExecution.executedPlan.collect {
+            case scan: FileSourceScanExecTransformer => scan
+          }.head
+
+          val fieldType =
+            scan.schema.fields.head.dataType
+              .asInstanceOf[ArrayType]
+              .elementType
+              .asInstanceOf[StructType]
+          assert(fieldType.size == expectSize)
+        }
+      )
+    }
+
+    spark.sql("drop table if exists ajog")
+    spark.sql(
+      """
+        |CREATE TABLE if not exists ajog (
+        |  country STRING,
+        |  events ARRAY<STRUCT<time:BIGINT, lng:BIGINT, lat:BIGINT, net:STRING,
+        |     log_extra:MAP<STRING, STRING>, event_id:STRING, 
event_info:MAP<STRING, STRING>>>
+        |)
+        |USING orc
       """.stripMargin)
 
-    val scan = df.queryExecution.executedPlan.collect {
-      case scan: FileSourceScanExecTransformer => scan
-    }.head
+    spark.sql("""
+                |INSERT INTO ajog VALUES
+                |  ('USA', array(named_struct('time', 1622547800, 'lng', -122, 
'lat', 37, 'net',
+                |    'wifi', 'log_extra', map('key1', 'value1'), 'event_id', 
'event1',
+                |    'event_info', map('tab_type', '5', 'action', '13')))),
+                |  ('Canada', array(named_struct('time', 1622547801, 'lng', 
-79, 'lat', 43, 'net',
+                |    '4g', 'log_extra', map('key2', 'value2'), 'event_id', 
'event2',
+                |    'event_info', map('tab_type', '4', 'action', '12'))))
+       """.stripMargin)
 
-    val schema = scan.schema
-    assert(schema.size == 1)
-    val fieldType = schema.fields.head.dataType.asInstanceOf[StructType]
-    assert(fieldType.size == 1)
+    // Test nested column pruning on generator with single field extracted
+    val sql1 = """
+                 |select
+                 |case when event.event_info['tab_type'] in (5) then '1' else 
'0' end as entrance
+                 |from ajog
+                 |lateral view explode(events)  as event
+                 |where  event.event_info['action'] in (13)
+      """.stripMargin
+    assertFieldSizeAfterPruning(sql1, 1)
+
+    // Test nested column pruning on generator with multiple field extracted,
+    // which resolves SPARK-34956
+    val sql2 = """
+                 |select event.event_id,
+                 |case when event.event_info['tab_type'] in (5) then '1' else 
'0' end as entrance
+                 |from ajog
+                 |lateral view explode(events)  as event
+                 |where  event.event_info['action'] in (13)
+      """.stripMargin
+    assertFieldSizeAfterPruning(sql2, 2)
+
+    // Test nested column pruning with two adjacent generate operator
+    val sql3 = """
+                 |SELECT
+                 |abflag,
+                 |event.event_info,
+                 |event.log_extra
+                 |FROM
+                 |ajog
+                 |LATERAL VIEW EXPLODE(events) AS event
+                 |LATERAL VIEW EXPLODE(split(event.log_extra['key1'], ',')) AS 
abflag
+                 |WHERE
+                 |event.event_id = 'event1'
+                 |AND event.event_info['tab_type'] IS NOT NULL
+                 |AND event.event_info['tab_type'] != ''
+                 |AND event.log_extra['key1'] = 'value1'
+                 |LIMIT 100;
+      """.stripMargin
+    assertFieldSizeAfterPruning(sql3, 3)
 
-    spark.sql("drop table if exists aj")
+    spark.sql("drop table if exists ajog")
   }
 
   test("test hive table scan nested column pruning") {
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp 
b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
index d572b85385..e9fd4f358a 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
@@ -18,6 +18,7 @@
 
 #include <QueryPipeline/QueryPipeline.h>
 #include <Poco/URI.h>
+#include <Common/DebugUtils.h>
 
 namespace local_engine
 {
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala 
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 2ccdcae99b..c4c67f49a5 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -112,8 +112,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def enableCountDistinctWithoutExpand: Boolean =
     conf.getConf(ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND)
 
-  def enableExtendedGeneratorNestedColumnAliasing: Boolean =
-    conf.getConf(ENABLE_EXTENDED_GENERATOR_NESTED_COLUMN_ALIASING)
+  def enableExtendedColumnPruning: Boolean =
+    conf.getConf(ENABLE_EXTENDED_COLUMN_PRUNING)
 
   def veloxOrcScanEnabled: Boolean =
     conf.getConf(VELOX_ORC_SCAN_ENABLED)
@@ -1980,10 +1980,10 @@ object GlutenConfig {
       .booleanConf
       .createWithDefault(false)
 
-  val ENABLE_EXTENDED_GENERATOR_NESTED_COLUMN_ALIASING =
-    buildConf("spark.gluten.sql.extendedGeneratorNestedColumnAliasing")
+  val ENABLE_EXTENDED_COLUMN_PRUNING =
+    buildConf("spark.gluten.sql.extendedColumnPruning.enabled")
       .internal()
-      .doc("Do nested column aliasing for Project(Filter(Generator))")
+      .doc("Do extended nested column pruning for cases ignored by vanilla 
Spark.")
       .booleanConf
       .createWithDefault(true)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org

Reply via email to