allisonwang-db commented on a change in pull request #32301:
URL: https://github.com/apache/spark/pull/32301#discussion_r618834219
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
##########
@@ -30,54 +30,61 @@ import org.apache.spark.sql.types._
*/
object NestedColumnAliasing {
- def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+ def unapply(plan: LogicalPlan): Option[Map[Attribute, Seq[ExtractValue]]] = plan match {
Review comment:
It's a bit unnatural for `unapply` to return an option of the map built by
`getAttributeToExtractValues`. How about using the unapply method only to match
the pattern and extract the information needed to build the map:
```scala
def unapply(plan: LogicalPlan): Option[(Seq[Expression], Seq[Attribute])] = plan match {
  case Project(projectList, g: Generate) if
      (SQLConf.get.nestedPruningOnExpressions || SQLConf.get.nestedSchemaPruningEnabled) &&
      canPruneGenerator(g.generator) =>
    Some((projectList ++ g.generator.children, g.qualifiedGeneratorOutput))
  case ...
}
```
Then build the mapping in the optimizer rule.
```scala
case p @ NestedColumnAliasing((exprs, excludedAttrs)) =>
...
```
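To make the suggestion concrete, here is a rough sketch of the full rule body under that scheme (the name `attrToExtractValues` is illustrative, and I'm assuming `replacePlanWithAliases` takes the plan plus the map, since its signature is truncated in a later hunk):
```scala
case p @ NestedColumnAliasing((exprs, excludedAttrs)) =>
  // unapply only matches the pattern; the map is built here in the rule.
  NestedColumnAliasing.getAttributeToExtractValues(exprs, excludedAttrs)
    .map(attrToExtractValues =>
      NestedColumnAliasing.replacePlanWithAliases(p, attrToExtractValues))
    .getOrElse(p)
```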
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
##########
@@ -133,82 +140,84 @@ object NestedColumnAliasing {
}
/**
- * Return root references that are individually accessed as a whole, and `GetStructField`s
- * or `GetArrayStructField`s which on top of other `ExtractValue`s or special expressions.
- * Check `SelectedField` to see which expressions should be listed here.
+ * Check [[SelectedField]] to see which expressions should be listed here.
*/
- private def collectRootReferenceAndExtractValue(e: Expression): Seq[Expression] = e match {
- case _: AttributeReference => Seq(e)
- case GetStructField(_: ExtractValue | _: AttributeReference, _, _) => Seq(e)
+ private def isSelectedField(e: Expression): Boolean = e match {
+ case GetStructField(_: ExtractValue | _: AttributeReference, _, _) => true
case GetArrayStructFields(_: MapValues |
_: MapKeys |
_: ExtractValue |
- _: AttributeReference, _, _, _, _) => Seq(e)
- case es if es.children.nonEmpty => es.children.flatMap(collectRootReferenceAndExtractValue)
+ _: AttributeReference, _, _, _, _) => true
+ case _ => false
+ }
+
+ /**
+ * Return root references that are individually accessed.
+ */
+ private def collectAttributeReference(e: Expression): Seq[AttributeReference] = e match {
+ case a: AttributeReference => Seq(a)
+ case g if isSelectedField(g) => Seq.empty
+ case es if es.children.nonEmpty => es.children.flatMap(collectAttributeReference)
case _ => Seq.empty
}
/**
- * Return two maps in order to replace nested fields to aliases.
- *
- * If `exclusiveAttrs` is given, any nested field accessors of these attributes
- * won't be considered in nested fields aliasing.
- *
- * 1. ExtractValue -> Alias: A new alias is created for each nested field.
- * 2. ExprId -> Seq[Alias]: A reference attribute has multiple aliases pointing it.
+ * Return [[GetStructField]] or [[GetArrayStructFields]] on top of other [[ExtractValue]]s
+ * or special expressions.
*/
- def getAliasSubMap(exprList: Seq[Expression], exclusiveAttrs: Seq[Attribute] = Seq.empty)
- : Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = {
- val (nestedFieldReferences, otherRootReferences) =
- exprList.flatMap(collectRootReferenceAndExtractValue).partition {
- case _: ExtractValue => true
- case _ => false
- }
+ private def collectExtractValue(e: Expression): Seq[ExtractValue] = e match {
+ case g if isSelectedField(g) => Seq(g.asInstanceOf[ExtractValue])
+ case es if es.children.nonEmpty => es.children.flatMap(collectExtractValue)
+ case _ => Seq.empty
+ }
- // Note that when we group by extractors with their references, we should remove
- // cosmetic variations.
+ /**
+ * Creates a map from root [[Attribute]]s to non-redundant nested [[ExtractValue]]s in the
+ * case that only a subset of the nested fields are used.
+ * Nested field accessors of `exclusiveAttrs` are not considered in nested fields aliasing.
+ */
+ def getAttributeToExtractValues(
+ exprList: Seq[Expression],
+ exclusiveAttrs: Seq[Attribute]): Option[Map[Attribute, Seq[ExtractValue]]] = {
+
+ val nestedFieldReferences = exprList.flatMap(collectExtractValue)
+ val otherRootReferences = exprList.flatMap(collectAttributeReference)
val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
- val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]]
+
+ // Remove cosmetic variations when we group extractors by their references
+ val attributeToExtractValues = nestedFieldReferences
.filter(!_.references.subsetOf(exclusiveAttrSet))
.groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
- .flatMap { case (attr, nestedFields: Seq[ExtractValue]) =>
- // Remove redundant `ExtractValue`s if they share the same parent nest field.
+ .flatMap { case (attr: Attribute, nestedFields: Seq[ExtractValue]) =>
+ // Remove redundant [[ExtractValue]]s if they share the same parent nest field.
// For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`.
- // We only need to deal with two `ExtractValue`: `GetArrayStructFields` and
- // `GetStructField`. Please refer to the method `collectRootReferenceAndExtractValue`.
val dedupNestedFields = nestedFields.filter {
- case e @ (_: GetStructField | _: GetArrayStructFields) =>
- val child = e.children.head
+ // See [[collectExtractValue]]: we only need to deal with [[GetArrayStructFields]] and
+ // [[GetStructField]]
+ case GetStructField(child, _, _) =>
Review comment:
This seems to be the same as the original logic?
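To spell out the comparison, both shapes appear to end up filtering on the extractor's first child (sketch only; the shared filter body is elided):
```scala
// Before: bind the whole extractor, then read its first child.
case e @ (_: GetStructField | _: GetArrayStructFields) =>
  val child = e.children.head
  // ... same check on `child` ...

// After: destructure the child directly in the pattern.
case GetStructField(child, _, _) =>
  // ... same check on `child` ...
```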
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
##########
@@ -30,54 +30,61 @@ import org.apache.spark.sql.types._
*/
object NestedColumnAliasing {
- def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+ def unapply(plan: LogicalPlan): Option[Map[Attribute, Seq[ExtractValue]]] = plan match {
/**
* This pattern is needed to support [[Filter]] plan cases like
* [[Project]]->[[Filter]]->listed plan in `canProjectPushThrough` (e.g., [[Window]]).
* The reason why we don't simply add [[Filter]] in `canProjectPushThrough` is that
* the optimizer can hit an infinite loop during the [[PushDownPredicates]] rule.
*/
- case Project(projectList, Filter(condition, child))
- if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
- val exprCandidatesToPrune = projectList ++ Seq(condition) ++ child.expressions
- getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq).map {
- case (nestedFieldToAlias, attrToAliases) =>
- NestedColumnAliasing.replaceToAliases(plan, nestedFieldToAlias, attrToAliases)
- }
+ case Project(projectList, Filter(condition, child)) if
+ SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
+ getAttributeToExtractValues(
+ projectList ++ Seq(condition) ++ child.expressions, child.producedAttributes.toSeq)
- case Project(projectList, child)
- if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
- val exprCandidatesToPrune = projectList ++ child.expressions
- getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq).map {
- case (nestedFieldToAlias, attrToAliases) =>
- NestedColumnAliasing.replaceToAliases(plan, nestedFieldToAlias, attrToAliases)
- }
+ case Project(projectList, child) if
+ SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
+ getAttributeToExtractValues(
+ projectList ++ child.expressions, child.producedAttributes.toSeq)
case p if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(p) =>
- val exprCandidatesToPrune = p.expressions
- getAliasSubMap(exprCandidatesToPrune, p.producedAttributes.toSeq).map {
- case (nestedFieldToAlias, attrToAliases) =>
- NestedColumnAliasing.replaceToAliases(p, nestedFieldToAlias, attrToAliases)
- }
+ getAttributeToExtractValues(
+ p.expressions, p.producedAttributes.toSeq)
case _ => None
}
/**
* Replace nested columns to prune unused nested columns later.
*/
- private def replaceToAliases(
- plan: LogicalPlan,
- nestedFieldToAlias: Map[ExtractValue, Alias],
- attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = plan match {
- case Project(projectList, child) =>
- Project(
- getNewProjectList(projectList, nestedFieldToAlias),
- replaceWithAliases(child, nestedFieldToAlias, attrToAliases))
+ def replacePlanWithAliases(
+ plan: LogicalPlan,
Review comment:
nit: 4-space indentation
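For reference, a sketch of the expected style (the second parameter is illustrative, since the hunk cuts off after `plan`): continuation lines of the parameter list get 4 spaces of indentation:
```scala
def replacePlanWithAliases(
    plan: LogicalPlan,
    attrToExtractValues: Map[Attribute, Seq[ExtractValue]]): LogicalPlan = {
  // body indented 2 spaces as usual
  // ...
}
```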
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
##########
@@ -227,11 +236,11 @@ object NestedColumnAliasing {
}
/**
- * This prunes unnecessary nested columns from `Generate` and optional `Project` on top
+ * This prunes unnecessary nested columns from [[Generate]] and optional [[Project]] on top
* of it.
*/
object GeneratorNestedColumnAliasing {
- def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+ def unapply(plan: LogicalPlan): Option[Map[Attribute, Seq[ExtractValue]]] = plan match {
Review comment:
ditto