dtenedor commented on code in PR #36445:
URL: https://github.com/apache/spark/pull/36445#discussion_r870680210
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala:
##########
@@ -441,17 +504,21 @@ case class ResolveDefaultColumns(
*/
private def getSchemaForTargetTable(table: LogicalPlan): Option[StructType]
= {
// Check if the target table is already resolved. If so, return the
computed schema.
- table match {
- case r: NamedRelation if r.schema.fields.nonEmpty => return
Some(r.schema)
- case SubqueryAlias(_, r: NamedRelation) if r.schema.fields.nonEmpty =>
return Some (r.schema)
- case _ =>
+ val source: Option[LogicalPlan] = table.collectFirst {
Review Comment:
Good question:
1. It was necessary to change 'match' to 'collectFirst' to descend past any
SubqueryAlias nodes that may be present (these appeared in the target of MERGE
commands). I added a comment to mention this.
2. It reduced the number of 'return' calls in this method by one.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala:
##########
@@ -164,13 +166,74 @@ case class ResolveDefaultColumns(
// For each assignment in the UPDATE command's SET clause with a DEFAULT
column reference on
// the right-hand side, look up the corresponding expression from the
above map.
val newAssignments: Option[Seq[Assignment]] =
- replaceExplicitDefaultValuesForUpdateAssignments(u.assignments,
columnNamesToExpressions)
+ replaceExplicitDefaultValuesForUpdateAssignments(
+ u.assignments, CommandType.Update, columnNamesToExpressions)
newAssignments.map { n =>
u.copy(assignments = n)
}.getOrElse(u)
}.getOrElse(u)
}
+ /**
+ * Resolves DEFAULT column references for a MERGE INTO command.
+ */
+ private def resolveDefaultColumnsForMerge(m: MergeIntoTable): LogicalPlan = {
+ val schema: StructType =
getSchemaForTargetTable(m.targetTable).getOrElse(return m)
+ // Return a more descriptive error message if the user tries to use a
DEFAULT column reference
+ // inside an UPDATE command's WHERE clause; this is not allowed.
+ m.mergeCondition.map { c: Expression =>
+ if (c.find(isExplicitDefaultColumn).isDefined) {
+ throw
QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition()
+ }
+ }
+ val defaultExpressions: Seq[Expression] = schema.fields.map {
+ case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+ analyze(analyzer, f, "MERGE")
+ case _ => Literal(null)
+ }
+ val columnNamesToExpressions: Map[String, Expression] =
+ mapStructFieldNamesToExpressions(schema, defaultExpressions)
+ val newMatchedActions: Seq[Option[MergeAction]] = m.matchedActions.map {
Review Comment:
I looked into this and it was actually a bug -- we should be checking if
*any* of the MATCHED or NOT MATCHED clauses had the DEFAULT value replaced
(rather than *all* of them). I updated this logic accordingly and added a test
case.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala:
##########
@@ -1510,6 +1510,152 @@ class PlanResolutionSuite extends AnalysisTest {
case other => fail("Expect MergeIntoTable, but got:\n" +
other.treeString)
}
+
+ // default columns (implicit)
+ val sql6 =
+ s"""
+ |MERGE INTO $target AS target
+ |USING $source AS source
+ |ON target.i = source.i
+ |WHEN MATCHED AND (target.s='delete') THEN DELETE
+ |WHEN MATCHED AND (target.s='update')
+ | THEN UPDATE SET target.s = DEFAULT
+ |WHEN NOT MATCHED AND (source.s='insert')
+ | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)
+ """.stripMargin
+ parseAndResolve(sql6) match {
+ case m: MergeIntoTable =>
+ val source = m.sourceTable
+ val target = m.targetTable
+ val ti = target.output.find(_.name ==
"i").get.asInstanceOf[AttributeReference]
+ val si = source.output.find(_.name ==
"i").get.asInstanceOf[AttributeReference]
+ m.mergeCondition match {
+ case EqualTo(l: AttributeReference, r: AttributeReference) =>
+ assert(l.sameRef(ti) && r.sameRef(si))
+ case Literal(_, BooleanType) => // this is acceptable as a merge
condition
+ case other => fail("unexpected merge condition " + other)
+ }
+ assert(m.matchedActions.length == 2)
+ val first = m.matchedActions(0)
+ first match {
+ case DeleteAction(Some(EqualTo(_: AttributeReference,
StringLiteral("delete")))) =>
+ case other => fail("unexpected first matched action " + other)
+ }
+ val second = m.matchedActions(1)
+ second match {
+ case UpdateAction(Some(EqualTo(_: AttributeReference,
StringLiteral("update"))),
+ Seq(Assignment(
+ _: AttributeReference, AnsiCast(Literal(null, _),
StringType, _)))) =>
+ case other => fail("unexpected second matched action " + other)
+ }
+ assert(m.notMatchedActions.length == 1)
+ val negative = m.notMatchedActions(0)
+ negative match {
+ case InsertAction(Some(EqualTo(_: AttributeReference,
StringLiteral("insert"))),
+ Seq(Assignment(i: AttributeReference, AnsiCast(Literal(null, _),
IntegerType, _)),
+ Assignment(s: AttributeReference, AnsiCast(Literal(null, _),
StringType, _)))) =>
+ assert(i.name == "i")
+ assert(s.name == "s")
+ case other => fail("unexpected not matched action " + other)
+ }
+
+ case other =>
+ fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+ }
+ }
+
+
+ // default columns (explicit)
+ val mergeDefault1 =
Review Comment:
The target table is different. This might be easier to understand if the
test cases had comments :) I added some; hopefully it is clearer now.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala:
##########
@@ -164,13 +166,74 @@ case class ResolveDefaultColumns(
// For each assignment in the UPDATE command's SET clause with a DEFAULT
column reference on
// the right-hand side, look up the corresponding expression from the
above map.
val newAssignments: Option[Seq[Assignment]] =
- replaceExplicitDefaultValuesForUpdateAssignments(u.assignments,
columnNamesToExpressions)
+ replaceExplicitDefaultValuesForUpdateAssignments(
+ u.assignments, CommandType.Update, columnNamesToExpressions)
newAssignments.map { n =>
u.copy(assignments = n)
}.getOrElse(u)
}.getOrElse(u)
}
+ /**
+ * Resolves DEFAULT column references for a MERGE INTO command.
+ */
+ private def resolveDefaultColumnsForMerge(m: MergeIntoTable): LogicalPlan = {
+ val schema: StructType =
getSchemaForTargetTable(m.targetTable).getOrElse(return m)
+ // Return a more descriptive error message if the user tries to use a
DEFAULT column reference
+ // inside an UPDATE command's WHERE clause; this is not allowed.
+ m.mergeCondition.map { c: Expression =>
+ if (c.find(isExplicitDefaultColumn).isDefined) {
+ throw
QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition()
Review Comment:
Done, and also for the other error case with DEFAULT references
participating in complex expressions (e.g. `DEFAULT + 1`) in the MATCHED or NOT
MATCHED clauses.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala:
##########
@@ -1510,6 +1510,152 @@ class PlanResolutionSuite extends AnalysisTest {
case other => fail("Expect MergeIntoTable, but got:\n" +
other.treeString)
}
+
+ // default columns (implicit)
+ val sql6 =
+ s"""
+ |MERGE INTO $target AS target
+ |USING $source AS source
+ |ON target.i = source.i
+ |WHEN MATCHED AND (target.s='delete') THEN DELETE
+ |WHEN MATCHED AND (target.s='update')
+ | THEN UPDATE SET target.s = DEFAULT
+ |WHEN NOT MATCHED AND (source.s='insert')
+ | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)
+ """.stripMargin
+ parseAndResolve(sql6) match {
+ case m: MergeIntoTable =>
+ val source = m.sourceTable
+ val target = m.targetTable
+ val ti = target.output.find(_.name ==
"i").get.asInstanceOf[AttributeReference]
+ val si = source.output.find(_.name ==
"i").get.asInstanceOf[AttributeReference]
+ m.mergeCondition match {
+ case EqualTo(l: AttributeReference, r: AttributeReference) =>
+ assert(l.sameRef(ti) && r.sameRef(si))
+ case Literal(_, BooleanType) => // this is acceptable as a merge
condition
+ case other => fail("unexpected merge condition " + other)
+ }
+ assert(m.matchedActions.length == 2)
+ val first = m.matchedActions(0)
+ first match {
+ case DeleteAction(Some(EqualTo(_: AttributeReference,
StringLiteral("delete")))) =>
+ case other => fail("unexpected first matched action " + other)
+ }
+ val second = m.matchedActions(1)
+ second match {
+ case UpdateAction(Some(EqualTo(_: AttributeReference,
StringLiteral("update"))),
+ Seq(Assignment(
+ _: AttributeReference, AnsiCast(Literal(null, _),
StringType, _)))) =>
+ case other => fail("unexpected second matched action " + other)
+ }
+ assert(m.notMatchedActions.length == 1)
+ val negative = m.notMatchedActions(0)
+ negative match {
+ case InsertAction(Some(EqualTo(_: AttributeReference,
StringLiteral("insert"))),
+ Seq(Assignment(i: AttributeReference, AnsiCast(Literal(null, _),
IntegerType, _)),
+ Assignment(s: AttributeReference, AnsiCast(Literal(null, _),
StringType, _)))) =>
+ assert(i.name == "i")
+ assert(s.name == "s")
+ case other => fail("unexpected not matched action " + other)
+ }
+
+ case other =>
+ fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+ }
+ }
+
+
+ // default columns (explicit)
+ val mergeDefault1 =
+ s"""
+ |MERGE INTO defaultvalues AS target
+ |USING v2Table1 AS source
+ |ON target.i = source.i
+ |WHEN MATCHED AND (target.s='delete') THEN DELETE
+ |WHEN MATCHED AND (target.s='update')
+ | THEN UPDATE SET target.s = DEFAULT
+ |WHEN NOT MATCHED AND (source.s='insert')
+ | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)
+ """.stripMargin
+ parseAndResolve(mergeDefault1, true) match {
+ case m: MergeIntoTable =>
+ val cond = m.mergeCondition
+ cond match {
+ case EqualTo(l: UnresolvedAttribute, r: UnresolvedAttribute) =>
+ assert(l.nameParts.last == "i")
+ assert(r.nameParts.last == "i")
+ case Literal(_, BooleanType) => // this is acceptable as a merge
condition
+ case other => fail("unexpected merge condition " + other)
+ }
+ assert(m.matchedActions.length == 2)
+ val first = m.matchedActions(0)
+ first match {
+ case DeleteAction(Some(EqualTo(_: UnresolvedAttribute,
StringLiteral("delete")))) =>
+ case other => fail("unexpected first matched action " + other)
+ }
+ val second = m.matchedActions(1)
+ second match {
+ case UpdateAction(Some(EqualTo(_: UnresolvedAttribute,
StringLiteral("update"))),
+ Seq(Assignment(_: UnresolvedAttribute, Literal(42, IntegerType)))) =>
+ case other => fail("unexpected second matched action " + other)
+ }
+ assert(m.notMatchedActions.length == 1)
+ val negative = m.notMatchedActions(0)
+ negative match {
+ case InsertAction(Some(EqualTo(_: UnresolvedAttribute,
StringLiteral("insert"))),
+ Seq(
+ Assignment(_: UnresolvedAttribute, Literal(true, BooleanType)),
+ Assignment(_: UnresolvedAttribute, Literal(42, IntegerType)))) =>
+ case other => fail("unexpected not matched action " + other)
+ }
+
+ case other =>
+ fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+ }
+ val mergeDefault2 =
Review Comment:
I added some comments; hopefully this makes more sense now :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]