This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 767a52d5db35 [SPARK-47528][SQL] Add UserDefinedType support to DataTypeUtils.canWrite
767a52d5db35 is described below

commit 767a52d5db354786d5ca07ddc4192d0eb8e8be80
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Sun Mar 24 01:27:52 2024 -0700

[SPARK-47528][SQL] Add UserDefinedType support to DataTypeUtils.canWrite

### What changes were proposed in this pull request?

This patch adds `UserDefinedType` handling to `DataTypeUtils.canWrite`.

### Why are the changes needed?

A customer recently hit an issue when trying to save a DataFrame containing some UDTs as a table (`saveAsTable`). The error looks like:

```
- Cannot write 'xxx': struct<...> is incompatible with struct<...>
```

The catalog strings on the two sides are actually the same, which confused the customer. The cause is that `DataTypeUtils.canWrite` doesn't handle `UserDefinedType`: if a `UserDefinedType`'s underlying SQL type is the same as the read-side type, `canWrite` should return true for the two sides.
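To make the failure concrete, here is a minimal sketch of the kind of UDT involved, with hypothetical names (`Point`, `PointUDT`) invented for this illustration only. The UDT API is internal (`private[spark]`), so the sketch assumes the class is compiled under an `org.apache.spark` package, as third-party UDTs commonly are:

```scala
package org.apache.spark.example // UserDefinedType is private[spark]

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

// Hypothetical user class and its UDT. The UDT's underlying SQL type is a
// plain struct<x: double, y: double>, so its catalog string renders the
// same as an ordinary struct column with those fields.
case class Point(x: Double, y: Double)

class PointUDT extends UserDefinedType[Point] {
  override def sqlType: DataType = StructType(Seq(
    StructField("x", DoubleType, nullable = false),
    StructField("y", DoubleType, nullable = false)))

  override def serialize(p: Point): InternalRow = InternalRow(p.x, p.y)

  override def deserialize(datum: Any): Point = datum match {
    case row: InternalRow => Point(row.getDouble(0), row.getDouble(1))
  }

  override def userClass: Class[Point] = classOf[Point]
}
```

Before this patch, writing a `PointUDT` column into a table column of type `struct<x: double, y: double>` fell through to the error case in `canWrite`, so the error message printed the same struct catalog string on both sides.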
### Does this PR introduce _any_ user-facing change?

Yes. A write-side column with a `UserDefinedType` can now be written into a read-side column with the same underlying SQL data type.

### How was this patch tested?

Unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #45678 from viirya/udt_dt_write.

Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/catalyst/types/DataTypeUtils.scala   |  14 ++-
 .../types/DataTypeWriteCompatibilitySuite.scala    | 134 +++++++++++++++++++++
 2 files changed, 147 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index 01fb86bf2957..cf8e903f03a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
 import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT}
-import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, Decimal, DecimalType, MapType, NullType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, Decimal, DecimalType, MapType, NullType, StructField, StructType, UserDefinedType}
 import org.apache.spark.sql.types.DecimalType.{forType, fromDecimal}
 
 object DataTypeUtils {
@@ -64,6 +64,8 @@ object DataTypeUtils {
    * - Both types are structs and have the same number of fields. The type and nullability of each
    *   field from read/write is compatible. If byName is true, the name of each field from
    *   read/write needs to be the same.
+   * - It is a user-defined type and its underlying SQL type is the same as the read type, or the
+   *   read type is a user-defined type and its underlying SQL type is the same as the write type.
    * - Both types are atomic and the write type can be safely cast to the read type.
    *
    * Extra fields in write-side structs are not allowed to avoid accidentally writing data that
@@ -180,6 +182,16 @@ object DataTypeUtils {
     case (w, r) if DataTypeUtils.sameType(w, r) && !w.isInstanceOf[NullType] =>
       true
 
+    // If the write-side data type is a user-defined type, check against its underlying data type.
+    case (w, r) if w.isInstanceOf[UserDefinedType[_]] && !r.isInstanceOf[UserDefinedType[_]] =>
+      canWrite(tableName, w.asInstanceOf[UserDefinedType[_]].sqlType, r, byName, resolver,
+        context, storeAssignmentPolicy, addError)
+
+    // If the read-side data type is a user-defined type, check against its underlying data type.
+    case (w, r) if r.isInstanceOf[UserDefinedType[_]] && !w.isInstanceOf[UserDefinedType[_]] =>
+      canWrite(tableName, w, r.asInstanceOf[UserDefinedType[_]].sqlType, byName, resolver,
+        context, storeAssignmentPolicy, addError)
+
     case (w, r) =>
       throw QueryCompilationErrors.incompatibleDataToTableCannotSafelyCastError(
         tableName, context, w.catalogString, r.catalogString
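Conceptually, the two new cases unwrap a UDT to its `sqlType` and recurse. A new case fires only when exactly one side is a UDT; two identical UDTs already pass the earlier `sameType` check. Below is a simplified, standalone sketch of just that dispatch (same access caveat as the earlier sketch; the real method also threads the table name, resolver, error context, store-assignment policy, and error callback, as shown in the hunk above):

```scala
import org.apache.spark.sql.types.{DataType, UserDefinedType}

// Simplified sketch of the new dispatch only; the final case stands in for
// all of canWrite's other checks (structs, arrays, maps, safe casts, ...).
def canWriteSketch(write: DataType, read: DataType): Boolean = (write, read) match {
  case (w: UserDefinedType[_], r) if !r.isInstanceOf[UserDefinedType[_]] =>
    canWriteSketch(w.sqlType, r)   // unwrap the write-side UDT and re-check
  case (w, r: UserDefinedType[_]) if !w.isInstanceOf[UserDefinedType[_]] =>
    canWriteSketch(w, r.sqlType)   // unwrap the read-side UDT and re-check
  case (w, r) =>
    w == r                         // stand-in for the remaining compatibility checks
}
```

The tests below exercise this through the real entry point, for example `DataTypeUtils.canWrite("", udtType, sqlType, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg)`, covering the matching-`sqlType`, differing-nullability, and unsafe-cast combinations.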
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
index 7aaa69a0a5dd..8c9196cc33ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
@@ -510,6 +510,140 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite {
       "Should allow map of int written to map of long column")
   }
 
+  test("SPARK-47528: Check udt: underlying sql type is same") {
+    val udtType = new UserDefinedType[Any] {
+      override def sqlType: DataType = StructType(Seq(
+        StructField("col1", FloatType, nullable = false),
+        StructField("col2", FloatType, nullable = false)))
+
+      override def userClass: java.lang.Class[Any] = null
+
+      override def serialize(obj: Any): Any = null
+
+      override def deserialize(datum: Any): Any = null
+    }
+
+    val sqlType = StructType(Seq(
+      StructField("col1", FloatType, nullable = false),
+      StructField("col2", FloatType, nullable = false)))
+
+    assertAllowed(udtType, sqlType, "m",
+      "Should allow udt with same sqlType written to struct column")
+
+    assertAllowed(sqlType, udtType, "m",
+      "Should allow udt with same sqlType written to struct column")
+  }
+
+  test("SPARK-47528: Check udt: underlying sql type is same but different nullability") {
+    val udtType = new UserDefinedType[Any] {
+      override def sqlType: DataType = StructType(Seq(
+        StructField("col1", FloatType, nullable = false),
+        StructField("col2", FloatType, nullable = false)))
+
+      override def userClass: java.lang.Class[Any] = null
+
+      override def serialize(obj: Any): Any = null
+
+      override def deserialize(datum: Any): Any = null
+    }
+
+    val sqlType = StructType(Seq(
+      StructField("col1", FloatType, nullable = false),
+      StructField("col2", FloatType, nullable = true)))
+
+    assertAllowed(udtType, sqlType, "m",
+      "Should allow udt with same sqlType written to struct column")
+
+    val errs = new mutable.ArrayBuffer[String]()
+    checkError(
+      exception = intercept[AnalysisException] (
+        DataTypeUtils.canWrite("", sqlType, udtType, true,
+          analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg)
+      ),
+      errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_COLUMN",
+      parameters = Map(
+        "tableName" -> "``",
+        "colName" -> "`t`.`col2`"
+      )
+    )
+  }
+
+  test("SPARK-47528: Check udt: write underlying sql type is not same") {
+    val udtType = new UserDefinedType[Any] {
+      override def sqlType: DataType = StructType(Seq(
+        StructField("col1", FloatType, nullable = false),
+        StructField("col2", FloatType, nullable = false)))
+
+      override def userClass: java.lang.Class[Any] = null
+
+      override def serialize(obj: Any): Any = null
+
+      override def deserialize(datum: Any): Any = null
+    }
+
+    val sqlType = StructType(Seq(
+      StructField("col1", FloatType, nullable = false),
+      StructField("col2", IntegerType, nullable = false)))
+
+    if (canCast(udtType.sqlType, sqlType)) {
+      assertAllowed(udtType, sqlType, "m",
+        "Should allow udt with compatible sqlType written to struct column")
+    } else {
+      val errs = new mutable.ArrayBuffer[String]()
+      checkError(
+        exception = intercept[AnalysisException](
+          DataTypeUtils.canWrite("", udtType, sqlType, true,
+            analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg)
+        ),
+        errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST",
+        parameters = Map(
+          "tableName" -> "``",
+          "colName" -> "`t`.`col2`",
+          "srcType" -> "\"FLOAT\"",
+          "targetType" -> "\"INT\""
+        )
+      )
+    }
+  }
+
+  test("SPARK-47528: Check udt: read side underlying sql type is not same") {
+    val udtType = new UserDefinedType[Any] {
+      override def sqlType: DataType = StructType(Seq(
+        StructField("col1", FloatType, nullable = false),
+        StructField("col2", IntegerType, nullable = false)))
+
+      override def userClass: java.lang.Class[Any] = null
+
+      override def serialize(obj: Any): Any = null
+
+      override def deserialize(datum: Any): Any = null
+    }
+
+    val sqlType = StructType(Seq(
+      StructField("col1", FloatType, nullable = false),
+      StructField("col2", FloatType, nullable = false)))
+
+    if (canCast(sqlType, udtType.sqlType)) {
+      assertAllowed(sqlType, udtType, "m",
+        "Should allow udt with compatible sqlType written to struct column")
+    } else {
+      val errs = new mutable.ArrayBuffer[String]()
+      checkError(
+        exception = intercept[AnalysisException](
+          DataTypeUtils.canWrite("", sqlType, udtType, true,
+            analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg)
+        ),
+        errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST",
+        parameters = Map(
+          "tableName" -> "``",
+          "colName" -> "`t`.`col2`",
+          "srcType" -> "\"FLOAT\"",
+          "targetType" -> "\"INT\""
+        )
+      )
+    }
+  }
+
   test("Check types with multiple errors") {
     val readType = StructType(Seq(
       StructField("a", ArrayType(DoubleType, containsNull = false)),

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org