dtenedor commented on code in PR #36445:
URL: https://github.com/apache/spark/pull/36445#discussion_r870680210


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala:
##########
@@ -441,17 +504,21 @@ case class ResolveDefaultColumns(
    */
   private def getSchemaForTargetTable(table: LogicalPlan): Option[StructType] 
= {
     // Check if the target table is already resolved. If so, return the 
computed schema.
-    table match {
-      case r: NamedRelation if r.schema.fields.nonEmpty => return 
Some(r.schema)
-      case SubqueryAlias(_, r: NamedRelation) if r.schema.fields.nonEmpty => 
return Some (r.schema)
-      case _ =>
+    val source: Option[LogicalPlan] = table.collectFirst {

Review Comment:
   Good question. There were two reasons:
   1. It was necessary to change 'match' to 'collectFirst' in order to descend past any SubqueryAlias nodes that may be present (these appeared in the target of MERGE commands). I added a comment to mention this.
   2. It also reduced the number of 'return' calls in this method by one.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala:
##########
@@ -164,13 +166,74 @@ case class ResolveDefaultColumns(
       // For each assignment in the UPDATE command's SET clause with a DEFAULT 
column reference on
       // the right-hand side, look up the corresponding expression from the 
above map.
       val newAssignments: Option[Seq[Assignment]] =
-      replaceExplicitDefaultValuesForUpdateAssignments(u.assignments, 
columnNamesToExpressions)
+      replaceExplicitDefaultValuesForUpdateAssignments(
+        u.assignments, CommandType.Update, columnNamesToExpressions)
       newAssignments.map { n =>
         u.copy(assignments = n)
       }.getOrElse(u)
     }.getOrElse(u)
   }
 
+  /**
+   * Resolves DEFAULT column references for a MERGE INTO command.
+   */
+  private def resolveDefaultColumnsForMerge(m: MergeIntoTable): LogicalPlan = {
+    val schema: StructType = 
getSchemaForTargetTable(m.targetTable).getOrElse(return m)
+    // Return a more descriptive error message if the user tries to use a 
DEFAULT column reference
+    // inside an UPDATE command's WHERE clause; this is not allowed.
+    m.mergeCondition.map { c: Expression =>
+      if (c.find(isExplicitDefaultColumn).isDefined) {
+        throw 
QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition()
+      }
+    }
+    val defaultExpressions: Seq[Expression] = schema.fields.map {
+      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+        analyze(analyzer, f, "MERGE")
+      case _ => Literal(null)
+    }
+    val columnNamesToExpressions: Map[String, Expression] =
+      mapStructFieldNamesToExpressions(schema, defaultExpressions)
+    val newMatchedActions: Seq[Option[MergeAction]] = m.matchedActions.map {

Review Comment:
   I looked into this, and it was actually a bug: we should be checking whether *any* of the MATCHED or NOT MATCHED clauses had the DEFAULT value replaced, rather than whether *all* of them did. I updated this logic accordingly and added a test case.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala:
##########
@@ -1510,6 +1510,152 @@ class PlanResolutionSuite extends AnalysisTest {
 
           case other => fail("Expect MergeIntoTable, but got:\n" + 
other.treeString)
         }
+
+        // default columns (implicit)
+        val sql6 =
+          s"""
+             |MERGE INTO $target AS target
+             |USING $source AS source
+             |ON target.i = source.i
+             |WHEN MATCHED AND (target.s='delete') THEN DELETE
+             |WHEN MATCHED AND (target.s='update')
+             |  THEN UPDATE SET target.s = DEFAULT
+             |WHEN NOT MATCHED AND (source.s='insert')
+             |  THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)
+           """.stripMargin
+        parseAndResolve(sql6) match {
+          case m: MergeIntoTable =>
+            val source = m.sourceTable
+            val target = m.targetTable
+            val ti = target.output.find(_.name == 
"i").get.asInstanceOf[AttributeReference]
+            val si = source.output.find(_.name == 
"i").get.asInstanceOf[AttributeReference]
+            m.mergeCondition match {
+              case EqualTo(l: AttributeReference, r: AttributeReference) =>
+                assert(l.sameRef(ti) && r.sameRef(si))
+              case Literal(_, BooleanType) => // this is acceptable as a merge 
condition
+              case other => fail("unexpected merge condition " + other)
+            }
+            assert(m.matchedActions.length == 2)
+            val first = m.matchedActions(0)
+            first match {
+              case DeleteAction(Some(EqualTo(_: AttributeReference, 
StringLiteral("delete")))) =>
+              case other => fail("unexpected first matched action " + other)
+            }
+            val second = m.matchedActions(1)
+            second match {
+              case UpdateAction(Some(EqualTo(_: AttributeReference, 
StringLiteral("update"))),
+                  Seq(Assignment(
+                    _: AttributeReference, AnsiCast(Literal(null, _), 
StringType, _)))) =>
+              case other => fail("unexpected second matched action " + other)
+            }
+            assert(m.notMatchedActions.length == 1)
+            val negative = m.notMatchedActions(0)
+            negative match {
+              case InsertAction(Some(EqualTo(_: AttributeReference, 
StringLiteral("insert"))),
+              Seq(Assignment(i: AttributeReference, AnsiCast(Literal(null, _), 
IntegerType, _)),
+              Assignment(s: AttributeReference, AnsiCast(Literal(null, _), 
StringType, _)))) =>
+                assert(i.name == "i")
+                assert(s.name == "s")
+              case other => fail("unexpected not matched action " + other)
+            }
+
+          case other =>
+            fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+        }
+    }
+
+
+    // default columns (explicit)
+    val mergeDefault1 =

Review Comment:
   The target table is different: this test merges into `defaultvalues` rather than `$target`. This would be easier to understand if the test cases had comments :) I added some; hopefully it is clearer now.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala:
##########
@@ -164,13 +166,74 @@ case class ResolveDefaultColumns(
       // For each assignment in the UPDATE command's SET clause with a DEFAULT 
column reference on
       // the right-hand side, look up the corresponding expression from the 
above map.
       val newAssignments: Option[Seq[Assignment]] =
-      replaceExplicitDefaultValuesForUpdateAssignments(u.assignments, 
columnNamesToExpressions)
+      replaceExplicitDefaultValuesForUpdateAssignments(
+        u.assignments, CommandType.Update, columnNamesToExpressions)
       newAssignments.map { n =>
         u.copy(assignments = n)
       }.getOrElse(u)
     }.getOrElse(u)
   }
 
+  /**
+   * Resolves DEFAULT column references for a MERGE INTO command.
+   */
+  private def resolveDefaultColumnsForMerge(m: MergeIntoTable): LogicalPlan = {
+    val schema: StructType = 
getSchemaForTargetTable(m.targetTable).getOrElse(return m)
+    // Return a more descriptive error message if the user tries to use a 
DEFAULT column reference
+    // inside an UPDATE command's WHERE clause; this is not allowed.
+    m.mergeCondition.map { c: Expression =>
+      if (c.find(isExplicitDefaultColumn).isDefined) {
+        throw 
QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition()

Review Comment:
   Done. I did the same for the other error case: DEFAULT references that participate in complex expressions (e.g. `DEFAULT + 1`) inside the MATCHED or NOT MATCHED clauses.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala:
##########
@@ -1510,6 +1510,152 @@ class PlanResolutionSuite extends AnalysisTest {
 
           case other => fail("Expect MergeIntoTable, but got:\n" + 
other.treeString)
         }
+
+        // default columns (implicit)
+        val sql6 =
+          s"""
+             |MERGE INTO $target AS target
+             |USING $source AS source
+             |ON target.i = source.i
+             |WHEN MATCHED AND (target.s='delete') THEN DELETE
+             |WHEN MATCHED AND (target.s='update')
+             |  THEN UPDATE SET target.s = DEFAULT
+             |WHEN NOT MATCHED AND (source.s='insert')
+             |  THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)
+           """.stripMargin
+        parseAndResolve(sql6) match {
+          case m: MergeIntoTable =>
+            val source = m.sourceTable
+            val target = m.targetTable
+            val ti = target.output.find(_.name == 
"i").get.asInstanceOf[AttributeReference]
+            val si = source.output.find(_.name == 
"i").get.asInstanceOf[AttributeReference]
+            m.mergeCondition match {
+              case EqualTo(l: AttributeReference, r: AttributeReference) =>
+                assert(l.sameRef(ti) && r.sameRef(si))
+              case Literal(_, BooleanType) => // this is acceptable as a merge 
condition
+              case other => fail("unexpected merge condition " + other)
+            }
+            assert(m.matchedActions.length == 2)
+            val first = m.matchedActions(0)
+            first match {
+              case DeleteAction(Some(EqualTo(_: AttributeReference, 
StringLiteral("delete")))) =>
+              case other => fail("unexpected first matched action " + other)
+            }
+            val second = m.matchedActions(1)
+            second match {
+              case UpdateAction(Some(EqualTo(_: AttributeReference, 
StringLiteral("update"))),
+                  Seq(Assignment(
+                    _: AttributeReference, AnsiCast(Literal(null, _), 
StringType, _)))) =>
+              case other => fail("unexpected second matched action " + other)
+            }
+            assert(m.notMatchedActions.length == 1)
+            val negative = m.notMatchedActions(0)
+            negative match {
+              case InsertAction(Some(EqualTo(_: AttributeReference, 
StringLiteral("insert"))),
+              Seq(Assignment(i: AttributeReference, AnsiCast(Literal(null, _), 
IntegerType, _)),
+              Assignment(s: AttributeReference, AnsiCast(Literal(null, _), 
StringType, _)))) =>
+                assert(i.name == "i")
+                assert(s.name == "s")
+              case other => fail("unexpected not matched action " + other)
+            }
+
+          case other =>
+            fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+        }
+    }
+
+
+    // default columns (explicit)
+    val mergeDefault1 =
+      s"""
+         |MERGE INTO defaultvalues AS target
+         |USING v2Table1 AS source
+         |ON target.i = source.i
+         |WHEN MATCHED AND (target.s='delete') THEN DELETE
+         |WHEN MATCHED AND (target.s='update')
+         |  THEN UPDATE SET target.s = DEFAULT
+         |WHEN NOT MATCHED AND (source.s='insert')
+         |  THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)
+           """.stripMargin
+    parseAndResolve(mergeDefault1, true) match {
+      case m: MergeIntoTable =>
+        val cond = m.mergeCondition
+        cond match {
+          case EqualTo(l: UnresolvedAttribute, r: UnresolvedAttribute) =>
+            assert(l.nameParts.last == "i")
+            assert(r.nameParts.last == "i")
+          case Literal(_, BooleanType) => // this is acceptable as a merge 
condition
+          case other => fail("unexpected merge condition " + other)
+        }
+        assert(m.matchedActions.length == 2)
+        val first = m.matchedActions(0)
+        first match {
+          case DeleteAction(Some(EqualTo(_: UnresolvedAttribute, 
StringLiteral("delete")))) =>
+          case other => fail("unexpected first matched action " + other)
+        }
+        val second = m.matchedActions(1)
+        second match {
+          case UpdateAction(Some(EqualTo(_: UnresolvedAttribute, 
StringLiteral("update"))),
+          Seq(Assignment(_: UnresolvedAttribute, Literal(42, IntegerType)))) =>
+          case other => fail("unexpected second matched action " + other)
+        }
+        assert(m.notMatchedActions.length == 1)
+        val negative = m.notMatchedActions(0)
+        negative match {
+          case InsertAction(Some(EqualTo(_: UnresolvedAttribute, 
StringLiteral("insert"))),
+          Seq(
+          Assignment(_: UnresolvedAttribute, Literal(true, BooleanType)),
+          Assignment(_: UnresolvedAttribute, Literal(42, IntegerType)))) =>
+          case other => fail("unexpected not matched action " + other)
+        }
+
+      case other =>
+        fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+    }
+    val mergeDefault2 =

Review Comment:
   Added some comments; hopefully this makes more sense now :)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to