aokolnychyi commented on code in PR #55518:
URL: https://github.com/apache/spark/pull/55518#discussion_r3423916257


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala:
##########
@@ -154,30 +193,114 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
       cond: Expression): WriteDelta = {
 
     val operation = operationTable.operation.asInstanceOf[SupportsDelta]
+    val supportsColumnUpdate = operation.supportsColumnUpdates()
 
-    // resolve all needed attrs (e.g. row ID and any required metadata attrs)
-    val rowAttrs = relation.output
     val rowIdAttrs = resolveRowIdAttrs(relation, operation)
     val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation)
 
-    // construct a read relation and include all required metadata columns
+    val connectorDataAttrs = if (supportsColumnUpdate) {
+      resolveRequiredDataAttrs(relation, operation)
+    } else Nil
+
     val readRelation = buildRelationWithAttrs(relation, operationTable, 
metadataAttrs, rowIdAttrs)
 
+    // Connector-declared attrs not being assigned are passed through so 
ColumnPruning
+    // keeps them in the scan and the connector receives their current values.
+    val assignedAttrs = if (supportsColumnUpdate) 
computeAssignedAttrs(assignments)
+                        else relation.output
+    val connectorExtraAttrs: Seq[AttributeReference] = if 
(connectorDataAttrs.nonEmpty) {
+      val assignedAttrSet = AttributeSet(assignedAttrs)
+      connectorDataAttrs.filterNot(assignedAttrSet.contains)
+    } else Nil
+
     // build a plan for updated records that match the condition
     val matchedRowsPlan = Filter(cond, readRelation)
     val rowDeltaPlan = if (operation.representUpdateAsDeleteAndInsert) {
       buildDeletesAndInserts(matchedRowsPlan, assignments, rowIdAttrs)
+    } else if (supportsColumnUpdate) {
+      buildColumnUpdateProjection(
+        matchedRowsPlan, assignments, rowIdAttrs, metadataAttrs, 
connectorExtraAttrs)
     } else {
       buildWriteDeltaUpdateProjection(matchedRowsPlan, assignments, rowIdAttrs)
     }
 
+    val effectiveRowAttrs = if (supportsColumnUpdate && 
connectorDataAttrs.nonEmpty) {
+      connectorDataAttrs
+    } else if (supportsColumnUpdate) {
+      assignedAttrs
+    } else {
+      relation.output
+    }
+
     // build a plan to write the row delta to the table
     val writeRelation = relation.copy(table = operationTable)
-    val projections = buildWriteDeltaProjections(rowDeltaPlan, rowAttrs, 
rowIdAttrs, metadataAttrs)
+    val projections = buildWriteDeltaProjections(
+      rowDeltaPlan, effectiveRowAttrs, rowIdAttrs, metadataAttrs)
     val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
     WriteDelta(writeRelation, cond, rowDeltaPlan, relation, projections, 
groupFilterCond)
   }
 
+  /**
+   * Builds the WriteDelta projection for the column update path. The 
resulting Project
+   * references only the columns needed for the write, so ColumnPruning 
narrows the scan
+   * to match.
+   */
+  private def buildColumnUpdateProjection(
+      plan: LogicalPlan,
+      assignments: Seq[Assignment],
+      rowIdAttrs: Seq[Attribute],
+      metadataAttrs: Seq[Attribute],
+      connectorExtraAttrs: Seq[AttributeReference] = Nil): LogicalPlan = {
+
+    val assignedValues = assignments.collect {
+      case Assignment(key: Attribute, value) if !isIdentityAssignment(key, 
value) =>
+        Alias(value, key.name)()
+    }
+
+    val connectorExtraAttrSet = AttributeSet(connectorExtraAttrs)
+    val connectorPassThroughValues = plan.output.filter { a =>
+      connectorExtraAttrSet.contains(a) && 
!MetadataAttribute.isValid(a.metadata)
+    }
+
+    val metadataAttrSet = AttributeSet(metadataAttrs)
+    val metadataValues = plan.output.filter(metadataAttrSet.contains).map { 
attr =>
+      if (MetadataAttribute.isPreservedOnUpdate(attr)) {
+        attr
+      } else {
+        Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = 
Some(attr.metadata))
+      }
+    }
+
+    val rowIdAttrSet = AttributeSet(rowIdAttrs)
+    val rowIdValues = plan.output.filter(rowIdAttrSet.contains)
+
+    val originalRowIdValues = buildOriginalRowIdValues(rowIdAttrs, assignments)
+    val operationType = Alias(Literal(UPDATE_OPERATION), OPERATION_COLUMN)()
+
+    Project(
+      Seq(operationType) ++ assignedValues ++ connectorPassThroughValues ++
+        metadataValues ++ rowIdValues ++ originalRowIdValues,
+      plan)
+  }
+
+  // Returns the table attributes that are genuinely updated (non-identity) in 
this UPDATE.
+  private def computeAssignedAttrs(assignments: Seq[Assignment]): 
Seq[AttributeReference] = {
+    assignments.collect {
+      case Assignment(key: AttributeReference, value) if 
!isIdentityAssignment(key, value) => key
+    }
+  }
+
+  private def isIdentityAssignment(key: Attribute, value: Expression): Boolean 
= {

Review Comment:
   Does this handle nested columns (e.g. an assignment such as `SET s.f = ...` to a field of a struct column `s`)?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to