This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e863166  [SPARK-35194][SQL] Refactor nested column aliasing for readability
e863166 is described below

commit e8631660ecf316e4333210650d1f40b5912fb11b
Author: Karen Feng <[email protected]>
AuthorDate: Fri May 28 13:18:44 2021 +0000

    [SPARK-35194][SQL] Refactor nested column aliasing for readability
    
    ### What changes were proposed in this pull request?
    
    Refactors `NestedColumnAliasing` and `GeneratorNestedColumnAliasing` for readability.
    
    ### Why are the changes needed?
    
    Improves readability for future maintenance.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #32301 from karenfeng/refactor-nested-column-aliasing.
    
    Authored-by: Karen Feng <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/expressions/AttributeMap.scala    |   6 +
 .../sql/catalyst/expressions/AttributeMap.scala    |   6 +
 .../catalyst/optimizer/NestedColumnAliasing.scala  | 426 ++++++++++++---------
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   4 +-
 .../optimizer/NestedColumnAliasingSuite.scala      |   2 +-
 5 files changed, 250 insertions(+), 194 deletions(-)

diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 42b92d4..189318a 100644
--- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,6 +23,10 @@ package org.apache.spark.sql.catalyst.expressions
  * of the name, or the expected nullability).
  */
 object AttributeMap {
+  def apply[A](kvs: Map[Attribute, A]): AttributeMap[A] = {
+    new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)))
+  }
+
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
@@ -37,6 +41,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
 
   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
 
+  override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 = get(k).getOrElse(default)
+
   override def contains(k: Attribute): Boolean = get(k).isDefined
 
   override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index e6b53e3..7715291 100644
--- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,6 +23,10 @@ package org.apache.spark.sql.catalyst.expressions
  * of the name, or the expected nullability).
  */
 object AttributeMap {
+  def apply[A](kvs: Map[Attribute, A]): AttributeMap[A] = {
+    new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)))
+  }
+
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
@@ -37,6 +41,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
 
   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
 
+  override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 = get(k).getOrElse(default)
+
   override def contains(k: Attribute): Boolean = get(k).isDefined
 
   override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] =
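
For context, a minimal sketch of the two additions above (the attributes below are illustrative, not part of the patch): the new `apply` overload builds an `AttributeMap` directly from a `Map[Attribute, A]`, and the `getOrElse` override routes lookups through `get`, i.e. resolution by `exprId`:

    import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()

    // New overload: construct directly from a Map[Attribute, A].
    val m = AttributeMap(Map(a -> 1, b -> 2))

    // getOrElse resolves by exprId, like get; an unrelated attribute misses.
    assert(m.getOrElse(a, 0) == 1)
    assert(m.getOrElse(AttributeReference("c", IntegerType)(), 0) == 0)
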
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 5b12667..cd7032d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -17,71 +17,151 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 /**
- * This aims to handle a nested column aliasing pattern inside the `ColumnPruning` optimizer rule.
- * If a project or its child references to nested fields, and not all the fields
- * in a nested attribute are used, we can substitute them by alias attributes; then a project
- * of the nested fields as aliases on the children of the child will be created.
+ * This aims to handle a nested column aliasing pattern inside the [[ColumnPruning]] optimizer rule.
+ * If:
+ * - A [[Project]] or its child references nested fields
+ * - Not all of the fields in a nested attribute are used
+ * Then:
+ * - Substitute the nested field references with alias attributes
+ * - Add grandchild [[Project]]s transforming the nested fields to aliases
+ *
+ * Example 1: Project
+ * ------------------
+ * Before:
+ * +- Project [concat_ws(s#0.a, s#0.b) AS concat_ws(s.a, s.b)#1]
+ *   +- GlobalLimit 5
+ *     +- LocalLimit 5
+ *       +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [concat_ws(_extract_a#2, _extract_b#3) AS concat_ws(s.a, s.b)#1]
+ *   +- GlobalLimit 5
+ *     +- LocalLimit 5
+ *       +- Project [s#0.a AS _extract_a#2, s#0.b AS _extract_b#3]
+ *         +- LocalRelation <empty>, [s#0]
+ *
+ * Example 2: Project above Filter
+ * -------------------------------
+ * Before:
+ * +- Project [s#0.a AS s.a#1]
+ *   +- Filter (length(s#0.b) > 2)
+ *     +- GlobalLimit 5
+ *       +- LocalLimit 5
+ *         +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [_extract_a#2 AS s.a#1]
+ *   +- Filter (length(_extract_b#3) > 2)
+ *     +- GlobalLimit 5
+ *       +- LocalLimit 5
+ *         +- Project [s#0.a AS _extract_a#2, s#0.b AS _extract_b#3]
+ *           +- LocalRelation <empty>, [s#0]
+ *
+ * Example 3: Nested fields with referenced parents
+ * ------------------------------------------------
+ * Before:
+ * +- Project [s#0.a AS s.a#1, s#0.a.a1 AS s.a.a1#2]
+ *   +- GlobalLimit 5
+ *     +- LocalLimit 5
+ *       +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [_extract_a#3 AS s.a#1, _extract_a#3.name AS s.a.a1#2]
+ *   +- GlobalLimit 5
+ *     +- LocalLimit 5
+ *       +- Project [s#0.a AS _extract_a#3]
+ *         +- LocalRelation <empty>, [s#0]
+ *
+ * The schema of the datasource relation will be pruned in the [[SchemaPruning]] optimizer rule.
  */
 object NestedColumnAliasing {
 
   def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
     /**
      * This pattern is needed to support [[Filter]] plan cases like
-     * [[Project]]->[[Filter]]->listed plan in `canProjectPushThrough` (e.g., [[Window]]).
-     * The reason why we don't simply add [[Filter]] in `canProjectPushThrough` is that
+     * [[Project]]->[[Filter]]->listed plan in [[canProjectPushThrough]] (e.g., [[Window]]).
+     * The reason why we don't simply add [[Filter]] in [[canProjectPushThrough]] is that
      * the optimizer can hit an infinite loop during the [[PushDownPredicates]] rule.
      */
-    case Project(projectList, Filter(condition, child))
-        if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
-      val exprCandidatesToPrune = projectList ++ Seq(condition) ++ child.expressions
-      getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          NestedColumnAliasing.replaceToAliases(plan, nestedFieldToAlias, attrToAliases)
-      }
+    case Project(projectList, Filter(condition, child)) if
+        SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
+      rewritePlanIfSubsetFieldsUsed(
+        plan, projectList ++ Seq(condition) ++ child.expressions, child.producedAttributes.toSeq)
 
-    case Project(projectList, child)
-        if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
-      val exprCandidatesToPrune = projectList ++ child.expressions
-      getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          NestedColumnAliasing.replaceToAliases(plan, nestedFieldToAlias, attrToAliases)
-      }
+    case Project(projectList, child) if
+        SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
+      rewritePlanIfSubsetFieldsUsed(
+        plan, projectList ++ child.expressions, child.producedAttributes.toSeq)
 
     case p if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(p) =>
-      val exprCandidatesToPrune = p.expressions
-      getAliasSubMap(exprCandidatesToPrune, p.producedAttributes.toSeq).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          NestedColumnAliasing.replaceToAliases(p, nestedFieldToAlias, attrToAliases)
-      }
+      rewritePlanIfSubsetFieldsUsed(
+        plan, p.expressions, p.producedAttributes.toSeq)
 
     case _ => None
   }
 
   /**
+   * Rewrites a plan with aliases if only a subset of the nested fields are used.
+   */
+  def rewritePlanIfSubsetFieldsUsed(
+      plan: LogicalPlan,
+      exprList: Seq[Expression],
+      exclusiveAttrs: Seq[Attribute]): Option[LogicalPlan] = {
+    val attrToExtractValues = getAttributeToExtractValues(exprList, exclusiveAttrs)
+    if (attrToExtractValues.isEmpty) {
+      None
+    } else {
+      Some(rewritePlanWithAliases(plan, attrToExtractValues))
+    }
+  }
+
+  /**
    * Replace nested columns to prune unused nested columns later.
    */
-  private def replaceToAliases(
+  def rewritePlanWithAliases(
       plan: LogicalPlan,
-      nestedFieldToAlias: Map[ExtractValue, Alias],
-      attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = plan match {
-    case Project(projectList, child) =>
-      Project(
-        getNewProjectList(projectList, nestedFieldToAlias),
-        replaceWithAliases(child, nestedFieldToAlias, attrToAliases))
-
-    // The operators reaching here was already guarded by `canPruneOn`.
-    case other =>
-      replaceWithAliases(other, nestedFieldToAlias, attrToAliases)
+      attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]): LogicalPlan = {
+    // Each expression can contain multiple nested fields.
+    // Note that we keep the original names to deliver to parquet in a 
case-sensitive way.
+    // A new alias is created for each nested field.
+    // Implementation detail: we don't use mapValues, because it creates a mutable view.
+    val attributeToExtractValuesAndAliases =
+      attributeToExtractValues.map { case (attr, evSeq) =>
+        val evAliasSeq = evSeq.map { ev =>
+          val fieldName = ev match {
+            case g: GetStructField => g.extractFieldName
+            case g: GetArrayStructFields => g.field.name
+          }
+          ev -> Alias(ev, s"_extract_$fieldName")()
+        }
+
+        attr -> evAliasSeq
+      }
+
+    val nestedFieldToAlias = attributeToExtractValuesAndAliases.values.flatten.toMap
+
+    // A reference attribute can have multiple aliases for nested fields.
+    val attrToAliases = AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)))
+
+    plan match {
+      case Project(projectList, child) =>
+        Project(
+          getNewProjectList(projectList, nestedFieldToAlias),
+          replaceWithAliases(child, nestedFieldToAlias, attrToAliases))
+
+      // The operators reaching here are already guarded by [[canPruneOn]].
+      case other =>
+        replaceWithAliases(other, nestedFieldToAlias, attrToAliases)
+    }
   }
 
   /**
-   * Return a replaced project list.
+   * Replace the [[ExtractValue]]s in a project list with aliased attributes.
    */
   def getNewProjectList(
       projectList: Seq[NamedExpression],
@@ -93,15 +173,15 @@ object NestedColumnAliasing {
   }
 
   /**
-   * Return a plan with new children replaced with aliases, and expressions replaced with
-   * aliased attributes.
+   * Replace the grandchildren of a plan with [[Project]]s of the nested fields as aliases,
+   * and replace the [[ExtractValue]] expressions with aliased attributes.
    */
   def replaceWithAliases(
       plan: LogicalPlan,
       nestedFieldToAlias: Map[ExtractValue, Alias],
-      attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
+      attrToAliases: AttributeMap[Seq[Alias]]): LogicalPlan = {
     plan.withNewChildren(plan.children.map { plan =>
-      Project(plan.output.flatMap(a => attrToAliases.getOrElse(a.exprId, Seq(a))), plan)
+      Project(plan.output.flatMap(a => attrToAliases.getOrElse(a, Seq(a))), plan)
     }).transformExpressions {
       case f: ExtractValue if nestedFieldToAlias.contains(f) =>
         nestedFieldToAlias(f).toAttribute
@@ -109,7 +189,7 @@ object NestedColumnAliasing {
   }
 
   /**
-   * Returns true for those operators that we can prune nested column on it.
+   * Returns true for operators on which we can prune nested columns.
    */
   private def canPruneOn(plan: LogicalPlan) = plan match {
     case _: Aggregate => true
@@ -118,7 +198,7 @@ object NestedColumnAliasing {
   }
 
   /**
-   * Returns true for those operators that project can be pushed through.
+   * Returns true for operators through which project can be pushed.
    */
   private def canProjectPushThrough(plan: LogicalPlan) = plan match {
     case _: GlobalLimit => true
@@ -133,9 +213,10 @@ object NestedColumnAliasing {
   }
 
   /**
-   * Return root references that are individually accessed as a whole, and `GetStructField`s
-   * or `GetArrayStructField`s which on top of other `ExtractValue`s or special expressions.
-   * Check `SelectedField` to see which expressions should be listed here.
+   * Returns two types of expressions:
+   * - Root references that are individually accessed
+   * - [[GetStructField]] or [[GetArrayStructFields]] on top of other [[ExtractValue]]s
+   *   or special expressions.
    */
   private def collectRootReferenceAndExtractValue(e: Expression): Seq[Expression] = e match {
     case _: AttributeReference => Seq(e)
@@ -149,67 +230,55 @@ object NestedColumnAliasing {
   }
 
   /**
-   * Return two maps in order to replace nested fields to aliases.
-   *
-   * If `exclusiveAttrs` is given, any nested field accessors of these attributes
-   * won't be considered in nested fields aliasing.
-   *
-   * 1. ExtractValue -> Alias: A new alias is created for each nested field.
-   * 2. ExprId -> Seq[Alias]: A reference attribute has multiple aliases pointing it.
+   * Creates a map from root [[Attribute]]s to non-redundant nested [[ExtractValue]]s.
+   * Nested field accessors of `exclusiveAttrs` are not considered in nested fields aliasing.
    */
-  def getAliasSubMap(exprList: Seq[Expression], exclusiveAttrs: Seq[Attribute] = Seq.empty)
-    : Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = {
-    val (nestedFieldReferences, otherRootReferences) =
-      exprList.flatMap(collectRootReferenceAndExtractValue).partition {
-        case _: ExtractValue => true
-        case _ => false
+  def getAttributeToExtractValues(
+      exprList: Seq[Expression],
+      exclusiveAttrs: Seq[Attribute]): Map[Attribute, Seq[ExtractValue]] = {
+
+    val nestedFieldReferences = new mutable.ArrayBuffer[ExtractValue]()
+    val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
+    exprList.foreach { e =>
+      collectRootReferenceAndExtractValue(e).foreach {
+        case ev: ExtractValue =>
+          if (ev.references.size == 1) {
+            nestedFieldReferences.append(ev)
+          }
+        case ar: AttributeReference => otherRootReferences.append(ar)
       }
-
-    // Note that when we group by extractors with their references, we should remove
-    // cosmetic variations.
+    }
     val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
-    val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]]
+
+    // Remove cosmetic variations when we group extractors by their references
+    nestedFieldReferences
       .filter(!_.references.subsetOf(exclusiveAttrSet))
       .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
-      .flatMap { case (attr, nestedFields: Seq[ExtractValue]) =>
-        // Remove redundant `ExtractValue`s if they share the same parent nest field.
+      .flatMap { case (attr: Attribute, nestedFields: Seq[ExtractValue]) =>
+        // Remove redundant [[ExtractValue]]s if they share the same parent nest field.
         // For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`.
-        // We only need to deal with two `ExtractValue`: `GetArrayStructFields` and
-        // `GetStructField`. Please refer to the method `collectRootReferenceAndExtractValue`.
+        // Because `a.b` requires all of the inner fields of `b`, we cannot prune `a.b.c`.
         val dedupNestedFields = nestedFields.filter {
+          // See [[collectRootReferenceAndExtractValue]]: we only need to deal with [[GetArrayStructFields]] and
+          // [[GetStructField]]
           case e @ (_: GetStructField | _: GetArrayStructFields) =>
             val child = e.children.head
             nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
           case _ => true
-        }
-
-        // Each expression can contain multiple nested fields.
-        // Note that we keep the original names to deliver to parquet in a case-sensitive way.
-        val nestedFieldToAlias = dedupNestedFields.distinct.map { f =>
-          val exprId = NamedExpression.newExprId
-          (f, Alias(f, s"_gen_alias_${exprId.id}")(exprId, Seq.empty, None))
-        }
+        }.distinct
 
         // If all nested fields of `attr` are used, we don't need to introduce new aliases.
-        // By default, ColumnPruning rule uses `attr` already.
+        // By default, the [[ColumnPruning]] rule uses `attr` already.
         // Note that we need to remove cosmetic variations first, so we only count a
         // nested field once.
-        if (nestedFieldToAlias.nonEmpty &&
-            dedupNestedFields.map(_.canonicalized)
-              .distinct
-              .map { nestedField => totalFieldNum(nestedField.dataType) }
-              .sum < totalFieldNum(attr.dataType)) {
-          Some(attr.exprId -> nestedFieldToAlias)
+        val numUsedNestedFields = dedupNestedFields.map(_.canonicalized).distinct
+          .map { nestedField => totalFieldNum(nestedField.dataType) }.sum
+        if (numUsedNestedFields < totalFieldNum(attr.dataType)) {
+          Some((attr, dedupNestedFields.toSeq))
         } else {
           None
         }
       }
-
-    if (aliasSub.isEmpty) {
-      None
-    } else {
-      Some((aliasSub.values.flatten.toMap, aliasSub.map(x => (x._1, x._2.map(_._2)))))
-    }
   }
 
   /**
@@ -227,31 +296,9 @@ object NestedColumnAliasing {
 }
 
 /**
- * This prunes unnecessary nested columns from `Generate` and optional `Project` on top
- * of it.
+ * This prunes unnecessary nested columns from [[Generate]], or [[Project]] -> [[Generate]]
  */
 object GeneratorNestedColumnAliasing {
-  // Partitions `attrToAliases` based on whether the attribute is in Generator's output.
-  private def aliasesOnGeneratorOutput(
-      attrToAliases: Map[ExprId, Seq[Alias]],
-      generatorOutput: Seq[Attribute]) = {
-    val generatorOutputExprId = generatorOutput.map(_.exprId)
-    attrToAliases.partition { k =>
-      generatorOutputExprId.contains(k._1)
-    }
-  }
-
-  // Partitions `nestedFieldToAlias` based on whether the attribute of nested field extractor
-  // is in Generator's output.
-  private def nestedFieldOnGeneratorOutput(
-      nestedFieldToAlias: Map[ExtractValue, Alias],
-      generatorOutput: Seq[Attribute]) = {
-    val generatorOutputSet = AttributeSet(generatorOutput)
-    nestedFieldToAlias.partition { pair =>
-      pair._1.references.subsetOf(generatorOutputSet)
-    }
-  }
-
   def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
     // Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we
     // need to prune nested columns through Project and under Generate. The difference is
@@ -261,103 +308,100 @@ object GeneratorNestedColumnAliasing {
         SQLConf.get.nestedSchemaPruningEnabled) && canPruneGenerator(g.generator) =>
       // On top on `Generate`, a `Project` that might have nested column accessors.
       // We try to get alias maps for both project list and generator's children expressions.
-      val exprsToPrune = projectList ++ g.generator.children
-      NestedColumnAliasing.getAliasSubMap(exprsToPrune).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
-            nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput)
-          val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
-            aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)
-
-          // Push nested column accessors through `Generator`.
-          // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
-          val newChild = NestedColumnAliasing.replaceWithAliases(g,
-            nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
-          val pushedThrough = Project(NestedColumnAliasing
-            .getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild)
-
-          // If the generator output is `ArrayType`, we cannot push through the extractor.
-          // It is because we don't allow field extractor on two-level array,
-          // i.e., attr.field when attr is a ArrayType(ArrayType(...)).
-          // Similarily, we also cannot push through if the child of generator is `MapType`.
-          g.generator.children.head.dataType match {
-            case _: MapType => return Some(pushedThrough)
-            case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
-            case _ =>
-          }
-
-          // Pruning on `Generator`'s output. We only process single field case.
-          // For multiple field case, we cannot directly move field extractor into
-          // the generator expression. A workaround is to re-construct array of struct
-          // from multiple fields. But it will be more complicated and may not worth.
-          // TODO(SPARK-34956): support multiple fields.
-          if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
-            pushedThrough
-          } else {
-            // Only one nested column accessor.
-            // E.g., df.select(explode($"items").as("item")).select($"item.a")
-            pushedThrough match {
-              case p @ Project(_, newG: Generate) =>
-                // Replace the child expression of `ExplodeBase` generator with
-                // nested column accessor.
-                // E.g., df.select(explode($"items").as("item")).select($"item.a") =>
-                //       df.select(explode($"items.a").as("item.a"))
-                val rewrittenG = newG.transformExpressions {
-                  case e: ExplodeBase =>
-                    val extractor = nestedFieldsOnGenerator.head._1.transformUp {
-                      case _: Attribute =>
-                        e.child
-                      case g: GetStructField =>
-                        ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
-                    }
-                    e.withNewChildren(Seq(extractor))
-                }
+      val attrToExtractValues = NestedColumnAliasing.getAttributeToExtractValues(
+        projectList ++ g.generator.children, Seq.empty)
+      if (attrToExtractValues.isEmpty) {
+        return None
+      }
+      val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
+      val (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
+        attrToExtractValues.partition { case (attr, _) =>
+          attr.references.subsetOf(generatorOutputSet) }
+
+      val pushedThrough = NestedColumnAliasing.rewritePlanWithAliases(
+        plan, attrToExtractValuesNotOnGenerator)
+
+      // If the generator output is `ArrayType`, we cannot push through the extractor.
+      // It is because we don't allow field extractor on two-level array,
+      // i.e., attr.field when attr is a ArrayType(ArrayType(...)).
+      // Similarly, we also cannot push through if the child of the generator is `MapType`.
+      g.generator.children.head.dataType match {
+        case _: MapType => return Some(pushedThrough)
+        case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
+        case _ =>
+      }
 
-                // As we change the child of the generator, its output data type must be updated.
-                val updatedGeneratorOutput = rewrittenG.generatorOutput
-                  .zip(rewrittenG.generator.elementSchema.toAttributes)
-                  .map { case (oldAttr, newAttr) =>
-                  newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
-                }
-                assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
-                  "Updated generator output must have the same length " +
-                    "with original generator output.")
-                val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
-
-                // Replace nested column accessor with generator output.
-                p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
-                  case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
-                    updatedGenerate.output
-                      .find(a => attrToAliasesOnGenerator.contains(a.exprId))
-                      .getOrElse(f)
+      // Pruning on `Generator`'s output. We only process single field case.
+      // For multiple field case, we cannot directly move field extractor into
+      // the generator expression. A workaround is to re-construct array of struct
+      // from multiple fields. But it will be more complicated and may not be worth it.
+      // TODO(SPARK-34956): support multiple fields.
+      val nestedFieldsOnGenerator = attrToExtractValuesOnGenerator.values.flatten.toSet
+      if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
+        Some(pushedThrough)
+      } else {
+        // Only one nested column accessor.
+        // E.g., df.select(explode($"items").as("item")).select($"item.a")
+        val nestedFieldOnGenerator = nestedFieldsOnGenerator.head
+        pushedThrough match {
+          case p @ Project(_, newG: Generate) =>
+            // Replace the child expression of `ExplodeBase` generator with
+            // nested column accessor.
+            // E.g., df.select(explode($"items").as("item")).select($"item.a") =>
+            //       df.select(explode($"items.a").as("item.a"))
+            val rewrittenG = newG.transformExpressions {
+              case e: ExplodeBase =>
+                val extractor = nestedFieldOnGenerator.transformUp {
+                  case _: Attribute =>
+                    e.child
+                  case g: GetStructField =>
+                    ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
                 }
+                e.withNewChildren(Seq(extractor))
+            }
 
-              case other =>
-                // We should not reach here.
-                throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
+            // As we change the child of the generator, its output data type must be updated.
+            val updatedGeneratorOutput = rewrittenG.generatorOutput
+              .zip(rewrittenG.generator.elementSchema.toAttributes)
+              .map { case (oldAttr, newAttr) =>
+                newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
+              }
+            assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
+              "Updated generator output must have the same length " +
+                "with original generator output.")
+            val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
+
+            // Replace nested column accessor with generator output.
+            val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
+            val updatedProject = p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
+              case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
+                updatedGenerate.output
+                  .find(a => attrExprIdsOnGenerator.contains(a.exprId))
+                  .getOrElse(f)
             }
-          }
+            Some(updatedProject)
+
+          case other =>
+            // We should not reach here.
+            throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
+        }
       }
 
     case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&
-        canPruneGenerator(g.generator) =>
+      canPruneGenerator(g.generator) =>
       // If any child output is required by higher projection, we cannot prune on it even we
       // only use part of nested column of it. A required child output means it is referred
       // as a whole or partially by higher projection, pruning it here will cause unresolved
       // query plan.
-      NestedColumnAliasing.getAliasSubMap(
-        g.generator.children, g.requiredChildOutput).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
-          NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
-      }
+      NestedColumnAliasing.rewritePlanIfSubsetFieldsUsed(
+        plan, g.generator.children, g.requiredChildOutput)
 
     case _ =>
       None
   }
 
   /**
-   * This is a while-list for pruning nested fields at `Generator`.
+   * Types of [[Generator]] on which we can prune nested fields.
    */
   def canPruneGenerator(g: Generator): Boolean = g match {
     case _: Explode => true
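
To make the rewrites concrete, a hedged usage sketch (assuming a `SparkSession` named `spark` and `spark.sql.optimizer.nestedSchemaPruning.enabled` set to `true`; the column names are illustrative). The optimized plans printed by `explain` should surface the `_extract_*` aliases described in the scaladoc above:

    import org.apache.spark.sql.functions._
    import spark.implicits._

    val df = spark.range(3).select(
      struct($"id".as("a"), $"id".as("b")).as("s"),
      array(struct($"id".as("a"), $"id".as("b"))).as("items"))

    // NestedColumnAliasing shape (Example 1): only s.a and s.b are accessed,
    // so the field extractors are aliased below the limit.
    df.limit(5).select(concat_ws("-", $"s.a", $"s.b")).explain(true)

    // GeneratorNestedColumnAliasing shape: a single nested field is selected
    // from the exploded output, so the extractor can move into explode itself.
    df.select(explode($"items").as("item")).select($"item.a").explain(true)
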
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index defb3b4..028c9bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -752,7 +752,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))
 
     // prune unrequired nested fields from `Generate`.
-    case GeneratorNestedColumnAliasing(p) => p
+    case GeneratorNestedColumnAliasing(rewrittenPlan) => rewrittenPlan
 
     // Eliminate unneeded attributes from right side of a Left Existence Join.
     case j @ Join(_, right, LeftExistence(_), _, _) =>
@@ -786,7 +786,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Can't prune the columns on LeafNode
     case p @ Project(_, _: LeafNode) => p
 
-    case NestedColumnAliasing(p) => p
+    case NestedColumnAliasing(rewrittenPlan) => rewrittenPlan
 
     // for all other logical plans that inherits the output from it's children
     // Project over project is handled by the first case, skip it here.
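
Both rules plug into `ColumnPruning` as extractor objects: `unapply` returns `Some(rewrittenPlan)` only when the rewrite applies, so each rule body stays a single pattern match. A self-contained sketch of that pattern (the object and its rewrite are illustrative, not part of this patch):

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

    object CollapseTrivialProject {
      def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
        // Illustrative rewrite: drop a Project that passes its child through unchanged.
        case Project(projectList, child) if projectList == child.output => Some(child)
        case _ => None
      }
    }

    // A rule body can then be written as:
    //   case CollapseTrivialProject(rewrittenPlan) => rewrittenPlan
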
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index a856caa..643974c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -714,7 +714,7 @@ object NestedColumnAliasingSuite {
   def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
     val aliases = ArrayBuffer[String]()
     query.transformAllExpressions {
-      case a @ Alias(_, name) if name.startsWith("_gen_alias_") =>
+      case a @ Alias(_, name) if name.startsWith("_extract_") =>
         aliases += name
         a
     }
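
As an aside on the renamed prefix checked above: the generated aliases are now named after the extracted field, so plans read `_extract_a` rather than the old exprId-based `_gen_alias_<id>`. A small illustrative construction (the attribute is hypothetical):

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, GetStructField}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    val s = AttributeReference("s", StructType(Seq(StructField("a", IntegerType))))()
    // Mirrors rewritePlanWithAliases above: s.a gets the alias _extract_a.
    val alias = Alias(GetStructField(s, 0, Some("a")), "_extract_a")()
    assert(alias.name.startsWith("_extract_"))
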

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
