gengliangwang commented on a change in pull request #25581: [SPARK-28495][SQL] Introduce ANSI store assignment policy for table insertion
URL: https://github.com/apache/spark/pull/25581#discussion_r317950237
 
 

 ##########
 File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
 ##########
 @@ -22,20 +22,296 @@ import scala.collection.mutable
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
 
-class DataTypeWriteCompatibilitySuite extends SparkFunSuite {
-  private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, 
LongType, FloatType,
-    DoubleType, DateType, TimestampType, StringType, BinaryType)
+class StrictDataTypeWriteCompatibilitySuite extends 
DataTypeWriteCompatibilityBaseSuite {
+  override protected def storeAssignmentPolicy: 
SQLConf.StoreAssignmentPolicy.Value =
+    StoreAssignmentPolicy.STRICT
 
-  private val point2 = StructType(Seq(
+  test("Check atomic types: write allowed only when casting is safe") {
+    atomicTypes.foreach { w =>
+      atomicTypes.foreach { r =>
+        if (Cast.canUpCast(w, r)) {
+          assertAllowed(w, r, "t", s"Should allow writing $w to $r because 
cast is safe")
+
+        } else {
+          assertSingleError(w, r, "t",
+            s"Should not allow writing $w to $r because cast is not safe") { 
err =>
+            assert(err.contains("'t'"), "Should include the field name 
context")
+            assert(err.contains("Cannot safely cast"), "Should identify unsafe 
cast")
+            assert(err.contains(s"$w"), "Should include write type")
+            assert(err.contains(s"$r"), "Should include read type")
+          }
+        }
+      }
+    }
+  }
+
+  test("Check struct types: unsafe casts are not allowed") {
+    assertNumErrors(widerPoint2, point2, "t",
+      "Should fail because types require unsafe casts", 2) { errs =>
+
+      assert(errs(0).contains("'t.x'"), "Should include the nested field name 
context")
+      assert(errs(0).contains("Cannot safely cast"))
+
+      assert(errs(1).contains("'t.y'"), "Should include the nested field name 
context")
+      assert(errs(1).contains("Cannot safely cast"))
+    }
+  }
+
+  test("Check array types: unsafe casts are not allowed") {
+    val arrayOfLong = ArrayType(LongType)
+    val arrayOfInt = ArrayType(IntegerType)
+
+    assertSingleError(arrayOfLong, arrayOfInt, "arr",
+      "Should not allow array of longs to array of ints") { err =>
+      assert(err.contains("'arr.element'"),
+        "Should identify problem with named array's element type")
+      assert(err.contains("Cannot safely cast"))
+    }
+  }
+
+  test("Check map value types: casting Long to Integer is not allowed") {
+    val mapOfLong = MapType(StringType, LongType)
+    val mapOfInt = MapType(StringType, IntegerType)
+
+    assertSingleError(mapOfLong, mapOfInt, "m",
+      "Should not allow map of longs to map of ints") { err =>
+      assert(err.contains("'m.value'"), "Should identify problem with named 
map's value type")
+      assert(err.contains("Cannot safely cast"))
+    }
+  }
+
+  test("Check map key types: unsafe casts are not allowed") {
+    val mapKeyLong = MapType(LongType, StringType)
+    val mapKeyInt = MapType(IntegerType, StringType)
+
+    assertSingleError(mapKeyLong, mapKeyInt, "m",
+      "Should not allow map of long keys to map of int keys") { err =>
+      assert(err.contains("'m.key'"), "Should identify problem with named 
map's key type")
+      assert(err.contains("Cannot safely cast"))
+    }
+  }
+
+  test("Check types with multiple errors") {
+    val readType = StructType(Seq(
+      StructField("a", ArrayType(DoubleType, containsNull = false)),
+      StructField("arr_of_structs", ArrayType(point2, containsNull = false)),
+      StructField("bad_nested_type", ArrayType(StringType)),
+      StructField("m", MapType(LongType, FloatType, valueContainsNull = 
false)),
+      StructField("map_of_structs", MapType(StringType, point3, 
valueContainsNull = false)),
+      StructField("x", IntegerType, nullable = false),
+      StructField("missing1", StringType, nullable = false),
+      StructField("missing2", StringType)
+    ))
+
+    val missingMiddleField = StructType(Seq(
+      StructField("x", FloatType, nullable = false),
+      StructField("z", FloatType, nullable = false)))
+
+    val writeType = StructType(Seq(
+      StructField("a", ArrayType(StringType)),
+      StructField("arr_of_structs", ArrayType(point3)),
+      StructField("bad_nested_type", point3),
+      StructField("m", MapType(DoubleType, DoubleType)),
+      StructField("map_of_structs", MapType(StringType, missingMiddleField)),
+      StructField("y", LongType)
+    ))
+
+    assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) 
{ errs =>
+      assert(errs(0).contains("'top.a.element'"), "Should identify bad type")
+      assert(errs(0).contains("Cannot safely cast"))
+      assert(errs(0).contains("StringType to DoubleType"))
+
+      assert(errs(1).contains("'top.a'"), "Should identify bad type")
+      assert(errs(1).contains("Cannot write nullable elements to array of 
non-nulls"))
+
+      assert(errs(2).contains("'top.arr_of_structs.element'"), "Should 
identify bad type")
+      assert(errs(2).contains("'z'"), "Should identify bad field")
+      assert(errs(2).contains("Cannot write extra fields to struct"))
+
+      assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad 
type")
+      assert(errs(3).contains("Cannot write nullable elements to array of 
non-nulls"))
+
+      assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad 
type")
+      assert(errs(4).contains("is incompatible with"))
+
+      assert(errs(5).contains("'top.m.key'"), "Should identify bad type")
+      assert(errs(5).contains("Cannot safely cast"))
+      assert(errs(5).contains("DoubleType to LongType"))
+
+      assert(errs(6).contains("'top.m.value'"), "Should identify bad type")
+      assert(errs(6).contains("Cannot safely cast"))
+      assert(errs(6).contains("DoubleType to FloatType"))
+
+      assert(errs(7).contains("'top.m'"), "Should identify bad type")
+      assert(errs(7).contains("Cannot write nullable values to map of 
non-nulls"))
+
+      assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify 
bad type")
+      assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name 
mismatch")
+      assert(errs(8).contains("field name does not match"), "Should identify 
name problem")
+
+      assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify 
bad type")
+      assert(errs(9).contains("'z'"), "Should identify missing field")
+      assert(errs(9).contains("missing fields"), "Should detect missing field")
+
+      assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad 
type")
+      assert(errs(10).contains("Cannot write nullable values to map of 
non-nulls"))
+
+      assert(errs(11).contains("'top.x'"), "Should identify bad type")
+      assert(errs(11).contains("Cannot safely cast"))
+      assert(errs(11).contains("LongType to IntegerType"))
+
+      assert(errs(12).contains("'top'"), "Should identify bad type")
+      assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name 
mismatch")
+      assert(errs(12).contains("field name does not match"), "Should identify 
name problem")
+
+      assert(errs(13).contains("'top'"), "Should identify bad type")
+      assert(errs(13).contains("'missing1'"), "Should identify missing field")
+      assert(errs(13).contains("missing fields"), "Should detect missing 
field")
+    }
+  }
+}
+
+class ANSIDataTypeWriteCompatibilitySuite extends 
DataTypeWriteCompatibilityBaseSuite {
+  override protected def storeAssignmentPolicy: 
SQLConf.StoreAssignmentPolicy.Value =
+    StoreAssignmentPolicy.ANSI
+
+  test("Check atomic types: write allowed only when casting is safe") {
+    atomicTypes.foreach { w =>
+      atomicTypes.foreach { r =>
+        if ((w.isInstanceOf[NumericType] && r.isInstanceOf[NumericType]) ||
 
 Review comment:
   I thought we were using the method `canANSIStoreAssign` to test the method
itself. But it also makes sense to just use `Cast.canANSIStoreAssign(w, r)`;
we need to guarantee that the method is used accurately.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to