gengliangwang commented on a change in pull request #30260:
URL: https://github.com/apache/spark/pull/30260#discussion_r526112623



##########
File path: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
##########
@@ -930,150 +733,254 @@ abstract class CastSuiteBase extends SparkFunSuite with 
ExpressionEvalHelper {
     }
   }
 
-  test("Throw exception on casting out-of-range value to byte type") {
-    withSQLConf(SQLConf.ANSI_ENABLED.key -> 
requiredAnsiEnabledForOverflowTestCases.toString) {
-      testIntMaxAndMin(ByteType)
-      Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value =>
-        checkExceptionInExpression[ArithmeticException](cast(value, ByteType), 
"overflow")
-        checkExceptionInExpression[ArithmeticException](
-          cast(Literal(value * MICROS_PER_SECOND, TimestampType), ByteType), 
"overflow")
-        checkExceptionInExpression[ArithmeticException](
-          cast(Literal(value.toFloat, FloatType), ByteType), "overflow")
-        checkExceptionInExpression[ArithmeticException](
-          cast(Literal(value.toDouble, DoubleType), ByteType), "overflow")
-      }
+  test("ANSI mode: Throw exception on casting out-of-range value to byte 
type") {
+    testIntMaxAndMin(ByteType)
+    Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value =>
+      checkExceptionInExpression[ArithmeticException](cast(value, ByteType), 
"overflow")
+      checkExceptionInExpression[ArithmeticException](
+        cast(Literal(value.toFloat, FloatType), ByteType), "overflow")
+      checkExceptionInExpression[ArithmeticException](
+        cast(Literal(value.toDouble, DoubleType), ByteType), "overflow")
+    }
 
-      Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value =>
-        checkEvaluation(cast(value, ByteType), value)
-        checkEvaluation(cast(value.toString, ByteType), value)
-        checkEvaluation(cast(Decimal(value.toString), ByteType), value)
-        checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, 
TimestampType), ByteType), value)
-        checkEvaluation(cast(Literal(value.toInt, DateType), ByteType), null)
-        checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), 
value)
-        checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType), 
value)
-      }
+    Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value =>
+      checkEvaluation(cast(value, ByteType), value)
+      checkEvaluation(cast(value.toString, ByteType), value)
+      checkEvaluation(cast(Decimal(value.toString), ByteType), value)
+      checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), value)
+      checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType), 
value)
     }
   }
 
-  test("Throw exception on casting out-of-range value to short type") {
-    withSQLConf(SQLConf.ANSI_ENABLED.key -> 
requiredAnsiEnabledForOverflowTestCases.toString) {
-      testIntMaxAndMin(ShortType)
-      Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value =>
-        checkExceptionInExpression[ArithmeticException](cast(value, 
ShortType), "overflow")
-        checkExceptionInExpression[ArithmeticException](
-          cast(Literal(value * MICROS_PER_SECOND, TimestampType), ShortType), 
"overflow")
-        checkExceptionInExpression[ArithmeticException](
-          cast(Literal(value.toFloat, FloatType), ShortType), "overflow")
-        checkExceptionInExpression[ArithmeticException](
-          cast(Literal(value.toDouble, DoubleType), ShortType), "overflow")
-      }
+  test("ANSI mode: Throw exception on casting out-of-range value to short 
type") {
+    testIntMaxAndMin(ShortType)
+    Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value =>
+      checkExceptionInExpression[ArithmeticException](cast(value, ShortType), 
"overflow")
+      checkExceptionInExpression[ArithmeticException](
+        cast(Literal(value.toFloat, FloatType), ShortType), "overflow")
+      checkExceptionInExpression[ArithmeticException](
+        cast(Literal(value.toDouble, DoubleType), ShortType), "overflow")
+    }
 
-      Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value =>
-        checkEvaluation(cast(value, ShortType), value)
-        checkEvaluation(cast(value.toString, ShortType), value)
-        checkEvaluation(cast(Decimal(value.toString), ShortType), value)
-        checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, 
TimestampType), ShortType), value)
-        checkEvaluation(cast(Literal(value.toInt, DateType), ShortType), null)
-        checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType), 
value)
-        checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType), 
value)
-      }
+    Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value =>
+      checkEvaluation(cast(value, ShortType), value)
+      checkEvaluation(cast(value.toString, ShortType), value)
+      checkEvaluation(cast(Decimal(value.toString), ShortType), value)
+      checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType), 
value)
+      checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType), 
value)
     }
   }
 
