gengliangwang commented on a change in pull request #30260:
URL: https://github.com/apache/spark/pull/30260#discussion_r526112623
##########
File path:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
##########
@@ -930,150 +733,254 @@ abstract class CastSuiteBase extends SparkFunSuite with
ExpressionEvalHelper {
}
}
- test("Throw exception on casting out-of-range value to byte type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key ->
requiredAnsiEnabledForOverflowTestCases.toString) {
- testIntMaxAndMin(ByteType)
- Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value =>
- checkExceptionInExpression[ArithmeticException](cast(value, ByteType),
"overflow")
- checkExceptionInExpression[ArithmeticException](
- cast(Literal(value * MICROS_PER_SECOND, TimestampType), ByteType),
"overflow")
- checkExceptionInExpression[ArithmeticException](
- cast(Literal(value.toFloat, FloatType), ByteType), "overflow")
- checkExceptionInExpression[ArithmeticException](
- cast(Literal(value.toDouble, DoubleType), ByteType), "overflow")
- }
+ test("ANSI mode: Throw exception on casting out-of-range value to byte
type") {
+ testIntMaxAndMin(ByteType)
+ Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value =>
+ checkExceptionInExpression[ArithmeticException](cast(value, ByteType),
"overflow")
+ checkExceptionInExpression[ArithmeticException](
+ cast(Literal(value.toFloat, FloatType), ByteType), "overflow")
+ checkExceptionInExpression[ArithmeticException](
+ cast(Literal(value.toDouble, DoubleType), ByteType), "overflow")
+ }
- Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value =>
- checkEvaluation(cast(value, ByteType), value)
- checkEvaluation(cast(value.toString, ByteType), value)
- checkEvaluation(cast(Decimal(value.toString), ByteType), value)
- checkEvaluation(cast(Literal(value * MICROS_PER_SECOND,
TimestampType), ByteType), value)
- checkEvaluation(cast(Literal(value.toInt, DateType), ByteType), null)
- checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType),
value)
- checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType),
value)
- }
+ Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value =>
+ checkEvaluation(cast(value, ByteType), value)
+ checkEvaluation(cast(value.toString, ByteType), value)
+ checkEvaluation(cast(Decimal(value.toString), ByteType), value)
+ checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), value)
+ checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType),
value)
}
}
- test("Throw exception on casting out-of-range value to short type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key ->
requiredAnsiEnabledForOverflowTestCases.toString) {
- testIntMaxAndMin(ShortType)
- Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value =>
- checkExceptionInExpression[ArithmeticException](cast(value,
ShortType), "overflow")
- checkExceptionInExpression[ArithmeticException](
- cast(Literal(value * MICROS_PER_SECOND, TimestampType), ShortType),
"overflow")
- checkExceptionInExpression[ArithmeticException](
- cast(Literal(value.toFloat, FloatType), ShortType), "overflow")
- checkExceptionInExpression[ArithmeticException](
- cast(Literal(value.toDouble, DoubleType), ShortType), "overflow")
- }
+ test("ANSI mode: Throw exception on casting out-of-range value to short
type") {
+ testIntMaxAndMin(ShortType)
+ Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value =>
+ checkExceptionInExpression[ArithmeticException](cast(value, ShortType),
"overflow")
+ checkExceptionInExpression[ArithmeticException](
+ cast(Literal(value.toFloat, FloatType), ShortType), "overflow")
+ checkExceptionInExpression[ArithmeticException](
+ cast(Literal(value.toDouble, DoubleType), ShortType), "overflow")
+ }
- Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value =>
- checkEvaluation(cast(value, ShortType), value)
- checkEvaluation(cast(value.toString, ShortType), value)
- checkEvaluation(cast(Decimal(value.toString), ShortType), value)
- checkEvaluation(cast(Literal(value * MICROS_PER_SECOND,
TimestampType), ShortType), value)
- checkEvaluation(cast(Literal(value.toInt, DateType), ShortType), null)
- checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType),
value)
- checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType),
value)
- }
+ Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value =>
+ checkEvaluation(cast(value, ShortType), value)
+ checkEvaluation(cast(value.toString, ShortType), value)
+ checkEvaluation(cast(Decimal(value.toString), ShortType), value)
+ checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType),
value)
+ checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType),
value)
}
}
- test("Throw exception on casting out-of-range value to int type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key ->
requiredAnsiEnabledForOverflowTestCases.toString) {
- testIntMaxAndMin(IntegerType)
- testLongMaxAndMin(IntegerType)
+ test("ANSI mode: Throw exception on casting out-of-range value to int type")
{
+ testIntMaxAndMin(IntegerType)
+ testLongMaxAndMin(IntegerType)
- Seq(Int.MaxValue, 0, Int.MinValue).foreach { value =>
- checkEvaluation(cast(value, IntegerType), value)
- checkEvaluation(cast(value.toString, IntegerType), value)
- checkEvaluation(cast(Decimal(value.toString), IntegerType), value)
- checkEvaluation(cast(Literal(value * MICROS_PER_SECOND,
TimestampType), IntegerType), value)
- checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType),
value)
- }
- checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue)
- checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue)
+ Seq(Int.MaxValue, 0, Int.MinValue).foreach { value =>
+ checkEvaluation(cast(value, IntegerType), value)
+ checkEvaluation(cast(value.toString, IntegerType), value)
+ checkEvaluation(cast(Decimal(value.toString), IntegerType), value)
+ checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType),
value)
}
+ checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue)
+ checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue)
}
- test("Throw exception on casting out-of-range value to long type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key ->
requiredAnsiEnabledForOverflowTestCases.toString) {
- testLongMaxAndMin(LongType)
+ test("ANSI mode: Throw exception on casting out-of-range value to long
type") {
+ testLongMaxAndMin(LongType)
- Seq(Long.MaxValue, 0, Long.MinValue).foreach { value =>
- checkEvaluation(cast(value, LongType), value)
- checkEvaluation(cast(value.toString, LongType), value)
- checkEvaluation(cast(Decimal(value.toString), LongType), value)
- checkEvaluation(cast(Literal(value, TimestampType), LongType),
- Math.floorDiv(value, MICROS_PER_SECOND))
- }
- checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue)
- checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue)
- checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue)
- checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue)
+ Seq(Long.MaxValue, 0, Long.MinValue).foreach { value =>
+ checkEvaluation(cast(value, LongType), value)
+ checkEvaluation(cast(value.toString, LongType), value)
+ checkEvaluation(cast(Decimal(value.toString), LongType), value)
}
+ checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue)
+ checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue)
+ checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue)
+ checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue)
}
-}
-/**
- * Test suite for data type casting expression [[Cast]].
- */
-class CastSuite extends CastSuiteBase {
- // It is required to set SQLConf.ANSI_ENABLED as true for testing numeric
overflow.
- override protected def requiredAnsiEnabledForOverflowTestCases: Boolean =
true
+ test("ANSI mode: Throw exception on casting out-of-range value to decimal
type") {
+ checkExceptionInExpression[ArithmeticException](
+ cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented")
+ checkExceptionInExpression[ArithmeticException](
+ cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be
represented")
+ checkExceptionInExpression[ArithmeticException](
+ cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
+ }
- override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] =
None): CastBase = {
- v match {
- case lit: Expression => Cast(lit, targetType, timeZoneId)
- case _ => Cast(Literal(v), targetType, timeZoneId)
+ test("ANSI mode: disallow type conversions between Numeric types and
Timestamp type") {
+ import DataTypeTestUtils.numericTypes
+ checkInvalidCastFromNumericType(TimestampType)
+ val timestampLiteral = Literal(1L, TimestampType)
+ numericTypes.foreach { numericType =>
+ assert(cast(timestampLiteral,
numericType).checkInputDataTypes().isFailure)
}
}
- test("cast from int") {
- checkCast(0, false)
- checkCast(1, true)
- checkCast(-5, true)
- checkCast(1, 1.toByte)
- checkCast(1, 1.toShort)
- checkCast(1, 1)
- checkCast(1, 1.toLong)
- checkCast(1, 1.0f)
- checkCast(1, 1.0)
- checkCast(123, "123")
-
- checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
- checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
- checkEvaluation(cast(123, DecimalType(3, 1)), null)
- checkEvaluation(cast(123, DecimalType(2, 0)), null)
+ test("ANSI mode: disallow type conversions between Numeric types and Date
type") {
+ import DataTypeTestUtils.numericTypes
+ checkInvalidCastFromNumericType(DateType)
+ val dateLiteral = Literal(1, DateType)
+ numericTypes.foreach { numericType =>
+ assert(cast(dateLiteral, numericType).checkInputDataTypes().isFailure)
+ }
}
- test("cast from long") {
- checkCast(0L, false)
- checkCast(1L, true)
- checkCast(-5L, true)
- checkCast(1L, 1.toByte)
- checkCast(1L, 1.toShort)
- checkCast(1L, 1)
- checkCast(1L, 1.toLong)
- checkCast(1L, 1.0f)
- checkCast(1L, 1.0)
- checkCast(123L, "123")
+ test("ANSI mode: disallow type conversions between Numeric types and Binary
type") {
+ import DataTypeTestUtils.numericTypes
+ checkInvalidCastFromNumericType(BinaryType)
+ val binaryLiteral = Literal(new Array[Byte](1.toByte), BinaryType)
+ numericTypes.foreach { numericType =>
+ assert(cast(binaryLiteral, numericType).checkInputDataTypes().isFailure)
+ }
+ }
- checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
- checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
- checkEvaluation(cast(123L, DecimalType(3, 1)), null)
+ test("ANSI mode: disallow type conversions between Datatime types and
Boolean types") {
+ val timestampLiteral = Literal(1L, TimestampType)
+ assert(cast(timestampLiteral, BooleanType).checkInputDataTypes().isFailure)
+ val dateLiteral = Literal(1, DateType)
+ assert(cast(dateLiteral, BooleanType).checkInputDataTypes().isFailure)
- checkEvaluation(cast(123L, DecimalType(2, 0)), null)
+ val booleanLiteral = Literal(true, BooleanType)
+ assert(cast(booleanLiteral, TimestampType).checkInputDataTypes().isFailure)
+ assert(cast(booleanLiteral, DateType).checkInputDataTypes().isFailure)
}
- test("cast from int 2") {
- checkEvaluation(cast(1, LongType), 1.toLong)
-
- checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong)
- checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong)
+ test("ANSI mode: disallow casting complex types as String type") {
+ assert(cast(Literal.create(Array(1, 2, 3, 4, 5)),
StringType).checkInputDataTypes().isFailure)
+ assert(cast(Literal.create(Map(1 -> "a")),
StringType).checkInputDataTypes().isFailure)
+ assert(cast(Literal.create((1, "a", 0.1)),
StringType).checkInputDataTypes().isFailure)
+ }
- checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
+ test("cast from invalid string to numeric should throw
NumberFormatException") {
+ // cast to IntegerType
+ Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType =>
+ val array = Literal.create(Seq("123", "true", "f", null),
+ ArrayType(StringType, containsNull = true))
+ checkExceptionInExpression[NumberFormatException](
+ cast(array, ArrayType(dataType, containsNull = true)),
+ "invalid input syntax for type numeric: true")
+ checkExceptionInExpression[NumberFormatException](
+ cast("string", dataType), "invalid input syntax for type numeric:
string")
+ checkExceptionInExpression[NumberFormatException](
+ cast("123-string", dataType), "invalid input syntax for type numeric:
123-string")
+ checkExceptionInExpression[NumberFormatException](
+ cast("2020-07-19", dataType), "invalid input syntax for type numeric:
2020-07-19")
+ checkExceptionInExpression[NumberFormatException](
+ cast("1.23", dataType), "invalid input syntax for type numeric: 1.23")
+ }
+
+ Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType =>
+ checkExceptionInExpression[NumberFormatException](
+ cast("string", dataType), "invalid input syntax for type numeric:
string")
+ checkExceptionInExpression[NumberFormatException](
+ cast("123.000.00", dataType), "invalid input syntax for type numeric:
123.000.00")
+ checkExceptionInExpression[NumberFormatException](
+ cast("abc.com", dataType), "invalid input syntax for type numeric:
abc.com")
+ }
+ }
+
+ test("Fast fail for cast string type to decimal type in ansi mode") {
+ checkEvaluation(cast("12345678901234567890123456789012345678",
DecimalType(38, 0)),
+ Decimal("12345678901234567890123456789012345678"))
+ checkExceptionInExpression[ArithmeticException](
+ cast("123456789012345678901234567890123456789", DecimalType(38, 0)),
+ "out of decimal type range")
+ checkExceptionInExpression[ArithmeticException](
+ cast("12345678901234567890123456789012345678", DecimalType(38, 1)),
+ "cannot be represented as Decimal(38, 1)")
+
+ checkEvaluation(cast("0.00000000000000000000000000000000000001",
DecimalType(38, 0)),
+ Decimal("0"))
+ checkEvaluation(cast("0.00000000000000000000000000000000000000000001",
DecimalType(38, 0)),
+ Decimal("0"))
+ checkEvaluation(cast("0.00000000000000000000000000000000000001",
DecimalType(38, 18)),
+ Decimal("0E-18"))
+ checkEvaluation(cast("6E-120", DecimalType(38, 0)),
+ Decimal("0"))
+
+ checkEvaluation(cast("6E+37", DecimalType(38, 0)),
+ Decimal("60000000000000000000000000000000000000"))
+ checkExceptionInExpression[ArithmeticException](
+ cast("6E+38", DecimalType(38, 0)),
+ "out of decimal type range")
+ checkExceptionInExpression[ArithmeticException](
+ cast("6E+37", DecimalType(38, 1)),
+ "cannot be represented as Decimal(38, 1)")
+
+ checkExceptionInExpression[NumberFormatException](
+ cast("abcd", DecimalType(38, 1)),
+ "invalid input syntax for type numeric")
+ }
+}
+
+/**
+ * Test suite for data type casting expression [[Cast]].
+ */
+class CastSuite extends CastSuiteBase {
+
+ override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] =
None): CastBase = {
+ v match {
+ case lit: Expression => Cast(lit, targetType, timeZoneId)
+ case _ => Cast(Literal(v), targetType, timeZoneId)
+ }
+ }
+
+ test("null cast II") {
+ import DataTypeTestUtils._
+
+ checkNullCast(DateType, BooleanType)
+ checkNullCast(TimestampType, BooleanType)
+ checkNullCast(BooleanType, TimestampType)
+ numericTypes.foreach(dt => checkNullCast(dt, TimestampType))
+ numericTypes.foreach(dt => checkNullCast(TimestampType, dt))
+ numericTypes.foreach(dt => checkNullCast(DateType, dt))
+ }
+
+ test("cast from int") {
+ checkCast(0, false)
+ checkCast(1, true)
+ checkCast(-5, true)
+ checkCast(1, 1.toByte)
+ checkCast(1, 1.toShort)
Review comment:
Good catch. I have updated the code.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org