cloud-fan commented on code in PR #43843:
URL: https://github.com/apache/spark/pull/43843#discussion_r1400043814
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala:
##########
@@ -444,6 +447,156 @@ case class UnresolvedStar(target: Option[Seq[String]])
extends Star with Unevalu
override def toString: String = target.map(_.mkString("", ".",
".")).getOrElse("") + "*"
}
+/**
+ * Represents some of the input attributes to a given relational operator, for
example in
+ * "SELECT * EXCEPT(a) FROM ...".
+ *
+ * @param target an optional name that should be the target of the expansion.
If omitted all
+ * targets' columns are produced. This can only be a table name.
This
+ * is a list of identifiers that is the path of the expansion.
+ *
+ * @param excepts a list of names that should be excluded from the expansion.
+ *
+ */
+case class UnresolvedStarExcept(target: Option[Seq[String]], excepts:
Seq[Seq[String]])
+ extends UnresolvedStarBase(target) {
+
+ /**
+ * We expand the * EXCEPT by the following three steps:
+ * 1. use the original .expand() to get top-level column list or struct
expansion
+ * 2. resolve excepts (with respect to the Seq[NamedExpression] returned
from (1))
+ * 3. filter the expanded columns with the resolved except list. recursively
apply filtering in
+ * case of nested columns in the except list (in order to rewrite structs)
+ */
+ override def expand(input: LogicalPlan, resolver: Resolver):
Seq[NamedExpression] = {
+ // Use the UnresolvedStarBase expand method to get a seq of
NamedExpressions corresponding to
+ // the star expansion. This will yield a list of top-level columns from
the logical plan's
+ // output, or in the case of struct expansion (e.g. target=`x` for SELECT
x.*) it will give
+ // a seq of NamedExpressions corresponding to struct fields.
+ val expandedCols = super.expand(input, resolver)
+
+ // resolve except list with respect to the expandedCols
+ val resolvedExcepts = excepts.map { exceptParts =>
+ AttributeSeq(expandedCols.map(_.toAttribute)).resolve(exceptParts,
resolver).getOrElse {
+ val orderedCandidates =
StringUtils.orderSuggestedIdentifiersBySimilarity(
+ UnresolvedAttribute(exceptParts).name, expandedCols.map(a =>
a.qualifier :+ a.name))
+ // if target is defined and expandedCols does not include any
Attributes, it must be struct
+ // expansion; give message suggesting to use unqualified names of
nested fields.
+ if (target.isDefined &&
!expandedCols.exists(_.isInstanceOf[Attribute])) {
+ throw new AnalysisException(
+ errorClass = "EXCEPT_UNRESOLVED_COLUMN_IN_STRUCT_EXPANSION",
+ messageParameters = Map(
+ "objectName" -> UnresolvedAttribute(exceptParts).sql,
+ "objectList" -> orderedCandidates.mkString(", ")))
+ } else {
+ throw QueryCompilationErrors
+ .unresolvedColumnError(UnresolvedAttribute(exceptParts).name,
orderedCandidates)
+ }
+ }
+ }
+
+ // Convert each resolved except into a pair of (col: Attribute,
nestedColumn) representing the
+ // top level column in expandedCols that we must 'filter' based on
nestedColumn.
+ @scala.annotation.tailrec
+ def getRootStruct(expr: Expression, nestedColumn: Seq[String] = Nil)
+ : (NamedExpression, Seq[String]) = expr match {
+ case GetStructField(fieldExpr, _, Some(fieldName)) =>
+ getRootStruct(fieldExpr, fieldName +: nestedColumn)
+ case e: NamedExpression => e -> nestedColumn
+ case other: ExtractValue => throw new AnalysisException(
+ errorClass = "EXCEPT_NESTED_COLUMN_INVALID_TYPE",
+ messageParameters = Map("columnName" -> other.sql, "dataType" ->
other.dataType.toString))
+ }
+ // An exceptPair represents the column in expandedCols along with the path
of a nestedColumn
+ // that should be except-ed. Consider two examples:
+ // 1. excepting the entire col1 = (col1, Seq())
+ // 2. excepting a nested field in col2, col2.a.b = (col2, Seq(a, b))
+ // INVARIANT: we rely on the structure of the resolved except being an
Alias of GetStructField
+ // in the case of nested columns.
+ val exceptPairs = resolvedExcepts.map {
+ case Alias(exceptExpr, name) => getRootStruct(exceptExpr)
+ case except: NamedExpression => except -> Seq.empty
+ }
+
+ // Filter columns which correspond to ones listed in the except list and
return a new list of
+ // columns which exclude the columns in the except list. The 'filtering'
manifests as either
+ // dropping the column from the list of columns we return, or rewriting
the projected column in
+ // order to remove excepts that refer to nested columns. For example, with
the example above:
+ // excepts = Seq(
+ // (col1, Seq()), => filter col1 from the output
+ // (col2, Seq(a, b)) => rewrite col2 in the output so that it doesn't
include the nested field
+ // ) corresponding to col2.a.b
+ //
+ // This occurs in two steps:
+ // 1. group the excepts by the column they refer to (groupedExcepts)
+ // 2. filter/rewrite input columns based on four cases:
+ // a. column doesn't match any groupedExcepts => column unchanged
+ // b. column exists in groupedExcepts and:
+ // i. none of remainingExcepts are empty => recursively apply
filterColumns over the
+ // struct fields in order to rewrite the struct
+ // ii. a remainingExcept is empty, but there are multiple
remainingExcepts => we must
+ // have duplicate/overlapping excepts - throw an error
+ // iii. [otherwise] remainingExcept is exactly Seq(Seq()) => this is
the base 'filtering'
+ // case. we omit the column from the output (this is a column
we would like to
+ // except). NOTE: this case isn't explicitly listed in the
`collect` below since we
+ // 'collect' columns which match the cases above and omit ones
that fall into this
+ // remaining case.
+ def filterColumns(columns: Seq[NamedExpression], excepts:
Seq[(NamedExpression, Seq[String])])
+ : Seq[NamedExpression] = {
+ // group the except pairs by the column they refer to. NOTE: no groupMap
until scala 2.13
+ val groupedExcepts: AttributeMap[Seq[Seq[String]]] =
+ AttributeMap(excepts.groupBy(_._1.toAttribute).mapValues(v =>
v.map(_._2)))
+
+ // map input columns while searching for the except entry corresponding
to the current column
+ columns.map(col => col -> groupedExcepts.get(col.toAttribute)).collect {
+ // pass through columns that don't match anything in groupedExcepts
+ case (col, None) => col
+ // found a match but nestedExcepts has remaining excepts - recurse to
rewrite the struct
+ case (col, Some(nestedExcepts)) if !nestedExcepts.exists(_.isEmpty) =>
+ val fields = col.dataType match {
+ case s: StructType => s.fields
+ // we shouldn't be here since we throw the same error above (in
getRootStruct), but
+ // nonetheless just throw the same error
Review Comment:
   Reaching this point indicates a bug; we can throw
   `SparkException.internalError` instead.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]