-  test("Throw exception on casting out-of-range value to int type") {
-    withSQLConf(SQLConf.ANSI_ENABLED.key -> 
requiredAnsiEnabledForOverflowTestCases.toString) {
-      testIntMaxAndMin(IntegerType)
-      testLongMaxAndMin(IntegerType)
+  test("ANSI mode: Throw exception on casting out-of-range value to int type") 
{
+    testIntMaxAndMin(IntegerType)
+    testLongMaxAndMin(IntegerType)
 
-      Seq(Int.MaxValue, 0, Int.MinValue).foreach { value =>
-        checkEvaluation(cast(value, IntegerType), value)
-        checkEvaluation(cast(value.toString, IntegerType), value)
-        checkEvaluation(cast(Decimal(value.toString), IntegerType), value)
-        checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, 
TimestampType), IntegerType), value)
-        checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType), 
value)
-      }
-      checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue)
-      checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue)
+    Seq(Int.MaxValue, 0, Int.MinValue).foreach { value =>
+      checkEvaluation(cast(value, IntegerType), value)
+      checkEvaluation(cast(value.toString, IntegerType), value)
+      checkEvaluation(cast(Decimal(value.toString), IntegerType), value)
+      checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType), 
value)
     }
+    checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue)
+    checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue)
   }
 
-  test("Throw exception on casting out-of-range value to long type") {
-    withSQLConf(SQLConf.ANSI_ENABLED.key -> 
requiredAnsiEnabledForOverflowTestCases.toString) {
-      testLongMaxAndMin(LongType)
+  test("ANSI mode: Throw exception on casting out-of-range value to long 
type") {
+    testLongMaxAndMin(LongType)
 
-      Seq(Long.MaxValue, 0, Long.MinValue).foreach { value =>
-        checkEvaluation(cast(value, LongType), value)
-        checkEvaluation(cast(value.toString, LongType), value)
-        checkEvaluation(cast(Decimal(value.toString), LongType), value)
-        checkEvaluation(cast(Literal(value, TimestampType), LongType),
-          Math.floorDiv(value, MICROS_PER_SECOND))
-      }
-      checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue)
-      checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue)
-      checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue)
-      checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue)
+    Seq(Long.MaxValue, 0, Long.MinValue).foreach { value =>
+      checkEvaluation(cast(value, LongType), value)
+      checkEvaluation(cast(value.toString, LongType), value)
+      checkEvaluation(cast(Decimal(value.toString), LongType), value)
     }
+    checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue)
+    checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue)
+    checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue)
+    checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue)
   }
-}
 
