cloud-fan commented on a change in pull request #25997: [SPARK-29326][SQL] ANSI store assignment policy: throw exception on casting failure
URL: https://github.com/apache/spark/pull/25997#discussion_r331000307
##########
File path: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
##########
@@ -1189,3 +1041,189 @@ class CastSuite extends SparkFunSuite with
ExpressionEvalHelper {
}
}
}
+
+/**
+ * Test suite for data type casting expression [[Cast]].
+ */
+class CastSuite extends CastSuiteBase {
+ // It is required to set SQLConf.ANSI_ENABLED as true for testing numeric
overflow.
+ override protected def requiredAnsiEnabledForOverflowTestCases: Boolean =
true
+
+ test("cast from int") {
+ checkCast(0, false)
+ checkCast(1, true)
+ checkCast(-5, true)
+ checkCast(1, 1.toByte)
+ checkCast(1, 1.toShort)
+ checkCast(1, 1)
+ checkCast(1, 1.toLong)
+ checkCast(1, 1.0f)
+ checkCast(1, 1.0)
+ checkCast(123, "123")
+
+ checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
+ checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
+ checkEvaluation(cast(123, DecimalType(3, 1)), null)
+ checkEvaluation(cast(123, DecimalType(2, 0)), null)
+ }
+
+ test("cast from long") {
+ checkCast(0L, false)
+ checkCast(1L, true)
+ checkCast(-5L, true)
+ checkCast(1L, 1.toByte)
+ checkCast(1L, 1.toShort)
+ checkCast(1L, 1)
+ checkCast(1L, 1.toLong)
+ checkCast(1L, 1.0f)
+ checkCast(1L, 1.0)
+ checkCast(123L, "123")
+
+ checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
+ checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
+ checkEvaluation(cast(123L, DecimalType(3, 1)), null)
+
+ checkEvaluation(cast(123L, DecimalType(2, 0)), null)
+ }
+
+ test("cast from int 2") {
+ checkEvaluation(cast(1, LongType), 1.toLong)
+ checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong)
+ checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong)
+
+ checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
+ checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
+ checkEvaluation(cast(123, DecimalType(3, 1)), null)
+ checkEvaluation(cast(123, DecimalType(2, 0)), null)
+ }
+
+ test("casting to fixed-precision decimals") {
+ assert(cast(123, DecimalType.USER_DEFAULT).nullable === false)
+ assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable)
+ assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable)
+ assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === false)
+
+ assert(cast(123, DecimalType(2, 1)).nullable)
+ assert(cast(10.03f, DecimalType(2, 1)).nullable)
+ assert(cast(10.03, DecimalType(2, 1)).nullable)
+ assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable)
+
+ assert(cast(123, DecimalType.IntDecimal).nullable === false)
+ assert(cast(10.03f, DecimalType.FloatDecimal).nullable)
+ assert(cast(10.03, DecimalType.DoubleDecimal).nullable)
+ assert(cast(Decimal(10.03), DecimalType(4, 2)).nullable === false)
+ assert(cast(Decimal(10.03), DecimalType(5, 3)).nullable === false)
+
+ assert(cast(Decimal(10.03), DecimalType(3, 1)).nullable)
+ assert(cast(Decimal(10.03), DecimalType(4, 1)).nullable === false)
+ assert(cast(Decimal(9.95), DecimalType(2, 1)).nullable)
+ assert(cast(Decimal(9.95), DecimalType(3, 1)).nullable === false)
+
+ assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable)
+ assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false)
+ assert(cast(Decimal("995"), DecimalType(2, -1)).nullable)
+ assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false)
+
+ assert(cast(true, DecimalType.SYSTEM_DEFAULT).nullable === false)
+ assert(cast(true, DecimalType(1, 1)).nullable)
+
+
+ checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03))
+ checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
+ checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(cast(10.03, DecimalType(1, 0)), null)
+ checkEvaluation(cast(10.03, DecimalType(2, 1)), null)
+ checkEvaluation(cast(10.03, DecimalType(3, 2)), null)
+ checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null)
+
+ checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05))
+ checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05))
+ checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1))
+ checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(cast(10.05, DecimalType(1, 0)), null)
+ checkEvaluation(cast(10.05, DecimalType(2, 1)), null)
+ checkEvaluation(cast(10.05, DecimalType(3, 2)), null)
+ checkEvaluation(cast(Decimal(10.05), DecimalType(3, 1)), Decimal(10.1))
+ checkEvaluation(cast(Decimal(10.05), DecimalType(3, 2)), null)
+
+ checkEvaluation(cast(9.95, DecimalType(3, 2)), Decimal(9.95))
+ checkEvaluation(cast(9.95, DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(cast(9.95, DecimalType(2, 1)), null)
+ checkEvaluation(cast(9.95, DecimalType(1, 0)), null)
+ checkEvaluation(cast(Decimal(9.95), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(cast(Decimal(9.95), DecimalType(1, 0)), null)
+
+ checkEvaluation(cast(-9.95, DecimalType(3, 2)), Decimal(-9.95))
+ checkEvaluation(cast(-9.95, DecimalType(3, 1)), Decimal(-10.0))
+ checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10))
+ checkEvaluation(cast(-9.95, DecimalType(2, 1)), null)
+ checkEvaluation(cast(-9.95, DecimalType(1, 0)), null)
+ checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
+ checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null)
+
+ checkEvaluation(cast(Decimal("1003"), DecimalType.SYSTEM_DEFAULT),
Decimal(1003))
+ checkEvaluation(cast(Decimal("1003"), DecimalType(4, 0)), Decimal(1003))
+ checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000))
+ checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000))
+ checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null)
+ checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null)
+ checkEvaluation(cast(Decimal("1003"), DecimalType(3, 0)), null)
+
+ checkEvaluation(cast(Decimal("995"), DecimalType(3, 0)), Decimal(995))
+ checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000))
+ checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000))
+ checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null)
+ checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null)
+
+ checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null)
+ checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null)
+ checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null)
+ checkEvaluation(cast(1.0f / 0.0f, DecimalType.SYSTEM_DEFAULT), null)
+
+ checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null)
+ checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null)
+ checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null)
+ checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null)
+
+ checkEvaluation(cast(true, DecimalType(2, 1)), Decimal(1))
+ checkEvaluation(cast(true, DecimalType(1, 1)), null)
+ }
+
+ test("SPARK-28470: Cast should honor nullOnOverflow property") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ checkEvaluation(Cast(Literal("134.12"), DecimalType(3, 2)), null)
+ checkEvaluation(
+ Cast(Literal(Timestamp.valueOf("2019-07-25 22:04:36")), DecimalType(3,
2)), null)
+ checkEvaluation(Cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)),
null)
+ checkEvaluation(Cast(Literal(134.12), DecimalType(3, 2)), null)
+ }
+ }
+}
+
+/**
+ * Test suite for data type casting expression [[AnsiCast]].
+ */
+class AnsiCastSuite extends CastSuiteBase {
+ // It is not required to set SQLConf.ANSI_ENABLED as true for testing
numeric overflow.
+ override protected def requiredAnsiEnabledForOverflowTestCases: Boolean =
false
+
+ override def checkEvaluation(
Review comment:
Shall we override `cast` instead of overriding both `checkEvaluation` and `checkExceptionInExpression`?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]