This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 767a52d5db35 [SPARK-47528][SQL] Add UserDefinedType support to DataTypeUtils.canWrite
767a52d5db35 is described below

commit 767a52d5db354786d5ca07ddc4192d0eb8e8be80
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Sun Mar 24 01:27:52 2024 -0700

    [SPARK-47528][SQL] Add UserDefinedType support to DataTypeUtils.canWrite
    
    ### What changes were proposed in this pull request?
    
    This patch adds `UserDefinedType` handling to `DataTypeUtils.canWrite`.
    
    ### Why are the changes needed?
    
    Our customer recently hit an issue when trying to save a DataFrame containing some UDTs as a table (`saveAsTable`). The error looks like:
    
    ```
    - Cannot write 'xxx': struct<...> is incompatible with struct<...>
    ```
    
    The catalog strings on the two sides are actually the same, which confused the customer.
    
    This is because `DataTypeUtils.canWrite` doesn't handle `UserDefinedType`. If a `UserDefinedType`'s underlying SQL type is the same as the read-side type, `canWrite` should return true for the two sides.
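
    As an illustrative sketch (not part of this patch; `PointUDT` and its field names are made up, mirroring the anonymous UDT in the new test suite), a UDT's `catalogString` is derived from its underlying struct, so both sides of the error message print identically even though the old `canWrite` fell through to the incompatibility case:

    ```
    import org.apache.spark.sql.types._

    // Hypothetical UDT whose underlying SQL type is a two-float struct.
    class PointUDT extends UserDefinedType[Any] {
      override def sqlType: DataType = StructType(Seq(
        StructField("x", FloatType, nullable = false),
        StructField("y", FloatType, nullable = false)))
      override def userClass: java.lang.Class[Any] = null
      override def serialize(obj: Any): Any = obj
      override def deserialize(datum: Any): Any = datum
    }

    val writeType = new PointUDT           // write side: the UDT itself
    val readType = new PointUDT().sqlType  // read side: the plain struct

    // Both sides render the same catalog string, hence the confusing
    // "struct<...> is incompatible with struct<...>" message.
    assert(writeType.catalogString == readType.catalogString)
    ```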
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A write-side column with a `UserDefinedType` can now be written into a read-side column with the same underlying SQL data type.
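
    For example (a sketch only; the table `t` is hypothetical and MLlib's `VectorUDT` is just one concrete UDT), an append like the following previously failed the `canWrite` check despite matching catalog strings, and now succeeds when the target column is declared with the UDT's underlying SQL type:

    ```
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    // "features" carries MLlib's VectorUDT on the write side.
    val df = Seq((1, Vectors.dense(1.0, 2.0))).toDF("id", "features")

    // If table "t" declares "features" with the UDT's underlying SQL type,
    // this append is now accepted instead of failing with
    // "struct<...> is incompatible with struct<...>".
    df.write.mode("append").saveAsTable("t")
    ```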
    
    ### How was this patch tested?
    
    Unit test
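
    The new suite can be run locally with the usual sbt workflow, e.g.:

    ```
    build/sbt "catalyst/testOnly *DataTypeWriteCompatibilitySuite"
    ```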
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45678 from viirya/udt_dt_write.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/catalyst/types/DataTypeUtils.scala   |  14 ++-
 .../types/DataTypeWriteCompatibilitySuite.scala    | 134 +++++++++++++++++++++
 2 files changed, 147 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index 01fb86bf2957..cf8e903f03a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
 import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT}
-import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, Decimal, DecimalType, MapType, NullType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, Decimal, DecimalType, MapType, NullType, StructField, StructType, UserDefinedType}
 import org.apache.spark.sql.types.DecimalType.{forType, fromDecimal}
 
 object DataTypeUtils {
@@ -64,6 +64,8 @@ object DataTypeUtils {
    * - Both types are structs and have the same number of fields. The type and nullability of each
    *   field from read/write is compatible. If byName is true, the name of each field from
    *   read/write needs to be the same.
+   * - The write type is a user-defined type and its underlying SQL type is the same as the read
+   *   type, or the read type is a user-defined type and its underlying SQL type is the same as the write type.
    * - Both types are atomic and the write type can be safely cast to the read type.
    *
    * Extra fields in write-side structs are not allowed to avoid accidentally writing data that
@@ -180,6 +182,16 @@ object DataTypeUtils {
       case (w, r) if DataTypeUtils.sameType(w, r) && !w.isInstanceOf[NullType] =>
         true
 
+      // If write-side data type is a user-defined type, check with its underlying data type.
+      case (w, r) if w.isInstanceOf[UserDefinedType[_]] && !r.isInstanceOf[UserDefinedType[_]] =>
+        canWrite(tableName, w.asInstanceOf[UserDefinedType[_]].sqlType, r, byName, resolver,
+          context, storeAssignmentPolicy, addError)
+
+      // If read-side data type is a user-defined type, check with its underlying data type.
+      case (w, r) if r.isInstanceOf[UserDefinedType[_]] && !w.isInstanceOf[UserDefinedType[_]] =>
+        canWrite(tableName, w, r.asInstanceOf[UserDefinedType[_]].sqlType, byName, resolver,
+          context, storeAssignmentPolicy, addError)
+
       case (w, r) =>
        throw QueryCompilationErrors.incompatibleDataToTableCannotSafelyCastError(
           tableName, context, w.catalogString, r.catalogString
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
index 7aaa69a0a5dd..8c9196cc33ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
@@ -510,6 +510,140 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite {
       "Should allow map of int written to map of long column")
   }
 
+  test("SPARK-47528: Check udt: underlying sql type is same") {
+    val udtType = new UserDefinedType[Any] {
+      override def sqlType: DataType = StructType(Seq(
+        StructField("col1", FloatType, nullable = false),
+        StructField("col2", FloatType, nullable = false)))
+
+      override def userClass: java.lang.Class[Any] = null
+
+      override def serialize(obj: Any): Any = null
+
+      override def deserialize(datum: Any): Any = null
+    }
+
+    val sqlType = StructType(Seq(
+      StructField("col1", FloatType, nullable = false),
+      StructField("col2", FloatType, nullable = false)))
+
+    assertAllowed(udtType, sqlType, "m",
+      "Should allow udt with same sqlType written to struct column")
+
+    assertAllowed(sqlType, udtType, "m",
+      "Should allow struct with same sqlType written to udt column")
+  }
+
+  test("SPARK-47528: Check udt: underlying sql type is same but different 
nullability") {
+    val udtType = new UserDefinedType[Any] {
+      override def sqlType: DataType = StructType(Seq(
+        StructField("col1", FloatType, nullable = false),
+        StructField("col2", FloatType, nullable = false)))
+
+      override def userClass: java.lang.Class[Any] = null
+
+      override def serialize(obj: Any): Any = null
+
+      override def deserialize(datum: Any): Any = null
+    }
+
+    val sqlType = StructType(Seq(
+      StructField("col1", FloatType, nullable = false),
+      StructField("col2", FloatType, nullable = true)))
+
+    assertAllowed(udtType, sqlType, "m",
+      "Should allow udt with same sqlType written to struct column")
+
+    val errs = new mutable.ArrayBuffer[String]()
+    checkError(
+      exception = intercept[AnalysisException] (
+        DataTypeUtils.canWrite("", sqlType, udtType, true,
+          analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg)
+      ),
+      errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_COLUMN",
+      parameters = Map(
+        "tableName" -> "``",
+        "colName" -> "`t`.`col2`"
+      )
+    )
+  }
+
+  test("SPARK-47528: Check udt: write underlying sql type is not same") {
+    val udtType = new UserDefinedType[Any] {
+      override def sqlType: DataType = StructType(Seq(
+        StructField("col1", FloatType, nullable = false),
+        StructField("col2", FloatType, nullable = false)))
+
+      override def userClass: java.lang.Class[Any] = null
+
+      override def serialize(obj: Any): Any = null
+
+      override def deserialize(datum: Any): Any = null
+    }
+
+    val sqlType = StructType(Seq(
+      StructField("col1", FloatType, nullable = false),
+      StructField("col2", IntegerType, nullable = false)))
+
+    if (canCast(udtType.sqlType, sqlType)) {
+      assertAllowed(udtType, sqlType, "m",
+        "Should allow udt with compatible sqlType written to struct column")
+    } else {
+      val errs = new mutable.ArrayBuffer[String]()
+      checkError(
+        exception = intercept[AnalysisException](
+          DataTypeUtils.canWrite("", udtType, sqlType, true,
+            analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg)
+        ),
+        errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST",
+        parameters = Map(
+          "tableName" -> "``",
+          "colName" -> "`t`.`col2`",
+          "srcType" -> "\"FLOAT\"",
+          "targetType" -> "\"INT\""
+        )
+      )
+    }
+  }
+
+  test("SPARK-47528: Check udt: read side underlying sql type is not same") {
+    val udtType = new UserDefinedType[Any] {
+      override def sqlType: DataType = StructType(Seq(
+        StructField("col1", FloatType, nullable = false),
+        StructField("col2", IntegerType, nullable = false)))
+
+      override def userClass: java.lang.Class[Any] = null
+
+      override def serialize(obj: Any): Any = null
+
+      override def deserialize(datum: Any): Any = null
+    }
+
+    val sqlType = StructType(Seq(
+      StructField("col1", FloatType, nullable = false),
+      StructField("col2", FloatType, nullable = false)))
+
+    if (canCast(sqlType, udtType.sqlType)) {
+      assertAllowed(sqlType, udtType, "m",
+        "Should allow struct with compatible sqlType written to udt column")
+    } else {
+      val errs = new mutable.ArrayBuffer[String]()
+      checkError(
+        exception = intercept[AnalysisException](
+          DataTypeUtils.canWrite("", sqlType, udtType, true,
+            analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg)
+        ),
+        errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST",
+        parameters = Map(
+          "tableName" -> "``",
+          "colName" -> "`t`.`col2`",
+          "srcType" -> "\"FLOAT\"",
+          "targetType" -> "\"INT\""
+        )
+      )
+    }
+  }
+
   test("Check types with multiple errors") {
     val readType = StructType(Seq(
       StructField("a", ArrayType(DoubleType, containsNull = false)),

