cloud-fan commented on code in PR #43843:
URL: https://github.com/apache/spark/pull/43843#discussion_r1400570698


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala:
##########
@@ -444,6 +448,145 @@ case class UnresolvedStar(target: Option[Seq[String]]) 
extends Star with Unevalu
   override def toString: String = target.map(_.mkString("", ".", 
".")).getOrElse("") + "*"
 }
 
+/**
+ * Represents some of the input attributes to a given relational operator, for 
example in
+ * "SELECT * EXCEPT(a) FROM ...".
+ *
+ * @param target an optional qualifier that should be the target of the expansion.
+ *              If omitted, all targets' columns are produced. This can only refer
+ *              to a table name, and is given as a list of identifiers forming the
+ *              path of the expansion.
+ *
+ * @param excepts a list of names that should be excluded from the expansion.
+ *
+ */
+case class UnresolvedStarExcept(target: Option[Seq[String]], excepts: 
Seq[Seq[String]])
+  extends UnresolvedStarBase(target) {
+
+  /**
+   * We expand the * EXCEPT by the following three steps:
+   * 1. use the original .expand() to get top-level column list or struct 
expansion
+   * 2. resolve excepts (with respect to the Seq[NamedExpression] returned 
from (1))
+   * 3. filter the expanded columns with the resolved except list. recursively 
apply filtering in
+   *    case of nested columns in the except list (in order to rewrite structs)
+   */
+  override def expand(input: LogicalPlan, resolver: Resolver): 
Seq[NamedExpression] = {
+    // Use the UnresolvedStarBase expand method to get a seq of 
NamedExpressions corresponding to
+    // the star expansion. This will yield a list of top-level columns from 
the logical plan's
+    // output, or in the case of struct expansion (e.g. target=`x` for SELECT 
x.*) it will give
+    // a seq of NamedExpressions corresponding to struct fields.
+    val expandedCols = super.expand(input, resolver)
+
+    // resolve except list with respect to the expandedCols
+    val resolvedExcepts = excepts.map { exceptParts =>
+      AttributeSeq(expandedCols.map(_.toAttribute)).resolve(exceptParts, 
resolver).getOrElse {
+        val orderedCandidates = 
StringUtils.orderSuggestedIdentifiersBySimilarity(
+          UnresolvedAttribute(exceptParts).name, expandedCols.map(a => 
a.qualifier :+ a.name))
+        // if target is defined and expandedCols does not include any 
Attributes, it must be struct
+        // expansion; give message suggesting to use unqualified names of 
nested fields.
+        throw QueryCompilationErrors
+          .unresolvedColumnError(UnresolvedAttribute(exceptParts).name, 
orderedCandidates)
+      }
+    }
+
+    // Convert each resolved except into a pair of (col: Attribute, 
nestedColumn) representing the
+    // top level column in expandedCols that we must 'filter' based on 
nestedColumn.
+    @scala.annotation.tailrec
+    def getRootColumn(expr: Expression, nestedColumn: Seq[String] = Nil)
+      : (NamedExpression, Seq[String]) = expr match {
+      case GetStructField(fieldExpr, _, Some(fieldName)) =>
+        getRootColumn(fieldExpr, fieldName +: nestedColumn)
+      case e: NamedExpression => e -> nestedColumn
+      case other: ExtractValue => throw new AnalysisException(
+        errorClass = "EXCEPT_NESTED_COLUMN_INVALID_TYPE",
+        messageParameters = Map("columnName" -> other.sql, "dataType" -> 
other.dataType.toString))
+    }
+    // An exceptPair represents the column in expandedCols along with the path 
of a nestedColumn
+    // that should be except-ed. Consider two examples:
+    // 1. excepting the entire col1 = (col1, Seq())
+    // 2. excepting a nested field in col2, col2.a.b = (col2, Seq(a, b))
+    // INVARIANT: we rely on the structure of the resolved except being an 
Alias of GetStructField
+    // in the case of nested columns.
+    val exceptPairs = resolvedExcepts.map {
+      case Alias(exceptExpr, name) => getRootColumn(exceptExpr)
+      case except: NamedExpression => except -> Seq.empty
+    }
+
+    // Filter columns which correspond to ones listed in the except list and 
return a new list of
+    // columns which exclude the columns in the except list. The 'filtering' 
manifests as either
+    // dropping the column from the list of columns we return, or rewriting 
the projected column in
+    // order to remove excepts that refer to nested columns. For example, with 
the example above:
+    // excepts = Seq(
+    //   (col1, Seq()),    => filter col1 from the output
+    //   (col2, Seq(a, b)) => rewrite col2 in the output so that it doesn't 
include the nested field
+    // )                      corresponding to col2.a.b
+    //
+    // This occurs in two steps:
+    // 1. group the excepts by the column they refer to (groupedExcepts)
+    // 2. filter/rewrite input columns based on four cases:
+    //    a. column doesn't match any groupedExcepts => column unchanged
+    //    b. column exists in groupedExcepts and:
+    //       i.   none of remainingExcepts are empty => recursively apply 
filterColumns over the
+    //            struct fields in order to rewrite the struct
+    //       ii.  a remainingExcept is empty, but there are multiple 
remainingExcepts => we must
+    //            have duplicate/overlapping excepts - throw an error
+    //       iii. [otherwise] remainingExcepts is exactly Seq(Seq()) => this is 
the base 'filtering'
+    //            case. we omit the column from the output (this is a column 
we would like to
+    //            except). NOTE: this case isn't explicitly listed in the 
`collect` below since we
+    //            'collect' columns which match the cases above and omit ones 
that fall into this
+    //            remaining case.
+    def filterColumns(columns: Seq[NamedExpression], excepts: 
Seq[(NamedExpression, Seq[String])])
+      : Seq[NamedExpression] = {
+      // group the except pairs by the column they refer to. NOTE: no groupMap 
until scala 2.13
+      val groupedExcepts: AttributeMap[Seq[Seq[String]]] =
+        AttributeMap(excepts.groupBy(_._1.toAttribute).mapValues(v => 
v.map(_._2)))
+
+      // map input columns while searching for the except entry corresponding 
to the current column
+      columns.map(col => col -> groupedExcepts.get(col.toAttribute)).collect {
+        // pass through columns that don't match anything in groupedExcepts
+        case (col, None) => col
+        // found a match but nestedExcepts has remaining excepts - recurse to 
rewrite the struct
+        case (col, Some(nestedExcepts)) if !nestedExcepts.exists(_.isEmpty) =>

Review Comment:
   ```suggestion
           case (col, Some(nestedExcepts)) if nestedExcepts.forall(_.nonEmpty) 
=>
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to