This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 26af020 chore: Add more cast tests and improve test framework (#351)
26af020 is described below
commit 26af0202ddbc9780022f870466b8499c86f50b62
Author: Andy Grove <[email protected]>
AuthorDate: Tue Apr 30 13:23:19 2024 -0600
chore: Add more cast tests and improve test framework (#351)
---
.../scala/org/apache/comet/CometCastSuite.scala | 611 +++++++++++++++++++--
.../scala/org/apache/spark/sql/CometTestBase.scala | 7 +-
2 files changed, 575 insertions(+), 43 deletions(-)
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 669a855..500b8f8 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -24,6 +24,7 @@ import java.io.File
import scala.util.Random
import org.apache.spark.sql.{CometTestBase, DataFrame, SaveMode}
+import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
@@ -43,65 +44,438 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
private val datePattern = "0123456789/" + whitespaceChars
private val timestampPattern = "0123456789/:T" + whitespaceChars
- ignore("cast long to short") {
- castTest(generateLongs, DataTypes.ShortType)
+ test("all valid cast combinations covered") {
+ val names = testNames
+
+ def assertTestsExist(fromTypes: Seq[DataType], toTypes: Seq[DataType]):
Unit = {
+ for (fromType <- fromTypes) {
+ for (toType <- toTypes) {
+ val expectedTestName = s"cast $fromType to $toType"
+ val testExists = names.contains(expectedTestName)
+ if (Cast.canCast(fromType, toType)) {
+ if (fromType == toType) {
+ if (testExists) {
+ fail(s"Found redundant test for no-op cast: $expectedTestName")
+ }
+ } else if (!testExists) {
+ fail(s"Missing test: $expectedTestName")
+ }
+ } else if (testExists) {
+ fail(s"Found test for cast that Spark does not support:
$expectedTestName")
+ }
+ }
+ }
+ }
+
+ // make sure we have tests for all combinations of our supported types
+ val supportedTypes =
+ Seq(
+ DataTypes.BooleanType,
+ DataTypes.ByteType,
+ DataTypes.ShortType,
+ DataTypes.IntegerType,
+ DataTypes.LongType,
+ DataTypes.FloatType,
+ DataTypes.DoubleType,
+ DataTypes.createDecimalType(10, 2),
+ DataTypes.StringType,
+ DataTypes.DateType,
+ DataTypes.TimestampType)
+ // TODO add DataTypes.TimestampNTZType for Spark 3.4 and later
+ assertTestsExist(supportedTypes, supportedTypes)
+ }
+
+ // CAST from BooleanType
+
+ ignore("cast BooleanType to ByteType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateBools(), DataTypes.ByteType)
+ }
+
+ ignore("cast BooleanType to ShortType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateBools(), DataTypes.ShortType)
+ }
+
+ test("cast BooleanType to IntegerType") {
+ castTest(generateBools(), DataTypes.IntegerType)
+ }
+
+ test("cast BooleanType to LongType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateBools(), DataTypes.LongType)
+ }
+
+ test("cast BooleanType to FloatType") {
+ castTest(generateBools(), DataTypes.FloatType)
+ }
+
+ test("cast BooleanType to DoubleType") {
+ castTest(generateBools(), DataTypes.DoubleType)
+ }
+
+ ignore("cast BooleanType to DecimalType(10,2)") {
+ // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+ castTest(generateBools(), DataTypes.createDecimalType(10, 2))
+ }
+
+ test("cast BooleanType to StringType") {
+ castTest(generateBools(), DataTypes.StringType)
+ }
+
+ ignore("cast BooleanType to TimestampType") {
+ // Arrow error: Cast error: Casting from Boolean to Timestamp(Microsecond, Some("UTC")) not supported
+ castTest(generateBools(), DataTypes.TimestampType)
+ }
+
+ // CAST from ByteType
+
+ test("cast ByteType to BooleanType") {
+ castTest(generateBytes(), DataTypes.BooleanType)
+ }
+
+ ignore("cast ByteType to ShortType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateBytes(), DataTypes.ShortType)
+ }
+
+ test("cast ByteType to IntegerType") {
+ castTest(generateBytes(), DataTypes.IntegerType)
+ }
+
+ test("cast ByteType to LongType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateBytes(), DataTypes.LongType)
+ }
+
+ test("cast ByteType to FloatType") {
+ castTest(generateBytes(), DataTypes.FloatType)
+ }
+
+ test("cast ByteType to DoubleType") {
+ castTest(generateBytes(), DataTypes.DoubleType)
+ }
+
+ ignore("cast ByteType to DecimalType(10,2)") {
+ // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+ castTest(generateBytes(), DataTypes.createDecimalType(10, 2))
+ }
+
+ test("cast ByteType to StringType") {
+ castTest(generateBytes(), DataTypes.StringType)
+ }
+
+ ignore("cast ByteType to TimestampType") {
+ // input: -1, expected: 1969-12-31 15:59:59.0, actual: 1969-12-31 15:59:59.999999
+ castTest(generateBytes(), DataTypes.TimestampType)
+ }
+
+ // CAST from ShortType
+
+ test("cast ShortType to BooleanType") {
+ castTest(generateShorts(), DataTypes.BooleanType)
+ }
+
+ ignore("cast ShortType to ByteType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateShorts(), DataTypes.ByteType)
+ }
+
+ test("cast ShortType to IntegerType") {
+ castTest(generateShorts(), DataTypes.IntegerType)
+ }
+
+ test("cast ShortType to LongType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateShorts(), DataTypes.LongType)
+ }
+
+ test("cast ShortType to FloatType") {
+ castTest(generateShorts(), DataTypes.FloatType)
+ }
+
+ test("cast ShortType to DoubleType") {
+ castTest(generateShorts(), DataTypes.DoubleType)
+ }
+
+ ignore("cast ShortType to DecimalType(10,2)") {
+ // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+ castTest(generateShorts(), DataTypes.createDecimalType(10, 2))
+ }
+
+ test("cast ShortType to StringType") {
+ castTest(generateShorts(), DataTypes.StringType)
+ }
+
+ ignore("cast ShortType to TimestampType") {
+ // input: -1003, expected: 1969-12-31 15:43:17.0, actual: 1969-12-31 15:59:59.998997
+ castTest(generateShorts(), DataTypes.TimestampType)
+ }
+
+ // CAST from integer
+
+ test("cast IntegerType to BooleanType") {
+ castTest(generateInts(), DataTypes.BooleanType)
+ }
+
+ ignore("cast IntegerType to ByteType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateInts(), DataTypes.ByteType)
+ }
+
+ ignore("cast IntegerType to ShortType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateInts(), DataTypes.ShortType)
+ }
+
+ test("cast IntegerType to LongType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateInts(), DataTypes.LongType)
+ }
+
+ test("cast IntegerType to FloatType") {
+ castTest(generateInts(), DataTypes.FloatType)
+ }
+
+ test("cast IntegerType to DoubleType") {
+ castTest(generateInts(), DataTypes.DoubleType)
+ }
+
+ ignore("cast IntegerType to DecimalType(10,2)") {
+ // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+ castTest(generateInts(), DataTypes.createDecimalType(10, 2))
+ }
+
+ test("cast IntegerType to StringType") {
+ castTest(generateInts(), DataTypes.StringType)
+ }
+
+ ignore("cast IntegerType to TimestampType") {
+ // input: -1000479329, expected: 1938-04-19 01:04:31.0, actual: 1969-12-31 15:43:19.520671
+ castTest(generateInts(), DataTypes.TimestampType)
+ }
+
+ // CAST from LongType
+
+ test("cast LongType to BooleanType") {
+ castTest(generateLongs(), DataTypes.BooleanType)
+ }
+
+ ignore("cast LongType to ByteType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateLongs(), DataTypes.ByteType)
+ }
+
+ ignore("cast LongType to ShortType") {
+ // https://github.com/apache/datafusion-comet/issues/311
+ castTest(generateLongs(), DataTypes.ShortType)
+ }
+
+ ignore("cast LongType to IntegerType") {
+ castTest(generateLongs(), DataTypes.IntegerType)
+ }
+
+ test("cast LongType to FloatType") {
+ castTest(generateLongs(), DataTypes.FloatType)
+ }
+
+ test("cast LongType to DoubleType") {
+ castTest(generateLongs(), DataTypes.DoubleType)
+ }
+
+ ignore("cast LongType to DecimalType(10,2)") {
+ // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+ castTest(generateLongs(), DataTypes.createDecimalType(10, 2))
+ }
+
+ test("cast LongType to StringType") {
+ castTest(generateLongs(), DataTypes.StringType)
+ }
+
+ ignore("cast LongType to TimestampType") {
+ // java.lang.ArithmeticException: long overflow
+ castTest(generateLongs(), DataTypes.TimestampType)
+ }
+
+ // CAST from FloatType
+
+ ignore("cast FloatType to BooleanType") {
+ castTest(generateFloats(), DataTypes.BooleanType)
+ }
+
+ ignore("cast FloatType to ByteType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateFloats(), DataTypes.ByteType)
+ }
+
+ ignore("cast FloatType to ShortType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateFloats(), DataTypes.ShortType)
}
- ignore("cast float to bool") {
- castTest(generateFloats, DataTypes.BooleanType)
+ ignore("cast FloatType to IntegerType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateFloats(), DataTypes.IntegerType)
}
- ignore("cast float to int") {
- castTest(generateFloats, DataTypes.IntegerType)
+ ignore("cast FloatType to LongType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateFloats(), DataTypes.LongType)
}
- ignore("cast float to string") {
- castTest(generateFloats, DataTypes.StringType)
+ ignore("cast FloatType to DoubleType") {
+ // fails due to incompatible sort order for 0.0 and -0.0
+ castTest(generateFloats(), DataTypes.DoubleType)
}
- test("cast string to bool") {
+ ignore("cast FloatType to DecimalType(10,2)") {
+ // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+ castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
+ }
+
+ ignore("cast FloatType to StringType") {
+ // https://github.com/apache/datafusion-comet/issues/312
+ castTest(generateFloats(), DataTypes.StringType)
+ }
+
+ ignore("cast FloatType to TimestampType") {
+ // https://github.com/apache/datafusion-comet/issues/312
+ castTest(generateFloats(), DataTypes.TimestampType)
+ }
+
+ // CAST from DoubleType
+
+ ignore("cast DoubleType to BooleanType") {
+ // fails due to incompatible sort order for 0.0 and -0.0
+ castTest(generateDoubles(), DataTypes.BooleanType)
+ }
+
+ ignore("cast DoubleType to ByteType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateDoubles(), DataTypes.ByteType)
+ }
+
+ ignore("cast DoubleType to ShortType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateDoubles(), DataTypes.ShortType)
+ }
+
+ ignore("cast DoubleType to IntegerType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateDoubles(), DataTypes.IntegerType)
+ }
+
+ ignore("cast DoubleType to LongType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateDoubles(), DataTypes.LongType)
+ }
+
+ ignore("cast DoubleType to FloatType") {
+ castTest(generateDoubles(), DataTypes.FloatType)
+ }
+
+ ignore("cast DoubleType to DecimalType(10,2)") {
+ // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+ castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
+ }
+
+ ignore("cast DoubleType to StringType") {
+ // https://github.com/apache/datafusion-comet/issues/312
+ castTest(generateDoubles(), DataTypes.StringType)
+ }
+
+ ignore("cast DoubleType to TimestampType") {
+ castTest(generateDoubles(), DataTypes.TimestampType)
+ }
+
+ // CAST from DecimalType(10,2)
+
+ ignore("cast DecimalType(10,2) to BooleanType") {
+ // Arrow error: Cast error: Casting from Decimal128(38, 18) to Boolean not supported
+ castTest(generateDecimals(), DataTypes.BooleanType)
+ }
+
+ ignore("cast DecimalType(10,2) to ByteType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateDecimals(), DataTypes.ByteType)
+ }
+
+ ignore("cast DecimalType(10,2) to ShortType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateDecimals(), DataTypes.ShortType)
+ }
+
+ ignore("cast DecimalType(10,2) to IntegerType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateDecimals(), DataTypes.IntegerType)
+ }
+
+ ignore("cast DecimalType(10,2) to LongType") {
+ // https://github.com/apache/datafusion-comet/issues/350
+ castTest(generateDecimals(), DataTypes.LongType)
+ }
+
+ test("cast DecimalType(10,2) to FloatType") {
+ castTest(generateDecimals(), DataTypes.FloatType)
+ }
+
+ test("cast DecimalType(10,2) to DoubleType") {
+ castTest(generateDecimals(), DataTypes.DoubleType)
+ }
+
+ ignore("cast DecimalType(10,2) to StringType") {
+ castTest(generateDecimals(), DataTypes.StringType)
+ }
+
+ ignore("cast DecimalType(10,2) to TimestampType") {
+ castTest(generateDecimals(), DataTypes.TimestampType)
+ }
+
+ // CAST from StringType
+
+ test("cast StringType to BooleanType") {
val testValues =
(Seq("TRUE", "True", "true", "FALSE", "False", "false", "1", "0", "",
null) ++
generateStrings("truefalseTRUEFALSEyesno10" + whitespaceChars,
8)).toDF("a")
castTest(testValues, DataTypes.BooleanType)
}
- ignore("cast string to byte") {
+ ignore("cast StringType to ByteType") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ByteType)
}
- ignore("cast string to short") {
+ ignore("cast StringType to ShortType") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType)
}
- ignore("cast string to int") {
+ ignore("cast StringType to IntegerType") {
castTest(generateStrings(numericPattern, 8).toDF("a"),
DataTypes.IntegerType)
}
- ignore("cast string to long") {
+ ignore("cast StringType to LongType") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType)
}
- ignore("cast string to float") {
+ ignore("cast StringType to FloatType") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.FloatType)
}
- ignore("cast string to double") {
+ ignore("cast StringType to DoubleType") {
castTest(generateStrings(numericPattern, 8).toDF("a"),
DataTypes.DoubleType)
}
- ignore("cast string to decimal") {
+ ignore("cast StringType to DecimalType(10,2)") {
val values = generateStrings(numericPattern, 8).toDF("a")
castTest(values, DataTypes.createDecimalType(10, 2))
castTest(values, DataTypes.createDecimalType(10, 0))
castTest(values, DataTypes.createDecimalType(10, -2))
}
- ignore("cast string to date") {
+ ignore("cast StringType to DateType") {
castTest(generateStrings(datePattern, 8).toDF("a"), DataTypes.DateType)
}
- test("cast string to timestamp disabled by default") {
+ test("cast StringType to TimestampType disabled by default") {
val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a")
castFallbackTest(
values.toDF("a"),
@@ -109,21 +483,163 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"spark.comet.cast.stringToTimestamp is disabled")
}
- ignore("cast string to timestamp") {
+ ignore("cast StringType to TimestampType") {
withSQLConf((CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key, "true")) {
val values = Seq("2020-01-01T12:34:56.123456", "T2") ++
generateStrings(timestampPattern, 8)
castTest(values.toDF("a"), DataTypes.TimestampType)
}
}
+ // CAST from date
+
+ ignore("cast DateType to BooleanType") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to ByteType") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to ShortType") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to IntegerType") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to LongType") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to FloatType") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to DoubleType") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to DecimalType(10,2)") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to StringType") {
+ // TODO: implement
+ }
+
+ ignore("cast DateType to TimestampType") {
+ // TODO: implement
+ }
+
+ // CAST from TimestampType
+
+ ignore("cast TimestampType to BooleanType") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to ByteType") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to ShortType") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to IntegerType") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to LongType") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to FloatType") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to DoubleType") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to DecimalType(10,2)") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to StringType") {
+ // TODO: implement
+ }
+
+ ignore("cast TimestampType to DateType") {
+ // TODO: implement
+ }
+
private def generateFloats(): DataFrame = {
val r = new Random(0)
- Range(0, dataSize).map(_ => r.nextFloat()).toDF("a")
+ val values = Seq(
+ Float.MaxValue,
+ Float.MinPositiveValue,
+ Float.MinValue,
+ Float.NaN,
+ Float.PositiveInfinity,
+ Float.NegativeInfinity,
+ 0.0f,
+ -0.0f) ++
+ Range(0, dataSize).map(_ => r.nextFloat())
+ withNulls(values).toDF("a")
+ }
+
+ private def generateDoubles(): DataFrame = {
+ val r = new Random(0)
+ val values = Seq(
+ Double.MaxValue,
+ Double.MinPositiveValue,
+ Double.MinValue,
+ Double.NaN,
+ Double.PositiveInfinity,
+ Double.NegativeInfinity,
+ 0.0d,
+ -0.0d) ++
+ Range(0, dataSize).map(_ => r.nextDouble())
+ withNulls(values).toDF("a")
+ }
+
+ private def generateBools(): DataFrame = {
+ withNulls(Seq(true, false)).toDF("a")
+ }
+
+ private def generateBytes(): DataFrame = {
+ val r = new Random(0)
+ val values = Seq(Byte.MinValue, Byte.MaxValue) ++
+ Range(0, dataSize).map(_ => r.nextInt().toByte)
+ withNulls(values).toDF("a")
+ }
+
+ private def generateShorts(): DataFrame = {
+ val r = new Random(0)
+ val values = Seq(Short.MinValue, Short.MaxValue) ++
+ Range(0, dataSize).map(_ => r.nextInt().toShort)
+ withNulls(values).toDF("a")
+ }
+
+ private def generateInts(): DataFrame = {
+ val r = new Random(0)
+ val values = Seq(Int.MinValue, Int.MaxValue) ++
+ Range(0, dataSize).map(_ => r.nextInt())
+ withNulls(values).toDF("a")
}
private def generateLongs(): DataFrame = {
val r = new Random(0)
- Range(0, dataSize).map(_ => r.nextLong()).toDF("a")
+ val values = Seq(Long.MinValue, Long.MaxValue) ++
+ Range(0, dataSize).map(_ => r.nextLong())
+ withNulls(values).toDF("a")
+ }
+
+ private def generateDecimals(): DataFrame = {
+ // TODO improve this
+ val values = Seq(BigDecimal("123456.789"), BigDecimal("-123456.789"),
BigDecimal("0.0"))
+ withNulls(values).toDF("a")
}
private def generateString(r: Random, chars: String, maxLen: Int): String = {
@@ -131,11 +647,16 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
Range(0, len).map(_ => chars.charAt(r.nextInt(chars.length))).mkString
}
+ // TODO return DataFrame for consistency with other generators and include null values
private def generateStrings(chars: String, maxLen: Int): Seq[String] = {
val r = new Random(0)
Range(0, dataSize).map(_ => generateString(r, chars, maxLen))
}
+ private def withNulls[T](values: Seq[T]): Seq[Option[T]] = {
+ values.map(v => Some(v)) ++ Seq(None)
+ }
+
private def castFallbackTest(
input: DataFrame,
toType: DataType,
@@ -176,28 +697,38 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// cast() should throw exception on invalid inputs when ansi mode is enabled
val df = data.withColumn("converted", col("a").cast(toType))
- val (expected, actual) = checkSparkThrows(df)
-
- if (CometSparkSessionExtensions.isSpark34Plus) {
- // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by
- // removing the "Execution error: " error message prefix that is added by DataFusion
- val cometMessage = actual.getMessage
- .substring("Execution error: ".length)
-
- assert(expected.getMessage == cometMessage)
- } else {
- // Spark 3.2 and 3.3 have a different error message format so we can't do a direct
- // comparison between Spark and Comet.
- // Spark message is in format `invalid input syntax for type TYPE: VALUE`
- // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
- // We just check that the comet message contains the same invalid value as the Spark message
- val sparkInvalidValue =
- expected.getMessage.substring(expected.getMessage.indexOf(':') + 2)
- assert(actual.getMessage.contains(sparkInvalidValue))
+ checkSparkMaybeThrows(df) match {
+ case (None, None) =>
+ // neither system threw an exception
+ case (None, Some(e)) =>
+ // Spark succeeded but Comet failed
+ throw e
+ case (Some(e), None) =>
+ // Spark failed but Comet succeeded
+ fail(s"Comet should have failed with ${e.getCause.getMessage}")
+ case (Some(sparkException), Some(cometException)) =>
+ // both systems threw an exception so we make sure they are the same
+ val sparkMessage = sparkException.getCause.getMessage
+ // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by
+ // removing the "Execution error: " error message prefix that is added by DataFusion
+ val cometMessage = cometException.getCause.getMessage
+ .replace("Execution error: ", "")
+ if (CometSparkSessionExtensions.isSpark34Plus) {
+ assert(cometMessage == sparkMessage)
+ } else {
+ // Spark 3.2 and 3.3 have a different error message format so we can't do a direct
+ // comparison between Spark and Comet.
+ // Spark message is in format `invalid input syntax for type TYPE: VALUE`
+ // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
+ // We just check that the comet message contains the same invalid value as the Spark message
+ val sparkInvalidValue =
sparkMessage.substring(sparkMessage.indexOf(':') + 2)
+ assert(cometMessage.contains(sparkInvalidValue))
+ }
}
// try_cast() should always return null for invalid inputs
- val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")
+ val df2 =
+ spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
checkSparkAnswer(df2)
}
}
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 27428b8..8fda136 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -230,15 +230,16 @@ abstract class CometTestBase
checkAnswerWithTol(dfComet, expected, absTol: Double)
}
- protected def checkSparkThrows(df: => DataFrame): (Throwable, Throwable) = {
+ protected def checkSparkMaybeThrows(
+ df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
var expected: Option[Throwable] = None
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
expected = Try(dfSpark.collect()).failed.toOption
}
val dfComet = Dataset.ofRows(spark, df.logicalPlan)
- val actual = Try(dfComet.collect()).failed.get
- (expected.get.getCause, actual.getCause)
+ val actual = Try(dfComet.collect()).failed.toOption
+ (expected, actual)
}
protected def checkSparkAnswerAndCompareExplainPlan(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]