This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 09fb739877f [SPARK-43313][SQL] Adding missing column DEFAULT values for MERGE INSERT actions
09fb739877f is described below

commit 09fb739877f035c5bd11f9e9716595c6975b5c49
Author: Daniel Tenedorio <[email protected]>
AuthorDate: Thu May 4 13:18:15 2023 -0700

    [SPARK-43313][SQL] Adding missing column DEFAULT values for MERGE INSERT actions
    
    ### What changes were proposed in this pull request?
    
    This PR updates column DEFAULT assignment to add missing values for MERGE INSERT actions. This brings the behavior to parity with non-MERGE INSERT commands.
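    
    For illustration only (not part of the patch), a minimal sketch of the new behavior, assuming hypothetical table names and a data source whose tables support MERGE INTO:
    
        // Column 'b' declares a DEFAULT; the MERGE INSERT action below omits it.
        spark.sql("CREATE TABLE t (a INT, b INT DEFAULT 42)")
        spark.sql(
          """MERGE INTO t USING (SELECT 1 AS a) AS s ON t.a = s.a
            |WHEN NOT MATCHED THEN INSERT (a) VALUES (s.a)""".stripMargin)
        // After this change the omitted column 'b' is filled with its DEFAULT (42),
        // matching the behavior of "INSERT INTO t (a) VALUES (1)".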
    
    * It also adds a small convenience feature: if the provided default value is a literal of a wider type than the target column, but the value fits within the narrower type, we just coerce it. For example, `CREATE TABLE t (col INT DEFAULT 42L)` returned an error before this PR because `42L` has a long integer type, which is wider than `col`'s integer type; after this PR we coerce it to `42` since the value fits within the integer range (see the coercion sketch after this list).
    * We also add the `SupportsCustomSchemaWrite` interface, which tables may implement to exclude certain pseudocolumns from consideration when resolving column DEFAULT values (see the interface sketch after this list).
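    
    A minimal sketch of the coercion convenience (the table name is illustrative):
    
        // 42L is a BIGINT literal, wider than the INT target column, but the value
        // fits, so it is coerced to the INT literal 42 instead of raising an error.
        spark.sql("CREATE TABLE t2 (col INT DEFAULT 42L) USING PARQUET")
        spark.sql("INSERT INTO t2 VALUES (DEFAULT)") // stores 42
        // Boolean, array, struct, and map types are excluded from this coercion, so
        // e.g. an INT column with default true still fails to analyze.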
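    
    A minimal Scala sketch of a connector table opting in to `SupportsCustomSchemaWrite` (the class name and "_row_id" pseudocolumn are hypothetical, mirroring the test helper added in this patch):
    
        import org.apache.spark.sql.connector.catalog.{Table, TableCapability}
        import org.apache.spark.sql.connector.write.SupportsCustomSchemaWrite
        import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    
        // A table whose physical schema ends with a hidden "_row_id" pseudocolumn.
        class MyTable extends Table with SupportsCustomSchemaWrite {
          override def name: String = "my_table"
          override def schema: StructType = StructType(Seq(
            StructField("a", IntegerType),
            StructField("b", StringType),
            StructField("_row_id", StringType)))
          override def capabilities(): java.util.Set[TableCapability] =
            new java.util.HashSet[TableCapability]()
          // Hide the pseudocolumn from column DEFAULT resolution when inserting.
          override def customSchemaForInserts: StructType =
            StructType(schema.fields.dropRight(1))
        }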
    
    ### Why are the changes needed?
    
    These changes make column DEFAULT values usable in more situations.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see above.
    
    ### How was this patch tested?
    
    This PR adds new unit test coverage.
    
    Closes #40996 from dtenedor/merge-actions.
    
    Authored-by: Daniel Tenedorio <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit 3a0e6bde2aaa11e1165f4fde040ff02e1743795e)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../connector/write/SupportsCustomSchemaWrite.java |  38 ++++++
 .../catalyst/analysis/ResolveDefaultColumns.scala  |  56 +++++++-
 .../catalyst/util/ResolveDefaultColumnsUtil.scala  |  32 ++++-
 .../analysis/ResolveDefaultColumnsSuite.scala      | 151 ++++++++++++++++++++-
 .../org/apache/spark/sql/sources/InsertSuite.scala |   4 +-
 5 files changed, 265 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java
new file mode 100644
index 00000000000..9435625a1c4
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Trait for tables that support custom schemas for write operations including INSERT INTO commands
+ * whose target table columns have explicit or implicit default values.
+ *
+ * @since 3.4.1
+ */
+@Evolving
+public interface SupportsCustomSchemaWrite {
+    /**
+     * Represents a table with a custom schema to use for resolving DEFAULT column references when
+     * inserting into the table. For example, this can be useful for excluding hidden pseudocolumns.
+     *
+     * @return the new schema to use for this process.
+     */
+    StructType customSchemaForInserts();
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
index d028d4ff596..13e9866645a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
@@ -18,13 +18,14 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
-import org.apache.spark.sql.connector.catalog.CatalogV2Util
+import org.apache.spark.sql.connector.write.SupportsCustomSchemaWrite
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
@@ -214,8 +215,11 @@ case class ResolveDefaultColumns(
        throw QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition()
       }
     }
+    val columnsWithDefaults = ArrayBuffer.empty[String]
     val defaultExpressions: Seq[Expression] = schema.fields.map {
-      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "MERGE")
+      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+        columnsWithDefaults.append(normalizeFieldName(f.name))
+        analyze(f, "MERGE")
       case _ => Literal(null)
     }
     val columnNamesToExpressions: Map[String, Expression] =
@@ -228,7 +232,8 @@ case class ResolveDefaultColumns(
       }.getOrElse(action)
     }
    val newNotMatchedActions: Seq[MergeAction] = m.notMatchedActions.map { action: MergeAction =>
-      replaceExplicitDefaultValuesInMergeAction(action, columnNamesToExpressions).map { r =>
+      val expanded = addMissingDefaultValuesForMergeAction(action, m, columnsWithDefaults.toSeq)
+      replaceExplicitDefaultValuesInMergeAction(expanded, columnNamesToExpressions).map { r =>
         replaced = true
         r
       }.getOrElse(action)
@@ -249,6 +254,38 @@ case class ResolveDefaultColumns(
     }
   }
 
+  /** Adds new expressions to a merge action to generate missing default column values. */
+  def addMissingDefaultValuesForMergeAction(
+      action: MergeAction,
+      m: MergeIntoTable,
+      columnNamesWithDefaults: Seq[String]): MergeAction = {
+    action match {
+      case i: InsertAction =>
+        val targetColumns: Set[String] = i.assignments.map(_.key).flatMap { expr =>
+          expr match {
+            case a: AttributeReference => Seq(normalizeFieldName(a.name))
+            case u: UnresolvedAttribute => Seq(u.nameParts.map(normalizeFieldName).mkString("."))
+            case _ => Seq()
+          }
+        }.toSet
+        val targetTable: String = m.targetTable match {
+          case SubqueryAlias(id, _) => id.name
+          case d: DataSourceV2Relation => d.name
+        }
+        val missingColumnNamesWithDefaults = columnNamesWithDefaults.filter { name =>
+          !targetColumns.contains(normalizeFieldName(name)) &&
+            !targetColumns.contains(
+              s"${normalizeFieldName(targetTable)}.${normalizeFieldName(name)}")
+        }
+        val newAssignments: Seq[Assignment] = missingColumnNamesWithDefaults.map { key =>
+          Assignment(UnresolvedAttribute(key), UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
+        }
+        i.copy(assignments = i.assignments ++ newAssignments)
+      case _ =>
+        action
+    }
+  }
+
   /**
   * Replaces unresolved DEFAULT column references with corresponding values in one action of a
    * MERGE INTO command.
@@ -547,9 +584,14 @@ case class ResolveDefaultColumns(
        name: String => colNamesToFields.getOrElse(normalizeFieldName(name), return None)
       }
     val userSpecifiedColNames: Set[String] = userSpecifiedCols.toSet
+      .map(normalizeFieldName)
     val nonUserSpecifiedFields: Seq[StructField] =
       schema.fields.filter {
-        field => !userSpecifiedColNames.contains(field.name)
+        field => !userSpecifiedColNames.contains(
+          normalizeFieldName(
+            field.name
+          )
+        )
       }
     Some(StructType(userSpecifiedFields ++
       getStructFieldsForDefaultExpressions(nonUserSpecifiedFields)))
@@ -589,8 +631,10 @@ case class ResolveDefaultColumns(
     resolved.collectFirst {
       case r: UnresolvedCatalogRelation =>
         r.tableMeta.schema
-      case d: DataSourceV2Relation if !d.skipSchemaResolution && !d.isStreaming =>
-        CatalogV2Util.v2ColumnsToStructType(d.table.columns())
+      case DataSourceV2Relation(table: SupportsCustomSchemaWrite, _, _, _, _) =>
+        table.customSchemaForInserts
+      case r: NamedRelation if !r.skipSchemaResolution =>
+        r.schema
       case v: View if v.isTempViewStoringAnalyzedPlan =>
         v.schema
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index d0287cc602b..8c7e2ad4f1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
@@ -218,10 +219,33 @@ object ResolveDefaultColumns {
     } else if (Cast.canUpCast(analyzed.dataType, dataType)) {
       Cast(analyzed, dataType)
     } else {
-      throw new AnalysisException(
-        s"Failed to execute $statementType command because the destination 
table column " +
-          s"$colName has a DEFAULT value with type $dataType, but the " +
-          s"statement provided a value of incompatible type 
${analyzed.dataType}")
+      // If the provided default value is a literal of a wider type than the target column, but the
+      // literal value fits within the narrower type, just coerce it for convenience. Exclude
+      // boolean/array/struct/map types from consideration for this type coercion to avoid
+      // surprising behavior like interpreting "false" as integer zero.
+      val result = if (analyzed.isInstanceOf[Literal] &&
+        !Seq(dataType, analyzed.dataType).exists(_ match {
+          case _: BooleanType | _: ArrayType | _: StructType | _: MapType => true
+          case _ => false
+        })) {
+        try {
+          val casted = Cast(analyzed, dataType, evalMode = EvalMode.TRY).eval()
+          if (casted != null) {
+            Some(Literal(casted, dataType))
+          } else {
+            None
+          }
+        } catch {
+          case _: SparkThrowable | _: RuntimeException =>
+            None
+        }
+      } else None
+      result.getOrElse {
+        throw new AnalysisException(
+          s"Failed to execute $statementType command because the destination 
table column " +
+            s"$colName has a DEFAULT value with type $dataType, but the " +
+            s"statement provided a value of incompatible type 
${analyzed.dataType}")
+      }
     }
   }
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala
index ea0dc9d603d..ba52ac995b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala
@@ -17,11 +17,16 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
+import org.apache.spark.sql.connector.catalog.{Table, TableCapability}
+import org.apache.spark.sql.connector.write.SupportsCustomSchemaWrite
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession {
   val rule = ResolveDefaultColumns(null)
@@ -93,4 +98,142 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-43313: Column default values with implicit coercion from 
provided values") {
+    withDatabase("demos") {
+      sql("create database demos")
+      withTable("demos.test_ts") {
+        // If the provided default value is a literal of a wider type than the target column, but
+        // the literal value fits within the narrower type, just coerce it for convenience.
+        sql(
+          """create table demos.test_ts (
+            |a int default 42L,
+            |b timestamp_ntz default '2022-01-02',
+            |c date default '2022-01-03',
+            |f float default 0D
+            |) using parquet""".stripMargin)
+        sql("insert into demos.test_ts(a) values (default)")
+        checkAnswer(spark.table("demos.test_ts"),
+          sql("select 42, timestamp_ntz'2022-01-02', date'2022-01-03', 0f"))
+        // If the provided default value is a literal of a different type than the target column
+        // such that no coercion is possible, throw an error.
+        Seq(
+          "create table demos.test_ts_other (a int default 'abc') using 
parquet",
+          "create table demos.test_ts_other (a timestamp default '2022-01-02') 
using parquet",
+          "create table demos.test_ts_other (a boolean default 'true') using 
parquet",
+          "create table demos.test_ts_other (a int default true) using parquet"
+        ).foreach { command =>
+          assert(intercept[AnalysisException](sql(command))
+            .getMessage.contains("statement provided a value of incompatible type"))
+        }
+      }
+    }
+  }
+
+  /**
+   * This is a new relation type that defines the 'customSchemaForInserts' method.
+   * Its implementation drops the last table column as it represents an internal pseudocolumn.
+   */
+  case class TableWithCustomInsertSchema(output: Seq[Attribute], numMetadataColumns: Int)
+    extends Table with SupportsCustomSchemaWrite {
+    override def name: String = "t"
+    override def schema: StructType = StructType.fromAttributes(output)
+    override def capabilities(): java.util.Set[TableCapability] =
+      new java.util.HashSet[TableCapability]()
+    override def customSchemaForInserts: StructType =
+      StructType(schema.fields.dropRight(numMetadataColumns))
+  }
+
+  /** Helper method to generate a DSV2 relation using the above table type. */
+  private def relationWithCustomInsertSchema(
+      output: Seq[AttributeReference], numMetadataColumns: Int): DataSourceV2Relation = {
+    DataSourceV2Relation(
+      TableWithCustomInsertSchema(output, numMetadataColumns),
+      output,
+      catalog = None,
+      identifier = None,
+      options = CaseInsensitiveStringMap.empty)
+  }
+
+  test("SPARK-43313: Add missing default values for MERGE INSERT actions") {
+    val testRelation = SubqueryAlias(
+      "testRelation",
+      relationWithCustomInsertSchema(Seq(
+        AttributeReference(
+          "a",
+          StringType,
+          true,
+          new MetadataBuilder()
+            .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'a'")
+            .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'a'")
+            .build())(),
+        AttributeReference(
+          "b",
+          StringType,
+          true,
+          new MetadataBuilder()
+            .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'b'")
+            .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'b'")
+            .build())(),
+        AttributeReference(
+          "c",
+          StringType,
+          true,
+          new MetadataBuilder()
+            .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'c'")
+            .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'c'")
+            .build())(),
+        AttributeReference(
+          "pseudocolumn",
+          StringType,
+          true,
+          new MetadataBuilder()
+            .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'pseudocolumn'")
+            .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'pseudocolumn'")
+            .build())()),
+        numMetadataColumns = 1))
+    val testRelation2 =
+      SubqueryAlias(
+        "testRelation2",
+        relationWithCustomInsertSchema(Seq(
+          AttributeReference("d", StringType)(),
+          AttributeReference("e", StringType)(),
+          AttributeReference("f", StringType)()),
+        numMetadataColumns = 0))
+    val mergePlan = MergeIntoTable(
+      targetTable = testRelation,
+      sourceTable = testRelation2,
+      mergeCondition = EqualTo(testRelation.output.head, testRelation2.output.head),
+      matchedActions = Seq(DeleteAction(None)),
+      notMatchedActions = Seq(
+        InsertAction(
+          condition = None,
+          assignments = Seq(
+            Assignment(
+              key = UnresolvedAttribute("a"),
+              value = UnresolvedAttribute("DEFAULT")),
+            Assignment(
+              key = UnresolvedAttribute(Seq("testRelation", "b")),
+              value = Literal("xyz"))))),
+      notMatchedBySourceActions = Seq(DeleteAction(None)))
+    // Run the 'addMissingDefaultValuesForMergeAction' method of the 'ResolveDefaultColumns' rule
+    // on a MERGE INSERT action with two assignments, one to the target table's column 'a' and
+    // another to the target table's column 'b'.
+    val columnNamesWithDefaults = Seq("a", "b", "c")
+    val actualMergeAction =
+      rule.apply(mergePlan).asInstanceOf[MergeIntoTable].notMatchedActions.head
+    val expectedMergeAction =
+      InsertAction(
+        condition = None,
+        assignments = Seq(
+          Assignment(key = UnresolvedAttribute("a"), value = Literal("a")),
+          Assignment(key = UnresolvedAttribute(Seq("testRelation", "b")), value = Literal("xyz")),
+          Assignment(key = UnresolvedAttribute("c"), value = Literal("c"))))
+    assert(expectedMergeAction == actualMergeAction)
+    // Run the same method on another MERGE DELETE action. There is no change because this method
+    // only operates on MERGE INSERT actions.
+    assert(rule.addMissingDefaultValuesForMergeAction(
+      mergePlan.matchedActions.head, mergePlan, columnNamesWithDefaults) ==
+      mergePlan.matchedActions.head)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 2207661478d..7d86a60815c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -1044,8 +1044,8 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
   test("SPARK-38336 INSERT INTO statements with tables with default columns: 
negative tests") {
     object Errors {
       val COMMON_SUBSTRING = " has a DEFAULT value"
-      val COLUMN_DEFAULT_NOT_FOUND = "`default` cannot be resolved."
      val BAD_SUBQUERY = "subquery expressions are not allowed in DEFAULT values"
+      val TARGET_TABLE_NOT_FOUND = "The table or view `t` cannot be found"
     }
     // The default value fails to analyze.
     withTable("t") {
@@ -1096,7 +1096,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
     withTable("t") {
       assert(intercept[AnalysisException] {
         sql("insert into t values(false, default)")
-      }.getMessage.contains(Errors.COLUMN_DEFAULT_NOT_FOUND))
+      }.getMessage.contains(Errors.TARGET_TABLE_NOT_FOUND))
     }
     // The default value parses but the type is not coercible.
     withTable("t") {

