This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e863166 [SPARK-35194][SQL] Refactor nested column aliasing for
readability
e863166 is described below
commit e8631660ecf316e4333210650d1f40b5912fb11b
Author: Karen Feng <[email protected]>
AuthorDate: Fri May 28 13:18:44 2021 +0000
[SPARK-35194][SQL] Refactor nested column aliasing for readability
### What changes were proposed in this pull request?
Refactors `NestedColumnAliasing` and `GeneratorNestedColumnAliasing` for
readability.
### Why are the changes needed?
Improves readability for future maintenance.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #32301 from karenfeng/refactor-nested-column-aliasing.
Authored-by: Karen Feng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/expressions/AttributeMap.scala | 6 +
.../sql/catalyst/expressions/AttributeMap.scala | 6 +
.../catalyst/optimizer/NestedColumnAliasing.scala | 426 ++++++++++++---------
.../spark/sql/catalyst/optimizer/Optimizer.scala | 4 +-
.../optimizer/NestedColumnAliasingSuite.scala | 2 +-
5 files changed, 250 insertions(+), 194 deletions(-)
diff --git
a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 42b92d4..189318a 100644
---
a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++
b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,6 +23,10 @@ package org.apache.spark.sql.catalyst.expressions
* of the name, or the expected nullability).
*/
object AttributeMap {
+ def apply[A](kvs: Map[Attribute, A]): AttributeMap[A] = {
+ new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)))
+ }
+
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}
@@ -37,6 +41,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute,
A)])
override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
+ override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 =
get(k).getOrElse(default)
+
override def contains(k: Attribute): Boolean = get(k).isDefined
override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
baseMap.values.toMap + kv
diff --git
a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index e6b53e3..7715291 100644
---
a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++
b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,6 +23,10 @@ package org.apache.spark.sql.catalyst.expressions
* of the name, or the expected nullability).
*/
object AttributeMap {
+ def apply[A](kvs: Map[Attribute, A]): AttributeMap[A] = {
+ new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)))
+ }
+
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}
@@ -37,6 +41,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute,
A)])
override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
+ override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 =
get(k).getOrElse(default)
+
override def contains(k: Attribute): Boolean = get(k).isDefined
override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1]
=
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 5b12667..cd7032d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -17,71 +17,151 @@
package org.apache.spark.sql.catalyst.optimizer
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
/**
- * This aims to handle a nested column aliasing pattern inside the
`ColumnPruning` optimizer rule.
- * If a project or its child references to nested fields, and not all the
fields
- * in a nested attribute are used, we can substitute them by alias attributes;
then a project
- * of the nested fields as aliases on the children of the child will be
created.
+ * This aims to handle a nested column aliasing pattern inside the
[[ColumnPruning]] optimizer rule.
+ * If:
+ * - A [[Project]] or its child references nested fields
+ * - Not all of the fields in a nested attribute are used
+ * Then:
+ * - Substitute the nested field references with alias attributes
+ * - Add grandchild [[Project]]s transforming the nested fields to aliases
+ *
+ * Example 1: Project
+ * ------------------
+ * Before:
+ * +- Project [concat_ws(s#0.a, s#0.b) AS concat_ws(s.a, s.b)#1]
+ * +- GlobalLimit 5
+ * +- LocalLimit 5
+ * +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [concat_ws(_extract_a#2, _extract_b#3) AS concat_ws(s.a, s.b)#1]
+ * +- GlobalLimit 5
+ * +- LocalLimit 5
+ * +- Project [s#0.a AS _extract_a#2, s#0.b AS _extract_b#3]
+ * +- LocalRelation <empty>, [s#0]
+ *
+ * Example 2: Project above Filter
+ * -------------------------------
+ * Before:
+ * +- Project [s#0.a AS s.a#1]
+ * +- Filter (length(s#0.b) > 2)
+ * +- GlobalLimit 5
+ * +- LocalLimit 5
+ * +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [_extract_a#2 AS s.a#1]
+ * +- Filter (length(_extract_b#3) > 2)
+ * +- GlobalLimit 5
+ * +- LocalLimit 5
+ * +- Project [s#0.a AS _extract_a#2, s#0.b AS _extract_b#3]
+ * +- LocalRelation <empty>, [s#0]
+ *
+ * Example 3: Nested fields with referenced parents
+ * ------------------------------------------------
+ * Before:
+ * +- Project [s#0.a AS s.a#1, s#0.a.a1 AS s.a.a1#2]
+ * +- GlobalLimit 5
+ * +- LocalLimit 5
+ * +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [_extract_a#3 AS s.a#1, _extract_a#3.a1 AS s.a.a1#2]
+ * +- GlobalLimit 5
+ * +- LocalLimit 5
+ * +- Project [s#0.a AS _extract_a#3]
+ * +- LocalRelation <empty>, [s#0]
+ *
+ * The schema of the datasource relation will be pruned in the
[[SchemaPruning]] optimizer rule.
*/
object NestedColumnAliasing {
def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
/**
* This pattern is needed to support [[Filter]] plan cases like
- * [[Project]]->[[Filter]]->listed plan in `canProjectPushThrough` (e.g.,
[[Window]]).
- * The reason why we don't simply add [[Filter]] in
`canProjectPushThrough` is that
+ * [[Project]]->[[Filter]]->listed plan in [[canProjectPushThrough]]
(e.g., [[Window]]).
+ * The reason why we don't simply add [[Filter]] in
[[canProjectPushThrough]] is that
* the optimizer can hit an infinite loop during the
[[PushDownPredicates]] rule.
*/
- case Project(projectList, Filter(condition, child))
- if SQLConf.get.nestedSchemaPruningEnabled &&
canProjectPushThrough(child) =>
- val exprCandidatesToPrune = projectList ++ Seq(condition) ++
child.expressions
- getAliasSubMap(exprCandidatesToPrune,
child.producedAttributes.toSeq).map {
- case (nestedFieldToAlias, attrToAliases) =>
- NestedColumnAliasing.replaceToAliases(plan, nestedFieldToAlias,
attrToAliases)
- }
+ case Project(projectList, Filter(condition, child)) if
+ SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child)
=>
+ rewritePlanIfSubsetFieldsUsed(
+ plan, projectList ++ Seq(condition) ++ child.expressions,
child.producedAttributes.toSeq)
- case Project(projectList, child)
- if SQLConf.get.nestedSchemaPruningEnabled &&
canProjectPushThrough(child) =>
- val exprCandidatesToPrune = projectList ++ child.expressions
- getAliasSubMap(exprCandidatesToPrune,
child.producedAttributes.toSeq).map {
- case (nestedFieldToAlias, attrToAliases) =>
- NestedColumnAliasing.replaceToAliases(plan, nestedFieldToAlias,
attrToAliases)
- }
+ case Project(projectList, child) if
+ SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child)
=>
+ rewritePlanIfSubsetFieldsUsed(
+ plan, projectList ++ child.expressions, child.producedAttributes.toSeq)
case p if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(p) =>
- val exprCandidatesToPrune = p.expressions
- getAliasSubMap(exprCandidatesToPrune, p.producedAttributes.toSeq).map {
- case (nestedFieldToAlias, attrToAliases) =>
- NestedColumnAliasing.replaceToAliases(p, nestedFieldToAlias,
attrToAliases)
- }
+ rewritePlanIfSubsetFieldsUsed(
+ plan, p.expressions, p.producedAttributes.toSeq)
case _ => None
}
/**
+ * Rewrites a plan with aliases if only a subset of the nested fields are
used.
+ */
+ def rewritePlanIfSubsetFieldsUsed(
+ plan: LogicalPlan,
+ exprList: Seq[Expression],
+ exclusiveAttrs: Seq[Attribute]): Option[LogicalPlan] = {
+ val attrToExtractValues = getAttributeToExtractValues(exprList,
exclusiveAttrs)
+ if (attrToExtractValues.isEmpty) {
+ None
+ } else {
+ Some(rewritePlanWithAliases(plan, attrToExtractValues))
+ }
+ }
+
+ /**
* Replace nested columns to prune unused nested columns later.
*/
- private def replaceToAliases(
+ def rewritePlanWithAliases(
plan: LogicalPlan,
- nestedFieldToAlias: Map[ExtractValue, Alias],
- attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = plan match {
- case Project(projectList, child) =>
- Project(
- getNewProjectList(projectList, nestedFieldToAlias),
- replaceWithAliases(child, nestedFieldToAlias, attrToAliases))
-
- // The operators reaching here was already guarded by `canPruneOn`.
- case other =>
- replaceWithAliases(other, nestedFieldToAlias, attrToAliases)
+ attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]):
LogicalPlan = {
+ // Each expression can contain multiple nested fields.
+ // Note that we keep the original names to deliver to parquet in a
case-sensitive way.
+ // A new alias is created for each nested field.
+ // Implementation detail: we don't use mapValues, because it creates a
lazy view that is re-evaluated on every access.
+ val attributeToExtractValuesAndAliases =
+ attributeToExtractValues.map { case (attr, evSeq) =>
+ val evAliasSeq = evSeq.map { ev =>
+ val fieldName = ev match {
+ case g: GetStructField => g.extractFieldName
+ case g: GetArrayStructFields => g.field.name
+ }
+ ev -> Alias(ev, s"_extract_$fieldName")()
+ }
+
+ attr -> evAliasSeq
+ }
+
+ val nestedFieldToAlias =
attributeToExtractValuesAndAliases.values.flatten.toMap
+
+ // A reference attribute can have multiple aliases for nested fields.
+ val attrToAliases =
AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)))
+
+ plan match {
+ case Project(projectList, child) =>
+ Project(
+ getNewProjectList(projectList, nestedFieldToAlias),
+ replaceWithAliases(child, nestedFieldToAlias, attrToAliases))
+
+ // The operators reaching here are already guarded by [[canPruneOn]].
+ case other =>
+ replaceWithAliases(other, nestedFieldToAlias, attrToAliases)
+ }
}
/**
- * Return a replaced project list.
+ * Replace the [[ExtractValue]]s in a project list with aliased attributes.
*/
def getNewProjectList(
projectList: Seq[NamedExpression],
@@ -93,15 +173,15 @@ object NestedColumnAliasing {
}
/**
- * Return a plan with new children replaced with aliases, and expressions
replaced with
- * aliased attributes.
+ * Replace the grandchildren of a plan with [[Project]]s of the nested
fields as aliases,
+ * and replace the [[ExtractValue]] expressions with aliased attributes.
*/
def replaceWithAliases(
plan: LogicalPlan,
nestedFieldToAlias: Map[ExtractValue, Alias],
- attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
+ attrToAliases: AttributeMap[Seq[Alias]]): LogicalPlan = {
plan.withNewChildren(plan.children.map { plan =>
- Project(plan.output.flatMap(a => attrToAliases.getOrElse(a.exprId,
Seq(a))), plan)
+ Project(plan.output.flatMap(a => attrToAliases.getOrElse(a, Seq(a))),
plan)
}).transformExpressions {
case f: ExtractValue if nestedFieldToAlias.contains(f) =>
nestedFieldToAlias(f).toAttribute
@@ -109,7 +189,7 @@ object NestedColumnAliasing {
}
/**
- * Returns true for those operators that we can prune nested column on it.
+ * Returns true for operators on which we can prune nested columns.
*/
private def canPruneOn(plan: LogicalPlan) = plan match {
case _: Aggregate => true
@@ -118,7 +198,7 @@ object NestedColumnAliasing {
}
/**
- * Returns true for those operators that project can be pushed through.
+ * Returns true for operators through which project can be pushed.
*/
private def canProjectPushThrough(plan: LogicalPlan) = plan match {
case _: GlobalLimit => true
@@ -133,9 +213,10 @@ object NestedColumnAliasing {
}
/**
- * Return root references that are individually accessed as a whole, and
`GetStructField`s
- * or `GetArrayStructField`s which on top of other `ExtractValue`s or
special expressions.
- * Check `SelectedField` to see which expressions should be listed here.
+ * Returns two types of expressions:
+ * - Root references that are individually accessed
+ * - [[GetStructField]] or [[GetArrayStructFields]] on top of other
[[ExtractValue]]s
+ * or special expressions.
*/
private def collectRootReferenceAndExtractValue(e: Expression):
Seq[Expression] = e match {
case _: AttributeReference => Seq(e)
@@ -149,67 +230,55 @@ object NestedColumnAliasing {
}
/**
- * Return two maps in order to replace nested fields to aliases.
- *
- * If `exclusiveAttrs` is given, any nested field accessors of these
attributes
- * won't be considered in nested fields aliasing.
- *
- * 1. ExtractValue -> Alias: A new alias is created for each nested field.
- * 2. ExprId -> Seq[Alias]: A reference attribute has multiple aliases
pointing it.
+ * Creates a map from root [[Attribute]]s to non-redundant nested
[[ExtractValue]]s.
+ * Nested field accessors of `exclusiveAttrs` are not considered in nested
fields aliasing.
*/
- def getAliasSubMap(exprList: Seq[Expression], exclusiveAttrs: Seq[Attribute]
= Seq.empty)
- : Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = {
- val (nestedFieldReferences, otherRootReferences) =
- exprList.flatMap(collectRootReferenceAndExtractValue).partition {
- case _: ExtractValue => true
- case _ => false
+ def getAttributeToExtractValues(
+ exprList: Seq[Expression],
+ exclusiveAttrs: Seq[Attribute]): Map[Attribute, Seq[ExtractValue]] = {
+
+ val nestedFieldReferences = new mutable.ArrayBuffer[ExtractValue]()
+ val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
+ exprList.foreach { e =>
+ collectRootReferenceAndExtractValue(e).foreach {
+ case ev: ExtractValue =>
+ if (ev.references.size == 1) {
+ nestedFieldReferences.append(ev)
+ }
+ case ar: AttributeReference => otherRootReferences.append(ar)
}
-
- // Note that when we group by extractors with their references, we should
remove
- // cosmetic variations.
+ }
val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
- val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]]
+
+ // Remove cosmetic variations when we group extractors by their references
+ nestedFieldReferences
.filter(!_.references.subsetOf(exclusiveAttrSet))
.groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
- .flatMap { case (attr, nestedFields: Seq[ExtractValue]) =>
- // Remove redundant `ExtractValue`s if they share the same parent nest
field.
+ .flatMap { case (attr: Attribute, nestedFields: Seq[ExtractValue]) =>
+ // Remove redundant [[ExtractValue]]s if they share the same parent
nest field.
// For example, when `a.b` and `a.b.c` are in project list, we only
need to alias `a.b`.
- // We only need to deal with two `ExtractValue`:
`GetArrayStructFields` and
- // `GetStructField`. Please refer to the method
`collectRootReferenceAndExtractValue`.
+ // Because `a.b` requires all of the inner fields of `b`, we cannot
prune `a.b.c`.
val dedupNestedFields = nestedFields.filter {
+ // See [[collectRootReferenceAndExtractValue]]: we only need to deal with
[[GetArrayStructFields]] and
+ // [[GetStructField]]
case e @ (_: GetStructField | _: GetArrayStructFields) =>
val child = e.children.head
nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
case _ => true
- }
-
- // Each expression can contain multiple nested fields.
- // Note that we keep the original names to deliver to parquet in a
case-sensitive way.
- val nestedFieldToAlias = dedupNestedFields.distinct.map { f =>
- val exprId = NamedExpression.newExprId
- (f, Alias(f, s"_gen_alias_${exprId.id}")(exprId, Seq.empty, None))
- }
+ }.distinct
// If all nested fields of `attr` are used, we don't need to introduce
new aliases.
- // By default, ColumnPruning rule uses `attr` already.
+ // By default, the [[ColumnPruning]] rule uses `attr` already.
// Note that we need to remove cosmetic variations first, so we only
count a
// nested field once.
- if (nestedFieldToAlias.nonEmpty &&
- dedupNestedFields.map(_.canonicalized)
- .distinct
- .map { nestedField => totalFieldNum(nestedField.dataType) }
- .sum < totalFieldNum(attr.dataType)) {
- Some(attr.exprId -> nestedFieldToAlias)
+ val numUsedNestedFields =
dedupNestedFields.map(_.canonicalized).distinct
+ .map { nestedField => totalFieldNum(nestedField.dataType) }.sum
+ if (numUsedNestedFields < totalFieldNum(attr.dataType)) {
+ Some((attr, dedupNestedFields.toSeq))
} else {
None
}
}
-
- if (aliasSub.isEmpty) {
- None
- } else {
- Some((aliasSub.values.flatten.toMap, aliasSub.map(x => (x._1,
x._2.map(_._2)))))
- }
}
/**
@@ -227,31 +296,9 @@ object NestedColumnAliasing {
}
/**
- * This prunes unnecessary nested columns from `Generate` and optional
`Project` on top
- * of it.
+ * This prunes unnecessary nested columns from [[Generate]], or [[Project]] ->
[[Generate]]
*/
object GeneratorNestedColumnAliasing {
- // Partitions `attrToAliases` based on whether the attribute is in
Generator's output.
- private def aliasesOnGeneratorOutput(
- attrToAliases: Map[ExprId, Seq[Alias]],
- generatorOutput: Seq[Attribute]) = {
- val generatorOutputExprId = generatorOutput.map(_.exprId)
- attrToAliases.partition { k =>
- generatorOutputExprId.contains(k._1)
- }
- }
-
- // Partitions `nestedFieldToAlias` based on whether the attribute of nested
field extractor
- // is in Generator's output.
- private def nestedFieldOnGeneratorOutput(
- nestedFieldToAlias: Map[ExtractValue, Alias],
- generatorOutput: Seq[Attribute]) = {
- val generatorOutputSet = AttributeSet(generatorOutput)
- nestedFieldToAlias.partition { pair =>
- pair._1.references.subsetOf(generatorOutputSet)
- }
- }
-
def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
// Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is
enabled, we
// need to prune nested columns through Project and under Generate. The
difference is
@@ -261,103 +308,100 @@ object GeneratorNestedColumnAliasing {
SQLConf.get.nestedSchemaPruningEnabled) &&
canPruneGenerator(g.generator) =>
// On top of `Generate`, a `Project` that might have nested column
accessors.
// We try to get alias maps for both project list and generator's
children expressions.
- val exprsToPrune = projectList ++ g.generator.children
- NestedColumnAliasing.getAliasSubMap(exprsToPrune).map {
- case (nestedFieldToAlias, attrToAliases) =>
- val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
- nestedFieldOnGeneratorOutput(nestedFieldToAlias,
g.qualifiedGeneratorOutput)
- val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
- aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)
-
- // Push nested column accessors through `Generator`.
- // Defer updating `Generate.unrequiredChildIndex` to next round of
`ColumnPruning`.
- val newChild = NestedColumnAliasing.replaceWithAliases(g,
- nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
- val pushedThrough = Project(NestedColumnAliasing
- .getNewProjectList(projectList, nestedFieldsNotOnGenerator),
newChild)
-
- // If the generator output is `ArrayType`, we cannot push through
the extractor.
- // It is because we don't allow field extractor on two-level array,
- // i.e., attr.field when attr is a ArrayType(ArrayType(...)).
- // Similarily, we also cannot push through if the child of generator
is `MapType`.
- g.generator.children.head.dataType match {
- case _: MapType => return Some(pushedThrough)
- case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
- case _ =>
- }
-
- // Pruning on `Generator`'s output. We only process single field
case.
- // For multiple field case, we cannot directly move field extractor
into
- // the generator expression. A workaround is to re-construct array
of struct
- // from multiple fields. But it will be more complicated and may not
worth.
- // TODO(SPARK-34956): support multiple fields.
- if (nestedFieldsOnGenerator.size > 1 ||
nestedFieldsOnGenerator.isEmpty) {
- pushedThrough
- } else {
- // Only one nested column accessor.
- // E.g., df.select(explode($"items").as("item")).select($"item.a")
- pushedThrough match {
- case p @ Project(_, newG: Generate) =>
- // Replace the child expression of `ExplodeBase` generator with
- // nested column accessor.
- // E.g.,
df.select(explode($"items").as("item")).select($"item.a") =>
- // df.select(explode($"items.a").as("item.a"))
- val rewrittenG = newG.transformExpressions {
- case e: ExplodeBase =>
- val extractor =
nestedFieldsOnGenerator.head._1.transformUp {
- case _: Attribute =>
- e.child
- case g: GetStructField =>
- ExtractValue(g.child, Literal(g.extractFieldName),
SQLConf.get.resolver)
- }
- e.withNewChildren(Seq(extractor))
- }
+ val attrToExtractValues =
NestedColumnAliasing.getAttributeToExtractValues(
+ projectList ++ g.generator.children, Seq.empty)
+ if (attrToExtractValues.isEmpty) {
+ return None
+ }
+ val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
+ val (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
+ attrToExtractValues.partition { case (attr, _) =>
+ attr.references.subsetOf(generatorOutputSet) }
+
+ val pushedThrough = NestedColumnAliasing.rewritePlanWithAliases(
+ plan, attrToExtractValuesNotOnGenerator)
+
+ // If the generator output is `ArrayType`, we cannot push through the
extractor.
+ // It is because we don't allow field extractor on two-level array,
+ // i.e., attr.field when attr is a ArrayType(ArrayType(...)).
+ // Similarly, we also cannot push through if the child of generator is
`MapType`.
+ g.generator.children.head.dataType match {
+ case _: MapType => return Some(pushedThrough)
+ case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
+ case _ =>
+ }
- // As we change the child of the generator, its output data
type must be updated.
- val updatedGeneratorOutput = rewrittenG.generatorOutput
- .zip(rewrittenG.generator.elementSchema.toAttributes)
- .map { case (oldAttr, newAttr) =>
- newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
- }
- assert(updatedGeneratorOutput.length ==
rewrittenG.generatorOutput.length,
- "Updated generator output must have the same length " +
- "with original generator output.")
- val updatedGenerate = rewrittenG.copy(generatorOutput =
updatedGeneratorOutput)
-
- // Replace nested column accessor with generator output.
- p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
- case f: ExtractValue if nestedFieldsOnGenerator.contains(f)
=>
- updatedGenerate.output
- .find(a => attrToAliasesOnGenerator.contains(a.exprId))
- .getOrElse(f)
+ // Pruning on `Generator`'s output. We only process single field case.
+ // For multiple field case, we cannot directly move field extractor into
+ // the generator expression. A workaround is to re-construct array of
struct
+ // from multiple fields. But it will be more complicated and may not be
worth it.
+ // TODO(SPARK-34956): support multiple fields.
+ val nestedFieldsOnGenerator =
attrToExtractValuesOnGenerator.values.flatten.toSet
+ if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty)
{
+ Some(pushedThrough)
+ } else {
+ // Only one nested column accessor.
+ // E.g., df.select(explode($"items").as("item")).select($"item.a")
+ val nestedFieldOnGenerator = nestedFieldsOnGenerator.head
+ pushedThrough match {
+ case p @ Project(_, newG: Generate) =>
+ // Replace the child expression of `ExplodeBase` generator with
+ // nested column accessor.
+ // E.g., df.select(explode($"items").as("item")).select($"item.a")
=>
+ // df.select(explode($"items.a").as("item.a"))
+ val rewrittenG = newG.transformExpressions {
+ case e: ExplodeBase =>
+ val extractor = nestedFieldOnGenerator.transformUp {
+ case _: Attribute =>
+ e.child
+ case g: GetStructField =>
+ ExtractValue(g.child, Literal(g.extractFieldName),
SQLConf.get.resolver)
}
+ e.withNewChildren(Seq(extractor))
+ }
- case other =>
- // We should not reach here.
- throw new IllegalStateException(s"Unreasonable plan after
optimization: $other")
+ // As we change the child of the generator, its output data type
must be updated.
+ val updatedGeneratorOutput = rewrittenG.generatorOutput
+ .zip(rewrittenG.generator.elementSchema.toAttributes)
+ .map { case (oldAttr, newAttr) =>
+ newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
+ }
+ assert(updatedGeneratorOutput.length ==
rewrittenG.generatorOutput.length,
+ "Updated generator output must have the same length " +
+ "with original generator output.")
+ val updatedGenerate = rewrittenG.copy(generatorOutput =
updatedGeneratorOutput)
+
+ // Replace nested column accessor with generator output.
+ val attrExprIdsOnGenerator =
attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
+ val updatedProject =
p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
+ case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
+ updatedGenerate.output
+ .find(a => attrExprIdsOnGenerator.contains(a.exprId))
+ .getOrElse(f)
}
- }
+ Some(updatedProject)
+
+ case other =>
+ // We should not reach here.
+ throw new IllegalStateException(s"Unreasonable plan after
optimization: $other")
+ }
}
case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&
- canPruneGenerator(g.generator) =>
+ canPruneGenerator(g.generator) =>
// If any child output is required by higher projection, we cannot prune
on it even if we
// only use part of nested column of it. A required child output means
it is referred
// as a whole or partially by higher projection, pruning it here will
cause unresolved
// query plan.
- NestedColumnAliasing.getAliasSubMap(
- g.generator.children, g.requiredChildOutput).map {
- case (nestedFieldToAlias, attrToAliases) =>
- // Defer updating `Generate.unrequiredChildIndex` to next round of
`ColumnPruning`.
- NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias,
attrToAliases)
- }
+ NestedColumnAliasing.rewritePlanIfSubsetFieldsUsed(
+ plan, g.generator.children, g.requiredChildOutput)
case _ =>
None
}
/**
- * This is a while-list for pruning nested fields at `Generator`.
+ * Types of [[Generator]] on which we can prune nested fields.
*/
def canPruneGenerator(g: Generator): Boolean = g match {
case _: Explode => true
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index defb3b4..028c9bd 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -752,7 +752,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
p.copy(child = g.copy(child = newChild, unrequiredChildIndex =
unrequiredIndices))
// prune unrequired nested fields from `Generate`.
- case GeneratorNestedColumnAliasing(p) => p
+ case GeneratorNestedColumnAliasing(rewrittenPlan) => rewrittenPlan
// Eliminate unneeded attributes from right side of a Left Existence Join.
case j @ Join(_, right, LeftExistence(_), _, _) =>
@@ -786,7 +786,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
// Can't prune the columns on LeafNode
case p @ Project(_, _: LeafNode) => p
- case NestedColumnAliasing(p) => p
+ case NestedColumnAliasing(rewrittenPlan) => rewrittenPlan
// for all other logical plans that inherit the output from their children
// Project over project is handled by the first case, skip it here.
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index a856caa..643974c 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -714,7 +714,7 @@ object NestedColumnAliasingSuite {
def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
val aliases = ArrayBuffer[String]()
query.transformAllExpressions {
- case a @ Alias(_, name) if name.startsWith("_gen_alias_") =>
+ case a @ Alias(_, name) if name.startsWith("_extract_") =>
aliases += name
a
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]