gengliangwang commented on a change in pull request #25581: [SPARK-28495][SQL]
Introduce ANSI store assignment policy for table insertion
URL: https://github.com/apache/spark/pull/25581#discussion_r317950237
##########
File path:
sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
##########
@@ -22,20 +22,296 @@ import scala.collection.mutable
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
-class DataTypeWriteCompatibilitySuite extends SparkFunSuite {
- private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType,
LongType, FloatType,
- DoubleType, DateType, TimestampType, StringType, BinaryType)
+class StrictDataTypeWriteCompatibilitySuite extends
DataTypeWriteCompatibilityBaseSuite {
+ override protected def storeAssignmentPolicy:
SQLConf.StoreAssignmentPolicy.Value =
+ StoreAssignmentPolicy.STRICT
- private val point2 = StructType(Seq(
+ test("Check atomic types: write allowed only when casting is safe") {
+ atomicTypes.foreach { w =>
+ atomicTypes.foreach { r =>
+ if (Cast.canUpCast(w, r)) {
+ assertAllowed(w, r, "t", s"Should allow writing $w to $r because
cast is safe")
+
+ } else {
+ assertSingleError(w, r, "t",
+ s"Should not allow writing $w to $r because cast is not safe") {
err =>
+ assert(err.contains("'t'"), "Should include the field name
context")
+ assert(err.contains("Cannot safely cast"), "Should identify unsafe
cast")
+ assert(err.contains(s"$w"), "Should include write type")
+ assert(err.contains(s"$r"), "Should include read type")
+ }
+ }
+ }
+ }
+ }
+
+ test("Check struct types: unsafe casts are not allowed") {
+ assertNumErrors(widerPoint2, point2, "t",
+ "Should fail because types require unsafe casts", 2) { errs =>
+
+ assert(errs(0).contains("'t.x'"), "Should include the nested field name
context")
+ assert(errs(0).contains("Cannot safely cast"))
+
+ assert(errs(1).contains("'t.y'"), "Should include the nested field name
context")
+ assert(errs(1).contains("Cannot safely cast"))
+ }
+ }
+
+ test("Check array types: unsafe casts are not allowed") {
+ val arrayOfLong = ArrayType(LongType)
+ val arrayOfInt = ArrayType(IntegerType)
+
+ assertSingleError(arrayOfLong, arrayOfInt, "arr",
+ "Should not allow array of longs to array of ints") { err =>
+ assert(err.contains("'arr.element'"),
+ "Should identify problem with named array's element type")
+ assert(err.contains("Cannot safely cast"))
+ }
+ }
+
+ test("Check map value types: casting Long to Integer is not allowed") {
+ val mapOfLong = MapType(StringType, LongType)
+ val mapOfInt = MapType(StringType, IntegerType)
+
+ assertSingleError(mapOfLong, mapOfInt, "m",
+ "Should not allow map of longs to map of ints") { err =>
+ assert(err.contains("'m.value'"), "Should identify problem with named
map's value type")
+ assert(err.contains("Cannot safely cast"))
+ }
+ }
+
+ test("Check map key types: unsafe casts are not allowed") {
+ val mapKeyLong = MapType(LongType, StringType)
+ val mapKeyInt = MapType(IntegerType, StringType)
+
+ assertSingleError(mapKeyLong, mapKeyInt, "m",
+ "Should not allow map of long keys to map of int keys") { err =>
+ assert(err.contains("'m.key'"), "Should identify problem with named
map's key type")
+ assert(err.contains("Cannot safely cast"))
+ }
+ }
+
+ test("Check types with multiple errors") {
+ val readType = StructType(Seq(
+ StructField("a", ArrayType(DoubleType, containsNull = false)),
+ StructField("arr_of_structs", ArrayType(point2, containsNull = false)),
+ StructField("bad_nested_type", ArrayType(StringType)),
+ StructField("m", MapType(LongType, FloatType, valueContainsNull =
false)),
+ StructField("map_of_structs", MapType(StringType, point3,
valueContainsNull = false)),
+ StructField("x", IntegerType, nullable = false),
+ StructField("missing1", StringType, nullable = false),
+ StructField("missing2", StringType)
+ ))
+
+ val missingMiddleField = StructType(Seq(
+ StructField("x", FloatType, nullable = false),
+ StructField("z", FloatType, nullable = false)))
+
+ val writeType = StructType(Seq(
+ StructField("a", ArrayType(StringType)),
+ StructField("arr_of_structs", ArrayType(point3)),
+ StructField("bad_nested_type", point3),
+ StructField("m", MapType(DoubleType, DoubleType)),
+ StructField("map_of_structs", MapType(StringType, missingMiddleField)),
+ StructField("y", LongType)
+ ))
+
+ assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14)
{ errs =>
+ assert(errs(0).contains("'top.a.element'"), "Should identify bad type")
+ assert(errs(0).contains("Cannot safely cast"))
+ assert(errs(0).contains("StringType to DoubleType"))
+
+ assert(errs(1).contains("'top.a'"), "Should identify bad type")
+ assert(errs(1).contains("Cannot write nullable elements to array of
non-nulls"))
+
+ assert(errs(2).contains("'top.arr_of_structs.element'"), "Should
identify bad type")
+ assert(errs(2).contains("'z'"), "Should identify bad field")
+ assert(errs(2).contains("Cannot write extra fields to struct"))
+
+ assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad
type")
+ assert(errs(3).contains("Cannot write nullable elements to array of
non-nulls"))
+
+ assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad
type")
+ assert(errs(4).contains("is incompatible with"))
+
+ assert(errs(5).contains("'top.m.key'"), "Should identify bad type")
+ assert(errs(5).contains("Cannot safely cast"))
+ assert(errs(5).contains("DoubleType to LongType"))
+
+ assert(errs(6).contains("'top.m.value'"), "Should identify bad type")
+ assert(errs(6).contains("Cannot safely cast"))
+ assert(errs(6).contains("DoubleType to FloatType"))
+
+ assert(errs(7).contains("'top.m'"), "Should identify bad type")
+ assert(errs(7).contains("Cannot write nullable values to map of
non-nulls"))
+
+ assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify
bad type")
+ assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name
mismatch")
+ assert(errs(8).contains("field name does not match"), "Should identify
name problem")
+
+ assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify
bad type")
+ assert(errs(9).contains("'z'"), "Should identify missing field")
+ assert(errs(9).contains("missing fields"), "Should detect missing field")
+
+ assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad
type")
+ assert(errs(10).contains("Cannot write nullable values to map of
non-nulls"))
+
+ assert(errs(11).contains("'top.x'"), "Should identify bad type")
+ assert(errs(11).contains("Cannot safely cast"))
+ assert(errs(11).contains("LongType to IntegerType"))
+
+ assert(errs(12).contains("'top'"), "Should identify bad type")
+ assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name
mismatch")
+ assert(errs(12).contains("field name does not match"), "Should identify
name problem")
+
+ assert(errs(13).contains("'top'"), "Should identify bad type")
+ assert(errs(13).contains("'missing1'"), "Should identify missing field")
+ assert(errs(13).contains("missing fields"), "Should detect missing
field")
+ }
+ }
+}
+
+class ANSIDataTypeWriteCompatibilitySuite extends
DataTypeWriteCompatibilityBaseSuite {
+ override protected def storeAssignmentPolicy:
SQLConf.StoreAssignmentPolicy.Value =
+ StoreAssignmentPolicy.ANSI
+
+ test("Check atomic types: write allowed only when casting is safe") {
+ atomicTypes.foreach { w =>
+ atomicTypes.foreach { r =>
+ if ((w.isInstanceOf[NumericType] && r.isInstanceOf[NumericType]) ||
Review comment:
I thought the explicit type checks here were meant to test `canANSIStoreAssign` itself, instead of
calling it. But it also makes sense to just use `Cast.canANSIStoreAssign(w, r)`: we need to
guarantee that the method is applied correctly.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]