This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new b39ed882 feat: Implement Spark-compatible CAST between integer types (#340)
b39ed882 is described below
commit b39ed8823a6bf5ce7b624918c37cec40afcf7b36
Author: గణేష్ <[email protected]>
AuthorDate: Sat May 4 03:37:47 2024 +0530
feat: Implement Spark-compatible CAST between integer types (#340)
* handled cast for long to short
* handled cast for all overflow cases
* ran make format
* added check for overflow exception for Spark versions below 3.4
* added comments on why we do the overflow check;
added a check before we fetch the sparkInvalidValue
* use -1 instead of 0; -1 indicates the provided character is not present
* ran mvn spotless:apply
* check for presence of ':' and have asserts accordingly
* reusing existing test functions
* added one more check in assert when ':' is not present
* redid the compare logic per Andy's suggestions
---------
Co-authored-by: ganesh.maddula <[email protected]>
---
core/src/errors.rs | 9 ++
core/src/execution/datafusion/expressions/cast.rs | 98 ++++++++++++++++++++++
.../scala/org/apache/comet/CometCastSuite.scala | 37 +++++---
3 files changed, 131 insertions(+), 13 deletions(-)
diff --git a/core/src/errors.rs b/core/src/errors.rs
index f02bd196..a06c613a 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -72,6 +72,15 @@ pub enum CometError {
to_type: String,
},
+ #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\"
cannot be cast to \"{to_type}\" \
+ due to an overflow. Use `try_cast` to tolerate overflow and return
NULL instead. If necessary \
+ set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ CastOverFlow {
+ value: String,
+ from_type: String,
+ to_type: String,
+ },
+
#[error(transparent)]
Arrow {
#[from]
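
For context on the formatting above: the #[error(...)] attribute suggests CometError derives its Display implementation via the thiserror crate. A minimal standalone sketch of how such a variant renders its message (the DemoError type and the sample values are illustrative only, not part of this patch):

    use thiserror::Error;

    #[derive(Error, Debug)]
    enum DemoError {
        // The format string becomes the Display output for this variant
        #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" \
            cannot be cast to \"{to_type}\" due to an overflow.")]
        CastOverFlow {
            value: String,
            from_type: String,
            to_type: String,
        },
    }

    fn main() {
        let err = DemoError::CastOverFlow {
            // Spark renders a SMALLINT literal with an "S" suffix, hence "200S"
            value: "200S".to_string(),
            from_type: "SMALLINT".to_string(),
            to_type: "TINYINT".to_string(),
        };
        // Prints: [CAST_OVERFLOW] The value 200S of the type "SMALLINT"
        //         cannot be cast to "TINYINT" due to an overflow.
        println!("{err}");
    }
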
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 45859c5f..a6e3adac 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -176,6 +176,62 @@ macro_rules! cast_float_to_string {
}};
}
+macro_rules! cast_int_to_int_macro {
+    (
+        $array: expr,
+        $eval_mode:expr,
+        $from_arrow_primitive_type: ty,
+        $to_arrow_primitive_type: ty,
+        $from_data_type: expr,
+        $to_native_type: ty,
+        $spark_from_data_type_name: expr,
+        $spark_to_data_type_name: expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
+            .unwrap();
+        let spark_int_literal_suffix = match $from_data_type {
+            &DataType::Int64 => "L",
+            &DataType::Int16 => "S",
+            &DataType::Int8 => "T",
+            _ => "",
+        };
+
+        let output_array = match $eval_mode {
+            EvalMode::Legacy => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let res = <$to_native_type>::try_from(value);
+                        if res.is_err() {
+                            Err(CometError::CastOverFlow {
+                                value: value.to_string() + spark_int_literal_suffix,
+                                from_type: $spark_from_data_type_name.to_string(),
+                                to_type: $spark_to_data_type_name.to_string(),
+                            })
+                        } else {
+                            Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
+                        }
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
+        }?;
+        let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
+        result
+    }};
+}
+
impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -218,6 +274,16 @@ impl Cast {
             (DataType::Utf8, DataType::Timestamp(_, _)) => {
                 Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
             }
+            (DataType::Int64, DataType::Int32)
+            | (DataType::Int64, DataType::Int16)
+            | (DataType::Int64, DataType::Int8)
+            | (DataType::Int32, DataType::Int16)
+            | (DataType::Int32, DataType::Int8)
+            | (DataType::Int16, DataType::Int8)
+                if self.eval_mode != EvalMode::Try =>
+            {
+                Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
+            }
             (
                 DataType::Utf8,
                 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
@@ -349,6 +415,38 @@ impl Cast {
cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}
+    fn spark_cast_int_to_int(
+        array: &dyn Array,
+        eval_mode: EvalMode,
+        from_type: &DataType,
+        to_type: &DataType,
+    ) -> CometResult<ArrayRef> {
+        match (from_type, to_type) {
+            (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
+            ),
+            (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
+            ),
+            (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
+            ),
+            (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
+                array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
+            ),
+            (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
+            ),
+            (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
+            ),
+            _ => unreachable!(
+                "{}",
+                format!("invalid integer type {to_type} in cast from {from_type}")
+            ),
+        }
+    }
+
fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
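
Stripped of the Arrow plumbing, the behavioral core of cast_int_to_int_macro above is the contrast between a wrapping `as` cast (Spark's legacy, non-ANSI semantics) and a checked try_from (which surfaces overflow as CometError::CastOverFlow); Try mode is excluded by the match guard and falls through to the default cast path. A self-contained sketch of the Int64-to-Int16 case, with illustrative function names and an inline error string standing in for the real error type:

    fn cast_i64_to_i16_legacy(v: i64) -> i16 {
        // `as` silently wraps on overflow, matching Spark's legacy CAST
        v as i16
    }

    fn cast_i64_to_i16_checked(v: i64) -> Result<i16, String> {
        // try_from fails on overflow; the macro maps this failure to
        // CometError::CastOverFlow, appending Spark's literal suffix ("L" for BIGINT)
        i16::try_from(v).map_err(|_| {
            format!(
                "[CAST_OVERFLOW] The value {v}L of the type \"BIGINT\" \
                 cannot be cast to \"SMALLINT\" due to an overflow."
            )
        })
    }

    fn main() {
        assert_eq!(cast_i64_to_i16_legacy(40_000), -25_536); // wrapped, no error
        assert!(cast_i64_to_i16_checked(40_000).is_err()); // overflow surfaced
        assert_eq!(cast_i64_to_i16_checked(123), Ok(123)); // in-range values pass through
    }
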
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 54b13679..483301e0 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -166,7 +166,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateShorts(), DataTypes.BooleanType)
}
- ignore("cast ShortType to ByteType") {
+ test("cast ShortType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateShorts(), DataTypes.ByteType)
}
@@ -210,12 +210,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateInts(), DataTypes.BooleanType)
}
- ignore("cast IntegerType to ByteType") {
+ test("cast IntegerType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateInts(), DataTypes.ByteType)
}
- ignore("cast IntegerType to ShortType") {
+ test("cast IntegerType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateInts(), DataTypes.ShortType)
}
@@ -256,17 +256,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateLongs(), DataTypes.BooleanType)
}
- ignore("cast LongType to ByteType") {
+ test("cast LongType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.ByteType)
}
- ignore("cast LongType to ShortType") {
+ test("cast LongType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.ShortType)
}
- ignore("cast LongType to IntegerType") {
+ test("cast LongType to IntegerType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.IntegerType)
}
@@ -921,15 +921,26 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val cometMessage = cometException.getCause.getMessage
.replace("Execution error: ", "")
if (CometSparkSessionExtensions.isSpark34Plus) {
+ // for Spark 3.4 we expect to reproduce the error message exactly
assert(cometMessage == sparkMessage)
+ } else if (CometSparkSessionExtensions.isSpark33Plus) {
+          // for Spark 3.3 we just need to strip the prefix from the Comet message
+          // before comparing
+          val cometMessageModified = cometMessage
+            .replace("[CAST_INVALID_INPUT] ", "")
+            .replace("[CAST_OVERFLOW] ", "")
+          assert(cometMessageModified == sparkMessage)
         } else {
-          // Spark 3.2 and 3.3 have a different error message format so we can't do a direct
-          // comparison between Spark and Comet.
-          // Spark message is in format `invalid input syntax for type TYPE: VALUE`
-          // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
-          // We just check that the comet message contains the same invalid value as the Spark message
-          val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
-          assert(cometMessage.contains(sparkInvalidValue))
+          // for Spark 3.2 we just make sure we are seeing a similar type of error
+          if (sparkMessage.contains("causes overflow")) {
+            assert(cometMessage.contains("due to an overflow"))
+          } else {
+            // assume that this is an invalid input message in the form:
+            // `invalid input syntax for type numeric: -9223372036854775809`
+            // we just check that the Comet message contains the same literal value
+            val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
+            assert(cometMessage.contains(sparkInvalidValue))
+          }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]