rdblue commented on a change in pull request #1986:
URL: https://github.com/apache/iceberg/pull/1986#discussion_r550291799



##########
File path: 
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentSupport.scala
##########
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateNamedStruct, Expression, ExtractValue, GetStructField, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan}
+import org.apache.spark.sql.types.{DataType, NullType, StructField, StructType}
+
+trait AssignmentAlignmentSupport extends CastSupport {
+
+  private case class ColumnUpdate(nameParts: Seq[String], expr: Expression)
+
+  /**
+   * Aligns assignments to match table columns.
+   * <p>
+   * This method processes and reorders given assignments so that each target column gets
+   * an expression it should be set to. If a column does not have a matching assignment,
+   * it will be set to its current value. For example, if one passes a table with columns c1, c2
+   * and an assignment c2 = 1, this method will return c1 = c1, c2 = 1.
+   * <p>
+   * This method also handles updates to nested columns. If there is an assignment to a particular
+   * nested field, this method will construct a new struct with one field updated,
+   * preserving other fields that have not been modified. For example, if one passes a table with
+   * columns c1, c2 where c2 is a struct with fields n1 and n2 and an assignment c2.n2 = 1,
+   * this method will return c1 = c1, c2 = struct(c2.n1, 1).
+   *
+   * @param table a target table
+   * @param assignments assignments to align
+   * @return aligned assignments that match table columns
+   */
+  protected def alignAssignments(
+      table: LogicalPlan,
+      assignments: Seq[Assignment]): Seq[Assignment] = {
+
+    val columnUpdates = assignments.map(a => ColumnUpdate(getNameParts(a.key), a.value))
+    val outputExprs = applyUpdates(table.output, columnUpdates)
+    outputExprs.zip(table.output).map {
+      case (expr, attr) => Assignment(attr, expr)
+    }
+  }
+
+  private def applyUpdates(
+      cols: Seq[NamedExpression],
+      updates: Seq[ColumnUpdate],
+      resolver: Resolver = conf.resolver,
+      namePrefix: Seq[String] = Nil): Seq[Expression] = {
+
+    // iterate through columns at the current level and find which column updates match
+    cols.map { col =>
+      // find matches for this column or any of its children
+      val prefixMatchedUpdates = updates.filter(a => resolver(a.nameParts.head, col.name))
+      prefixMatchedUpdates match {
+        // if there is no exact match and no match for children, return the column as is
+        case updates if updates.isEmpty =>
+          col
+
+        // if there is an exact match, return the assigned expression
+        case Seq(update) if isExactMatch(update, col, resolver) =>
+          castIfNeeded(update.expr, col.dataType, resolver)
+
+        // if there are matches only for children
+        case updates if !hasExactMatch(updates, col, resolver) =>
+          col.dataType match {
+            case StructType(fields) =>
+              applyStructUpdates(col, prefixMatchedUpdates, fields, resolver, namePrefix)
+            case otherType =>
+              val colName = (namePrefix :+ col.name).mkString(".")
+              throw new AnalysisException(
+                "Updating nested fields is only supported for StructType " +
+                s"but $colName is of type $otherType"
+              )
+          }
+
+        // if there are conflicting updates, throw an exception
+        // there are two illegal scenarios:
+        // - multiple updates to the same column
+        // - updates to a top-level struct and its nested fields (e.g., a.b and a.b.c)
+        case updates if hasExactMatch(updates, col, resolver) =>
+          val conflictingCols = updates.map(u => (namePrefix ++ u.nameParts).mkString("."))
+          throw new AnalysisException(
+            "Updates are in conflict for these columns: " +
+            conflictingCols.distinct.mkString("[", ", ", "]"))
+      }
+    }
+  }
+
+  private def applyStructUpdates(
+      col: NamedExpression,
+      updates: Seq[ColumnUpdate],
+      fields: Seq[StructField],
+      resolver: Resolver,
+      namePrefix: Seq[String]): Expression = {
+
+    // build field expressions
+    val fieldExprs = fields.zipWithIndex.map { case (field, ordinal) =>
+      Alias(GetStructField(col, ordinal, Some(field.name)), field.name)()
+    }
+
+    // recursively apply this method on nested fields
+    val newUpdates = updates.map(u => u.copy(nameParts = u.nameParts.tail))
+    val updatedFieldExprs = applyUpdates(fieldExprs, newUpdates, resolver, namePrefix :+ col.name)
+
+    // construct a new struct with updated field expressions
+    toNamedStruct(fields, updatedFieldExprs)
+  }
+
+  private def toNamedStruct(fields: Seq[StructField], fieldExprs: Seq[Expression]): Expression = {
+    val namedStructExprs = fields.zip(fieldExprs).flatMap { case (field, expr) =>
+      Seq(Literal(field.name), expr)
+    }
+    CreateNamedStruct(namedStructExprs)
+  }
+
+  private def hasExactMatch(
+      updates: Seq[ColumnUpdate],
+      col: NamedExpression,
+      resolver: Resolver): Boolean = {
+
+    updates.exists(assignment => isExactMatch(assignment, col, resolver))
+  }
+
+  private def isExactMatch(
+      update: ColumnUpdate,
+      col: NamedExpression,
+      resolver: Resolver): Boolean = {
+
+    update.nameParts match {
+      case Seq(namePart) if resolver(namePart, col.name) => true
+      case _ => false
+    }
+  }
+
+  protected def castIfNeeded(
+      expr: Expression,
+      dataType: DataType,
+      resolver: Resolver): Expression = expr match {
+    // some types cannot be cast from NullType (e.g. StructType)
+    case Literal(value, NullType) => Literal(value, dataType)
+    case _ =>
+      (expr.dataType, dataType) match {
+        // resolve structs by name if they have the same number of fields and their names match
+        // e.g., it is ok to set a struct with fields (a, b) as another struct with fields (b, a)
+        // it is invalid to set a struct with fields (a, d) as another struct with fields (a, b)

Review comment:
   > In particular, what should happen if one sets a struct with fields (b, a) to a struct column with fields (a, b)? Should we match the fields by name or should we match them by position?
   
   SQL writes are always by position unless you have `(names...) VALUES (values...)`, but that's not the case for these structs. So I think the right behavior for SQL is by position.
   
   Because dataframe columns don't have an obvious position, the expectation for users is that the writes happen by name. That's why we added `byName` variants of the logical plans. I think that extends to nested structs as well, because nested structs are easy to produce by some conversion from an object. When converting from an object using an `Encoder`, there is no column order guarantee, so it is reasonable to assume that columns will be written by name.
   
   Because we don't currently support dataframe merge into, I think we should move forward with the behavior you've added, which matches the insert behavior. We can add the `byName` flag later when we add a dataframe API.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