-/**
- * Test suite for data type casting expression [[Cast]].
- */
-class CastSuite extends CastSuiteBase {
-  // It is required to set SQLConf.ANSI_ENABLED as true for testing numeric 
overflow.
-  override protected def requiredAnsiEnabledForOverflowTestCases: Boolean = 
true
+  test("ANSI mode: Throw exception on casting out-of-range value to decimal 
type") {
+    checkExceptionInExpression[ArithmeticException](
+      cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented")
+    checkExceptionInExpression[ArithmeticException](
+      cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be 
represented")
+    checkExceptionInExpression[ArithmeticException](
+      cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
+  }
 
-  override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = 
None): CastBase = {
-    v match {
-      case lit: Expression => Cast(lit, targetType, timeZoneId)
-      case _ => Cast(Literal(v), targetType, timeZoneId)
+  test("ANSI mode: disallow type conversions between Numeric types and 
Timestamp type") {
+    import DataTypeTestUtils.numericTypes
+    checkInvalidCastFromNumericType(TimestampType)
+    val timestampLiteral = Literal(1L, TimestampType)
+    numericTypes.foreach { numericType =>
+      assert(cast(timestampLiteral, 
numericType).checkInputDataTypes().isFailure)
     }
   }
 
-  test("cast from int") {
-    checkCast(0, false)
-    checkCast(1, true)
-    checkCast(-5, true)
-    checkCast(1, 1.toByte)
-    checkCast(1, 1.toShort)
-    checkCast(1, 1)
-    checkCast(1, 1.toLong)
-    checkCast(1, 1.0f)
-    checkCast(1, 1.0)
-    checkCast(123, "123")
-
-    checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
-    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
-    checkEvaluation(cast(123, DecimalType(3, 1)), null)
-    checkEvaluation(cast(123, DecimalType(2, 0)), null)
+  test("ANSI mode: disallow type conversions between Numeric types and Date 
type") {
+    import DataTypeTestUtils.numericTypes
+    checkInvalidCastFromNumericType(DateType)
+    val dateLiteral = Literal(1, DateType)
+    numericTypes.foreach { numericType =>
+      assert(cast(dateLiteral, numericType).checkInputDataTypes().isFailure)
+    }
   }
 
-  test("cast from long") {
-    checkCast(0L, false)
-    checkCast(1L, true)
-    checkCast(-5L, true)
-    checkCast(1L, 1.toByte)
-    checkCast(1L, 1.toShort)
-    checkCast(1L, 1)
-    checkCast(1L, 1.toLong)
-    checkCast(1L, 1.0f)
-    checkCast(1L, 1.0)
-    checkCast(123L, "123")
+  test("ANSI mode: disallow type conversions between Numeric types and Binary 
type") {
+    import DataTypeTestUtils.numericTypes
+    checkInvalidCastFromNumericType(BinaryType)
+    val binaryLiteral = Literal(new Array[Byte](1.toByte), BinaryType)
+    numericTypes.foreach { numericType =>
+      assert(cast(binaryLiteral, numericType).checkInputDataTypes().isFailure)
+    }
+  }
 
-    checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
-    checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
-    checkEvaluation(cast(123L, DecimalType(3, 1)), null)
+  test("ANSI mode: disallow type conversions between Datatime types and 
Boolean types") {
+    val timestampLiteral = Literal(1L, TimestampType)
+    assert(cast(timestampLiteral, BooleanType).checkInputDataTypes().isFailure)
+    val dateLiteral = Literal(1, DateType)
+    assert(cast(dateLiteral, BooleanType).checkInputDataTypes().isFailure)
 
-    checkEvaluation(cast(123L, DecimalType(2, 0)), null)
+    val booleanLiteral = Literal(true, BooleanType)
+    assert(cast(booleanLiteral, TimestampType).checkInputDataTypes().isFailure)
+    assert(cast(booleanLiteral, DateType).checkInputDataTypes().isFailure)
   }
 
-  test("cast from int 2") {
-    checkEvaluation(cast(1, LongType), 1.toLong)
-
-    checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong)
-    checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong)
+  test("ANSI mode: disallow casting complex types as String type") {
+    assert(cast(Literal.create(Array(1, 2, 3, 4, 5)), 
StringType).checkInputDataTypes().isFailure)
+    assert(cast(Literal.create(Map(1 -> "a")), 
StringType).checkInputDataTypes().isFailure)
+    assert(cast(Literal.create((1, "a", 0.1)), 
StringType).checkInputDataTypes().isFailure)
+  }
 
-    checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
+  test("cast from invalid string to numeric should throw 
NumberFormatException") {
+    // cast to IntegerType
+    Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType =>
+      val array = Literal.create(Seq("123", "true", "f", null),
+        ArrayType(StringType, containsNull = true))
+      checkExceptionInExpression[NumberFormatException](
+        cast(array, ArrayType(dataType, containsNull = true)),
+        "invalid input syntax for type numeric: true")
+      checkExceptionInExpression[NumberFormatException](
+        cast("string", dataType), "invalid input syntax for type numeric: 
string")
+      checkExceptionInExpression[NumberFormatException](
+        cast("123-string", dataType), "invalid input syntax for type numeric: 
123-string")
+      checkExceptionInExpression[NumberFormatException](
+        cast("2020-07-19", dataType), "invalid input syntax for type numeric: 
2020-07-19")
+      checkExceptionInExpression[NumberFormatException](
+        cast("1.23", dataType), "invalid input syntax for type numeric: 1.23")
+    }
+
+    Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType =>
+      checkExceptionInExpression[NumberFormatException](
+        cast("string", dataType), "invalid input syntax for type numeric: 
string")
+      checkExceptionInExpression[NumberFormatException](
+        cast("123.000.00", dataType), "invalid input syntax for type numeric: 
123.000.00")
+      checkExceptionInExpression[NumberFormatException](
+        cast("abc.com", dataType), "invalid input syntax for type numeric: 
abc.com")
+    }
+  }
+
+  test("Fast fail for cast string type to decimal type in ansi mode") {
+    checkEvaluation(cast("12345678901234567890123456789012345678", 
DecimalType(38, 0)),
+      Decimal("12345678901234567890123456789012345678"))
+    checkExceptionInExpression[ArithmeticException](
+      cast("123456789012345678901234567890123456789", DecimalType(38, 0)),
+      "out of decimal type range")
+    checkExceptionInExpression[ArithmeticException](
+      cast("12345678901234567890123456789012345678", DecimalType(38, 1)),
+      "cannot be represented as Decimal(38, 1)")
+
+    checkEvaluation(cast("0.00000000000000000000000000000000000001", 
DecimalType(38, 0)),
+      Decimal("0"))
+    checkEvaluation(cast("0.00000000000000000000000000000000000000000001", 
DecimalType(38, 0)),
+      Decimal("0"))
+    checkEvaluation(cast("0.00000000000000000000000000000000000001", 
DecimalType(38, 18)),
+      Decimal("0E-18"))
+    checkEvaluation(cast("6E-120", DecimalType(38, 0)),
+      Decimal("0"))
+
+    checkEvaluation(cast("6E+37", DecimalType(38, 0)),
+      Decimal("60000000000000000000000000000000000000"))
+    checkExceptionInExpression[ArithmeticException](
+      cast("6E+38", DecimalType(38, 0)),
+      "out of decimal type range")
+    checkExceptionInExpression[ArithmeticException](
+      cast("6E+37", DecimalType(38, 1)),
+      "cannot be represented as Decimal(38, 1)")
+
+    checkExceptionInExpression[NumberFormatException](
+      cast("abcd", DecimalType(38, 1)),
+      "invalid input syntax for type numeric")
+  }
+}
+
+/**
+ * Test suite for data type casting expression [[Cast]].
+ */
+class CastSuite extends CastSuiteBase {
+
+  override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = 
None): CastBase = {
+    v match {
+      case lit: Expression => Cast(lit, targetType, timeZoneId)
+      case _ => Cast(Literal(v), targetType, timeZoneId)
+    }
+  }
+
+  test("null cast II") {
+    import DataTypeTestUtils._
+
+    checkNullCast(DateType, BooleanType)
+    checkNullCast(TimestampType, BooleanType)
+    checkNullCast(BooleanType, TimestampType)
+    numericTypes.foreach(dt => checkNullCast(dt, TimestampType))
+    numericTypes.foreach(dt => checkNullCast(TimestampType, dt))
+    numericTypes.foreach(dt => checkNullCast(DateType, dt))
+  }
+
+  test("cast from int") {
+    checkCast(0, false)
+    checkCast(1, true)
+    checkCast(-5, true)
+    checkCast(1, 1.toByte)
+    checkCast(1, 1.toShort)

Review comment:
       Good catch. I have updated the code.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to