sunxiaoguang commented on code in PR #49453:
URL: https://github.com/apache/spark/pull/49453#discussion_r1917621813
##########
connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala:
##########
@@ -241,6 +241,56 @@ class MySQLIntegrationSuite extends
DockerJDBCIntegrationV2Suite with V2JDBCTest
assert(rows10(0).getString(0) === "amy")
assert(rows10(1).getString(0) === "alex")
}
+
+ test("SPARK-50793: MySQL JDBC Connector failed to cast some types") {
+ val tableName = catalogName + ".test_cast_function"
+ withTable(tableName) {
+ val stringValue = "0"
+ val stringLiteral = "'0'"
+ val longValue = 0L
+ val binaryValue = Array[Byte](0x30)
+ val binaryLiteral = "x'30'"
+ val doubleValue = 0.0
+ val doubleLiteral = "0.0"
+ // CREATE table to use types defined in Spark SQL
+ sql(s"""CREATE TABLE $tableName (
+ string_col STRING,
+ long_col LONG,
+ binary_col BINARY,
+ double_col DOUBLE
+ )""")
+ sql(
+ s"INSERT INTO $tableName VALUES($stringLiteral, $longValue, $binaryLiteral, $doubleValue)")
+
+ def testCast(castType: String, sourceCol: String, targetCol: String,
+ sourceValue: Any, targetValue: Any): Unit = {
+ val sql =
+ s"""SELECT $sourceCol, CAST($sourceCol AS $castType) FROM $tableName
+ |WHERE CAST($sourceCol AS $castType) = $targetCol""".stripMargin
+ val df = spark.sql(sql)
+ checkFilterPushed(df)
+ val rows = df.collect()
Review Comment:
After taking a look at the `checkAnswer` implementation, I see that it uses
`==` to compare values of every type except the few that need special
handling. This means the check may ignore the actual types of the values, so
let's double-check whether behavior like the following is acceptable for tests
related to type casts:
```scala
val i = 1
val s = 1.toShort
val l = 1L
println(i == s)
println(i == l)
println(s ==l)
```
The output for these lines of code is:
```
true
true
true
```
FYI: This is the implementation of `compare`, which `checkAnswer` ultimately uses:
```scala
def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
case (null, null) => true
case (null, _) => false
case (_, null) => false
case (a: Array[_], b: Array[_]) =>
a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
case (a: Map[_, _], b: Map[_, _]) =>
a.size == b.size && a.keys.forall { aKey =>
b.keys.find(bKey => compare(aKey, bKey)).exists(bKey =>
compare(a(aKey), b(bKey)))
}
case (a: Iterable[_], b: Iterable[_]) =>
a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
case (a: Product, b: Product) =>
compare(a.productIterator.toSeq, b.productIterator.toSeq)
case (a: Row, b: Row) =>
compare(a.toSeq, b.toSeq)
// 0.0 == -0.0, turn float/double to bits before comparison, to
// distinguish 0.0 and -0.0.
case (a: Double, b: Double) =>
java.lang.Double.doubleToRawLongBits(a) ==
java.lang.Double.doubleToRawLongBits(b)
case (a: Float, b: Float) =>
java.lang.Float.floatToRawIntBits(a) ==
java.lang.Float.floatToRawIntBits(b)
case (a, b) => a == b
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]