This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d5fd71e43a9 [SPARK-39078][SQL] Support UPDATE commands with DEFAULT values
d5fd71e43a9 is described below

commit d5fd71e43a98077abc8b226763f42f62c694715e
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Sat May 7 10:43:33 2022 +0800

    [SPARK-39078][SQL] Support UPDATE commands with DEFAULT values
    
    ### What changes were proposed in this pull request?
    
    DEFAULT column references are now supported in UPDATE commands: an assignment of the form `SET col = DEFAULT` resolves to the column's default value, or to NULL if the column defines none.
    
    Example:
    
    ```
    CREATE TABLE t(
      i STRING,
      c CHAR(5),
      d VARCHAR(5),
      e STRING DEFAULT 'abc',
      f BIGINT DEFAULT 42)
      USING DELTA;
    INSERT INTO t VALUES('1', '12345', '67890', 'def', 31);
    UPDATE t SET c = DEFAULT WHERE i = '1';
    SELECT * FROM t;
    > "1", NULL, "67890", "def", 31
    UPDATE t SET e = DEFAULT WHERE i = '1';
    SELECT * FROM t;
    > "1", NULL, "67890", "abc", 31
    UPDATE t SET e = 'ghi' WHERE i = '1';
    SELECT * FROM t;
    > "1", NULL, "67890", "ghi", 31
    ```
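
    The rule also rejects DEFAULT references where they are not supported. A
    sketch of the two new error paths (messages per the QueryCompilationErrors
    additions in the diff below):

    ```
    spark.sql("UPDATE t SET f = DEFAULT + 1 WHERE i = '1'")
    // AnalysisException: Failed to execute UPDATE command because the SET list
    // contains a DEFAULT column reference as part of another expression
    spark.sql("UPDATE t SET f = 42 WHERE f = DEFAULT")
    // AnalysisException: Failed to execute UPDATE command because the WHERE
    // clause contains a DEFAULT column reference
    ```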
    
    ### Why are the changes needed?
    
    This makes UPDATE commands easier to use.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes: UPDATE commands now accept DEFAULT column references in SET assignments, as shown above.
    
    ### How was this patch tested?
    
    This PR adds new unit test coverage in PlanResolutionSuite and InsertSuite.
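
    A sketch of the negative-path assertion pattern the updated suites use (the
    exact statement below is illustrative; error text per QueryCompilationErrors):

    ```
    assert(intercept[AnalysisException] {
      sql("update t set s = default + 1")
    }.getMessage.contains(
      QueryCompilationErrors
        .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause()
        .getMessage))
    ```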
    
    Closes #36415 from dtenedor/default-updates.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../catalyst/analysis/ResolveDefaultColumns.scala  | 202 ++++++++++++++++-----
 .../catalyst/util/ResolveDefaultColumnsUtil.scala  |   5 -
 .../spark/sql/errors/QueryCompilationErrors.scala  |  25 +++
 .../execution/command/PlanResolutionSuite.scala    |  76 +++++++-
 .../org/apache/spark/sql/sources/InsertSuite.scala |  10 +-
 5 files changed, 266 insertions(+), 52 deletions(-)
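
Under the hood, resolveDefaultColumnsForUpdate (first file below) maps each column name in the UPDATE target table to its analyzed DEFAULT expression (falling back to a NULL literal when a column defines none) and substitutes the right-hand side of each `SET col = DEFAULT` assignment; DEFAULT references nested inside other expressions or in the WHERE clause raise errors instead. A minimal plain-Scala sketch of that substitution (hypothetical types, not the actual catalyst code below):

```
case class Assignment(column: String, value: String)

// Rewrite `SET col = DEFAULT` assignments using a name -> default-expression
// map, normalizing names the way normalizeFieldName does in the rule below.
def rewriteDefaults(
    assignments: Seq[Assignment],
    defaults: Map[String, String],
    caseSensitive: Boolean = false): Seq[Assignment] = {
  def norm(s: String) = if (caseSensitive) s else s.toLowerCase
  assignments.map {
    case Assignment(col, "DEFAULT") =>
      // Columns with no explicit default resolve to NULL, matching the
      // behavior noted in the PlanResolutionSuite comments below.
      Assignment(col, defaults.getOrElse(norm(col), "NULL"))
    case other => other
  }
}

// Example: UPDATE t SET e = DEFAULT, f = DEFAULT with defaults e -> 'abc', f -> 42
// rewriteDefaults(Seq(Assignment("e", "DEFAULT"), Assignment("f", "DEFAULT")),
//   Map("e" -> "'abc'", "f" -> "42"))
// => Seq(Assignment("e", "'abc'"), Assignment("f", "42"))
```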

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
index ffbe18a7dfa..9612713f593 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -33,7 +34,7 @@ import org.apache.spark.sql.types._
  * This is a rule to process DEFAULT columns in statements such as CREATE/REPLACE TABLE.
  *
  * Background: CREATE TABLE and ALTER TABLE invocations support setting column default values for
- * later operations. Following INSERT, and INSERT MERGE commands may then reference the value
+ * later operations. Following INSERT, UPDATE, and MERGE commands may then reference the value
  * using the DEFAULT keyword as needed.
  *
  * Example:
@@ -60,6 +61,8 @@ case class ResolveDefaultColumns(
       case i@InsertIntoStatement(_, _, _, project: Project, _, _)
         if !project.projectList.exists(_.isInstanceOf[Star]) =>
         resolveDefaultColumnsForInsertFromProject(i)
+      case u: UpdateTable =>
+        resolveDefaultColumnsForUpdate(u)
     }
   }
 
@@ -98,23 +101,23 @@ case class ResolveDefaultColumns(
       node = node.children(0)
     }
     val table = node.asInstanceOf[UnresolvedInlineTable]
-    val insertTableSchemaWithoutPartitionColumns: StructType =
+    val insertTableSchemaWithoutPartitionColumns: Option[StructType] =
       getInsertTableSchemaWithoutPartitionColumns(i)
-        .getOrElse(return i)
-    val regenerated: InsertIntoStatement =
-      regenerateUserSpecifiedCols(i, insertTableSchemaWithoutPartitionColumns)
-    val expanded: UnresolvedInlineTable =
-      addMissingDefaultValuesForInsertFromInlineTable(
-        table, insertTableSchemaWithoutPartitionColumns)
-    val replaced: LogicalPlan =
-      replaceExplicitDefaultValuesForInputOfInsertInto(
-        analyzer, insertTableSchemaWithoutPartitionColumns, expanded)
-        .getOrElse(return i)
-    node = replaced
-    for (child <- children.reverse) {
-      node = child.withNewChildren(Seq(node))
-    }
-    regenerated.copy(query = node)
+    insertTableSchemaWithoutPartitionColumns.map { schema: StructType =>
+      val regenerated: InsertIntoStatement =
+        regenerateUserSpecifiedCols(i, schema)
+      val expanded: UnresolvedInlineTable =
+        addMissingDefaultValuesForInsertFromInlineTable(table, schema)
+      val replaced: Option[LogicalPlan] =
+        replaceExplicitDefaultValuesForInputOfInsertInto(analyzer, schema, expanded)
+      replaced.map { r: LogicalPlan =>
+        node = r
+        for (child <- children.reverse) {
+          node = child.withNewChildren(Seq(node))
+        }
+        regenerated.copy(query = node)
+      }.getOrElse(i)
+    }.getOrElse(i)
   }
 
   /**
@@ -122,20 +125,50 @@ case class ResolveDefaultColumns(
    * projection.
    */
   private def resolveDefaultColumnsForInsertFromProject(i: InsertIntoStatement): LogicalPlan = {
-    val insertTableSchemaWithoutPartitionColumns: StructType =
+    val insertTableSchemaWithoutPartitionColumns: Option[StructType] =
       getInsertTableSchemaWithoutPartitionColumns(i)
-        .getOrElse(return i)
-    val regenerated: InsertIntoStatement =
-      regenerateUserSpecifiedCols(i, insertTableSchemaWithoutPartitionColumns)
-    val project: Project = i.query.asInstanceOf[Project]
-    val expanded: Project =
-      addMissingDefaultValuesForInsertFromProject(
-        project, insertTableSchemaWithoutPartitionColumns)
-    val replaced: LogicalPlan =
-      replaceExplicitDefaultValuesForInputOfInsertInto(
-        analyzer, insertTableSchemaWithoutPartitionColumns, expanded)
-        .getOrElse(return i)
-    regenerated.copy(query = replaced)
+    insertTableSchemaWithoutPartitionColumns.map { schema =>
+      val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema)
+      val project: Project = i.query.asInstanceOf[Project]
+      val expanded: Project =
+        addMissingDefaultValuesForInsertFromProject(project, schema)
+      val replaced: Option[LogicalPlan] =
+        replaceExplicitDefaultValuesForInputOfInsertInto(analyzer, schema, expanded)
+      replaced.map { r =>
+        regenerated.copy(query = r)
+      }.getOrElse(i)
+    }.getOrElse(i)
+  }
+
+  /**
+   * Resolves DEFAULT column references for an UPDATE command.
+   */
+  private def resolveDefaultColumnsForUpdate(u: UpdateTable): LogicalPlan = {
+    // Return a more descriptive error message if the user tries to use a DEFAULT column reference
+    // inside an UPDATE command's WHERE clause; this is not allowed.
+    u.condition.foreach { c: Expression =>
+      if (c.find(isExplicitDefaultColumn).isDefined) {
+        throw QueryCompilationErrors.defaultReferencesNotAllowedInUpdateWhereClause()
+      }
+    }
+    val schemaForTargetTable: Option[StructType] = getSchemaForTargetTable(u.table)
+    schemaForTargetTable.map { schema =>
+      val defaultExpressions: Seq[Expression] = schema.fields.map {
+        case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+          analyze(analyzer, f, "UPDATE")
+        case _ => Literal(null)
+      }
+      // Create a map from each column name in the target table to its DEFAULT expression.
+      val columnNamesToExpressions: Map[String, Expression] =
+        mapStructFieldNamesToExpressions(schema, defaultExpressions)
+      // For each assignment in the UPDATE command's SET clause with a DEFAULT column reference on
+      // the right-hand side, look up the corresponding expression from the above map.
+      val newAssignments: Option[Seq[Assignment]] =
+      replaceExplicitDefaultValuesForUpdateAssignments(u.assignments, columnNamesToExpressions)
+      newAssignments.map { n =>
+        u.copy(assignments = n)
+      }.getOrElse(u)
+    }.getOrElse(u)
   }
 
   /**
@@ -260,7 +293,7 @@ case class ResolveDefaultColumns(
           expr = row(i)
         defaultExpr = if (i < defaultExpressions.size) defaultExpressions(i) else Literal(null)
         } yield replaceExplicitDefaultReferenceInExpression(
-          expr, defaultExpr, DEFAULTS_IN_EXPRESSIONS_ERROR, false).map { e =>
+          expr, defaultExpr, CommandType.Insert, addAlias = false).map { e =>
           replaced = true
           e
         }.getOrElse(expr)
@@ -286,7 +319,7 @@ case class ResolveDefaultColumns(
         projectExpr = project.projectList(i)
         defaultExpr = if (i < defaultExpressions.size) defaultExpressions(i) else Literal(null)
       } yield replaceExplicitDefaultReferenceInExpression(
-        projectExpr, defaultExpr, DEFAULTS_IN_EXPRESSIONS_ERROR, true).map { e =>
+        projectExpr, defaultExpr, CommandType.Insert, addAlias = true).map { e =>
         replaced = true
         e.asInstanceOf[NamedExpression]
       }.getOrElse(projectExpr)
@@ -298,19 +331,26 @@ case class ResolveDefaultColumns(
     }
   }
 
+  /**
+   * Represents a type of command we are currently processing.
+   */
+  private object CommandType extends Enumeration {
+    val Insert, Update = Value
+  }
+
   /**
    * Checks if a given input expression is an unresolved "DEFAULT" attribute reference.
    *
    * @param input the input expression to examine.
    * @param defaultExpr the default to return if [[input]] is an unresolved "DEFAULT" reference.
-   * @param complexDefaultError error if [[input]] is a complex expression with "DEFAULT" inside.
+   * @param command the type of command we are currently processing.
    * @param addAlias if true, wraps the result with an alias of the original default column name.
    * @return [[defaultExpr]] if [[input]] is an unresolved "DEFAULT" attribute reference.
    */
   private def replaceExplicitDefaultReferenceInExpression(
       input: Expression,
       defaultExpr: Expression,
-      complexDefaultError: String,
+      command: CommandType.Value,
       addAlias: Boolean): Option[Expression] = {
     input match {
       case a@Alias(u: UnresolvedAttribute, _)
@@ -325,7 +365,14 @@ case class ResolveDefaultColumns(
         }
       case expr@_
         if expr.find(isExplicitDefaultColumn).isDefined =>
-        throw new AnalysisException(complexDefaultError)
+        command match {
+          case CommandType.Insert =>
+            throw QueryCompilationErrors
+              .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList()
+          case CommandType.Update =>
+            throw QueryCompilationErrors
+              .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause()
+        }
       case _ =>
         None
     }
@@ -344,16 +391,10 @@ case class ResolveDefaultColumns(
     if (userSpecifiedCols.isEmpty) {
       return Some(schema)
     }
-    def normalize(str: String) = {
-      if (SQLConf.get.caseSensitiveAnalysis) str else str.toLowerCase()
-    }
-    val colNamesToFields: Map[String, StructField] =
-      schema.fields.map {
-        field: StructField => normalize(field.name) -> field
-      }.toMap
+    val colNamesToFields: Map[String, StructField] = mapStructFieldNamesToFields(schema)
     val userSpecifiedFields: Seq[StructField] =
       userSpecifiedCols.map {
-        name: String => colNamesToFields.getOrElse(normalize(name), return None)
+        name: String => colNamesToFields.getOrElse(normalizeFieldName(name), return None)
       }
     val userSpecifiedColNames: Set[String] = userSpecifiedCols.toSet
     val nonUserSpecifiedFields: Seq[StructField] =
@@ -364,10 +405,47 @@ case class ResolveDefaultColumns(
       getStructFieldsForDefaultExpressions(nonUserSpecifiedFields)))
   }
 
+  /**
+   * Normalizes a schema field name suitable for use in map lookups.
+   */
+  private def normalizeFieldName(str: String): String = {
+    if (SQLConf.get.caseSensitiveAnalysis) {
+      str
+    } else {
+      str.toLowerCase()
+    }
+  }
+
+  /**
+   * Returns a map of the names of fields in a schema to the fields themselves.
+   */
+  private def mapStructFieldNamesToFields(schema: StructType): Map[String, StructField] = {
+    schema.fields.map {
+      field: StructField => normalizeFieldName(field.name) -> field
+    }.toMap
+  }
+
+  /**
+   * Returns a map of the names of fields in a schema to corresponding expressions.
+   */
+  private def mapStructFieldNamesToExpressions(
+      schema: StructType,
+      expressions: Seq[Expression]): Map[String, Expression] = {
+    val namesToFields: Map[String, StructField] = mapStructFieldNamesToFields(schema)
+    val namesAndExpressions: Seq[(String, Expression)] = namesToFields.keys.toSeq.zip(expressions)
+    namesAndExpressions.toMap
+  }
+
   /**
    * Returns the schema for the target table of a DML command, looking into the catalog if needed.
    */
   private def getSchemaForTargetTable(table: LogicalPlan): Option[StructType] = {
+    // Check if the target table is already resolved. If so, return the computed schema.
+    table match {
+      case r: NamedRelation if r.schema.fields.nonEmpty => return Some(r.schema)
+      case SubqueryAlias(_, r: NamedRelation) if r.schema.fields.nonEmpty => return Some(r.schema)
+      case _ =>
+    }
     // Lookup the relation from the catalog by name. This either succeeds or returns some "not
     // found" error. In the latter cases, return out of this rule without changing anything and let
     // the analyzer return a proper error message elsewhere.
@@ -389,4 +467,42 @@ case class ResolveDefaultColumns(
       case _ => None
     }
   }
+
+  /**
+   * Replaces unresolved DEFAULT column references with corresponding values in a series of
+   * assignments in an UPDATE command.
+   */
+  private def replaceExplicitDefaultValuesForUpdateAssignments(
+      assignments: Seq[Assignment],
+      columnNamesToExpressions: Map[String, Expression]): Option[Seq[Assignment]] = {
+    var replaced = false
+    val newAssignments: Seq[Assignment] =
+      for (assignment <- assignments) yield {
+        val destColName = assignment.key match {
+          case a: AttributeReference => a.name
+          case u: UnresolvedAttribute => u.nameParts.last
+          case _ => ""
+        }
+        val adjusted: String = normalizeFieldName(destColName)
+        val lookup: Option[Expression] = columnNamesToExpressions.get(adjusted)
+        val newValue: Expression = lookup.map { defaultExpr =>
+          val updated: Option[Expression] =
+            replaceExplicitDefaultReferenceInExpression(
+              assignment.value,
+              defaultExpr,
+              CommandType.Update,
+              addAlias = false)
+          updated.map { e =>
+            replaced = true
+            e
+          }.getOrElse(assignment.value)
+        }.getOrElse(assignment.value)
+        assignment.copy(value = newValue)
+      }
+    if (replaced) {
+      Some(newAssignments)
+    } else {
+      None
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index b7db9be114f..6e23ba46396 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -55,11 +55,6 @@ object ResolveDefaultColumns {
   // Name of attributes representing explicit references to the value stored in the above
   // CURRENT_DEFAULT_COLUMN_METADATA.
   val CURRENT_DEFAULT_COLUMN_NAME = "DEFAULT"
-  // Return a more descriptive error message if the user tries to nest the DEFAULT column reference
-  // inside some other expression, such as DEFAULT + 1 (this is not allowed).
-  val DEFAULTS_IN_EXPRESSIONS_ERROR = "Failed to execute INSERT INTO command because the " +
-    "VALUES list contains a DEFAULT column reference as part of another expression; this is " +
-    "not allowed"
 
   /**
    * Finds "current default" expressions in CREATE/REPLACE TABLE columns and constant-folds them.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 151c43cd92d..3b167eeb417 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2407,4 +2407,29 @@ object QueryCompilationErrors extends QueryErrorsBase {
   def noSuchFunctionError(database: String, funcInfo: String): Throwable = {
     new AnalysisException(s"$database does not support function: $funcInfo")
   }
+
+  // Return a more descriptive error message if the user tries to nest a DEFAULT column reference
+  // inside some other expression (such as DEFAULT + 1) in an INSERT INTO command's VALUES list;
+  // this is not allowed.
+  def defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList(): Throwable = {
+    new AnalysisException(
+      "Failed to execute INSERT INTO command because the VALUES list contains 
a DEFAULT column " +
+        "reference as part of another expression; this is not allowed")
+  }
+
+  // Return a more descriptive error message if the user tries to nest a DEFAULT column reference
+  // inside some other expression in an UPDATE command's SET clause; this is not allowed.
+  def defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause(): Throwable = {
+    new AnalysisException(
+      "Failed to execute UPDATE command because the SET list contains a 
DEFAULT column reference " +
+        "as part of another expression; this is not allowed")
+  }
+
+  // Return a more descriptive error message if the user tries to use a DEFAULT column reference
+  // inside an UPDATE command's WHERE clause; this is not allowed.
+  def defaultReferencesNotAllowedInUpdateWhereClause(): Throwable = {
+    new AnalysisException(
+      "Failed to execute UPDATE command because the WHERE clause contains a 
DEFAULT column " +
+        "reference; this is not allowed")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index 63d2b9d47f1..84b900f4cd7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -33,14 +33,16 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
 import org.apache.spark.sql.connector.FakeV2Provider
 import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, V1Table}
 import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.sources.SimpleScanSource
-import org.apache.spark.sql.types.{BooleanType, CharType, DoubleType, IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, CharType, DoubleType, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType}
 
 class PlanResolutionSuite extends AnalysisTest {
   import CatalystSqlParser._
@@ -83,6 +85,22 @@ class PlanResolutionSuite extends AnalysisTest {
     t
   }
 
+  private val defaultValues: Table = {
+    val t = mock(classOf[Table])
+    when(t.schema()).thenReturn(
+      new StructType()
+        .add("i", BooleanType, true,
+          new MetadataBuilder()
+            
.putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "true")
+            
.putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, 
"true").build())
+        .add("s", IntegerType, true,
+          new MetadataBuilder()
+            
.putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "42")
+            
.putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, 
"42").build()))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    t
+  }
+
   private val v1Table: V1Table = {
     val t = mock(classOf[CatalogTable])
     when(t.schema).thenReturn(new StructType()
@@ -118,6 +136,7 @@ class PlanResolutionSuite extends AnalysisTest {
         case "tab1" => table1
         case "tab2" => table2
         case "charvarchar" => charVarcharTable
+        case "defaultvalues" => defaultValues
         case name => throw new NoSuchTableException(name)
       }
     })
@@ -974,11 +993,19 @@ class PlanResolutionSuite extends AnalysisTest {
            |SET t.age=32
            |WHERE t.name IN (SELECT s.name FROM s)
          """.stripMargin
+      val sql5 = s"UPDATE $tblName SET name=DEFAULT, age=DEFAULT"
+      // Note: 'i' and 's' are the names of the columns in 'tblName'.
+      val sql6 = s"UPDATE $tblName SET i=DEFAULT, s=DEFAULT"
+      val sql7 = s"UPDATE defaultvalues SET i=DEFAULT, s=DEFAULT"
+      val sql8 = s"UPDATE $tblName SET name='Robert', age=32 WHERE p=DEFAULT"
 
       val parsed1 = parseAndResolve(sql1)
       val parsed2 = parseAndResolve(sql2)
       val parsed3 = parseAndResolve(sql3)
       val parsed4 = parseAndResolve(sql4)
+      val parsed5 = parseAndResolve(sql5)
+      val parsed6 = parseAndResolve(sql6)
+      val parsed7 = parseAndResolve(sql7, true)
 
       parsed1 match {
         case UpdateTable(
@@ -1035,6 +1062,53 @@ class PlanResolutionSuite extends AnalysisTest {
 
         case _ => fail("Expect UpdateTable, but got:\n" + parsed4.treeString)
       }
+
+      parsed5 match {
+        case UpdateTable(
+          AsDataSourceV2Relation(_),
+          Seq(
+            Assignment(name: UnresolvedAttribute, UnresolvedAttribute(Seq("DEFAULT"))),
+            Assignment(age: UnresolvedAttribute, UnresolvedAttribute(Seq("DEFAULT")))),
+          None) =>
+          assert(name.name == "name")
+          assert(age.name == "age")
+
+        case _ => fail("Expect UpdateTable, but got:\n" + parsed5.treeString)
+      }
+
+      parsed6 match {
+        case UpdateTable(
+          AsDataSourceV2Relation(_),
+          Seq(
+            // Note that when resolving DEFAULT column references, the analyzer will insert literal
+            // NULL values if the corresponding table does not define an explicit default value for
+            // that column. This is intended.
+            Assignment(i: AttributeReference, AnsiCast(Literal(null, _), IntegerType, _)),
+            Assignment(s: AttributeReference, AnsiCast(Literal(null, _), StringType, _))),
+          None) =>
+          assert(i.name == "i")
+          assert(s.name == "s")
+
+        case _ => fail("Expect UpdateTable, but got:\n" + parsed6.treeString)
+      }
+
+      parsed7 match {
+        case UpdateTable(
+          _,
+          Seq(
+            Assignment(i: AttributeReference, Literal(true, BooleanType)),
+            Assignment(s: AttributeReference, Literal(42, IntegerType))),
+          None) =>
+          assert(i.name == "i")
+          assert(s.name == "s")
+
+        case _ => fail("Expect UpdateTable, but got:\n" + parsed7.treeString)
+      }
+
+      assert(intercept[AnalysisException] {
+        parseAndResolve(sql8)
+      }.getMessage.contains(
+        QueryCompilationErrors.defaultReferencesNotAllowedInUpdateWhereClause().getMessage))
     }
 
     val sql1 = "UPDATE non_existing SET id=1"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index b312e65154a..a2237b377cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -1058,14 +1058,18 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
       sql("create table t(i boolean, s bigint default 42) using parquet")
       assert(intercept[AnalysisException] {
         sql("insert into t values(false, default + 1)")
-      }.getMessage.contains(ResolveDefaultColumns.DEFAULTS_IN_EXPRESSIONS_ERROR))
+      }.getMessage.contains(
+        QueryCompilationErrors.defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList()
+          .getMessage))
     }
     // Explicit default values may not participate in complex expressions in the SELECT query.
     withTable("t") {
       sql("create table t(i boolean, s bigint default 42) using parquet")
       assert(intercept[AnalysisException] {
         sql("insert into t select false, default + 1")
-      }.getMessage.contains(ResolveDefaultColumns.DEFAULTS_IN_EXPRESSIONS_ERROR))
+      }.getMessage.contains(
+        QueryCompilationErrors.defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList()
+          .getMessage))
     }
     // Explicit default values have a reasonable error path if the table is not found.
     withTable("t") {

